This project implements a two-stage neural network model to simulate and predict whole-brain activity dynamics in Drosophila (the fruit fly).
This project requires two main datasets:
- connectome data from the FlyWire project (version 630/783): https://codex.flywire.ai/
- neural activity recordings from the Drosophila brain: https://doi.org/10.6084/m9.figshare.13349282
We also provide preprocessed data files in the `data/` directory for convenience.
Please download the datasets (https://drive.google.com/file/d/1TPuHJ-IC1yQtL5TMngAjGJa_JAMdusnu/view?usp=drive_link) and place them in the appropriate directory (`data/`) as specified in the code.
- Loads and processes Drosophila brain connectome data
- Simulates neural activity with biologically plausible dynamics
- Predicts firing rates across brain regions (neuropils)
- Evaluates prediction accuracy using bin classification and MSE metrics
- Generates comprehensive visualizations (heatmaps, time series, correlation plots, error analysis)
- Automatic experiment organization with filepath management
- Smart checkpoint resumption
The model follows a complete workflow with 4 stages:
- First-round training (`train1`): Train the spiking neural network to capture brain dynamics
- Generate training data (`eval1`): Generate training data from the trained spiking network
- Second-round training (`train2`): Train the RNN encoder/decoder to process input signals
- Evaluate (`evaluate`): Evaluate the complete model and save predictions with comprehensive visualizations
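First-round training relies on eligibility traces (controlled by `--etrace_decay`). As a minimal illustration of the general idea, here is a generic decaying-trace update; this is a sketch only, not the repository's or the paper's exact learning rule:

```python
import numpy as np

def etrace_step(e, presyn_activity, decay=0.99):
    """Generic eligibility-trace update: the old trace decays geometrically
    and the current presynaptic activity is accumulated. decay=0 reduces to
    the non-temporal (instantaneous) case, matching --etrace_decay 0."""
    return decay * e + presyn_activity

# Toy usage: under constant input x, the trace converges to x / (1 - decay).
e = np.zeros(3)
for _ in range(100):
    e = etrace_step(e, np.ones(3), decay=0.9)
print(e.round(3))  # close to [10. 10. 10.]
```

Because the trace is updated recursively, its memory cost stays constant in the sequence length, which is the property that makes linear-memory online learning possible.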
Run the complete workflow from start to finish:
```bash
python drosophila_whole_brain_fitting.py --mode all --epoch_round1 50 --epoch_round2 50
```

You can run each stage separately for more control:
```bash
# Stage 1: Train spiking neural network
python drosophila_whole_brain_fitting.py --mode train1 --epoch_round1 500

# Stage 2: Generate training data (requires checkpoint from stage 1)
python drosophila_whole_brain_fitting.py --mode eval1 --filepath results/v4_2/630#2017-10-26_1#...

# Stage 3: Train RNN (requires generated data)
python drosophila_whole_brain_fitting.py --mode train2 --epoch_round2 1000 --filepath results/v4_2/630#2017-10-26_1#...

# Stage 4: Evaluate (requires both checkpoints)
python drosophila_whole_brain_fitting.py --mode evaluate --filepath results/v4_2/630#2017-10-26_1#...
```

- `--flywire_version`: Version of the FlyWire connectome data (`630` or `783`, default: `630`)
- `--neural_activity_id`: ID of the neural activity recording (default: `2017-10-26_1`)
- `--bin_size`: Bin size for discretizing firing rates in Hz (default: `0.25`)
- `--devices`: GPU device IDs, e.g., `"0"` or `"0,1"` (default: `0`)
- `--mode`: Which stage(s) to run (default: `all`)
  - `all`: Run the complete pipeline (train1 → eval1 → train2 → evaluate)
  - `train1`: First-round training only
  - `eval1`: Generate training data (evaluate stage 1)
  - `train2`: Second-round training only
  - `evaluate`: Evaluation only
- `--filepath`: Base directory path for checkpoints and results (default: auto-generated)
  - If not provided: automatically generates a unique directory path based on the training parameters
  - If provided: uses the specified directory and loads settings from `first-round-losses.txt`
  - Examples:
    - Auto-generated: `results/v4_2/630#2017-10-26_1#100.0Hz#...#2025-12-11-15-30-45`
    - Custom: `results/my_experiment`
  - All checkpoints, logs, and results are saved to this directory
  - Enables easy experiment resumption and organization
- `--epoch_round1`: Number of epochs for first-round training (default: `500`)
- `--epoch_round2`: Number of epochs for second-round training (default: `1000`)
- `--batch_size`: Batch size for training (default: `128`)
- `--lr`: Learning rate for first-round training (default: `0.01`)
- `--lr_round2`: Learning rate for second-round training (default: `0.001`)
- `--etrace_decay`: Decay factor for eligibility traces, `0` for non-temporal (default: `0.99`)
- `--scale_factor`: Scale factor for synaptic connections in mV (default: `0.000825`)
- `--n_rank`: LoRA rank for low-rank adaptation (default: `20`)
- `--n_hidden`: RNN hidden size for second-round training (default: `256`)
- `--sim_before_train`: Fraction of simulation steps before training (default: `0.1`)
- `--noise_sigma`: Noise sigma for data augmentation (default: `0.05`)
- `--max_firing_rate`: Maximum firing rate for neural activity in Hz (default: `100.0`)
- `--loss_fn`: Loss function (`mse`, `mae`, `huber`, `cosine_distance`, `log_cosh`; default: `mse`)
- `--grad_clip`: Gradient clipping value (default: `1.0`)
- `--dt`: Time step for simulation in ms (default: `0.2`)
- `--seed`: Random seed for reproducibility (default: `2025`)
- `--input_style`: Input style for second-round training (`v1` or `v2`, default: `v1`)
- `--split`: Train/test split ratio (informational, default: `0.6`)
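For reference, the two less standard `--loss_fn` choices can be written as follows. This is a plausible NumPy sketch of `log_cosh` and `cosine_distance`; the repository's exact definitions may differ:

```python
import numpy as np

def log_cosh_loss(pred, target):
    # log(cosh(x)) behaves like x**2 / 2 near zero and |x| - log(2) for large
    # |x|, giving a smooth, outlier-robust alternative to MSE.
    diff = pred - target
    # Numerically stable form: log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2)
    return np.mean(np.abs(diff) + np.log1p(np.exp(-2.0 * np.abs(diff))) - np.log(2.0))

def cosine_distance_loss(pred, target, eps=1e-8):
    # 1 - cosine similarity between flattened prediction and target vectors.
    p, t = pred.ravel(), target.ravel()
    return 1.0 - np.dot(p, t) / (np.linalg.norm(p) * np.linalg.norm(t) + eps)
```

The stable `log_cosh` form avoids the overflow that a naive `np.log(np.cosh(diff))` would hit for large residuals.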
Quick test run with reduced epochs:
```bash
python drosophila_whole_brain_fitting.py --mode all --epoch_round1 2 --epoch_round2 2
```

Full training with custom hyperparameters:
```bash
python drosophila_whole_brain_fitting.py \
    --mode all \
    --flywire_version 630 \
    --neural_activity_id 2017-10-26_1 \
    --epoch_round1 500 \
    --epoch_round2 1000 \
    --lr 0.01 \
    --batch_size 128 \
    --devices 0
```

Using a custom filepath for experiment organization:
```bash
python drosophila_whole_brain_fitting.py \
    --mode all \
    --filepath results/my_experiment \
    --epoch_round1 500 \
    --epoch_round2 1000
```

Resuming from an auto-generated filepath:
```bash
# First run creates an auto-generated path
python drosophila_whole_brain_fitting.py --mode train1 --epoch_round1 100

# Resume by providing the generated path
python drosophila_whole_brain_fitting.py \
    --filepath results/v4_2/630#2017-10-26_1#...#2025-12-11-15-30-45 \
    --mode all
```

The workflow creates the following outputs in the results directory:
Checkpoints:
- `first-round-checkpoint.msgpack`: Best checkpoint from first-round training (spiking network)
- `second-round-rnn-checkpoint-v1.msgpack`: Best checkpoint from second-round training (RNN)
Training Data and Logs:
- `simulated_neuropil_fr.npy`: Generated training data from stage 1 evaluation
- `first-round-losses.txt`: Training logs and hyperparameter settings for the first round
- `evaluation_stats.txt`: Evaluation metrics summary (bin accuracy, MSE loss)
Predictions:
- `neuropil_fr_predictions.npy`: Final predictions on test data
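The predictions are a plain NumPy array and can be loaded with `np.load`. The snippet below uses a synthetic stand-in file; the real array's layout (time steps × neuropils) and sizes are assumptions for illustration only:

```python
import numpy as np

# Hypothetical stand-in for neuropil_fr_predictions.npy: the shape
# (time steps x neuropils) used here is assumed, not taken from the repo.
preds = np.random.default_rng(0).uniform(0.0, 100.0, size=(200, 75))
np.save("neuropil_fr_predictions_demo.npy", preds)

loaded = np.load("neuropil_fr_predictions_demo.npy")
print(loaded.shape)  # (200, 75)
print(loaded.mean(axis=0)[:3])  # per-neuropil mean firing rates
```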
Visualizations:
- `images/`: Training visualizations (neuropil firing rate comparisons during training)
- `evaluation_plots/`: Comprehensive evaluation visualizations, including:
  - `heatmap_comparison.png`: Side-by-side heatmaps of ground truth vs. simulated activity
  - `time_series_comparison.png`: Temporal evolution for selected neuropils
  - `correlation_scatter.png`: Predicted vs. actual firing rates with correlation
  - `barplot_comparison_t*.png`: Bar plot comparisons at 3 sample time points
  - `error_analysis.png`: 4-panel error diagnostics (heatmap, distribution, per-neuropil, over-time)
The model automatically evaluates performance when running in `all` or `evaluate` mode:
Metrics:
- Bin accuracy: Percentage of correctly predicted firing rate bins
- MSE loss: Mean squared error between predicted and actual firing rates
- Correlation: Pearson correlation between predictions and ground truth
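These metrics can be recomputed from the saved arrays along the following lines. The binning scheme (floor division by `--bin_size`) is an assumed reading, not necessarily the script's exact computation:

```python
import numpy as np

def evaluate_metrics(pred, target, bin_size=0.25):
    """Sketch of the three reported metrics. Binning by floor(fr / bin_size)
    is an assumption; the repository may discretize differently."""
    bin_acc = float(np.mean(np.floor(pred / bin_size) == np.floor(target / bin_size)))
    mse = float(np.mean((pred - target) ** 2))
    # Pearson correlation over all (time, neuropil) entries
    corr = float(np.corrcoef(pred.ravel(), target.ravel())[0, 1])
    return bin_acc, mse, corr

# Perfect predictions score bin accuracy 1.0, MSE 0.0, correlation ~1.0.
x = np.array([[0.1, 0.3], [0.6, 0.9]])
print(evaluate_metrics(x, x))
```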
Outputs:
- Numerical results saved to `evaluation_stats.txt`
- Predictions saved to `neuropil_fr_predictions.npy`
- Comprehensive visualizations saved to the `evaluation_plots/` directory
The workflow generates two types of visualizations:
Generated during first-round training and saved to the `images/` directory:
- Neuropil firing rate comparisons (simulated vs ground truth)
- Bar plots for each batch showing model progress
Generated during final evaluation and saved to the `evaluation_plots/` directory:
- Heatmap Comparison: Side-by-side visualization of ground truth and simulated firing rates across all neuropils and time steps
- Time Series Plots: Detailed temporal evolution for 6 neuropils with highest variance
- Correlation Scatter Plot: Overall prediction accuracy with correlation coefficient
- Sample Bar Plots: Neuropil-level comparisons at 25%, 50%, and 75% time points
- Error Analysis: Four-panel comprehensive error diagnostics:
- Absolute error heatmap across time and neuropils
- Relative error distribution histogram
- Mean error per neuropil bar plot
- Mean error over time line plot
All visualizations are publication-ready (150 DPI, proper labels and legends).
If you use this code or data, please cite:
```bibtex
@Article{Wang2026,
  author={Wang, Chaoming
    and Dong, Xingsi
    and Ji, Zilong
    and Xiao, Mingqing
    and Jiang, Jiedong
    and Liu, Xiao
    and Huan, Yuxiang
    and Wu, Si},
  title={Model-agnostic linear-memory online learning in spiking neural networks},
  journal={Nature Communications},
  year={2026},
  month={Jan},
  day={19},
  abstract={Spiking neural networks (SNNs) offer a promising paradigm for modeling brain dynamics and developing neuromorphic intelligence, yet an online learning system capable of training rich spiking dynamics over long horizons with low memory footprints has been missing. Existing online approaches either incur quadratic memory growth, sacrifice biological fidelity through oversimplified models, or lack end-to-end automated tooling. Here, we introduce BrainTrace, a model-agnostic, linear-memory, and automated online learning system for spiking neural networks. BrainTrace standardizes model specification to encompass diverse neuronal and synaptic dynamics; implements a linear-memory online learning rule by exploiting intrinsic properties of spiking dynamics; and provides a compiler that automatically generates optimized online-learning code for arbitrary user-defined models. Across diverse dynamics and tasks, BrainTrace achieves strong learning performance with a low memory footprint and high computational throughput. Critically, these properties enable online fitting of a whole-brain-scale Drosophila SNN that recapitulates region-level functional activity. By reconciling generality, efficiency, and usability, BrainTrace establishes a foundation for spiking network modeling at scale.},
  issn={2041-1723},
  doi={10.1038/s41467-026-68453-w},
  url={https://doi.org/10.1038/s41467-026-68453-w}
}
```