04: Contrastive Learning for Batch Detection
In mass spectrometry, batch effects (systematic variation between acquisition runs, for example drift in instrument sensitivity over time) can mask true biological signals. This notebook demonstrates how contrastive learning with SimCLR and a Transformer encoder aims to learn embeddings that emphasize class-relevant features while remaining invariant to batch-level variation.
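SimCLR is usually trained with the NT-Xent (normalized temperature-scaled cross-entropy) loss. Two augmented views of the same spectrum form a positive pair, and every other sample in the batch serves as a negative. The NumPy sketch below illustrates that loss under stated assumptions: the function name `nt_xent_loss` and the temperature of 0.5 are illustrative choices, and the `fishy` library may implement the objective differently.

```python
import numpy as np


def nt_xent_loss(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.5) -> float:
    """NT-Xent loss as used by SimCLR (illustrative sketch, not the fishy implementation).

    z1, z2: (N, D) embeddings of two augmented views of the same N spectra.
    Row i of z1 and row i of z2 form a positive pair; all other rows are negatives.
    """
    z = np.concatenate([z1, z2], axis=0)                  # (2N, D)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)       # unit-normalise each embedding
    sim = z @ z.T / temperature                            # (2N, 2N) scaled cosine similarities
    np.fill_diagonal(sim, -np.inf)                         # a sample is never compared with itself
    n = z1.shape[0]
    # Index of each row's positive partner: i <-> i + N
    pos_idx = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    # Numerically stable row-wise log-softmax
    row_max = sim.max(axis=1, keepdims=True)
    log_prob = sim - row_max - np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))
    # Average negative log-probability assigned to the positive partner
    return float(-log_prob[np.arange(2 * n), pos_idx].mean())


# Usage: well-aligned views should give a lower loss than unrelated ones
rng = np.random.default_rng(0)
views = rng.normal(size=(8, 16))
print(nt_xent_loss(views, views + 0.01 * rng.normal(size=views.shape)))
print(nt_xent_loss(views, rng.normal(size=(8, 16))))
```

Lower temperatures sharpen the softmax and weight hard negatives more heavily; the best value is dataset-dependent.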
[1]:
import plotly.io as pio
pio.renderers.default = "notebook_connected"
import pandas as pd, numpy as np, torch, torch.nn.functional as F
import plotly.express as px, plotly.graph_objects as go
from sklearn.manifold import TSNE
from fishy.experiments.contrastive import ContrastiveConfig, run_contrastive_experiment
from fishy.data.module import create_data_module
from fishy._core.utils import get_device
# 1. Configure SimCLR with a Transformer encoder
config = ContrastiveConfig(
    contrastive_method="simclr",
    encoder_type="transformer",
    dataset="batch-detection",
    num_epochs=100,
    batch_size=32,
    embedding_dim=64,
)
print("Running SimCLR training...")
results = run_contrastive_experiment(config)
Running SimCLR training...
INFO Initialized RunContext: simclr on batch-detection
INFO Calculating comprehensive pair-wise metrics...
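1. Pair-wise Performance & Margin Analysis
For the pair-wise evaluation, two spectra from the same class form a "positive pair" and two spectra from different classes form a "negative pair". A good embedding assigns high cosine similarity to positive pairs and low similarity to negative pairs. (During SimCLR training itself, a positive pair is two augmented views of the same spectrum.) The first plot tracks pair-wise prediction accuracy during training. The Similarity Margin Histogram illustrates the target outcome: two distinct peaks with a clear gap (the margin) between them. Note that the histogram in this cell is drawn from simulated normal distributions, not from the trained model.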
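To plot real margins instead of simulated ones, the same-class and different-class similarities can be computed directly from the embeddings. A minimal sketch, assuming an `(N, D)` embedding array and integer class labels such as the `embeddings` and `y_indices` produced in the t-SNE step; the helper names `pair_similarities` and `similarity_margin` are hypothetical, and defining the margin as a difference of means is one simple choice among several.

```python
import numpy as np


def pair_similarities(embeddings: np.ndarray, labels: np.ndarray):
    """Split all unique pairwise cosine similarities into same-class and different-class groups."""
    z = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-8)
    sim = z @ z.T
    i, j = np.triu_indices(len(labels), k=1)   # each unordered pair once, no self-pairs
    same = labels[i] == labels[j]
    pair_sims = sim[i, j]
    return pair_sims[same], pair_sims[~same]


def similarity_margin(pos: np.ndarray, neg: np.ndarray) -> float:
    """Gap between mean positive and mean negative similarity (larger is better)."""
    return float(pos.mean() - neg.mean())


# Toy example: two tight, well-separated classes in 2D
emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
lab = np.array([0, 0, 1, 1])
pos, neg = pair_similarities(emb, lab)
print(len(pos), len(neg))  # 2 4
print(similarity_margin(pos, neg))
```

The returned arrays can be passed straight to `go.Histogram` in place of `pos_sims` and `neg_sims`. For large N, the N x N similarity matrix grows quadratically, so subsampling pairs may be necessary.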
[2]:
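history = results.get("history")
if history and "accuracy" in history:
    acc = history["accuracy"]
    # Pass an explicit x so the 'x' label mapping applies (epochs numbered from 1)
    px.line(x=np.arange(1, len(acc) + 1), y=acc,
            title="Pair-wise Prediction Accuracy per Epoch",
            labels={"x": "Epoch", "y": "Accuracy"}, template="plotly_white").show()

# Similarity Margin Plot
# NOTE: illustrative only. These values are simulated from normal distributions;
# they are NOT computed from the trained model's embeddings.
rng = np.random.default_rng(42)
pos_sims = np.clip(rng.normal(0.85, 0.05, 500), -1.0, 1.0)  # cosine similarity lies in [-1, 1]
neg_sims = np.clip(rng.normal(0.2, 0.15, 500), -1.0, 1.0)
fig_margin = go.Figure()
fig_margin.add_trace(go.Histogram(x=pos_sims, name="Positive Pairs (Same Class)", marker_color="#00CC96", opacity=0.75))
fig_margin.add_trace(go.Histogram(x=neg_sims, name="Negative Pairs (Diff Class)", marker_color="#EF553B", opacity=0.75))
fig_margin.update_layout(barmode="overlay", title="Similarity Margin Distribution (simulated illustration)",
                         xaxis_title="Cosine Similarity", yaxis_title="Frequency", template="plotly_white")
fig_margin.show()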
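2. Embedding Visualization (t-SNE)
We project the Transformer embeddings of the trained model into 2D with t-SNE and color each point by class. A well-trained model should form tight clusters for each species even when the samples were collected on different instruments or in different batches. Because t-SNE preserves local neighborhoods but not global distances, cluster tightness is informative while the distances between clusters are not.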
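Because t-SNE distorts global geometry, a score computed in the original embedding space is a useful complement to the plot. The sketch below computes the mean silhouette coefficient with cosine distance; `cosine_silhouette` is a hypothetical helper, not part of `fishy`, and `sklearn.metrics.silhouette_score(embeddings, labels, metric="cosine")` should give essentially the same value.

```python
import numpy as np


def cosine_silhouette(embeddings: np.ndarray, labels: np.ndarray) -> float:
    """Mean silhouette coefficient with cosine distance (1 - cosine similarity).

    Ranges from -1 to 1; values near 1 mean each sample is much closer to its own
    class than to the nearest other class. Requires at least two classes.
    """
    z = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-8)
    dist = 1.0 - z @ z.T
    np.fill_diagonal(dist, 0.0)                      # distance of a sample to itself
    classes = np.unique(labels)
    scores = np.zeros(len(labels))
    for k in range(len(labels)):
        own = labels == labels[k]
        n_own = own.sum()
        if n_own < 2:
            continue                                 # singleton classes score 0 by convention
        a = dist[k, own].sum() / (n_own - 1)         # mean distance to the rest of its class
        # Mean distance to the nearest other class
        b = min(dist[k, labels == c].mean() for c in classes if c != labels[k])
        denom = max(a, b)
        scores[k] = 0.0 if denom == 0 else (b - a) / denom
    return float(scores.mean())


# Toy example: tight classes score near 1, labels that cut across clusters score below 0
emb = np.array([[1.0, 0.0], [0.99, 0.01], [0.0, 1.0], [0.01, 0.99]])
print(cosine_silhouette(emb, np.array([0, 0, 1, 1])))
print(cosine_silhouette(emb, np.array([0, 1, 0, 1])))
```

If per-sample batch labels were available (this notebook does not load any), the same score computed against batch labels should be near zero or negative for a batch-invariant embedding.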
[3]:
print("Generating t-SNE projection of the embedding space...")
dm = create_data_module("batch-detection")
dm.setup()
X, y = dm.get_numpy_data()
names = dm.get_class_names()
# Labels may be one-hot (N, C) or integer (N,); convert to integer class indices
y_indices = np.argmax(y, axis=1) if y.ndim > 1 else y.flatten().astype(int)

model = results.get("model")
embeddings = None
if model is not None:
    device = get_device()
    model = model.to(device).eval()  # make sure model and inputs share a device
    with torch.no_grad():
        X_t = torch.tensor(X, dtype=torch.float32, device=device)
        if hasattr(model, "forward_one"):
            embeddings = model.forward_one(X_t).cpu().numpy()
        else:
            # Fall back to an encoder/backbone submodule, or the model itself
            encoder = getattr(model, "encoder", getattr(model, "backbone", model))
            embeddings = encoder(X_t).cpu().numpy()

if embeddings is not None:
    # t-SNE requires perplexity < number of samples
    tsne = TSNE(n_components=2, perplexity=min(30, len(embeddings) - 1), random_state=42)
    X_emb_2d = tsne.fit_transform(embeddings)
    px.scatter(x=X_emb_2d[:, 0], y=X_emb_2d[:, 1], color=[names[i] for i in y_indices],
               title="t-SNE of Transformer Embeddings", template="plotly_white").show()
Generating t-SNE projection of the embedding space...
3. Nearest-Neighbor Retrieval Gallery
Contrastive learning lets us treat any spectrum as a search query: ranking all samples by the cosine similarity of their embeddings to the query's embedding returns the samples the model considers most similar, which ideally reflects biological similarity. The gallery plots the raw spectrum of a query sample together with its top 3 nearest neighbors (excluding the query itself).
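The gallery inspects a single query. A common way to summarise retrieval quality over every sample is precision@k: the fraction of each sample's k nearest neighbors that share its class. A sketch, assuming integer class labels; `retrieval_precision_at_k` is a hypothetical helper, and k=3 simply mirrors the gallery.

```python
import numpy as np


def retrieval_precision_at_k(embeddings: np.ndarray, labels: np.ndarray, k: int = 3) -> float:
    """Average fraction of each sample's k nearest neighbors (cosine, self excluded) sharing its label."""
    z = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-8)
    sim = z @ z.T
    np.fill_diagonal(sim, -np.inf)                # never retrieve the query itself
    top_k = np.argsort(-sim, axis=1)[:, :k]       # (N, k) neighbor indices, most similar first
    hits = labels[top_k] == labels[:, None]       # does each neighbor share the query's class?
    return float(hits.mean())


emb = np.array([[1.0, 0.0], [0.99, 0.01], [0.0, 1.0], [0.01, 0.99]])
print(retrieval_precision_at_k(emb, np.array([0, 0, 1, 1]), k=1))  # 1.0
```

With C balanced classes, a random embedding would score roughly 1/C, which gives a baseline for reading the result.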
[4]:
if embeddings is not None:
    query_idx = 0
    # L2-normalise embeddings so that a dot product equals cosine similarity
    norm_all = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-8)
    similarities = norm_all @ norm_all[query_idx]
    # Rank by similarity, drop the query itself, keep the 3 nearest neighbors
    neighbors = [i for i in np.argsort(-similarities) if i != query_idx][:3]
    top_indices = [query_idx] + neighbors
    feature_idx = np.arange(X.shape[1])  # x-axis: raw feature index of each spectrum
    fig_ret = go.Figure()
    for i, idx in enumerate(top_indices):
        label = "Query" if i == 0 else f"Match {i} (Sim: {similarities[idx]:.3f})"
        fig_ret.add_trace(go.Scatter(x=feature_idx, y=X[idx], name=f"{label}: {names[y_indices[idx]]}",
                                     line=dict(width=1.5 if i == 0 else 1, dash="solid" if i == 0 else "dot")))
    fig_ret.update_layout(title="Biological Retrieval: Query Sample vs. Nearest Neighbors in Embedding Space",
                          xaxis_title="Feature Index", yaxis_title="Intensity", template="plotly_white")
    fig_ret.show()
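4. Embedding Centroid Similarity Matrix
The matrix shows the cosine similarity between class centroids, i.e. the mean embedding of each class. Off-diagonal values close to 1 flag pairs of species that the model places close together, that is, species that are "chemically similar" in the eyes of the model; lower values indicate classes the model keeps apart.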
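With many classes the heatmap becomes dense, so it can help to list the most similar class pairs directly. A small sketch, assuming a symmetric similarity matrix like the one computed in this section and a list of its axis labels; `most_similar_class_pairs` is a hypothetical helper name.

```python
import numpy as np


def most_similar_class_pairs(sim_matrix: np.ndarray, class_names, top_n: int = 3):
    """Return the top_n most similar distinct class pairs from a symmetric similarity matrix."""
    i, j = np.triu_indices(len(class_names), k=1)   # upper triangle, diagonal excluded
    order = np.argsort(-sim_matrix[i, j])[:top_n]   # highest similarity first
    return [(class_names[i[o]], class_names[j[o]], float(sim_matrix[i[o], j[o]])) for o in order]


sim = np.array([[1.0, 0.9, 0.1],
                [0.9, 1.0, 0.3],
                [0.1, 0.3, 1.0]])
print(most_similar_class_pairs(sim, ["A", "B", "C"], top_n=2))  # [('A', 'B', 0.9), ('B', 'C', 0.3)]
```

Only the upper triangle is scanned, so each pair is reported once rather than twice.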
[5]:
if embeddings is not None:
    # One centroid (mean embedding) per class that actually has samples,
    # so each matrix row always lines up with its axis label
    present = np.unique(y_indices)
    class_labels = [names[c] for c in present]
    centroids = np.stack([embeddings[y_indices == c].mean(axis=0) for c in present])
    norm_centroids = centroids / (np.linalg.norm(centroids, axis=1, keepdims=True) + 1e-8)
    sim_matrix = norm_centroids @ norm_centroids.T
    # Fix the color range to [-1, 1] so the diverging scale is centered on 0
    px.imshow(sim_matrix, x=class_labels, y=class_labels, text_auto=".2f",
              title="Class Centroid Cosine Similarity (Embedding Space)",
              color_continuous_scale="RdBu_r", zmin=-1, zmax=1, template="plotly_white").show()
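5. Final Performance Summary
Final pair-wise evaluation metrics for the contrastive model, aggregated across the validation set. For a binary same-class/different-class task, a balanced accuracy of about 0.5 is chance level, so the values reported below (F1 0.0576, balanced accuracy 0.5057) indicate that the pair-wise classifier in this run performs essentially at chance.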
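To interpret these two numbers, it helps to see how they are computed for binary pair predictions. The sketch assumes label 1 marks a same-class (positive) pair; `pairwise_binary_metrics` is illustrative and not the `fishy` implementation (`sklearn.metrics.f1_score` and `sklearn.metrics.balanced_accuracy_score` compute the same quantities). The toy example shows a chance-level predictor on an imbalanced pair set: balanced accuracy stays at 0.5 while F1 is far lower, because most predicted positives are false.

```python
import numpy as np


def pairwise_binary_metrics(y_true: np.ndarray, y_pred: np.ndarray):
    """Return (F1, balanced accuracy) for binary pair predictions, with 1 = same-class pair."""
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    tpr = tp / (tp + fn) if tp + fn else 0.0          # recall on positive pairs
    tnr = tn / (tn + fp) if tn + fp else 0.0          # recall on negative pairs
    precision = tp / (tp + fp) if tp + fp else 0.0
    f1 = 2 * precision * tpr / (precision + tpr) if precision + tpr else 0.0
    return float(f1), float((tpr + tnr) / 2)


# Toy example: 10 same-class pairs, 90 different-class pairs,
# and a predictor that flags half of each group as "same class" (chance level)
y_true = np.array([1] * 10 + [0] * 90)
y_pred = np.array([1] * 5 + [0] * 5 + [1] * 45 + [0] * 45)
f1, bal_acc = pairwise_binary_metrics(y_true, y_pred)
print(f"F1 = {f1:.3f}, balanced accuracy = {bal_acc:.3f}")  # F1 = 0.167, balanced accuracy = 0.500
```

The notebook does not report the class balance of its validation pairs, so this is one possible reading of the low F1 score rather than a diagnosis.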
[6]:
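# Default to NaN (not 0) so a missing metric is not mistaken for a real score
print(f"Final Pair-wise F1 Score: {results.get('val_f1', float('nan')):.4f}")
print(f"Final Pair-wise Balanced Accuracy: {results.get('val_balanced_accuracy', float('nan')):.4f}")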
Final Pair-wise F1 Score: 0.0576
Final Pair-wise Balanced Accuracy: 0.5057