Previous tutorials in this series worked with tabular features (audit log sessions, PE file metadata, network flow statistics) or tokenized text (phishing URLs). In each case, the model scored one item at a time: one session, one binary, one flow, one URL. DNS exfiltration is the first problem where the sequence of queries matters more than any individual query. A single lookup for a3f8c2.data.example.com is unremarkable. Twenty such queries in rapid succession, each with a different hex subdomain under the same base domain, strongly suggest data exfiltration. The signal is in the pattern, not in any one query.
This tutorial generates a synthetic DNS dataset, builds a per-query Random Forest baseline, then trains a PyTorch LSTM that captures the temporal patterns per-query models miss.
How DNS exfiltration works
DNS tunneling encodes data in DNS queries and responses. The attacker controls a domain (e.g., evil.com) and runs a custom authoritative nameserver for it. Malware on the victim host encodes stolen data as subdomain labels and sends DNS queries for names under that domain. The recursive resolver forwards each query to the attacker’s nameserver, which decodes the data and can respond with commands in TXT records.
Normal DNS:                          Exfiltration DNS:
Client → Resolver → Auth Server      Client → Resolver → Attacker NS
Q: www.google.com (A)                Q: a3f8c2e9.data.evil.com (TXT)
R: 142.250.80.46                     Q: b7d1a04f.data.evil.com (TXT)
Purpose: name resolution             Q: c9e3f218.data.evil.com (TXT)
                                     Purpose: subdomain = data chunk
This works because DNS is allowed through nearly every firewall. Blocking it breaks name resolution for the entire network. Detection is hard for three reasons:
- Individual queries look plausible. A query for a3f8c2e9.data.evil.com is syntactically valid. Hex subdomain labels appear in legitimate CDN and analytics infrastructure.
- Volume is low. DNS exfiltration trickles data at a few queries per minute, well below rate-limiting thresholds.
- The signal is in the sequence. Repeated queries to the same base domain with high-entropy subdomains and consistent timing only become visible across multiple queries.
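To make the third point concrete, here is a minimal sketch that groups a window of query names by base domain and reports how many distinct subdomains each one received. The `summarize_window` helper and the sample window are illustrative assumptions, not part of the tutorial's pipeline, and taking the last two labels as the base domain is a simplification that ignores multi-label public suffixes such as co.uk.

```python
import math
from collections import defaultdict

def shannon_entropy(s):
    """Shannon entropy of a string, in bits per character."""
    if not s:
        return 0.0
    probs = [s.count(c) / len(s) for c in set(s)]
    return -sum(p * math.log2(p) for p in probs)

def summarize_window(fqdns):
    """Hypothetical helper: per-base-domain statistics for one query window.

    The base domain is approximated as the last two labels of the name.
    """
    groups = defaultdict(list)
    for fqdn in fqdns:
        parts = fqdn.split('.')
        groups['.'.join(parts[-2:])].append('.'.join(parts[:-2]))
    return {
        base: {
            'queries': len(subs),
            'unique_subdomains': len(set(subs)),
            'mean_entropy': sum(shannon_entropy(s) for s in subs) / len(subs),
        }
        for base, subs in groups.items()
    }

# Made-up window: three exfiltration-style queries mixed into normal lookups
window = [
    'www.google.com',
    'a3f8c2e9.data.evil.com',
    'api.github.com',
    'b7d1a04f.data.evil.com',
    'c9e3f218.data.evil.com',
    'www.google.com',
]
summary = summarize_window(window)
for base, stats in summary.items():
    print(base, stats)
```

No single evil.com query stands out, but the window-level view shows three queries with three distinct, higher-entropy subdomains under one base domain, while google.com repeats the same low-entropy label.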
Setting up the environment
python -m venv venv && source venv/bin/activate
pip install torch numpy pandas scikit-learn matplotlib
Generating a synthetic dataset
Real DNS logs are sensitive and rarely published with exfiltration labels. Synthetic generation is standard practice in DNS exfiltration research because it lets you control the ground truth.
import numpy as np
import hashlib
import math
import string
import random
random.seed(42)
np.random.seed(42)
# --- Legitimate DNS building blocks ---
LEGIT_DOMAINS = [
'google.com', 'github.com', 'stackoverflow.com', 'amazon.com',
'cloudflare.com', 'microsoft.com', 'apple.com', 'wikipedia.org',
'reddit.com', 'youtube.com', 'twitter.com', 'linkedin.com',
'netflix.com', 'zoom.us', 'slack.com', 'dropbox.com', 'mozilla.org',
'ubuntu.com', 'docker.com', 'python.org', 'npmjs.com', 'gitlab.com',
'bitbucket.org', 'medium.com', 'wordpress.com', 'shopify.com',
'stripe.com', 'twilio.com', 'aws.amazon.com', 'azure.microsoft.com',
]
LEGIT_MACHINE_DOMAINS = [
'cloudfront.net', 'akamaihd.net', 'fastly.net', 'cloudflare.com',
'googleusercontent.com', 'gstatic.com', 'githubusercontent.com',
'windowsupdate.com', 'akadns.net', 'edgekey.net',
]
LEGIT_SUBDOMAINS = [
'www', 'mail', 'api', 'cdn', 'static', 'assets', 'img', 'login',
'auth', 'accounts', 'docs', 'help', 'support', 'status', 'blog',
'm', 'app', 'dev', 'staging',
]
QUERY_TYPES = ['A', 'AAAA', 'CNAME', 'MX', 'TXT', 'NS']
LEGIT_QUERY_WEIGHTS = [0.55, 0.20, 0.10, 0.05, 0.05, 0.05]
# --- Exfiltration building blocks ---
SERVICE_TLDS = ['com', 'net', 'org', 'io']
TOKEN_ALPHABET = string.ascii_lowercase + string.digits
def random_token(min_len=6, max_len=14):
    return ''.join(random.choices(TOKEN_ALPHABET, k=random.randint(min_len, max_len)))
def generate_service_domain():
    stem = f'{random.choice(["cdn", "edge", "sync", "assets"])}-{random_token(6, 10)}'
    return f'{stem}.{random.choice(SERVICE_TLDS)}'
def generate_exfil_domain():
    return generate_service_domain()
def entropy(s):
    if not s:
        return 0.0
    probs = [s.count(c) / len(s) for c in set(s)]
    return -sum(p * math.log2(p) for p in probs if p > 0)
def generate_legit_machine_query(domain=None):
    domain = domain or (generate_service_domain()
                        if random.random() < 0.5 else random.choice(LEGIT_MACHINE_DOMAINS))
    prefix = random.choice(['cdn', 'img', 'static', 'assets', 'edge', 'cache'])
    fqdn = f'{prefix}-{random_token()}.{domain}'
    qtype = np.random.choice(['A', 'AAAA', 'CNAME', 'TXT'], p=[0.55, 0.25, 0.15, 0.05])
    return fqdn, qtype, max(0.2, np.random.normal(1.4, 0.8))
def generate_legit_query():
    if random.random() < 0.25:
        return generate_legit_machine_query()
    domain = random.choice(LEGIT_DOMAINS)
    if random.random() < 0.7:
        fqdn = f'{random.choice(LEGIT_SUBDOMAINS)}.{domain}'
    else:
        fqdn = domain
    qtype = np.random.choice(QUERY_TYPES, p=LEGIT_QUERY_WEIGHTS)
    return fqdn, qtype, np.random.exponential(2.0)
def generate_exfil_query(base_domain):
    prefix = random.choice(['cdn', 'img', 'static', 'assets', 'edge', 'cache'])
    chunk = f'{prefix}-{random_token(6, 10)}'
    fqdn = f'{chunk}.{base_domain}'
    qtype = np.random.choice(['A', 'AAAA', 'TXT'], p=[0.65, 0.25, 0.10])
    return fqdn, qtype, max(0.2, np.random.normal(1.2, 0.15))
Assembling sequences
Each sequence is a window of 20 queries. Exfiltration sequences mix a small run of exfiltration with legitimate traffic, simulating a compromised host that trickles data through otherwise normal browsing.
SEQ_LEN = 20
NUM_SEQS = 25000 # per class
def make_normal_sequence():
    if random.random() < 0.45:
        domain = generate_service_domain() if random.random() < 0.6 else random.choice(LEGIT_MACHINE_DOMAINS)
        burst_len = random.randint(5, 12)
        start = random.randint(0, SEQ_LEN - burst_len)
        queries = []
        for i in range(SEQ_LEN):
            if start <= i < start + burst_len:
                queries.append(generate_legit_machine_query(domain))
            else:
                queries.append(generate_legit_query())
        return queries, [0] * SEQ_LEN
    return [generate_legit_query() for _ in range(SEQ_LEN)], [0] * SEQ_LEN
def make_exfil_sequence():
    base_domain = generate_exfil_domain()
    num_exfil = random.randint(4, 8)
    start = random.randint(0, SEQ_LEN - num_exfil)
    queries = []
    q_labels = []
    for i in range(SEQ_LEN):
        if start <= i < start + num_exfil:
            queries.append(generate_exfil_query(base_domain))
            q_labels.append(1)
        else:
            queries.append(generate_legit_query())
            q_labels.append(0)
    return queries, q_labels
sequences = []
query_labels = []
labels = []
for _ in range(NUM_SEQS):
    seq, q_labels = make_normal_sequence()
    sequences.append(seq)
    query_labels.append(q_labels)
    labels.append(0)
for _ in range(NUM_SEQS):
    seq, q_labels = make_exfil_sequence()
    sequences.append(seq)
    query_labels.append(q_labels)
    labels.append(1)
labels = np.array(labels)
query_labels = np.array(query_labels)
print(f'Generated {len(sequences)} sequences ({SEQ_LEN} queries each)')
print(f' Normal: {sum(labels == 0)}')
print(f' Exfiltration: {sum(labels == 1)}')
Output:
Generated 50000 sequences (20 queries each)
 Normal: 25000
 Exfiltration: 25000
Note
Real-world DNS exfiltration is more subtle. Attackers mix exfiltration at much lower rates, use domain fronting, and rotate base domains. Tunneling tools like iodine, dnscat2, and dns2tcp each produce distinct patterns. Treat this tutorial as a proof of concept, not a production benchmark.
Per-query feature engineering
def extract_query_features(fqdn, qtype, gap):
    parts = fqdn.split('.')
    base_domain = '.'.join(parts[-2:]) if len(parts) >= 2 else fqdn
    subdomains = parts[:-2] if len(parts) > 2 else []
    sub_str = '.'.join(subdomains)
    sub_len = max(len(sub_str), 1)
    base_hash = int.from_bytes(hashlib.sha1(base_domain.encode()).digest()[:2], 'big') / 65535.0
    qtype_map = {'A': 0, 'AAAA': 1, 'CNAME': 2, 'MX': 3, 'TXT': 4, 'NS': 5}
    return [
        len(fqdn),                                              # query_len
        len(sub_str),                                           # subdomain_len
        len(subdomains),                                        # subdomain_count
        entropy(sub_str),                                       # subdomain_entropy
        int(any(c.isdigit() for c in sub_str)),                 # has_digits
        sum(c.isdigit() for c in sub_str) / sub_len,            # digit_ratio
        sum(c in string.hexdigits for c in sub_str) / sub_len,  # hex_ratio
        base_hash,                                              # stable base-domain bucket
        qtype_map.get(qtype, 6),                                # qtype_encoded
        int(qtype == 'TXT'),                                    # is_txt
        gap,                                                    # inter-query gap
    ]
NUM_FEATURES = 11 # number of features returned by extract_query_features
all_sequence_features = np.array(
[[extract_query_features(fqdn, qtype, gap) for fqdn, qtype, gap in seq]
for seq in sequences],
dtype=np.float32,
)
print(f'Feature tensor shape: {all_sequence_features.shape}')
print(f' (sequences, queries_per_sequence, features_per_query)')
Output:
Feature tensor shape: (50000, 20, 11)
 (sequences, queries_per_sequence, features_per_query)
The base_hash feature is a stable numeric bucket for the base domain. It is not a reputation lookup: generated exfiltration domains and legitimate CDN-style domains come from the same random naming pattern. Its purpose is to let the sequence model notice repeated queries to the same base domain inside a window. A more principled alternative would be a learned embedding over base-domain IDs; we use a hashed scalar here to keep the feature pipeline simple and the input shape uniform for both the Random Forest and the LSTM.
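For readers curious about the embedding alternative, the following is a rough sketch of how it could look, not something the tutorial's models use. The bucket count, embedding width, and the `base_domain_bucket` and `BaseDomainEmbedding` names are all assumptions made for illustration. The ten numeric features stand for the eleven above minus base_hash.

```python
import hashlib
import torch
import torch.nn as nn

NUM_BUCKETS = 4096  # assumed hash-table size
EMBED_DIM = 8       # assumed embedding width

def base_domain_bucket(base_domain, num_buckets=NUM_BUCKETS):
    """Map a base domain to a stable integer bucket (hashing trick)."""
    digest = hashlib.sha1(base_domain.encode()).digest()
    return int.from_bytes(digest[:4], 'big') % num_buckets

class BaseDomainEmbedding(nn.Module):
    """Hypothetical front end: concatenates a learned base-domain vector
    onto the numeric per-query features before they reach a sequence model."""
    def __init__(self, num_buckets=NUM_BUCKETS, embed_dim=EMBED_DIM):
        super().__init__()
        self.embed = nn.Embedding(num_buckets, embed_dim)

    def forward(self, numeric_feats, bucket_ids):
        # numeric_feats: (batch, seq_len, n_numeric), float
        # bucket_ids:    (batch, seq_len), long
        return torch.cat([numeric_feats, self.embed(bucket_ids)], dim=-1)

domains = ['evil.com', 'google.com', 'evil.com']
ids = torch.tensor([[base_domain_bucket(d) for d in domains]])  # shape (1, 3)
numeric = torch.zeros(1, 3, 10)
out = BaseDomainEmbedding()(numeric, ids)
print(out.shape)  # torch.Size([1, 3, 18])
```

Repeated base domains map to identical embedding vectors, so the recurrence can compare them directly instead of relying on the arbitrary ordering of a hashed scalar. The cost is that the Random Forest and the LSTM would no longer share one input shape, which is the trade-off the tutorial avoids.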
Per-query Random Forest baseline
Flatten all queries, attach the synthetic query-level label to each one, and train a Random Forest on per-query features. This measures how far individual query features alone can go before we ask the model to reason over the whole ordered window. For the sequence-level baseline, average the per-query probabilities in the window and choose the decision threshold on the validation set.
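One consequence of averaging is worth spelling out. Because only 4 to 8 of the 20 queries in an exfiltration window carry the positive label, the mean score is diluted by the normal queries around them. The sketch below works through the idealized case of a perfect per-query model that outputs exactly 1.0 for exfiltration queries and 0.0 for the rest; `max_mean_score` is a hypothetical helper used only for this arithmetic.

```python
SEQ_LEN = 20

def max_mean_score(num_exfil, seq_len=SEQ_LEN):
    """Mean window score under an idealized per-query model that outputs
    1.0 for every exfiltration query and 0.0 for every normal query."""
    return num_exfil / seq_len

for k in (4, 6, 8):
    print(f'{k} exfil queries -> mean score {max_mean_score(k):.2f}')
# 4 exfil queries -> mean score 0.20
# 6 exfil queries -> mean score 0.30
# 8 exfil queries -> mean score 0.40
```

Under that idealization the window mean never exceeds 0.40, so a 0.5 cut on the averaged score would miss every exfiltration window. This is why the baseline searches thresholds between 0.05 and 0.5 on the validation set rather than using 0.5.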
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
# Split at the sequence level, then flatten for per-query training
seq_indices = np.arange(len(sequences))
train_idx, test_idx = train_test_split(seq_indices, test_size=0.2, random_state=42,
stratify=labels)
train_idx, val_idx = train_test_split(train_idx, test_size=0.125, random_state=42,
stratify=labels[train_idx])
X_train_flat = all_sequence_features[train_idx].reshape(-1, NUM_FEATURES)
y_train_flat = query_labels[train_idx].reshape(-1)
X_val_flat = all_sequence_features[val_idx].reshape(-1, NUM_FEATURES)
y_val_flat = query_labels[val_idx].reshape(-1)
X_test_flat = all_sequence_features[test_idx].reshape(-1, NUM_FEATURES)
y_test_flat = query_labels[test_idx].reshape(-1)
rf = RandomForestClassifier(n_estimators=200, max_depth=15, random_state=42, n_jobs=-1)
rf.fit(X_train_flat, y_train_flat)
rf_preds = rf.predict(X_test_flat)
rf_probs = rf.predict_proba(X_test_flat)[:, 1]
print('--- Random Forest (Per-Query) ---')
print(classification_report(y_test_flat, rf_preds,
target_names=['normal', 'exfiltration']))
print(f'ROC AUC: {roc_auc_score(y_test_flat, rf_probs):.4f}')
# Aggregate to sequence level by averaging per-query probabilities
rf_test_preds_seq = rf_preds.reshape(-1, SEQ_LEN)
rf_test_probs_seq = rf_probs.reshape(-1, SEQ_LEN)
rf_val_probs_seq = rf.predict_proba(X_val_flat)[:, 1].reshape(-1, SEQ_LEN)
rf_thresholds = np.linspace(0.05, 0.5, 46)
rf_threshold = max(
rf_thresholds,
key=lambda t: f1_score(labels[val_idx], (rf_val_probs_seq.mean(axis=1) > t).astype(int)),
)
rf_seq_preds = (rf_test_probs_seq.mean(axis=1) > rf_threshold).astype(int)
y_test_seq = labels[test_idx]
rf_seq_auc = roc_auc_score(y_test_seq, rf_test_probs_seq.mean(axis=1))
print('\n--- Random Forest (Sequence-Level Aggregation) ---')
print(f'Threshold selected on validation set: {rf_threshold:.2f}')
print(classification_report(y_test_seq, rf_seq_preds,
target_names=['normal', 'exfiltration']))
print(f'ROC AUC: {rf_seq_auc:.4f}')
Representative output:
--- Random Forest (Per-Query) ---
              precision    recall  f1-score   support
      normal       0.99      0.96      0.98    169964
exfiltration       0.82      0.96      0.88     30036
    accuracy                           0.96    200000
ROC AUC: 0.9872
--- Random Forest (Sequence-Level Aggregation) ---
Threshold selected on validation set: 0.14
              precision    recall  f1-score   support
      normal       0.99      0.96      0.98      5000
exfiltration       0.97      0.99      0.98      5000
    accuracy                           0.98     10000
ROC AUC: 0.9960
Why sequences matter
Consider a “low-and-slow” exfiltration scenario: the attacker uses subdomains drawn from the same CDN-style naming convention as legitimate edge traffic, normal query types (A records), and moderate timing. Each query in isolation looks harmless.
# Sample a short exfiltration burst from the training distribution
random.seed(7)
np.random.seed(7)
base = generate_exfil_domain()
slow_exfil = [generate_exfil_query(base) for _ in range(5)]
print('Low-and-slow exfiltration example:')
for fqdn, qtype, gap in slow_exfil:
    feats = extract_query_features(fqdn, qtype, gap)
    prob = rf.predict_proba([feats])[0][1]
    print(f' {fqdn:40s} type={qtype} P(exfil)={prob:.3f}')
Illustrative output:
Low-and-slow exfiltration example:
 cdn-a8f3k2.edge-9b1qx.com                type=A P(exfil)=0.213
 cache-7m4n2p.edge-9b1qx.com              type=A P(exfil)=0.198
 static-q3w8e1.edge-9b1qx.com             type=AAAA P(exfil)=0.176
 edge-r5t9y2.edge-9b1qx.com               type=A P(exfil)=0.241
 img-u8i6o4.edge-9b1qx.com                type=A P(exfil)=0.205
Every query in this example scores well below 0.5. But the sequence reveals the pattern: all target the same base domain, subdomains share a consistent token shape, and the timing is regular. An LSTM seeing the full window can catch this. More generally, sequence models capture query bursts to one domain, progressive subdomain patterns suggesting chunked data, and timing regularity that contrasts with bursty normal browsing.
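Timing regularity is the easiest of these sequence-level signals to quantify by hand. The sketch below compares the coefficient of variation (standard deviation divided by mean) of inter-query gaps drawn from the same two distributions the synthetic generator uses: a tight normal around 1.2 s for exfiltration and an exponential with mean 2.0 s for ordinary browsing. The `coefficient_of_variation` helper and the sample size of 2000 are assumptions for illustration; a real 20-query window gives a much noisier estimate.

```python
import random
import statistics

def coefficient_of_variation(gaps):
    """Standard deviation divided by mean; lower means more regular timing."""
    return statistics.pstdev(gaps) / statistics.mean(gaps)

rng = random.Random(0)
# Same gap distributions as the tutorial's generator functions
exfil_gaps = [max(0.2, rng.gauss(1.2, 0.15)) for _ in range(2000)]
browse_gaps = [rng.expovariate(1 / 2.0) for _ in range(2000)]  # mean 2.0 s

cv_exfil = coefficient_of_variation(exfil_gaps)
cv_browse = coefficient_of_variation(browse_gaps)
print(f'exfiltration gaps CV: {cv_exfil:.2f}')
print(f'browsing gaps CV:     {cv_browse:.2f}')
```

An exponential distribution has a coefficient of variation of 1 by construction, while the exfiltration gaps sit near 0.15 / 1.2, so the two regimes separate cleanly once enough gaps are observed together. A per-query model sees one gap at a time and cannot compute this statistic.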
Building the LSTM
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
class DNSSequenceDataset(Dataset):
    """Wraps sequence feature arrays and labels for PyTorch DataLoader."""
    def __init__(self, features, labels):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]
class DNSSequenceClassifier(nn.Module):
    """LSTM classifier that reads per-query features in sequence order
    and classifies the final hidden state as normal or exfiltration."""
    def __init__(self, input_dim, hidden_dim=64, num_layers=2, dropout=0.3):
        super().__init__()
        self.lstm = nn.LSTM(
            input_dim, hidden_dim, num_layers,
            batch_first=True, dropout=dropout,
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, 1),
        )
    def forward(self, x):
        # x shape: (batch, seq_len, features)
        lstm_out, (hidden, _) = self.lstm(x)
        # Use final hidden state from the last LSTM layer
        return self.classifier(hidden[-1])
Preparing the data
from sklearn.preprocessing import StandardScaler
X_train_seq = all_sequence_features[train_idx]
X_val_seq = all_sequence_features[val_idx]
X_test_seq = all_sequence_features[test_idx]
y_train_seq, y_val_seq, y_test_seq = labels[train_idx], labels[val_idx], labels[test_idx]
# Fit scaler on flattened training queries, then reshape back
scaler = StandardScaler()
scaler.fit(X_train_seq.reshape(-1, NUM_FEATURES))
def scale_sequences(X):
    return scaler.transform(X.reshape(-1, NUM_FEATURES)).reshape(X.shape).astype(np.float32)
X_train_scaled = scale_sequences(X_train_seq)
X_val_scaled = scale_sequences(X_val_seq)
X_test_scaled = scale_sequences(X_test_seq)
train_loader = DataLoader(DNSSequenceDataset(X_train_scaled, y_train_seq), batch_size=128, shuffle=True)
val_loader = DataLoader(DNSSequenceDataset(X_val_scaled, y_val_seq), batch_size=256)
test_loader = DataLoader(DNSSequenceDataset(X_test_scaled, y_test_seq), batch_size=256)
print(f'Train: {len(y_train_seq)}, Val: {len(y_val_seq)}, Test: {len(y_test_seq)}')
Output:
Train: 35000, Val: 5000, Test: 10000
Training the LSTM
def train_lstm(model, train_loader, val_loader, epochs=30, lr=1e-3):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Training on device: {device}')
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        train_loss = train_correct = train_total = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            logits = model(X_batch).squeeze(-1)
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * len(y_batch)
            preds = (torch.sigmoid(logits) > 0.5).float()
            train_correct += (preds == y_batch).sum().item()
            train_total += len(y_batch)
        model.eval()
        val_loss = val_correct = val_total = 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                logits = model(X_batch).squeeze(-1)
                loss = criterion(logits, y_batch)
                val_loss += loss.item() * len(y_batch)
                preds = (torch.sigmoid(logits) > 0.5).float()
                val_correct += (preds == y_batch).sum().item()
                val_total += len(y_batch)
        if (epoch + 1) % 5 == 0:
            print(f'Epoch {epoch+1:3d}/{epochs} '
                  f'train_loss={train_loss/train_total:.4f} '
                  f'train_acc={train_correct/train_total:.4f} '
                  f'val_loss={val_loss/val_total:.4f} '
                  f'val_acc={val_correct/val_total:.4f}')
    return model
torch.manual_seed(42)
lstm_model = DNSSequenceClassifier(input_dim=NUM_FEATURES)
lstm_model = train_lstm(lstm_model, train_loader, val_loader)
Representative output (selected epochs; the loop prints every fifth epoch):
Epoch   5/30 train_loss=0.0993 train_acc=0.9648 val_loss=0.0728 val_acc=0.9750
Epoch  10/30 train_loss=0.0301 train_acc=0.9906 val_loss=0.0229 val_acc=0.9924
Epoch  20/30 train_loss=0.0133 train_acc=0.9960 val_loss=0.0177 val_acc=0.9950
Epoch  30/30 train_loss=0.0073 train_acc=0.9975 val_loss=0.0177 val_acc=0.9956
Evaluation and comparison
from sklearn.metrics import classification_report, roc_auc_score
def evaluate_lstm(model, test_loader):
    device = next(model.parameters()).device
    model.eval()
    all_probs, all_labels = [], []
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            logits = model(X_batch.to(device)).squeeze(-1)
            all_probs.extend(torch.sigmoid(logits).cpu().numpy())
            all_labels.extend(y_batch.numpy())
    all_probs, all_labels = np.array(all_probs), np.array(all_labels)
    preds = (all_probs > 0.5).astype(int)
    print('--- LSTM (Per-Sequence) ---')
    print(classification_report(all_labels, preds,
                                target_names=['normal', 'exfiltration']))
    print(f'ROC AUC: {roc_auc_score(all_labels, all_probs):.4f}')
    return all_probs, preds, all_labels
lstm_probs, lstm_preds, test_labels = evaluate_lstm(lstm_model, test_loader)
Representative output:
--- LSTM (Per-Sequence) ---
              precision    recall  f1-score   support
      normal       1.00      0.99      1.00      5000
exfiltration       0.99      1.00      1.00      5000
    accuracy                           1.00     10000
ROC AUC: 0.9997
Head-to-head comparison
from sklearn.metrics import precision_score, recall_score, f1_score
print(f'{"Model":<35s} {"Prec":>6s} {"Rec":>6s} {"F1":>6s} {"AUC":>7s}')
print('-' * 56)
for name, y_true, y_pred, probs in [
    ('Random Forest (per-query+agg)', y_test_seq, rf_seq_preds, rf_test_probs_seq.mean(axis=1)),
    ('LSTM (per-sequence)', test_labels, lstm_preds, lstm_probs),
]:
    print(f'{name:<35s} {precision_score(y_true, y_pred):>6.4f} '
          f'{recall_score(y_true, y_pred):>6.4f} {f1_score(y_true, y_pred):>6.4f} '
          f'{roc_auc_score(y_true, probs):>7.4f}')
Representative output:
Model                                 Prec    Rec     F1     AUC
--------------------------------------------------------
Random Forest (per-query+agg)       0.9651 0.9886 0.9767  0.9960
LSTM (per-sequence)                 0.9944 0.9994 0.9969  0.9997
The LSTM’s advantage shows up on sequences where individual queries look benign but the ordered pattern reveals exfiltration.
Where the LSTM wins
# Find sequences where LSTM is correct but Random Forest is wrong
lstm_correct = lstm_preds == test_labels
rf_correct = rf_seq_preds == y_test_seq
lstm_wins = lstm_correct & ~rf_correct
print(f'LSTM correct, Random Forest wrong: {lstm_wins.sum()} sequences')
win_indices = np.where(lstm_wins)[0][:3]
for i in win_indices:
    global_idx = test_idx[i]
    seq = sequences[global_idx]
    true_label = 'exfiltration' if test_labels[i] == 1 else 'normal'
    base_domains = set()
    for fqdn, _, _ in seq:
        parts = fqdn.split('.')
        base_domains.add('.'.join(parts[-2:]))
    print(f'\n Sequence {global_idx} (true: {true_label})')
    print(f' LSTM P(exfil)={lstm_probs[i]:.3f} RF P(exfil)={rf_test_probs_seq[i].mean():.3f}')
    print(f' Unique base domains: {len(base_domains)}')
    for fqdn, qtype, gap in seq[:3]:
        print(f' {fqdn:45s} {qtype:5s} gap={gap:.2f}s')
When the recovered sequence is true exfiltration, the LSTM catches it because the individual queries have moderate entropy (not flagged per-query), but the sequence shows repeated queries to the same base domain with regular timing. Some of these wins may instead be normal sequences, such as legitimate CDN-style bursts, whose averaged Random Forest score crossed the threshold; the true label printed for each sequence tells you which case you are looking at.
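Since a "win" can be either a missed exfiltration window that was recovered or a false alarm that was avoided, it can help to split the count by true label. The sketch below does that on toy arrays standing in for test_labels, lstm_preds, and rf_seq_preds; the `disagreement_breakdown` helper and its dictionary keys are hypothetical names, not part of the tutorial's code.

```python
import numpy as np

def disagreement_breakdown(y_true, preds_a, preds_b):
    """Count cases where model A is right and model B is wrong,
    split by the true label of the sequence."""
    y_true = np.asarray(y_true).astype(int)
    a_wins = (np.asarray(preds_a) == y_true) & (np.asarray(preds_b) != y_true)
    return {
        'missed_exfil_recovered': int((a_wins & (y_true == 1)).sum()),
        'false_alarms_avoided': int((a_wins & (y_true == 0)).sum()),
    }

# Toy data only: six sequences, two models
y_true = np.array([1, 1, 0, 0, 1, 0])
model_a = np.array([1, 1, 0, 0, 0, 0])
model_b = np.array([1, 0, 1, 1, 0, 0])
result = disagreement_breakdown(y_true, model_a, model_b)
print(result)  # {'missed_exfil_recovered': 1, 'false_alarms_avoided': 2}
```

Calling the same function with the two models swapped gives the cases where the Random Forest is right and the LSTM is wrong, which is a useful check before concluding that one model dominates the other.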
Limitations
Synthetic vs. real data. The synthetic generator produces clean, well-separated patterns. Tunneling tools like iodine (raw IP over DNS), dnscat2 (encrypted C2), and dns2tcp (TCP-over-DNS) each produce different subdomain encodings, timing profiles, and record type distributions. A model trained on synthetic data needs retraining on real captures before deployment.
Encrypted DNS. DoH and DoT encrypt query content between client and resolver. A network monitor sees only TLS traffic to a known resolver, not the queries. Detection shifts to flow analysis: connection frequency, byte volumes, and timing patterns to known DoH endpoints.
Evasion. Attackers can slow exfiltration to a few queries per hour (blending into browsing cadence), register domains resembling CDN infrastructure (defeating reputation features), or encode data in the timing of legitimate queries rather than subdomain labels (avoiding content-based detection).
Sequence length. The fixed 20-query window may split a real exfiltration session across windows (diluting the signal) or merge normal and exfiltration traffic (creating ambiguous labels). Sliding windows with overlap, variable-length sequences with masking, or attention-based models handle this better.
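As a rough sketch of the sliding-window option, the generator below cuts one host's query stream into overlapping windows so that an exfiltration run split by one window boundary falls wholly inside a neighboring window. The `sliding_windows` name and the stride of 5 are assumptions for illustration; the window length of 20 and the 11 features per query match the tutorial.

```python
import numpy as np

def sliding_windows(features, window=20, stride=5):
    """Yield (start_index, window_array) for overlapping windows.

    features has shape (num_queries, num_features). Streams shorter than
    one window yield nothing, and trailing queries that do not fill a
    complete window are dropped.
    """
    features = np.asarray(features)
    for start in range(0, len(features) - window + 1, stride):
        yield start, features[start:start + window]

# Stand-in stream: 50 queries with 11 features each
stream = np.arange(50 * 11, dtype=np.float32).reshape(50, 11)
windows = list(sliding_windows(stream))
print(len(windows), windows[0][1].shape)  # 7 (20, 11)
```

Each window can be scaled and scored exactly like the fixed windows in this tutorial. Overlap means one exfiltration run can trigger several adjacent alerts, so a deployment would likely need to merge alerts per host and time range.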
Warning
The 50/50 split inflates the precision, F1, and accuracy numbers above
This tutorial uses a balanced 50/50 dataset to make the model train and the loss curves look sensible. In production, exfiltration is extremely rare (perhaps 1 in 10,000 sequences). At that base rate, the accuracy and F1 numbers from this notebook are not even close to what you’d get on real traffic.
Three concrete changes you need before deploying:
- Weighted BCE loss. Once you train on realistically imbalanced data, pass pos_weight=torch.tensor([N_neg / N_pos]) to nn.BCEWithLogitsLoss, where the ratio comes from the class counts in your training set (e.g., roughly 10000.0 if exfiltration is 1-in-10k). Without this the model rationally predicts “benign” for everything and your loss looks fine.
pos_weight = torch.tensor([n_negative / max(n_positive, 1)], device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
- Precision-recall evaluation, not accuracy. ROC-AUC is also misleading at extreme imbalance because the false-positive rate is normalized by the huge negative class, so a rate that looks tiny can still produce far more false alerts than true ones. Report PR-AUC, precision-at-recall-K, and a full PR curve.
from sklearn.metrics import precision_recall_curve, average_precision_score
precision, recall, thresholds = precision_recall_curve(y_true, scores)
pr_auc = average_precision_score(y_true, scores)
- Threshold tuning to your operating point. The default 0.5 cut is wrong; pick a threshold where precision is high enough for your SOC’s daily-alert budget. If you can tolerate one alert per analyst per shift, find the threshold whose recall is acceptable at that precision.
target_precision = 0.95
meets_target = precision[:-1] >= target_precision
# np.argmax returns 0 when nothing qualifies, so check before indexing
cut = thresholds[np.argmax(meets_target)] if meets_target.any() else None
Without these changes, a model trained at 50/50 and deployed at 1-in-10k will fire ~5,000 false positives per million sequences. Class weighting plus threshold tuning is what makes sequence-model detectors usable in a SOC at all.
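The arithmetic behind that false-positive figure is short enough to check directly. The sketch below assumes a false-positive rate of 0.5% on normal windows, chosen to match the ~5,000 figure above, and a recall of 0.999, roughly what the LSTM achieved on the balanced test set. Both numbers and the `alerts_at_base_rate` helper are assumptions for illustration rather than measurements on real traffic.

```python
def alerts_at_base_rate(n_sequences, base_rate, recall, false_positive_rate):
    """Expected true alerts, false alerts, and precision at a given base rate."""
    positives = n_sequences * base_rate
    negatives = n_sequences - positives
    true_alerts = positives * recall
    false_alerts = negatives * false_positive_rate
    precision = true_alerts / (true_alerts + false_alerts)
    return true_alerts, false_alerts, precision

tp, fp, precision = alerts_at_base_rate(
    n_sequences=1_000_000,
    base_rate=1 / 10_000,        # 1 exfiltration window in 10,000
    recall=0.999,                # assumed
    false_positive_rate=0.005,   # assumed: 0.5% of normal windows flagged
)
print(f'true alerts:  {tp:.1f}')
print(f'false alerts: {fp:.1f}')
print(f'precision:    {precision:.3f}')
```

With about 100 true alerts against roughly 5,000 false ones, precision lands near 2%, even though the same detector showed precision above 0.99 on balanced data. Recall and false-positive rate do not change with the base rate; precision does.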
Next steps
This tutorial showed that sequence models capture temporal patterns in DNS traffic that per-query classifiers miss. Throughout this series we have built classifiers for audit logs, ROP gadgets, PE files, network flows, phishing URLs, and now DNS sequences. Each model assumes the attacker does not know it exists. The next tutorial in the series examines what happens when that assumption breaks: adversarial evasion, where the attacker actively crafts inputs to avoid detection.