{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 03: Oil and Cross-species Adulteration\n", "\n", "This notebook focuses on **regression** tasks within the REIMS (rapid evaporative ionization mass spectrometry) framework. We look specifically at predicting oil concentration levels and identifying adulteration in cross-species samples. Unlike classification, a regression model must learn the continuous (linear or non-linear) response of specific biomarkers to varying concentrations of a substance." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os, sys, warnings\n", "warnings.filterwarnings(\"ignore\")\n", "import pandas as pd, numpy as np\n", "import plotly.io as pio\n", "import plotly.express as px\n", "import plotly.graph_objects as go\n", "from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, precision_recall_curve, average_precision_score, mean_absolute_error\n", "from fishy import TrainingConfig, run_unified_training, display_final_summary, create_data_module\n", "pio.renderers.default = \"notebook_connected\"\n", "try:\n", "    import torch\n", "    import torch.nn as nn\n", "    from fishy.analysis.xai import GradCAM, ModelWrapper\n", "    from lime.lime_tabular import LimeTabularExplainer\n", "    from fishy._core.utils import get_device\n", "    HAS_XAI = True\n", "except ImportError:\n", "    HAS_XAI = False\n", "    print(\"XAI dependencies (lime, torch) not fully available.\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Run regression training for the oil dataset\n", "config = TrainingConfig(model=\"rf\", dataset=\"oil\", regression=True, wandb_log=False)\n", "results = run_unified_training(config)\n", "display_final_summary(results)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Regression Calibration (Predicted vs. Actual)\n", "\n", "The predicted vs. actual plot is the most direct visual check of regression performance.
A perfect model would place every point on the dashed 45-degree identity line. Systematic deviations from this line show where the model over- or under-estimates the adulteration level." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if \"predictions\" in results:\n", "    p = results[\"predictions\"]\n", "    # Convert to NumPy arrays so the element-wise arithmetic in later cells also works for plain lists\n", "    y_true = np.asarray(p[\"labels\"], dtype=float)\n", "    y_pred = np.asarray(p[\"preds\"], dtype=float)\n", "\n", "    fig_cal = px.scatter(x=y_true, y=y_pred, labels={'x': 'Actual Concentration', 'y': 'Predicted Concentration'},\n", "                         title=\"Regression Calibration: Predicted vs. Actual\",\n", "                         opacity=0.6, template=\"plotly_white\")\n", "    # Identity line (predicted == actual) spanning the range of the true values\n", "    lo, hi = float(y_true.min()), float(y_true.max())\n", "    fig_cal.add_shape(type=\"line\", x0=lo, y0=lo, x1=hi, y1=hi,\n", "                      line=dict(color=\"Red\", dash=\"dash\"))\n", "    fig_cal.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. MAE by Concentration (Limit of Detection)\n", "\n", "In chemical adulteration, it is critical to know at what concentration the model starts failing. This bar chart shows the mean absolute error (MAE) at each concentration level. A high error at low concentrations (e.g., 1%) indicates that those levels fall below what the model can reliably quantify, which gives a practical estimate of its **limit of detection**." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if \"predictions\" in results:\n", "    # Per-sample absolute error; averaging within each concentration level gives the MAE\n", "    errors = np.abs(y_true - y_pred)\n", "    err_df = pd.DataFrame({\"Actual\": y_true, \"MAE\": errors})\n", "    # Group by actual concentration and calculate mean error\n", "    mae_by_conc = err_df.groupby(\"Actual\").mean().reset_index()\n", "\n", "    px.bar(mae_by_conc, x=\"Actual\", y=\"MAE\",\n", "           title=\"Prediction Error (MAE) by Concentration Level\",\n", "           labels={'Actual': 'Actual Concentration (%)', 'MAE': 'Mean Absolute Error'},\n", "           template=\"plotly_white\", color=\"MAE\", color_continuous_scale=\"Reds\").show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.
Residual Distribution Analysis\n", "\n", "Residuals (actual minus predicted) should ideally be roughly symmetric and centered on zero. A skewed or multi-peaked histogram suggests a systematic bias, for example that the model fails to capture chemical features tied to particular concentration levels." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if \"predictions\" in results:\n", "    # Positive residuals are under-predictions, negative residuals are over-predictions\n", "    residuals = y_true - y_pred\n", "    px.histogram(x=residuals, nbins=30, labels={'x': 'Residual Error'},\n", "                 title=\"Residual Distribution (Error Analysis)\",\n", "                 template=\"plotly_white\", color_discrete_sequence=['#636EFA']).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Biomarker Drift over Concentration\n", "\n", "In a regression task, informative m/z peaks should show a linear (or at least monotonic) response to concentration. This plot tracks the intensity of the three features most strongly correlated with the label across the sorted classes. A clear gradient indicates that the spectra carry a concentration-dependent biological signal for the model to learn rather than sample-specific noise to memorize; it does not by itself show what the trained model relies on."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if \"data_module\" in results:\n", "    dm = results[\"data_module\"]\n", "    X_raw, y_raw = dm.get_numpy_data(labels_as_indices=True)\n", "    # Pearson correlation of each feature with the label to find linear biomarkers.\n", "    # Constant features yield NaN, which argsort would rank last (i.e. as 'top'), so map NaN to 0.\n", "    corrs = np.nan_to_num([np.corrcoef(X_raw[:, i], y_raw)[0, 1] for i in range(X_raw.shape[1])])\n", "    top_bio_idx = np.argsort(np.abs(corrs))[-3:]\n", "\n", "    feat_names = [f for f in dm.get_filtered_dataframe().columns if f not in [\"Class Name\", \"m/z\"]]\n", "\n", "    drift_df = pd.DataFrame({\"Concentration\": y_raw})\n", "    for idx in top_bio_idx:\n", "        drift_df[f\"m/z {feat_names[idx]}\"] = X_raw[:, idx]\n", "\n", "    melted_drift = drift_df.melt(id_vars=\"Concentration\", var_name=\"Biomarker\", value_name=\"Intensity\")\n", "    px.scatter(melted_drift, x=\"Concentration\", y=\"Intensity\", color=\"Biomarker\",\n", "               title=\"Biomarker Intensity Drift vs. Concentration\", template=\"plotly_white\").show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Performance & Interpretability\n", "\n", "When the task is run as ordinal classification instead of regression, we use a confusion matrix. For interpretability, LIME estimates which spectral peaks push the prediction for a single sample toward higher or lower concentration levels."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Only meaningful when the run was configured as classification\n", "if not config.regression and \"predictions\" in results:\n", "    p = results[\"predictions\"]\n", "    cn = results[\"class_names\"]\n", "    cm = confusion_matrix(p[\"labels\"], p[\"preds\"])\n", "    px.imshow(cm, x=cn, y=cn, text_auto=True, title=\"Confusion Matrix\", color_continuous_scale=\"Blues\").show()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if HAS_XAI and \"model\" in results and \"data_module\" in results:\n", "    try:\n", "        m = results[\"model\"]\n", "        dm = results[\"data_module\"]\n", "        X_x, y_x = dm.get_numpy_data(labels_as_indices=True)\n", "        feature_names = [str(c) for c in dm.get_filtered_dataframe().columns if c not in [\"Class Name\", \"m/z\"]]\n", "        class_names = results[\"class_names\"] if \"class_names\" in results else None\n", "        # LIME defaults to classification mode, so regression must be requested explicitly\n", "        lime_mode = \"regression\" if config.regression else \"classification\"\n", "        explainer = LimeTabularExplainer(X_x, mode=lime_mode, feature_names=feature_names, class_names=class_names, discretize_continuous=True)\n", "        # For regression, we explain the single output value; for classification, the class probabilities\n", "        predict_fn = m.predict if config.regression else ModelWrapper(m, str(get_device())).predict_proba\n", "        exp = explainer.explain_instance(X_x[0], predict_fn, num_features=10)\n", "        el = exp.as_list()\n", "        px.bar(x=[x[1] for x in el], y=[x[0] for x in el], orientation=\"h\", title=\"LIME Explanation (Sample 0)\").show()\n", "    except Exception as e:\n", "        print(f\"XAI Visualization failed: {e}\")\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 5 }