02: Fish Species and Part Identification

This notebook demonstrates identifying fish species and parts with a Transformer model, then takes a closer look at which spectral features act as stable biomarkers.

[1]:
import plotly.io as pio
pio.renderers.default = "notebook_connected"
import pandas as pd, numpy as np, torch, torch.nn as nn
import plotly.express as px, plotly.graph_objects as go, networkx as nx
from fishy import TrainingConfig, run_unified_training, display_final_summary
from fishy.analysis.xai import ModelWrapper, GradCAM
from fishy._core.utils import get_device
from lime.lime_tabular import LimeTabularExplainer
from sklearn.metrics import confusion_matrix

# Train the model
config = TrainingConfig(model="transformer", dataset="species", epochs=10, wandb_log=False)
results = run_unified_training(config)
display_final_summary(results)

INFO     Initialized RunContext: transformer on species
INFO     Evaluating pre-training phase
INFO     --- Fold 1/3 ---
INFO     --- Fold 2/3 ---
INFO     --- Fold 3/3 ---

╭─────────────────────────────────────╮
│ Training Complete - Results Summary │
│  Metric              Train     Val  │
│  Accuracy           0.9954  0.6296  │
│  Balanced Accuracy  0.9948  0.6250  │
│  MAE                0.0046  0.3704  │
│  MSE                0.0046  0.3704  │
│  Precision          0.9955  0.6297  │
│  Recall             0.9954  0.6296  │
│  F1 Score           0.9954  0.6294  │
╰─────────────────────────────────────╯
Elapsed training time: 67.9778 seconds

1. Interactive Confusion Matrix

The confusion matrix is the first step in understanding model performance. This interactive heatmap shows exactly which classes are being confused, with actual species on the rows and predicted species on the columns. In REIMS (rapid evaporative ionization mass spectrometry) data, confusion often happens between biologically similar species (e.g., different types of white fish).

[2]:
if "predictions" in results:
    p = results["predictions"]; cn = results["class_names"]
    cm = confusion_matrix(p["labels"], p["preds"])
    fig_cm = px.imshow(cm, x=cn, y=cn, text_auto=True,
                       title="Interactive Confusion Matrix",
                       color_continuous_scale="Viridis", template="plotly_white")
    fig_cm.update_layout(xaxis_title="Predicted Species", yaxis_title="Actual Species")
    fig_cm.show()
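
Raw counts can hide weak classes when class sizes differ. As a complement to the heatmap, here is a minimal sketch that row-normalizes a confusion matrix to get per-class recall and reports the largest off-diagonal cell. The helper `summarize_confusion`, the toy counts, and the placeholder class names are illustrative assumptions, not output from the model trained above.

```python
import numpy as np

def summarize_confusion(cm, class_names):
    """Return per-class recall and the most frequent off-diagonal confusion."""
    cm = np.asarray(cm, dtype=float)
    row_totals = cm.sum(axis=1, keepdims=True)
    # Row-normalize so each row (actual class) sums to 1; empty rows stay at 0
    normalized = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals > 0)
    recall = dict(zip(class_names, np.diag(normalized)))

    # Zero the diagonal so only misclassifications remain, then find the largest cell
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)
    actual, predicted = np.unravel_index(off_diag.argmax(), off_diag.shape)
    worst = (class_names[actual], class_names[predicted], int(off_diag[actual, predicted]))
    return recall, worst

# Illustrative counts only (rows = actual, columns = predicted)
toy_cm = [[8, 2, 0],
          [3, 6, 1],
          [0, 1, 9]]
toy_names = ["species_a", "species_b", "species_c"]

recall, worst = summarize_confusion(toy_cm, toy_names)
print(recall)
print(worst)
```

With real results, the same helper could be called on `cm` and `cn` from the cell above, assuming every class appears at least once in the labels or predictions so that the matrix and the name list have the same length.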

2. Feature Interaction Matrix

Rather than looking at individual peaks in isolation, we take the 10 features with the highest saliency (computed here from the first sample) and visualize how they correlate with each other across the dataset. The resulting 10x10 correlation matrix shows which of the model's most influential features rise and fall together, hinting at chemical dependencies between peaks.

[3]:
model = results["model"]; dm = results["data_module"]
try:
    model.eval()
    # Generating 10x10 Feature Interaction Matrix from Saliency
    print("Generating 10x10 Feature Interaction Matrix from Saliency...")
    target_layer = model.layer_norm2
    gc = GradCAM(model, target_layer)

    # Use first sample for saliency
    X_all, _ = dm.get_numpy_data()
    input_t = torch.tensor(X_all[0:1], dtype=torch.float32).to(get_device())
    input_t.requires_grad = True

    cam = gc.generate_cam(input_t).cpu().numpy()[0]
    gc.remove_hooks()

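    # Ten highest-saliency features, sorted by index to keep spectral order
    top_10_idx = np.sort(cam.argsort()[-10:])
    # Build feature names the same way as the later cells so they align with X_all columns
    feature_cols = [c for c in dm.get_filtered_dataframe().columns if c not in ["Class Name", "m/z"]]
    top_10_names = [feature_cols[i] for i in top_10_idx]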

    # Compute correlation matrix for these top features across the dataset
    corr_matrix = np.corrcoef(X_all[:, top_10_idx].T)

    fig_inter = px.imshow(corr_matrix, x=top_10_names, y=top_10_names,
                         title="Top 10 Feature Interaction Matrix (Saliency Correlations)",
                         color_continuous_scale="RdBu_r", template="plotly_white")
    fig_inter.show()
except Exception as e:
    print(f"Visualization failed: {e}")

Generating 10x10 Feature Interaction Matrix from Saliency...
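
The selection and correlation steps in the cell above can be isolated from the model. The following is a minimal sketch on synthetic data, assuming only a data matrix of shape (samples, features) and a saliency vector of length features. The function name `top_k_correlation` and the toy arrays are hypothetical and exist only for illustration.

```python
import numpy as np

def top_k_correlation(X, saliency, k=10):
    """Pick the k most salient features and correlate them across samples."""
    # argsort is ascending, so the last k entries are the most salient;
    # sorting the indices afterwards keeps them in spectral order
    top_idx = np.sort(np.argsort(saliency)[-k:])
    # rowvar=False treats columns as variables, equivalent to corrcoef(X[:, idx].T)
    corr = np.corrcoef(X[:, top_idx], rowvar=False)
    return top_idx, corr

rng = np.random.default_rng(0)
X_toy = rng.normal(size=(50, 30))
# Make features 3 and 5 strongly correlated on purpose
X_toy[:, 5] = 2.0 * X_toy[:, 3] + 0.01 * rng.normal(size=50)

saliency_toy = np.zeros(30)
saliency_toy[[3, 5, 12]] = [0.9, 0.8, 0.7]

idx, corr = top_k_correlation(X_toy, saliency_toy, k=3)
print(idx)
print(np.round(corr, 2))
```

Because the saliency in the notebook comes from a single sample, the chosen features may change with the sample. Averaging saliency over several samples before selecting the top k is one possible variation.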

3. Grad-CAM Saliency Overlay

While attention shows relationships, Grad-CAM shows local importance. By coloring the spectrum by its gradient-based importance, we can see exactly which chemical peaks "triggered" the model's final classification decision.

[4]:
try:
    # Select a layer that preserves the spectral feature dimension
    # layer_norm2 is the last layer before pooling and fc_out
    target_layer = model.layer_norm2
    gc = GradCAM(model, target_layer)

    X_raw, y_raw = dm.get_numpy_data(labels_as_indices=True)
    sample_idx = 0
    input_t = torch.tensor(X_raw[sample_idx:sample_idx+1], dtype=torch.float32).to(get_device())

    cam = gc.generate_cam(input_t).cpu().numpy()[0]
    gc.remove_hooks()

    feature_names = [f for f in dm.get_filtered_dataframe().columns if f not in ["Class Name", "m/z"]]
    mz_axis = np.array([float(f) for f in feature_names])
    spec = X_raw[sample_idx]

    fig_cam = go.Figure()
    fig_cam.add_trace(go.Scatter(x=mz_axis, y=spec, name="Spectrum", line=dict(color="lightgray", width=1)))
    fig_cam.add_trace(go.Scatter(x=mz_axis, y=spec, mode="markers",
                                 marker=dict(color=cam, colorscale="Viridis", size=8, showscale=True,
                                             colorbar=dict(title="Importance")),
                                 name="Grad-CAM Importance"))

    fig_cam.update_layout(title=f"Grad-CAM Saliency: {results['class_names'][y_raw[sample_idx]]}",
                          xaxis_title="m/z", yaxis_title="Intensity", template="plotly_white")
    fig_cam.show()
except Exception as e:
    print(f"Grad-CAM visualization failed: {e}")
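
The internals of `fishy`'s `GradCAM` class are not shown in this notebook, so the following is only a minimal NumPy sketch of the standard Grad-CAM combination step for a 1D signal. It assumes the layer activations and the gradients of the class score with respect to them are already available as arrays of shape (channels, length). The function name `grad_cam_1d` and the toy values are hypothetical.

```python
import numpy as np

def grad_cam_1d(activations, gradients):
    """Combine activations and gradients into a normalized 1D Grad-CAM map."""
    activations = np.asarray(activations, dtype=float)
    gradients = np.asarray(gradients, dtype=float)
    # Channel weights: average gradient over all positions
    weights = gradients.mean(axis=1)
    # Weighted sum of channels, then ReLU to keep only positive evidence
    cam = np.maximum((weights[:, None] * activations).sum(axis=0), 0.0)
    # Scale to [0, 1] unless the map is all zeros
    peak = cam.max()
    return cam / peak if peak > 0 else cam

# Two channels, three positions (illustrative values only)
toy_activations = [[1.0, 2.0, 3.0],
                   [3.0, 2.0, 1.0]]
toy_gradients = [[1.0, 1.0, 1.0],
                 [-1.0, -1.0, -1.0]]

cam_toy = grad_cam_1d(toy_activations, toy_gradients)
print(cam_toy)
```

The ReLU is what makes the map class-specific: positions where the weighted activations argue against the class are clipped to zero instead of showing up as negative importance.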

4. Advanced Biomarker Analysis

Beyond simple classification, we want to identify which chemical features (m/z peaks) are consistently used by the model to distinguish between classes. We run LIME on a subset of samples (up to 10 per class) and average the feature weights to find stable, class-specific diagnostic markers.

[5]:
class_names = results["class_names"]
X_all, y_all = dm.get_numpy_data(labels_as_indices=True)
feature_names = [f"{c}" for c in dm.get_filtered_dataframe().columns if c not in ["Class Name", "m/z"]]
mz_axis = np.array([float(c) for c in feature_names])

print("Analyzing stable biomarkers via LIME aggregate...")
wrapper = ModelWrapper(model, str(get_device()))
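# Fixed seeds keep the LIME perturbations and the per-class sample subset reproducible
explainer = LimeTabularExplainer(X_all, feature_names=feature_names, class_names=class_names,
                                 discretize_continuous=False, random_state=42)
rng = np.random.default_rng(42)

class_biomarkers = {}
all_top_indices = []

for c_idx in range(len(class_names)):
    c_indices = np.where(y_all == c_idx)[0]
    if len(c_indices) > 0:
        # Explain at most 10 randomly chosen samples of this class
        sub = rng.choice(c_indices, min(10, len(c_indices)), replace=False)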
        w_list = []
        for idx in sub:
            exp = explainer.explain_instance(X_all[idx], wrapper.predict_proba, num_features=20, labels=(c_idx,))
            w_list.append(dict(exp.as_list(label=c_idx)))

        avg_w = pd.DataFrame(w_list).mean().sort_values()
        top_features = avg_w.tail(10).index.tolist()
        class_biomarkers[c_idx] = [feature_names.index(f) for f in top_features]
        all_top_indices.extend(class_biomarkers[c_idx])

Analyzing stable biomarkers via LIME aggregate...
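
One caveat of averaging per-sample explanations is that a feature reported for only one sample still gets a mean, since missing values are skipped. The sketch below shows the same aggregation idea in plain Python with an optional support threshold. The helper `aggregate_weights`, its `min_support` parameter, and the toy weights are assumptions for illustration and are not part of LIME or `fishy`.

```python
from collections import defaultdict

def aggregate_weights(weight_dicts, top_k=10, min_support=1):
    """Average per-sample feature weights and rank features by mean weight."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for weights in weight_dicts:
        for feature, value in weights.items():
            sums[feature] += value
            counts[feature] += 1
    # Keep only features that were reported for at least min_support samples
    means = {f: sums[f] / counts[f] for f in sums if counts[f] >= min_support}
    ranked = sorted(means, key=means.get, reverse=True)
    return ranked[:top_k], means, dict(counts)

# Illustrative explanations for two samples (feature name -> weight)
toy_explanations = [
    {"100.5": 0.4, "200.1": 0.1},
    {"100.5": 0.2, "300.7": 0.9},
]

top_all, means_all, support = aggregate_weights(toy_explanations, top_k=2)
top_stable, means_stable, _ = aggregate_weights(toy_explanations, top_k=2, min_support=2)
print(top_all, top_stable)
```

In the toy data, "300.7" ranks first when every feature counts, but drops out once a feature must be reported for both samples. That is the distinction between a strong marker and a stable one.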

4.1 Distinct Class Biomarker Comparison

This plot overlays the top 10 diagnostic peaks for each class on a representative spectrum of that class (its first sample). It highlights the features that the model weights most strongly toward each class.

[6]:
fig_bio = go.Figure()
colors = ["gold", "cyan", "silver"]
for c_idx, c_name in enumerate(class_names):
    rep_idx = np.where(y_all == c_idx)[0][0]
    spec = X_all[rep_idx]
    bio_indices = class_biomarkers[c_idx]

    fig_bio.add_trace(go.Scatter(x=mz_axis, y=spec, name=f"{c_name} Base", line=dict(width=1), opacity=0.3))
    fig_bio.add_trace(go.Scatter(x=mz_axis[bio_indices], y=spec[bio_indices], mode="markers",
                                 marker=dict(size=12, color=colors[c_idx % len(colors)], line=dict(width=2, color="black")),
                                 name=f"Diagnostic for {c_name}"))

fig_bio.update_layout(title="Diagnostic Peak Comparison", template="plotly_white", xaxis_title="m/z", yaxis_title="Intensity")
fig_bio.show()

4.2 Biomarker Stability

This chart counts how often each m/z feature appears in the per-class top-10 lists. Features that appear in several lists are used consistently by the model and are likely capturing fundamental chemical differences between the fish species.

[7]:
feat_counts = pd.Series([mz_axis[i] for i in all_top_indices]).value_counts().head(15)
fig_stab = px.bar(x=[f"{x:.2f}" for x in feat_counts.index], y=feat_counts.values,
                  labels={"x": "m/z Feature", "y": "Frequency in Top Lists"},
                  title="Biomarker Stability (Aggregate across classes)", template="plotly_white")
fig_stab.show()
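
The frequency count behind this chart can be checked by hand on small inputs. The sketch below counts how many class lists contain each feature index and adds a pairwise Jaccard overlap between class lists as a second view of stability. The helper names and the toy index lists are hypothetical. The input is assumed to have the same shape as `class_biomarkers`, a dict mapping class index to a list of feature indices.

```python
from collections import Counter
from itertools import combinations

def list_frequency(class_lists):
    """Count in how many class lists each feature index appears."""
    return Counter(i for feats in class_lists.values() for i in set(feats))

def jaccard_overlap(class_lists):
    """Jaccard similarity between the feature lists of every pair of classes."""
    overlap = {}
    for a, b in combinations(sorted(class_lists), 2):
        set_a, set_b = set(class_lists[a]), set(class_lists[b])
        union = set_a | set_b
        overlap[(a, b)] = len(set_a & set_b) / len(union) if union else 0.0
    return overlap

# Illustrative top-feature indices for three classes
toy_lists = {0: [4, 7, 9], 1: [7, 9, 11], 2: [1, 2, 7]}

freq = list_frequency(toy_lists)
overlap = jaccard_overlap(toy_lists)
print(freq.most_common(3))
print(overlap)
```

A high Jaccard value between two classes means the model relies on largely the same peaks for both, which is worth comparing against the pairs that the confusion matrix shows as hard to separate.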

4.3 Biomarker Network

The biomarker network visualizes the correlations between the top identified features. Peaks that are highly correlated ($|r| > 0.8$) often belong to the same lipid or metabolic pathway. This gives the model's features a biochemical context.

[8]:
top_indices = list(set(all_top_indices))
if len(top_indices) > 1:
    subset_data = X_all[:, top_indices]
    corr_matrix = np.corrcoef(subset_data.T)
    G = nx.Graph()
    # One node per biomarker, labelled with its m/z value
    for i in range(len(top_indices)):
        G.add_node(i, label=f"{mz_axis[top_indices[i]]:.1f}")
    # Connect pairs whose absolute correlation exceeds the threshold
    for i in range(len(top_indices)):
        for j in range(i + 1, len(top_indices)):
            if abs(corr_matrix[i, j]) > 0.8:
                G.add_edge(i, j, weight=abs(corr_matrix[i, j]))

    pos = nx.spring_layout(G, seed=42)
    edge_x, edge_y = [], []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]; x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None]); edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=0.5, color="#888"), hoverinfo="none", mode="lines")
    node_trace = go.Scatter(x=[pos[n][0] for n in G.nodes()], y=[pos[n][1] for n in G.nodes()],
                            mode="markers+text", text=[G.nodes[n]["label"] for n in G.nodes()],
                            marker=dict(showscale=True, colorscale="YlGnBu", size=10,
                                        color=[len(list(G.neighbors(n))) for n in G.nodes()], line_width=2))

    fig_net = go.Figure(data=[edge_trace, node_trace],
                        layout=go.Layout(title="Biomarker Correlation Network (|r| > 0.8)",
                                         showlegend=False, hovermode="closest",
                                         xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                                         yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))
    fig_net.show()
else:
    print("Not enough biomarkers to form a network.")
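
To read the network as groups of co-varying peaks rather than as a picture, the thresholding step can be followed by a connected-components pass. The sketch below does this with NumPy and a small union-find instead of networkx, purely to keep the example dependency-light. The function names and the toy correlation matrix are hypothetical.

```python
import numpy as np

def correlation_edges(corr, threshold=0.8):
    """Return index pairs whose absolute correlation exceeds the threshold."""
    corr = np.asarray(corr)
    n = corr.shape[0]
    return [(i, j) for i in range(n) for j in range(i + 1, n)
            if abs(corr[i, j]) > threshold]

def connected_groups(n_nodes, edges):
    """Group nodes into connected components with a simple union-find."""
    parent = list(range(n_nodes))

    def find(x):
        # Follow parent links to the root, shortening the path on the way
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for i, j in edges:
        parent[find(i)] = find(j)

    groups = {}
    for node in range(n_nodes):
        groups.setdefault(find(node), []).append(node)
    return sorted(groups.values())

# Illustrative 4x4 correlation matrix
toy_corr = [[1.0, 0.9, 0.1, 0.0],
            [0.9, 1.0, -0.85, 0.2],
            [0.1, -0.85, 1.0, 0.3],
            [0.0, 0.2, 0.3, 1.0]]

edges = correlation_edges(toy_corr)
groups = connected_groups(4, edges)
print(edges)
print(groups)
```

Note that the strongly negative pair (1, 2) is linked as well, because the threshold is applied to the absolute correlation, matching the `abs(...)` check in the cell above.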