02: Fish Species and Part Identification
This notebook demonstrates the identification of fish species and parts using Transformers, including a deep dive into stable biomarkers.
[1]:
import plotly.io as pio
pio.renderers.default = "notebook_connected"
import pandas as pd, numpy as np, torch, torch.nn as nn
import plotly.express as px, plotly.graph_objects as go, networkx as nx
from fishy import TrainingConfig, run_unified_training, display_final_summary
from fishy.analysis.xai import ModelWrapper, GradCAM
from fishy._core.utils import get_device
from lime.lime_tabular import LimeTabularExplainer
from sklearn.metrics import confusion_matrix
# Train the model
config = TrainingConfig(model="transformer", dataset="species", epochs=10, wandb_log=False)
results = run_unified_training(config)
display_final_summary(results)
INFO Initialized RunContext: transformer on species
INFO Evaluating pre-training phase
INFO --- Fold 1/3 ---
INFO --- Fold 2/3 ---
INFO --- Fold 3/3 ---
Training Complete - Results Summary
Metric              Train    Val
Accuracy            0.9954   0.6296
Balanced Accuracy   0.9948   0.6250
MAE                 0.0046   0.3704
MSE                 0.0046   0.3704
Precision           0.9955   0.6297
Recall              0.9954   0.6296
F1 Score            0.9954   0.6294
Elapsed training time: 67.9778 seconds
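In this summary MAE and MSE both equal 1 - accuracy (0.0046 for train, 0.3704 for validation). That pattern is consistent with the two regression metrics being computed on integer class indices where every mistake is off by exactly one index, which is always the case with two classes. The sketch below assumes that encoding; it is not taken from `fishy`, and the function name `summary_metrics` is hypothetical. With three or more classes, a mistake spanning two indices makes MAE and MSE diverge from the error rate.

```python
import numpy as np

def summary_metrics(y_true, y_pred):
    """Accuracy plus MAE/MSE computed on integer class indices (an assumption)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    err = y_pred.astype(float) - y_true.astype(float)  # signed index difference
    return {
        "accuracy": float(np.mean(y_true == y_pred)),
        "mae": float(np.mean(np.abs(err))),
        "mse": float(np.mean(err ** 2)),
    }

# Two classes: every mistake is off by one index, so MAE == MSE == 1 - accuracy
binary = summary_metrics([0, 0, 1, 1, 1, 0, 1, 0], [0, 1, 1, 0, 1, 0, 1, 1])
print(binary)  # {'accuracy': 0.625, 'mae': 0.375, 'mse': 0.375}

# Three classes: confusing class 0 with class 2 costs 2 in MAE and 4 in MSE
multi = summary_metrics([0, 2], [2, 2])
print(multi)   # {'accuracy': 0.5, 'mae': 1.0, 'mse': 2.0}
```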
1. Interactive Confusion Matrix
The confusion matrix is a natural first step in understanding model performance. This interactive heatmap shows exactly which classes are being confused with one another. In REIMS (rapid evaporative ionisation mass spectrometry) data, confusion often happens between biologically similar species (e.g., different types of white fish).
[2]:
if "predictions" in results:
    p = results["predictions"]
    cn = results["class_names"]
    # Rows are true classes, columns are predicted classes (scikit-learn convention)
    cm = confusion_matrix(p["labels"], p["preds"])
    fig_cm = px.imshow(cm, x=cn, y=cn, text_auto=True,
                       title="Interactive Confusion Matrix",
                       color_continuous_scale="Viridis", template="plotly_white")
    fig_cm.update_layout(xaxis_title="Predicted Species", yaxis_title="Actual Species")
    fig_cm.show()
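Reading the heatmap row by row gives per-class recall: with scikit-learn's `confusion_matrix`, rows are true classes and columns are predicted classes, so each diagonal entry divided by its row sum is the recall for that class. Balanced accuracy is conventionally defined (e.g. by scikit-learn's `balanced_accuracy_score`) as the mean of these recalls, which is why it can differ from plain accuracy when classes have different sizes. A small sketch with a made-up 3x3 matrix (illustrative numbers, not results from this notebook; the helper names are hypothetical):

```python
import numpy as np

def per_class_recall(cm):
    """Diagonal over row sums; rows = true class, columns = predicted class."""
    cm = np.asarray(cm, dtype=float)
    support = cm.sum(axis=1)  # number of true samples per class
    # Classes with no true samples get recall 0 instead of a division-by-zero warning
    return np.divide(np.diag(cm), support, out=np.zeros_like(support), where=support > 0)

def balanced_accuracy(cm):
    return float(per_class_recall(cm).mean())

def accuracy(cm):
    cm = np.asarray(cm, dtype=float)
    return float(np.trace(cm) / cm.sum())

# Illustrative matrix: the third class has twice as many samples as the others
cm = np.array([[8, 2, 0],
               [3, 5, 2],
               [0, 1, 19]])
print(per_class_recall(cm))   # approximately [0.8, 0.5, 0.95]
print(balanced_accuracy(cm))  # approximately 0.75
print(accuracy(cm))           # approximately 0.8
```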
2. Feature Interaction Matrix
Rather than looking at individual peaks in isolation, we take the 10 features with the highest Grad-CAM saliency for one sample (the first in the dataset) and visualize how they co-vary across the whole dataset. The resulting 10x10 Pearson correlation matrix can reveal chemical dependencies among the model's most critical decision-making features.
[3]:
model = results["model"]
dm = results["data_module"]
try:
    model.eval()
    print("Generating 10x10 Feature Interaction Matrix from Saliency...")
    # layer_norm2 is the last layer before pooling and fc_out, so it preserves the spectral feature dimension
    target_layer = model.layer_norm2
    gc = GradCAM(model, target_layer)
    # Use the first sample for saliency
    X_all, _ = dm.get_numpy_data()
    input_t = torch.tensor(X_all[0:1], dtype=torch.float32).to(get_device())
    input_t.requires_grad = True
    try:
        cam = gc.generate_cam(input_t).cpu().numpy()[0]
    finally:
        gc.remove_hooks()  # always detach the hooks, even if CAM generation fails
    # Ten most salient features, re-sorted into spectral (m/z) order
    top_10_idx = np.sort(cam.argsort()[-10:])
    # Same feature-column filter as the later cells, instead of assuming one leading label column
    feature_cols = [c for c in dm.get_filtered_dataframe().columns if c not in ["Class Name", "m/z"]]
    top_10_names = [feature_cols[i] for i in top_10_idx]
    # Pearson correlation between these features across all samples (rows of X_all)
    corr_matrix = np.corrcoef(X_all[:, top_10_idx].T)
    fig_inter = px.imshow(corr_matrix, x=top_10_names, y=top_10_names,
                          title="Top 10 Feature Interaction Matrix (Saliency Correlations)",
                          color_continuous_scale="RdBu_r", zmin=-1, zmax=1,
                          template="plotly_white")
    fig_inter.show()
except Exception as e:
    print(f"Visualization failed: {e}")
Generating 10x10 Feature Interaction Matrix from Saliency...
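The selection step in the cell above reduces to two NumPy operations: take the indices of the k largest saliency values and sort them back into spectral order, then correlate those columns across samples. `np.corrcoef` treats each row as a variable by default, which is why the cell transposes the data; passing `rowvar=False` is equivalent. A standalone sketch on random data (the function name `top_k_correlation` and the synthetic inputs are illustrative):

```python
import numpy as np

def top_k_correlation(saliency, X, k=10):
    """Return the k most salient feature indices (in spectral order) and their correlation matrix.

    saliency: shape (n_features,); X: shape (n_samples, n_features).
    """
    saliency = np.asarray(saliency)
    X = np.asarray(X, dtype=float)
    idx = np.sort(np.argsort(saliency)[-k:])     # k largest, then back in m/z order
    corr = np.corrcoef(X[:, idx], rowvar=False)  # columns are the variables
    return idx, corr

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 30))   # 50 synthetic spectra with 30 features
saliency = rng.random(30)       # stand-in for a Grad-CAM map
idx, corr = top_k_correlation(saliency, X, k=10)
print(idx)
print(corr.shape)               # (10, 10)
```

A feature that is constant across all samples has zero variance, so its row and column in the matrix come out as NaN.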
3. Grad-CAM Saliency Overlay
While attention weights capture relationships between input positions, Grad-CAM measures local importance. Coloring the spectrum by its gradient-based importance shows which chemical peaks most strongly "triggered" the model's final classification decision for the selected sample.
[4]:
try:
    # Select a layer that preserves the spectral feature dimension;
    # layer_norm2 is the last layer before pooling and fc_out
    target_layer = model.layer_norm2
    gc = GradCAM(model, target_layer)
    X_raw, y_raw = dm.get_numpy_data(labels_as_indices=True)
    sample_idx = 0
    input_t = torch.tensor(X_raw[sample_idx:sample_idx+1], dtype=torch.float32).to(get_device())
    try:
        cam = gc.generate_cam(input_t).cpu().numpy()[0]
    finally:
        gc.remove_hooks()  # always detach the hooks, even if CAM generation fails
    # Feature columns are named by their m/z value
    feature_names = [f for f in dm.get_filtered_dataframe().columns if f not in ["Class Name", "m/z"]]
    mz_axis = np.array([float(f) for f in feature_names])
    spec = X_raw[sample_idx]
    fig_cam = go.Figure()
    fig_cam.add_trace(go.Scatter(x=mz_axis, y=spec, name="Spectrum",
                                 line=dict(color="lightgray", width=1)))
    fig_cam.add_trace(go.Scatter(x=mz_axis, y=spec, mode="markers",
                                 marker=dict(color=cam, colorscale="Viridis", size=8, showscale=True,
                                             colorbar=dict(title="Importance")),
                                 name="Grad-CAM Importance"))
    fig_cam.update_layout(title=f"Grad-CAM Saliency: {results['class_names'][y_raw[sample_idx]]}",
                          xaxis_title="m/z", yaxis_title="Intensity", template="plotly_white")
    fig_cam.show()
except Exception as e:
    print(f"Grad-CAM visualization failed: {e}")
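Grad-CAM combines two quantities from the target layer: its activations and the gradient of the class score with respect to them. Each channel gets a weight equal to its gradient averaged over positions, the weighted channels are summed at each position, and a ReLU keeps only positive evidence. The NumPy sketch below shows that combination step for a 1D spectral sequence. It assumes activations and gradients of shape `(n_positions, n_channels)` have already been captured (for example with hooks on `model.layer_norm2`); it is not necessarily how `fishy.analysis.xai.GradCAM` implements it, and the max-normalization is our choice. The toy gradients come from a hypothetical head that mean-pools the positions and applies a linear layer, for which the gradient is the same at every position.

```python
import numpy as np

def grad_cam_1d(activations, gradients, eps=1e-8):
    """Grad-CAM map over positions for one sample.

    activations, gradients: shape (n_positions, n_channels).
    """
    A = np.asarray(activations, dtype=float)
    G = np.asarray(gradients, dtype=float)
    alpha = G.mean(axis=0)            # channel weights: gradient averaged over positions
    cam = np.maximum(A @ alpha, 0.0)  # weighted sum of channels per position, then ReLU
    return cam / (cam.max() + eps)    # scale to [0, 1] for coloring

# Hypothetical toy head: logit_c = w_c . (mean over positions of A),
# so d(logit_c)/dA[p] = w_c / n_positions at every position p.
A = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [2.0, 2.0],
              [-1.0, 0.0]])
w_c = np.array([1.0, -0.5])
G = np.tile(w_c / len(A), (len(A), 1))
cam = grad_cam_1d(A, G)
print(np.round(cam, 3))  # positions 0 and 2 score 1.0, positions 1 and 3 score 0.0
```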
4. Advanced Biomarker Analysis
Beyond simple classification, we want to identify which chemical features (m/z peaks) the model consistently uses to distinguish between classes. For each class, we explain up to 10 randomly chosen samples with LIME, average the per-feature weights for that class, and keep the 10 features with the highest mean weight as stable, class-specific diagnostic markers.
[5]:
class_names = results["class_names"]
X_all, y_all = dm.get_numpy_data(labels_as_indices=True)
feature_names = [str(c) for c in dm.get_filtered_dataframe().columns if c not in ["Class Name", "m/z"]]
mz_axis = np.array([float(c) for c in feature_names])
print("Analyzing stable biomarkers via LIME aggregate...")
wrapper = ModelWrapper(model, str(get_device()))
explainer = LimeTabularExplainer(X_all, feature_names=feature_names, class_names=class_names,
                                 discretize_continuous=False, random_state=42)
rng = np.random.default_rng(42)  # fixed seed so the sampled explanations are reproducible
class_biomarkers = {}
all_top_indices = []
for c_idx in range(len(class_names)):
    c_indices = np.where(y_all == c_idx)[0]
    if len(c_indices) == 0:
        continue
    # Explain up to 10 randomly chosen samples of this class
    sub = rng.choice(c_indices, min(10, len(c_indices)), replace=False)
    w_list = []
    for idx in sub:
        exp = explainer.explain_instance(X_all[idx], wrapper.predict_proba,
                                         num_features=20, labels=(c_idx,))
        w_list.append(dict(exp.as_list(label=c_idx)))
    # A feature missing from an explanation counts as weight 0, so the average rewards consistency
    avg_w = pd.DataFrame(w_list).fillna(0.0).mean().sort_values()
    top_features = avg_w.tail(10).index.tolist()  # 10 largest (most positive) mean weights
    class_biomarkers[c_idx] = [feature_names.index(f) for f in top_features]
    all_top_indices.extend(class_biomarkers[c_idx])
Analyzing stable biomarkers via LIME aggregate...
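Averaging LIME weights across several explanations requires a decision about features that were not among the `num_features` returned for a given sample. `pandas.DataFrame.mean()` skips missing values by default, so a feature selected only once with a large weight can outrank one selected every time; filling missing entries with 0 treats "not selected" as "no contribution" and favors consistently used features. The sketch below contrasts the two on made-up weights; `aggregate_weights` and the m/z labels are hypothetical.

```python
import pandas as pd

def aggregate_weights(explanations, top_n=10, missing_as_zero=True):
    """Average per-feature LIME weights over several explanations of one class.

    explanations: list of {feature_name: weight} dicts, e.g. dict(exp.as_list(label=c)).
    """
    w = pd.DataFrame(explanations)  # rows: explained samples, columns: features
    if missing_as_zero:
        w = w.fillna(0.0)           # not selected for a sample -> zero contribution
    return w.mean().sort_values(ascending=False).head(top_n)

expl = [
    {"100.1": 0.30, "200.2": 0.10},
    {"100.1": 0.25, "300.3": 0.50},
    {"100.1": 0.20, "200.2": 0.05},
]
print(aggregate_weights(expl, top_n=2))                         # 100.1 first: chosen in every explanation
print(aggregate_weights(expl, top_n=2, missing_as_zero=False))  # 300.3 first: chosen once, large weight
```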
4.1 Distinct Class Biomarker Comparison
This plot overlays each class's top 10 diagnostic peaks on a representative spectrum of that class (here, the first sample of the class). The markers highlight the features the model relies on most when predicting each class.
[6]:
fig_bio = go.Figure()
colors = ["gold", "cyan", "silver"]
for c_idx, c_name in enumerate(class_names):
    if c_idx not in class_biomarkers:
        continue  # no samples of this class, so no biomarkers were computed
    # Use the first spectrum of the class as its representative
    rep_idx = np.where(y_all == c_idx)[0][0]
    spec = X_all[rep_idx]
    bio_indices = class_biomarkers[c_idx]
    fig_bio.add_trace(go.Scatter(x=mz_axis, y=spec, name=f"{c_name} Base",
                                 line=dict(width=1), opacity=0.3))
    fig_bio.add_trace(go.Scatter(x=mz_axis[bio_indices], y=spec[bio_indices], mode="markers",
                                 marker=dict(size=12, color=colors[c_idx % len(colors)],
                                             line=dict(width=2, color="black")),
                                 name=f"Diagnostic for {c_name}"))
fig_bio.update_layout(title="Diagnostic Peak Comparison", template="plotly_white",
                      xaxis_title="m/z", yaxis_title="Intensity")
fig_bio.show()
4.2 Biomarker Stability
This chart counts how many classes' top-10 lists contain each m/z feature and shows the 15 most frequent. Highly stable biomarkers (those appearing in several lists) are likely capturing fundamental chemical differences between the fish species.
[7]:
# Count how many class top-10 lists contain each m/z feature; keep the 15 most frequent
feat_counts = pd.Series([mz_axis[i] for i in all_top_indices]).value_counts().head(15)
fig_stab = px.bar(x=[f"{x:.2f}" for x in feat_counts.index], y=feat_counts.values,
                  labels={"x": "m/z Feature", "y": "Frequency in Top Lists"},
                  title="Biomarker Stability (Aggregate across classes)", template="plotly_white")
fig_stab.show()
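The frequency in this chart is bounded by the number of classes, because each class contributes one top-10 list and a feature appears at most once in it. A plain-Python sketch of the same count, using hypothetical class labels and made-up m/z values:

```python
from collections import Counter

def stability_counts(class_top_lists):
    """Number of class top lists that contain each feature, most frequent first."""
    counts = Counter()
    for features in class_top_lists.values():
        counts.update(set(features))  # set(): a feature counts at most once per class
    return counts.most_common()

# Hypothetical top lists for three classes (m/z values are made up)
top_lists = {
    "class_a": [101.2, 255.2, 281.2],
    "class_b": [255.2, 281.2, 303.2],
    "class_c": [281.2, 766.5],
}
ranking = stability_counts(top_lists)
print(ranking[0])  # (281.2, 3): present in all three lists
```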
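4.3 Biomarker Network
The biomarker network visualizes the correlations between the top identified features: each node is an m/z feature from the class biomarker lists, and an edge joins two features whose absolute Pearson correlation across all spectra exceeds 0.8. Peaks that are highly correlated ($|r| > 0.8$) often belong to the same lipid or metabolic pathway, which provides biochemical context for the model's features.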
[8]:
# Sorted so that node numbering (and hence the layout) is reproducible
top_indices = sorted(set(all_top_indices))
if len(top_indices) > 1:
    subset_data = X_all[:, top_indices]
    corr_matrix = np.corrcoef(subset_data.T)
    G = nx.Graph()
    for i in range(len(top_indices)):
        G.add_node(i, label=f"{mz_axis[top_indices[i]]:.1f}")
    # Connect feature pairs whose absolute Pearson correlation exceeds 0.8
    for i in range(len(top_indices)):
        for j in range(i + 1, len(top_indices)):
            if abs(corr_matrix[i, j]) > 0.8:
                G.add_edge(i, j, weight=abs(corr_matrix[i, j]))
    pos = nx.spring_layout(G, seed=42)
    edge_x, edge_y = [], []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend([x0, x1, None])  # None breaks the line between consecutive edges
        edge_y.extend([y0, y1, None])
    edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=0.5, color="#888"),
                            hoverinfo="none", mode="lines")
    node_trace = go.Scatter(x=[pos[n][0] for n in G.nodes()], y=[pos[n][1] for n in G.nodes()],
                            mode="markers+text", text=[G.nodes[n]["label"] for n in G.nodes()],
                            marker=dict(showscale=True, colorscale="YlGnBu", size=10,
                                        color=[G.degree(n) for n in G.nodes()],  # color nodes by degree
                                        line_width=2))
    fig_net = go.Figure(data=[edge_trace, node_trace],
                        layout=go.Layout(title="Biomarker Correlation Network (|r| > 0.8)",
                                         showlegend=False, hovermode="closest",
                                         xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                                         yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))
    fig_net.show()
else:
    print("Not enough biomarkers to form a network.")
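The edge rule in the network cell is a threshold on absolute Pearson correlation, and each node's color is its degree. The same graph structure can be computed with NumPy alone; the sketch below (the function name `correlation_edges` and the toy data are illustrative) returns the edge list and degrees for a small matrix whose second column is exactly twice the first.

```python
import numpy as np

def correlation_edges(X, threshold=0.8):
    """Edges (i, j, r) for feature pairs with |r| > threshold, plus each node's degree.

    X: shape (n_samples, n_features); features are the columns.
    """
    corr = np.corrcoef(np.asarray(X, dtype=float), rowvar=False)
    n = corr.shape[0]
    edges = [(i, j, float(corr[i, j]))
             for i in range(n) for j in range(i + 1, n)
             if abs(corr[i, j]) > threshold]
    degree = np.zeros(n, dtype=int)
    for i, j, _ in edges:
        degree[i] += 1
        degree[j] += 1
    return edges, degree

X = np.array([[1, 2, 5],
              [2, 4, 1],
              [3, 6, 4],
              [4, 8, 2]])
edges, degree = correlation_edges(X)
print(edges)   # one edge between features 0 and 1 (r = 1 up to rounding)
print(degree)  # [1 1 0]
```

The third column correlates with the others at about r = -0.42, so it stays an isolated node.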