{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 02: Fish Species and Part Identification\n", "\n", "This notebook demonstrates the identification of fish species and parts using Transformers, including a deep dive into stable biomarkers." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import plotly.io as pio\n", "pio.renderers.default = \"notebook_connected\"\n", "import pandas as pd, numpy as np, torch, torch.nn as nn\n", "import plotly.express as px, plotly.graph_objects as go, networkx as nx\n", "from fishy import TrainingConfig, run_unified_training, display_final_summary\n", "from fishy.analysis.xai import ModelWrapper, GradCAM\n", "from fishy._core.utils import get_device\n", "from lime.lime_tabular import LimeTabularExplainer\n", "from sklearn.metrics import confusion_matrix\n", "\n", "# Train the model\n", "config = TrainingConfig(model=\"transformer\", dataset=\"species\", epochs=10, wandb_log=False)\n", "results = run_unified_training(config)\n", "display_final_summary(results)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Interactive Confusion Matrix\n", "\n", "The confusion matrix is the first step in understanding model performance. This interactive heatmap allows us to see exactly which classes are being confused. In REIMS data, confusion often happens between biologically similar species (e.g., different types of white fish)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if \"predictions\" in results:\n", " p = results[\"predictions\"]; cn = results[\"class_names\"]\n", " cm = confusion_matrix(p[\"labels\"], p[\"preds\"])\n", " fig_cm = px.imshow(cm, x=cn, y=cn, text_auto=True, \n", " title=\"Interactive Confusion Matrix\", \n", " color_continuous_scale=\"Viridis\", template=\"plotly_white\")\n", " fig_cm.update_layout(xaxis_title=\"Predicted Species\", yaxis_title=\"Actual Species\")\n", " fig_cm.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Feature Interaction Matrix\n", "\n", "Rather than looking at individual peaks in isolation, we visualize how the top 10 most important features (identified via saliency) interact with each other across the dataset. This correlation matrix reveals chemical dependencies and provides a 10x10 interaction map of the model's most critical decision-making features." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = results[\"model\"]; dm = results[\"data_module\"]\n", "try:\n", " model.eval()\n", " # Generating 10x10 Feature Interaction Matrix from Saliency\n", " print(\"Generating 10x10 Feature Interaction Matrix from Saliency...\")\n", " target_layer = model.layer_norm2\n", " gc = GradCAM(model, target_layer)\n", " \n", " # Use first sample for saliency\n", " X_all, _ = dm.get_numpy_data()\n", " input_t = torch.tensor(X_all[0:1], dtype=torch.float32).to(get_device())\n", " input_t.requires_grad = True \n", " \n", " cam = gc.generate_cam(input_t).cpu().numpy()[0]\n", " gc.remove_hooks()\n", " \n", " top_10_idx = cam.argsort()[-10:][::-1]\n", " top_10_idx.sort() # Keep spectral order\n", " top_10_names = [dm.get_filtered_dataframe().columns[i+1] for i in top_10_idx]\n", " \n", " # Compute correlation matrix for these top features across the dataset\n", " corr_matrix = np.corrcoef(X_all[:, top_10_idx].T)\n", " \n", " 
fig_inter = px.imshow(corr_matrix, x=top_10_names, y=top_10_names, \n", " title=\"Top 10 Feature Interaction Matrix (Saliency Correlations)\",\n", " color_continuous_scale=\"RdBu_r\", template=\"plotly_white\")\n", " fig_inter.show()\n", "except Exception as e:\n", " print(f\"Visualization failed: {e}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Grad-CAM Saliency Overlay\n", "\n", "While attention shows relationships, Grad-CAM shows **local importance**. By coloring the spectrum by its gradient-based importance, we can see exactly which chemical peaks \"triggered\" the model's final classification decision." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "try:\n", " # Select a layer that preserves the spectral feature dimension\n", " # layer_norm2 is the last layer before pooling and fc_out\n", " target_layer = model.layer_norm2\n", " gc = GradCAM(model, target_layer)\n", " \n", " X_raw, y_raw = dm.get_numpy_data(labels_as_indices=True)\n", " sample_idx = 0\n", " input_t = torch.tensor(X_raw[sample_idx:sample_idx+1], dtype=torch.float32).to(get_device())\n", " \n", " cam = gc.generate_cam(input_t).cpu().numpy()[0]\n", " gc.remove_hooks()\n", " \n", " feature_names = [f for f in dm.get_filtered_dataframe().columns if f not in [\"Class Name\", \"m/z\"]]\n", " mz_axis = np.array([float(f) for f in feature_names])\n", " spec = X_raw[sample_idx]\n", " \n", " fig_cam = go.Figure()\n", " fig_cam.add_trace(go.Scatter(x=mz_axis, y=spec, name=\"Spectrum\", line=dict(color=\"lightgray\", width=1)))\n", " fig_cam.add_trace(go.Scatter(x=mz_axis, y=spec, mode=\"markers\", \n", " marker=dict(color=cam, colorscale=\"Viridis\", size=8, showscale=True, \n", " colorbar=dict(title=\"Importance\")),\n", " name=\"Grad-CAM Importance\"))\n", " \n", " fig_cam.update_layout(title=f\"Grad-CAM Saliency: {results['class_names'][y_raw[sample_idx]]}\",\n", " xaxis_title=\"m/z\", yaxis_title=\"Intensity\", 
template=\"plotly_white\")\n", " fig_cam.show()\n", "except Exception as e:\n", " print(f\"Grad-CAM visualization failed: {e}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Advanced Biomarker Analysis\n", "\n", "Beyond simple classification, we want to identify which chemical features (m/z peaks) the model consistently uses to distinguish between classes. We run LIME on up to 10 randomly sampled spectra per class and average the resulting feature weights to find stable, class-specific diagnostic markers." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class_names = results[\"class_names\"]\n", "X_all, y_all = dm.get_numpy_data(labels_as_indices=True)\n", "feature_names = [f\"{c}\" for c in dm.get_filtered_dataframe().columns if c not in [\"Class Name\", \"m/z\"]]\n", "mz_axis = np.array([float(c) for c in feature_names])\n", "\n", "print(\"Analyzing stable biomarkers via LIME aggregate...\")\n", "wrapper = ModelWrapper(model, str(get_device()))\n", "explainer = LimeTabularExplainer(X_all, feature_names=feature_names, class_names=class_names, discretize_continuous=False)\n", "\n", "class_biomarkers = {}  # class index -> feature indices of its top 10 LIME features\n", "all_top_indices = []   # the same indices pooled across all classes\n", "\n", "for c_idx in range(len(class_names)):\n", "    c_indices = np.where(y_all == c_idx)[0]\n", "    if len(c_indices) > 0:\n", "        # Explain up to 10 randomly chosen samples of this class\n", "        sub = np.random.choice(c_indices, min(10, len(c_indices)), replace=False)\n", "        w_list = []\n", "        for idx in sub:\n", "            exp = explainer.explain_instance(X_all[idx], wrapper.predict_proba, num_features=20, labels=(c_idx,))\n", "            w_list.append(dict(exp.as_list(label=c_idx)))\n", "\n", "        # Average each feature's weight over the samples in which LIME selected it\n", "        avg_w = pd.DataFrame(w_list).mean().sort_values()\n", "        top_features = avg_w.tail(10).index.tolist()\n", "        class_biomarkers[c_idx] = [feature_names.index(f) for f in top_features]\n", "        all_top_indices.extend(class_biomarkers[c_idx])\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.1 Distinct Class Biomarker Comparison\n", "\n", "This plot overlays the top 10 diagnostic peaks for each class on 
top of a representative spectrum. It highlights the features that the model has learned are unique to each class." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig_bio = go.Figure()\n", "colors = [\"gold\", \"cyan\", \"silver\"]\n", "for c_idx, c_name in enumerate(class_names):\n", " rep_idx = np.where(y_all == c_idx)[0][0]\n", " spec = X_all[rep_idx]\n", " bio_indices = class_biomarkers[c_idx]\n", " \n", " fig_bio.add_trace(go.Scatter(x=mz_axis, y=spec, name=f\"{c_name} Base\", line=dict(width=1), opacity=0.3))\n", " fig_bio.add_trace(go.Scatter(x=mz_axis[bio_indices], y=spec[bio_indices], mode=\"markers\", \n", " marker=dict(size=12, color=colors[c_idx % len(colors)], line=dict(width=2, color=\"black\")),\n", " name=f\"Diagnostic for {c_name}\"))\n", "\n", "fig_bio.update_layout(title=\"Diagnostic Peak Comparison\", template=\"plotly_white\", xaxis_title=\"m/z\", yaxis_title=\"Intensity\")\n", "fig_bio.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.2 Biomarker Stability\n", "\n", "This chart shows which m/z features were identified as important across multiple classes. Highly stable biomarkers (appearing in multiple lists) are likely capturing fundamental chemical differences between the fish species." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "feat_counts = pd.Series([mz_axis[i] for i in all_top_indices]).value_counts().head(15)\n", "px.bar(x=[f\"{x:.2f}\" for x in feat_counts.index], y=feat_counts.values, \n", " labels={'x':'m/z Feature', 'y':'Frequency in Top Lists'}, \n", " title=\"Biomarker Stability (Aggregate across classes)\", template=\"plotly_white\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.3 Biomarker Network\n", "\n", "The biomarker network visualizes the correlations between the top identified features. 
Peaks that are highly correlated ($|r| > 0.8$) often belong to the same lipid or metabolic pathway, which gives biochemical context to the model's features." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "top_indices = sorted(set(all_top_indices))  # unique features, ordered by feature index\n", "if len(top_indices) > 1:\n", "    subset_data = X_all[:, top_indices]\n", "    corr_matrix = np.corrcoef(subset_data.T)\n", "    G = nx.Graph()\n", "    for i, feat_idx in enumerate(top_indices):\n", "        G.add_node(i, label=f\"{mz_axis[feat_idx]:.1f}\")\n", "    # Connect features whose absolute Pearson correlation exceeds 0.8\n", "    for i in range(len(top_indices)):\n", "        for j in range(i + 1, len(top_indices)):\n", "            if abs(corr_matrix[i, j]) > 0.8:\n", "                G.add_edge(i, j, weight=abs(corr_matrix[i, j]))\n", "\n", "    pos = nx.spring_layout(G, seed=42)\n", "    edge_x, edge_y = [], []\n", "    for u, v in G.edges():\n", "        x0, y0 = pos[u]\n", "        x1, y1 = pos[v]\n", "        # None breaks the line so each edge is drawn as a separate segment\n", "        edge_x.extend([x0, x1, None])\n", "        edge_y.extend([y0, y1, None])\n", "\n", "    edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=0.5, color=\"#888\"), hoverinfo=\"none\", mode=\"lines\")\n", "    # Node color encodes degree, i.e. the number of strongly correlated partners\n", "    node_trace = go.Scatter(x=[pos[n][0] for n in G.nodes()], y=[pos[n][1] for n in G.nodes()],\n", "                            mode=\"markers+text\", text=[G.nodes[n][\"label\"] for n in G.nodes()],\n", "                            marker=dict(showscale=True, colorscale=\"YlGnBu\", size=10,\n", "                                        color=[len(list(G.neighbors(n))) for n in G.nodes()], line_width=2))\n", "\n", "    fig_net = go.Figure(data=[edge_trace, node_trace],\n", "                        layout=go.Layout(title=\"Biomarker Correlation Network (|r| > 0.8)\",\n", "                                         showlegend=False, hovermode=\"closest\",\n", "                                         xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),\n", "                                         yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))\n", "    fig_net.show()\n", "else:\n", "    print(\"Not enough biomarkers to form a network.\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "nbformat": 4, 
"nbformat_minor": 5 }