The ROP gadget classification tutorial used tabular features extracted from disassembly: instruction counts, register usage, side effects, and termination type. That approach works well for short, linear gadgets. But tabular features discard structural information. Two functions with the same instruction counts and register usage can have completely different control flow: one might be a straight-line computation, the other a nested loop with error-handling branches. This tutorial introduces graph neural networks (GNNs), which operate directly on the control flow graph (CFG) of a function, preserving the branching and looping structure that makes each function unique.
The binary similarity problem
Given two functions from two different binaries, are they semantically equivalent (compiled from the same source) or unrelated? Four security applications drive interest in this problem.
Vulnerability discovery. A CVE is published for a function in an open-source library. Is that vulnerable function present in this IoT device’s firmware? The firmware was compiled with a different compiler, optimization level, and architecture, so string matching and byte-level hashing fail.
Patch diffing. A vendor releases a firmware update with vague release notes. By comparing function embeddings between versions, you identify functions whose structure shifted significantly. Those are the most likely candidates for the patched functions.
Malware family clustering. Two samples share command-and-control code but differ in payload. Function similarity reveals shared code components across variants.
Plagiarism detection. Similarity embeddings identify matching functions across binaries even when compilation settings differ at every level.
The core challenge: the same source code produces different binaries under different compilation settings.
Same source function compiled four ways:
gcc -O0 x86_64 -> 42 instructions, 12 basic blocks
gcc -O2 x86_64 -> 28 instructions, 7 basic blocks (inlining, loop unrolling)
clang -O1 x86_64 -> 31 instructions, 9 basic blocks (different register allocation)
gcc -O1 ARM -> 35 instructions, 10 basic blocks (different ISA entirely)
All four are semantically identical. Byte hashes: all different.
CFG structure: similar (branching patterns preserved).

Control flow graphs as data
A CFG represents a function as a directed graph. Each node is a basic block (a straight-line instruction sequence with no internal branches) and each edge is a control flow transfer (conditional branch, unconditional jump, or fall-through).
     Block 0 (entry)          Block 0: cmp + conditional jump
       |        \             Block 1: second comparison
       v         v            Block 2: error return
    Block 1   Block 2         Block 3: call + success return
     |    \      |            Block 4: second error return
     v     v     |            Block 5: common exit (ret)
  Blk 3  Blk 4   |
     |     |     |
     v     v     v
     Block 5 (exit)

For GNN input, each node needs a feature vector derived from its basic block:
| Feature | Description |
|---|---|
| instruction_count | Number of instructions (normalized) |
| has_call | Block contains a CALL instruction |
| has_cmp | Block contains CMP or TEST |
| has_jump | Block ends with a conditional jump |
| arithmetic_ratio | Fraction of arithmetic instructions |
| memory_ratio | Fraction of memory access instructions |
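
As a rough illustration of the table, the sketch below computes the six features for one basic block from its list of mnemonics, and writes the example CFG from the diagram as an edge list. The mnemonic groupings (`ARITHMETIC`, `MEMORY`), the helper name `block_features`, the `max_len` normalization constant, and the contents of the entry block are all assumptions made for this example, not part of any library.

```python
# Hypothetical x86 mnemonic groupings, chosen only for illustration.
ARITHMETIC = {"add", "sub", "imul", "xor", "and", "or", "shl", "shr"}
MEMORY = {"mov", "lea", "push", "pop"}


def block_features(mnemonics, max_len=20):
    """Return the six per-block features, in the same order as the table."""
    n = len(mnemonics)
    if n == 0:
        return [0.0] * 6
    last = mnemonics[-1]
    return [
        min(n / max_len, 1.0),                                 # instruction_count
        float(any(m.startswith("call") for m in mnemonics)),   # has_call
        float(any(m in ("cmp", "test") for m in mnemonics)),   # has_cmp
        float(last.startswith("j") and last != "jmp"),         # has_jump (conditional only)
        sum(m in ARITHMETIC for m in mnemonics) / n,           # arithmetic_ratio
        sum(m in MEMORY for m in mnemonics) / n,               # memory_ratio
    ]


# Edges of the six-block example CFG, as (source, destination) pairs.
edges = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (3, 5), (4, 5)]
# The same edges in a two-row [sources, destinations] layout.
edge_index = [[s for s, _ in edges], [d for _, d in edges]]

entry_block = ["mov", "cmp", "jne"]  # made-up contents for Block 0
features = block_features(entry_block)
print(features)
print(edge_index)
```

The two-row edge layout is the same shape that PyTorch Geometric expects for `edge_index`, so the lists here convert directly to tensors.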
Setting up the environment
python -m venv venv && source venv/bin/activate
pip install torch torch-geometric numpy pandas scikit-learn matplotlib networkx
mkdir -p binary-similarity/{data,models,plots}

Note
PyTorch Geometric (PyG) is the standard library for GNNs in PyTorch. Installation may require matching your CUDA version. See the PyG installation guide for platform-specific instructions. CPU-only works fine for this tutorial’s synthetic dataset.
Generating a synthetic CFG dataset
For reproducibility, we generate synthetic CFGs that mimic structural properties of real control flow graphs: sparse connectivity, a single entry node, and realistic branching.
Warning
Real-world CFG extraction uses Ghidra’s Python API or r2pipe. The synthetic graphs capture structural properties but not the full complexity of real disassembly. A later section shows the Ghidra extraction approach.
import random

import numpy as np
import torch
from torch_geometric.data import Data

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)


def generate_cfg(num_nodes):
    """Generate a synthetic CFG with realistic branching properties."""
    edges_src, edges_dst = [], []
    # Fall-through edges, plus occasional forward branches
    for i in range(num_nodes - 1):
        edges_src.append(i)
        edges_dst.append(i + 1)
        if random.random() < 0.4 and i + 2 < num_nodes:
            target = random.randint(i + 2, min(i + 5, num_nodes - 1))
            edges_src.append(i)
            edges_dst.append(target)
    # Back edges (loops)
    for i in range(num_nodes):
        if random.random() < 0.15 and i >= 2:
            edges_src.append(i)
            edges_dst.append(random.randint(max(0, i - 4), i - 1))
    edge_index = torch.tensor([edges_src, edges_dst], dtype=torch.long)
    features = []
    for _ in range(num_nodes):
        features.append([
            random.randint(1, 20) / 20.0,   # instruction_count
            float(random.random() < 0.2),   # has_call
            float(random.random() < 0.35),  # has_cmp
            float(random.random() < 0.4),   # has_jump
            random.random() * 0.6,          # arithmetic_ratio
            random.random() * 0.8,          # memory_ratio
        ])
    return Data(x=torch.tensor(features, dtype=torch.float),
                edge_index=edge_index)


def perturb_cfg(graph, noise_level=0.1):
    """Perturb a CFG to simulate different compilation settings."""
    x = torch.clamp(graph.x.clone() + torch.randn_like(graph.x) * noise_level,
                    0.0, 1.0)
    edges_src = graph.edge_index[0].tolist()
    edges_dst = graph.edge_index[1].tolist()
    num_nodes = x.size(0)
    # Sometimes add one forward edge
    if num_nodes > 2 and random.random() < 0.5:
        src = random.randint(0, num_nodes - 2)
        edges_src.append(src)
        edges_dst.append(random.randint(src + 1, num_nodes - 1))
    # Sometimes drop one edge
    if len(edges_src) > num_nodes and random.random() < 0.3:
        idx = random.randint(0, len(edges_src) - 1)
        edges_src.pop(idx)
        edges_dst.pop(idx)
    return Data(x=x, edge_index=torch.tensor([edges_src, edges_dst],
                                             dtype=torch.long))


num_pairs = 10000
graphs_a, graphs_b, labels = [], [], []
for i in range(num_pairs):
    num_nodes = random.randint(5, 50)
    g1 = generate_cfg(num_nodes)
    if random.random() < 0.5:
        g2 = perturb_cfg(g1, noise_level=random.uniform(0.05, 0.2))
        labels.append(1.0)
    else:
        g2 = generate_cfg(random.randint(5, 50))
        labels.append(0.0)
    graphs_a.append(g1)
    graphs_b.append(g2)
labels = torch.tensor(labels, dtype=torch.float)
print(f"Generated {num_pairs} pairs: {int(labels.sum())} positive, "
      f"{num_pairs - int(labels.sum())} negative")

Generated 10000 pairs: 4968 positive, 5032 negative

Graph representation with PyTorch Geometric
PyG represents each graph as a Data object with node features x and edge list edge_index. Batching concatenates graphs into one large disconnected graph with a batch vector tracking node ownership. Since we have pairs, we need a custom collation function.
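
To make the batching step concrete, here is a minimal pure-Python sketch of the idea: node lists are concatenated, edge indices are shifted by the number of nodes already merged, and a batch vector records which graph each node belongs to. The dict layout and the helper name `batch_graphs` are assumptions for this example. This is a conceptual stand-in for PyG's batching, not its actual implementation.

```python
def batch_graphs(graphs):
    """Merge graphs into one large disconnected graph.

    Each graph is a dict with "x" (a list of per-node feature lists) and
    "edge_index" ([sources, destinations]); this layout is a hypothetical
    stand-in for a PyG Data object.
    """
    x, src, dst, batch = [], [], [], []
    offset = 0
    for graph_id, g in enumerate(graphs):
        num_nodes = len(g["x"])
        x.extend(g["x"])
        # Shift node ids so they stay unique in the merged graph.
        src.extend(s + offset for s in g["edge_index"][0])
        dst.extend(d + offset for d in g["edge_index"][1])
        # Record which original graph each node came from.
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return {"x": x, "edge_index": [src, dst], "batch": batch}


g1 = {"x": [[0.1], [0.2], [0.3]], "edge_index": [[0, 1], [1, 2]]}
g2 = {"x": [[0.4], [0.5]], "edge_index": [[0], [1]]}
merged = batch_graphs([g1, g2])
print(merged["edge_index"])  # [[0, 1, 3], [1, 2, 4]]
print(merged["batch"])       # [0, 0, 0, 1, 1]
```

Because no edge crosses between the two original graphs, message passing on the merged graph never mixes information between them.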
from torch_geometric.data import Batch
from torch.utils.data import Dataset


class CFGPairDataset(Dataset):
    """Dataset of CFG pairs with binary similarity labels."""

    def __init__(self, graphs_a, graphs_b, labels):
        self.graphs_a = graphs_a
        self.graphs_b = graphs_b
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.graphs_a[idx], self.graphs_b[idx], self.labels[idx]


def collate_pairs(batch):
    """Collate (graph_a, graph_b, label) tuples into batched graphs."""
    ga, gb, la = zip(*batch)
    return (Batch.from_data_list(list(ga)), Batch.from_data_list(list(gb)),
            torch.stack(list(la)))


n = len(labels)
n_train, n_val = int(0.7 * n), int(0.15 * n)
indices = torch.randperm(n).tolist()
train_a = [graphs_a[i] for i in indices[:n_train]]
train_b = [graphs_b[i] for i in indices[:n_train]]
train_labels = labels[indices[:n_train]]
val_a = [graphs_a[i] for i in indices[n_train:n_train + n_val]]
val_b = [graphs_b[i] for i in indices[n_train:n_train + n_val]]
val_labels = labels[indices[n_train:n_train + n_val]]
test_a = [graphs_a[i] for i in indices[n_train + n_val:]]
test_b = [graphs_b[i] for i in indices[n_train + n_val:]]
test_labels = labels[indices[n_train + n_val:]]
train_loader = torch.utils.data.DataLoader(
    CFGPairDataset(train_a, train_b, train_labels),
    batch_size=64, shuffle=True, collate_fn=collate_pairs)
val_loader = torch.utils.data.DataLoader(
    CFGPairDataset(val_a, val_b, val_labels),
    batch_size=64, shuffle=False, collate_fn=collate_pairs)
test_loader = torch.utils.data.DataLoader(
    CFGPairDataset(test_a, test_b, test_labels),
    batch_size=64, shuffle=False, collate_fn=collate_pairs)
print(f"Train: {len(train_a)}, Val: {len(val_a)}, Test: {len(test_a)}")

Train: 7000, Val: 1500, Test: 1500

Building the GNN
We use a Graph Isomorphism Network (GIN) encoder. With sum aggregation, GIN matches the discriminative power of the 1-dimensional Weisfeiler-Lehman (1-WL) graph isomorphism test, which makes it a strong choice when graph structure matters. Each GIN layer adds a node's own features to the sum of its neighbors' features and passes the result through a learnable MLP. Global mean pooling then collapses the node embeddings into a graph-level vector. A Siamese wrapper feeds both graphs through the shared encoder and computes the cosine similarity of the two embeddings.
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_mean_pool


class CFGEncoder(nn.Module):
    """GIN-based encoder that maps a CFG to a fixed-size embedding."""

    def __init__(self, input_dim, hidden_dim=64, num_layers=3):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for i in range(num_layers):
            in_dim = input_dim if i == 0 else hidden_dim
            mlp = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU(),
                                nn.Linear(hidden_dim, hidden_dim))
            self.convs.append(GINConv(mlp))
            self.bns.append(nn.BatchNorm1d(hidden_dim))
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv, bn in zip(self.convs, self.bns):
            x = torch.relu(bn(conv(x, edge_index)))
        return self.fc(global_mean_pool(x, batch))


class SimilarityModel(nn.Module):
    """Siamese GNN that scores binary function similarity."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, data1, data2):
        emb1 = self.encoder(data1)
        emb2 = self.encoder(data2)
        return F.cosine_similarity(emb1, emb2, dim=1), emb1, emb2

Note
The encoder weights are shared between both branches of the Siamese network. This forces the model to learn a general-purpose embedding space where similar functions land near each other regardless of which input slot they occupy.
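
For intuition about what one GIN layer and the pooling step compute, the following pure-Python sketch reproduces the arithmetic on a three-node graph. It is a simplification under stated assumptions: the learnable MLP is replaced by a fixed ReLU (`relu_mlp`, a made-up stand-in), batch normalization is omitted, and messages flow from edge source to edge destination, which is PyG's default direction. The function names are hypothetical.

```python
def gin_layer(x, edge_index, mlp, eps=0.0):
    """One GIN update: h_v <- mlp((1 + eps) * h_v + sum of in-neighbor features)."""
    dim = len(x[0])
    agg = [[(1.0 + eps) * value for value in row] for row in x]
    for s, d in zip(*edge_index):
        # Destination node d receives the features of source node s.
        for j in range(dim):
            agg[d][j] += x[s][j]
    return [mlp(row) for row in agg]


def mean_pool(x, batch, num_graphs):
    """Average node vectors per graph, using the batch vector for ownership."""
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for row, g in zip(x, batch):
        counts[g] += 1
        for j in range(dim):
            sums[g][j] += row[j]
    return [[v / counts[g] for v in sums[g]] for g in range(num_graphs)]


def relu_mlp(row):
    """Toy stand-in for the learnable MLP: identity weights followed by ReLU."""
    return [max(0.0, v) for v in row]


# Three-node chain 0 -> 1 -> 2 with one feature per node.
x = [[1.0], [2.0], [4.0]]
edge_index = [[0, 1], [1, 2]]
h = gin_layer(x, edge_index, relu_mlp)
pooled = mean_pool(h, batch=[0, 0, 0], num_graphs=1)
print(h)       # [[1.0], [3.0], [6.0]]
print(pooled)  # one graph-level vector: the mean of 1, 3 and 6
```

Stacking three such layers, as the encoder does, lets each node's embedding depend on blocks up to three edges upstream of it.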
Training
We train with binary cross-entropy on the rescaled cosine similarity (mapped from [-1, 1] to [0, 1]).
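
The loss for a single pair can be worked through by hand. The sketch below is a stdlib-only illustration, with the helper names `cosine_similarity` and `pair_loss` chosen for this example. Its `eps` clamp is an assumption to keep `log()` finite and is not necessarily how `nn.BCELoss` guards against the same case.

```python
import math


def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def pair_loss(u, v, label, eps=1e-7):
    """Binary cross-entropy on cosine similarity rescaled from [-1, 1] to [0, 1]."""
    p = (cosine_similarity(u, v) + 1.0) / 2.0
    p = min(max(p, eps), 1.0 - eps)  # assumed clamp so log() never sees 0
    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))


orthogonal = pair_loss([1.0, 0.0], [0.0, 1.0], label=1.0)
aligned = pair_loss([1.0, 0.0], [2.0, 0.0], label=1.0)
print(round(orthogonal, 4))  # 0.6931: cosine 0 maps to p = 0.5
print(aligned < 1e-6)        # True: cosine 1 maps to p close to 1
```

One consequence of this rescaling is that orthogonal embeddings score 0.5, so a negative pair only reaches zero loss when its embeddings point in opposite directions.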
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = CFGEncoder(input_dim=6, hidden_dim=64, num_layers=3)
model = SimilarityModel(encoder).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCELoss()
print(f"Device: {device}, Parameters: "
      f"{sum(p.numel() for p in model.parameters() if p.requires_grad):,}")


def train_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss, n = 0, 0
    for batch_a, batch_b, batch_labels in loader:
        batch_a, batch_b = batch_a.to(device), batch_b.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        sim, _, _ = model(batch_a, batch_b)
        loss = criterion((sim + 1) / 2, batch_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n += 1
    return total_loss / n


def evaluate(model, loader, criterion):
    model.eval()
    total_loss, n = 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch_a, batch_b, batch_labels in loader:
            batch_a, batch_b = batch_a.to(device), batch_b.to(device)
            batch_labels = batch_labels.to(device)
            sim, _, _ = model(batch_a, batch_b)
            sim_scaled = (sim + 1) / 2
            total_loss += criterion(sim_scaled, batch_labels).item()
            n += 1
            all_preds.append(sim_scaled.cpu())
            all_labels.append(batch_labels.cpu())
    return total_loss / n, torch.cat(all_preds), torch.cat(all_labels)


best_val_loss = float("inf")
for epoch in range(30):
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    val_loss, _, _ = evaluate(model, val_loader, criterion)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "binary-similarity/models/best_model.pt")
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1:3d}/30 "
              f"Train: {train_loss:.4f} Val: {val_loss:.4f}")

Representative output:
Device: cpu, Parameters: 25,792
Epoch 5/30 Train: 0.5842 Val: 0.5601
Epoch 10/30 Train: 0.4731 Val: 0.4583
Epoch 15/30 Train: 0.3892 Val: 0.3954
Epoch 20/30 Train: 0.3214 Val: 0.3487
Epoch 25/30 Train: 0.2756 Val: 0.3195
Epoch 30/30 Train: 0.2412 Val: 0.3018

Evaluation
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
model.load_state_dict(torch.load("binary-similarity/models/best_model.pt",
map_location=device, weights_only=True))
test_loss, test_preds, test_labels_t = evaluate(model, test_loader, criterion)
preds_np = test_preds.numpy()
labels_np = test_labels_t.numpy()
roc = roc_auc_score(labels_np, preds_np)
pr_precision, pr_recall, _ = precision_recall_curve(labels_np, preds_np)
pr = auc(pr_recall, pr_precision)
acc = ((preds_np >= 0.5).astype(float) == labels_np).mean()
print(f"Test loss: {test_loss:.4f}")
print(f"ROC AUC: {roc:.4f}")
print(f"PR AUC: {pr:.4f}")
print(f"Accuracy: {acc:.4f}")

Representative output:
Test loss: 0.3054
ROC AUC: 0.9142
PR AUC: 0.9078
Accuracy: 0.8413

Retrieval metrics
In practice, you embed a query function and rank all candidates by embedding distance. Recall@k measures how often the true match appears in the top k results.
def retrieval_recall_at_k(model, graphs_a, graphs_b, labels, k_values):
    """For each positive pair, check if the true match ranks in top-k."""
    model.eval()
    with torch.no_grad():
        embed = lambda gs: torch.cat([
            model.encoder(Batch.from_data_list([g.to(device)])).cpu()
            for g in gs], dim=0)
        all_a, all_b = embed(graphs_a), embed(graphs_b)
    results = {k: 0 for k in k_values}
    pos_idx = (labels == 1.0).nonzero(as_tuple=True)[0]
    for idx in pos_idx:
        sims = F.cosine_similarity(all_a[idx].unsqueeze(0), all_b, dim=1)
        rank = (sims.argsort(descending=True) == idx).nonzero(as_tuple=True)[0].item()
        for k in k_values:
            if rank < k:
                results[k] += 1
    return {k: v / len(pos_idx) for k, v in results.items()}


for k, v in retrieval_recall_at_k(
        model, test_a, test_b, test_labels, [1, 5, 10, 20]).items():
    print(f"Recall@{k:2d}: {v:.4f}")

Representative output:
Recall@ 1: 0.5634
Recall@ 5: 0.8102
Recall@10: 0.8856
Recall@20: 0.9347

Note
Recall@1 of 0.56 means the exact match is the top-ranked candidate 56% of the time. For vulnerability search across firmware with thousands of functions, recall@10 or recall@20 is more practical: present the analyst with a short list of candidates.
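
The recall@k arithmetic is easy to check on a toy example. The sketch below uses a small, made-up similarity matrix instead of model embeddings. The function name `recall_at_k` and the `true_index` argument are assumptions for this illustration.

```python
def recall_at_k(sim_matrix, true_index, k_values):
    """sim_matrix[q][c] is the similarity of query q to candidate c;
    true_index[q] is the position of q's correct candidate."""
    hits = {k: 0 for k in k_values}
    for q, sims in enumerate(sim_matrix):
        # Candidate indices sorted from most to least similar.
        ranking = sorted(range(len(sims)), key=lambda c: sims[c], reverse=True)
        rank = ranking.index(true_index[q])  # 0-based rank of the true match
        for k in k_values:
            if rank < k:
                hits[k] += 1
    return {k: hits[k] / len(sim_matrix) for k in k_values}


# Made-up similarities: 3 queries against 4 candidates.
sims = [
    [0.9, 0.2, 0.1, 0.3],  # true match (candidate 0) ranks first
    [0.8, 0.7, 0.1, 0.2],  # true match (candidate 1) ranks second
    [0.5, 0.6, 0.1, 0.4],  # true match (candidate 2) ranks last
]
result = recall_at_k(sims, true_index=[0, 1, 2], k_values=[1, 2, 4])
print(result)  # recall@1 = 1/3, recall@2 = 2/3, recall@4 = 1.0
```

Recall@k can only grow as k increases, which is why a short candidate list recovers matches that a strict top-1 criterion misses.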
Embedding visualization
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
model.eval()
with torch.no_grad():
    viz_n = 100
    emb_a = np.array([model.encoder(
        Batch.from_data_list([test_a[i].to(device)])).cpu().numpy()[0]
        for i in range(viz_n)])
    emb_b = np.array([model.encoder(
        Batch.from_data_list([test_b[i].to(device)])).cpu().numpy()[0]
        for i in range(viz_n)])
coords = TSNE(n_components=2, random_state=42, perplexity=30).fit_transform(
    np.vstack([emb_a, emb_b]))
plt.figure(figsize=(8, 6))
plt.scatter(coords[:viz_n, 0], coords[:viz_n, 1], color="#2196F3",
            alpha=0.7, s=20, label="Query functions")
plt.scatter(coords[viz_n:, 0], coords[viz_n:, 1], color="#FF5722",
            alpha=0.7, s=20, label="Candidate functions")
for i in range(viz_n):
    if test_labels[i].item() == 1.0:
        plt.plot([coords[i, 0], coords[viz_n + i, 0]],
                 [coords[i, 1], coords[viz_n + i, 1]],
                 color="#9E9E9E", alpha=0.25, linewidth=0.8)
plt.title("t-SNE of function embeddings with positive-pair links")
plt.legend()
plt.tight_layout()
plt.savefig("binary-similarity/plots/tsne_embeddings.png", dpi=150)
plt.show()

Extracting real CFGs with Ghidra
Note
This section requires Ghidra. If you don’t have it installed, skip ahead. The model works the same with synthetic or real graphs.
Save the following script as extract_cfgs.py and run it with Ghidra's headless analyzer:
analyzeHeadless /tmp/proj proj -import <binary> -postScript extract_cfgs.py -scriptPath .
import json
from ghidra.program.model.block import BasicBlockModel


def extract_node_features(block, listing):
    """Extract a 6-element feature vector from a basic block."""
    n, has_call, has_cmp, has_jump, arith, mem = 0, 0, 0, 0, 0, 0
    instr_iter = listing.getInstructions(block, True)
    while instr_iter.hasNext():
        m = instr_iter.next().getMnemonicString().lower()
        n += 1
        if m.startswith("call"): has_call = 1
        if m in ("cmp", "test"): has_cmp = 1
        # Conditional jumps only, matching the has_jump feature definition
        if m.startswith("j") and m != "jmp": has_jump = 1
        if m in ("add","sub","mul","imul","div","idiv","inc","dec",
                 "neg","shl","shr","sar","and","or","xor","not"): arith += 1
        if m in ("mov","lea","push","pop","movzx","movsx"): mem += 1
    if n == 0:
        return [0.0] * 6
    return [min(n / 20.0, 1.0), float(has_call), float(has_cmp),
            float(has_jump), arith / float(n), mem / float(n)]


program = currentProgram
listing = program.getListing()
block_model = BasicBlockModel(program)
output = []
for func in program.getFunctionManager().getFunctions(True):
    if func.isExternal() or func.isThunk():
        continue
    blocks, edges, addr_map = [], [], {}
    # First pass: assign an index to every basic block in the function
    block_iter = block_model.getCodeBlocksContaining(func.getBody(), None)
    idx = 0
    while block_iter.hasNext():
        block = block_iter.next()
        addr = block.getMinAddress().toString()
        addr_map[addr] = idx
        blocks.append(extract_node_features(block, listing))
        idx += 1
    # Second pass: record edges between blocks of the same function
    block_iter = block_model.getCodeBlocksContaining(func.getBody(), None)
    while block_iter.hasNext():
        block = block_iter.next()
        src = addr_map.get(block.getMinAddress().toString())
        if src is None: continue
        dest_iter = block.getDestinations(None)
        while dest_iter.hasNext():
            da = dest_iter.next().getDestinationAddress().toString()
            if da in addr_map:
                edges.append([src, addr_map[da]])
    if len(blocks) >= 2:
        output.append({"name": func.getName(), "nodes": blocks, "edges": edges})
with open(str(currentProgram.getExecutablePath()) + "_cfgs.json", "w") as f:
    json.dump(output, f, indent=2)
print("Extracted %d functions" % len(output))

To load extracted CFGs into PyG, parse the JSON and construct Data objects with x from the "nodes" arrays and edge_index from "edges".
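
A minimal sketch of that parsing step follows, using only the standard library so it runs without PyG installed. The helper name `load_cfgs`, the validation rules, and the sample function name `check_len` are assumptions for this example. The JSON keys match what the extraction script writes.

```python
import json


def load_cfgs(json_text, num_features=6):
    """Parse extracted CFGs into (name, node_features, edge_index) tuples."""
    graphs = []
    for func in json.loads(json_text):
        nodes = [[float(v) for v in row] for row in func["nodes"]]
        if any(len(row) != num_features for row in nodes):
            raise ValueError(f"{func['name']}: unexpected feature length")
        n = len(nodes)
        # Keep only edges whose endpoints are valid node indices.
        edges = [(s, d) for s, d in func["edges"] if 0 <= s < n and 0 <= d < n]
        # Transpose [[src, dst], ...] into the two-row [sources, destinations] layout.
        edge_index = [[s for s, _ in edges], [d for _, d in edges]]
        graphs.append((func["name"], nodes, edge_index))
    return graphs


sample = json.dumps([{
    "name": "check_len",  # hypothetical function name
    "nodes": [[0.15, 0, 1, 1, 0.0, 0.33], [0.05, 0, 0, 0, 0.0, 1.0]],
    "edges": [[0, 1]],
}])
loaded = load_cfgs(sample)
for name, nodes, edge_index in loaded:
    print(name, len(nodes), edge_index)  # check_len 2 [[0], [1]]
```

Each tuple can then be wrapped the same way `generate_cfg` does: `torch.tensor(nodes, dtype=torch.float)` for `x` and `torch.tensor(edge_index, dtype=torch.long)` for `edge_index`.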
Applications
Vulnerability search
Embed a known-vulnerable function and all functions from a firmware image, rank by cosine similarity, and present the top candidates to an analyst.
def vulnerability_search(model, vuln_graph, candidates, names, top_k=10):
    """Find firmware functions most similar to a known-vulnerable function."""
    model.eval()
    dev = next(model.parameters()).device
    with torch.no_grad():
        ve = model.encoder(Batch.from_data_list([vuln_graph]).to(dev))
        sims = [F.cosine_similarity(
            ve, model.encoder(Batch.from_data_list([g]).to(dev))).item()
            for g in candidates]
    ranked = sorted(zip(names, sims), key=lambda x: x[1], reverse=True)
    print(f"{'Rank':<6}{'Function':<30}{'Similarity':<10}")
    for i, (name, sim) in enumerate(ranked[:top_k]):
        print(f"{i+1:<6}{name:<30}{sim:.4f}")
    return ranked[:top_k]


# Demo with synthetic data
vulnerability_search(model, test_a[0], test_b[:100],
                     [f"func_{i:04d}" for i in range(100)])

Representative output:
Rank Function Similarity
1 func_0000 0.9312
2 func_0047 0.7845
3 func_0023 0.7632
4 func_0091 0.7518
5 func_0012 0.7201

Patch diffing
Embed the functions from both firmware versions, then look up each old function's best match in the new version. Functions whose best-match similarity falls below a cutoff (1.0 - threshold, so 0.7 with the default) are likely the patched ones.
def patch_diff(model, old_graphs, old_names, new_graphs, new_names, threshold=0.3):
    """Identify functions that changed significantly between firmware versions."""
    model.eval()
    dev = next(model.parameters()).device
    with torch.no_grad():
        embed = lambda gs: torch.cat([model.encoder(
            Batch.from_data_list([g]).to(dev)).cpu() for g in gs], dim=0)
        old_e, new_e = embed(old_graphs), embed(new_graphs)
    changed = []
    for i, name in enumerate(old_names):
        sims = F.cosine_similarity(old_e[i].unsqueeze(0), new_e, dim=1)
        best = sims.max().item()
        if best < (1.0 - threshold):
            changed.append((name, new_names[sims.argmax().item()], best))
    changed.sort(key=lambda x: x[2])
    print(f"{'Old Function':<25}{'Best Match':<25}{'Similarity':<10}")
    for old, new, sim in changed[:5]:
        print(f"{old:<25}{new:<25}{sim:.4f}")
    return changed


patch_diff(model, test_a[:50], [f"old_{i:03d}" for i in range(50)],
           test_b[:50], [f"new_{i:03d}" for i in range(50)])

Representative output:
Old Function Best Match Similarity
old_037 new_012 0.4215
old_019 new_044 0.4587
old_042 new_031 0.5012
old_008 new_027 0.5234
old_028 new_003 0.5891

Limitations
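Scalability. Exhaustive pairwise comparison across two firmware images with n and m functions costs O(n*m) similarity computations. An approximate nearest neighbor index such as FAISS answers each query in sublinear time after a one-time index build, which brings the total cost close to linear in the number of queries.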
Cross-architecture. A model trained on x86 CFGs won’t generalize to ARM or MIPS without architecture-agnostic features or multi-architecture training data.
Obfuscation. Control flow flattening, opaque predicates, and virtualization-based obfuscation deliberately destroy CFG structure. Heavily obfuscated code requires deobfuscation before CFG extraction produces useful graphs.
Feature engineering ceiling. Six per-node features are a simplification. Richer representations (instruction embeddings, data flow graphs, inter-procedural analysis) improve accuracy but increase complexity. Research systems like Gemini, SAFE, and jTrans use increasingly sophisticated representations.
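
To illustrate the idea behind approximate nearest neighbor search mentioned under scalability, here is a small random-hyperplane hashing sketch in pure Python. It is not FAISS and does not use its API. It is a swapped-in, much simpler technique that shows why only a fraction of candidates needs exact scoring. All function names, the bit count, and the toy embeddings are assumptions for this example.

```python
import random
from collections import defaultdict


def make_hyperplanes(dim, num_bits, seed=0):
    """Draw random hyperplane normals; each one contributes one signature bit."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_bits)]


def signature(vec, hyperplanes):
    """One bit per hyperplane: which side of it the vector falls on."""
    return tuple(sum(a * b for a, b in zip(vec, plane)) >= 0.0
                 for plane in hyperplanes)


def build_index(embeddings, hyperplanes):
    """Group embedding ids by signature so similar directions share a bucket."""
    buckets = defaultdict(list)
    for i, vec in enumerate(embeddings):
        buckets[signature(vec, hyperplanes)].append(i)
    return buckets


def candidates(query, buckets, hyperplanes):
    """Return ids sharing the query's bucket; only these need exact scoring."""
    return buckets.get(signature(query, hyperplanes), [])


planes = make_hyperplanes(dim=4, num_bits=8)
embeddings = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [-1.0, 0.0, 0.0, 0.0]]
index = build_index(embeddings, planes)
# A rescaled copy of embedding 0 has cosine similarity 1 with it, so it lands
# in the same bucket; embedding 2 points the opposite way and does not.
found = candidates([2.0, 0.0, 0.0, 0.0], index, planes)
print(found)
```

Vectors separated by a small angle agree on most bits, so near neighbors tend to collide while distant ones rarely do. Production indexes use several hash tables or learned quantizers to trade recall against speed.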
Where this fits in the series
Graph neural networks round out the core modeling techniques in the ML for Security series. Each tutorial so far has paired a technique with a security problem where it has genuine value:
| Part | Technique | Security problem |
|---|---|---|
| 1 | Isolation Forest | Anomaly detection on audit logs |
| 2 | XGBoost | ROP gadget classification |
| 3 | Feedforward neural networks | Malware classification (PE files) |
| 4 | Autoencoders | Network intrusion detection |
| 5 | Fine-tuned transformers | Phishing URL detection |
| 6 | LSTMs (sequence models) | DNS exfiltration detection |
| 7 | Adversarial ML | Evading and hardening classifiers |
| 8 | CNNs on flow data | Encrypted traffic classification |
| 9 | NLP / topic modeling | Threat intelligence extraction |
| 10 | Graph neural networks | Binary function similarity |
The consistent theme is that technique selection should follow from the data structure. Tabular features call for tree-based models or feedforward networks. Sequences call for RNNs or transformers. Graphs call for GNNs. Starting with the data rather than the model avoids the trap of forcing every problem into the latest architecture.
Upcoming parts shift from modeling to operations: validating these models against time-based data leakage, and deploying them with drift detection and automated retraining.
For deeper coverage of transformers and large language models, see the Transformers and LLMs series. For the offensive security side, see the Linux exploitation series.