A modular, GPU-accelerated Physics-Informed Neural Network framework built on JAX + Flax + Optax
underPINN is a research-grade PINN engine that combines classical collocation-based PINNs with Finite Basis decomposition (FBPINN), attention-augmented networks, residual-based adaptive weighting/resampling, transfer learning (including windowed time-marching for long-horizon unsteady flows), shock capturing with learnable artificial viscosity, non-Newtonian (Carreau) blood rheology, inverse problems, and a full restart/resume system — all JIT-compiled and differentiable via XLA on CPU, GPU, and TPU.
- MLP — standard multi-layer perceptron with tanh activations; configurable depth and width via a layer list
- GatedMLP — modified MLP (Wang et al. 2022): two input encoders U/V are gate-blended into every hidden layer, curing the pathological gradient flow of plain tanh MLPs on stiff PDEs; selectable in any flow example via
network.type: gated_mlp - FourierMLP — random Fourier feature embeddings (trainable σ) prepended to a standard MLP; essential for oscillatory solutions (Helmholtz, wave, high-Re flows) where plain MLPs fail to represent high spatial frequencies
- FBPINN — overlapping subdomain decomposition with sigmoid partition-of-unity windows; each subdomain gets its own network so training is never dominated by one region
- HybridAttention + SimpleGate — gated residual blocks inside each FBPINN subdomain; SimpleGate multiplies the hidden state element-wise by a learnable gate for compact, expressive feature modulation
lax.scanfused kernels — fuse N gradient steps into a single XLA kernel, eliminating Python dispatch between epochs; delivers 50–500× less overhead on GPU compared to a Python for-loop- Cosine LR decay — via
optax.cosine_decay_schedule; integrates seamlessly withTrainingConfig - RAR-D adaptive collocation resampling — periodically replaces a fraction of collocation points with samples drawn proportional to
|residual|^k(Lu et al., 2021); focuses compute on high-error regions without changing the total batch size - RAR/RAD shock-focused resampling —
rad_resample(Wu et al., 2023;p ∝ r^k/E[r^k] + c) refreshes the interior pool toward shocks/contacts in the compressible cases (ramp,sod_shock,toro3); config knobsrar_period,rar_candidates,rar_k,rar_c - Artificial viscosity for shocks — global Laplacian dissipation
−ε∇²Uon the conserved variables in the compressible Euler cases; ε can be fixed (art_visc) or learned asε = softplus(log_av)jointly with the network (trainable_visc: true) - Time-marching transfer learning — long-horizon unsteady problems split into windows; each window warm-starts from the previous one and chains its end-state as the next initial condition (pulsatile pipe flow), with per-window checkpoints and window-level restart
- RBA element-wise loss weighting — residual-based adaptivity assigns per-point weights so that boundary and collocation losses are automatically balanced during training
- EarlyStopping — monitors a metric (default: total loss) and halts training after
patienceepochs without improvement TrainingConfigdataclass — centralises all hyperparameters with runtime validation; a single object is passed to every solver- Callbacks —
ConsoleLogger(prints loss every N epochs),EarlyStopping(halts on plateau),ModelCheckpoint(saves best model during training); all callbacks fire correctly even insidelax.scanloops
RestartManagersavesparams.msgpack,opt_state.msgpack, loss histories, and ameta.jsonto<out_dir>/restart/everysave_restart_everyepochs- On re-run,
RestartManagerchecks the"done"flag; ifdone: false(interrupted run), training resumes exactly from the last snapshot — epoch counter, optimizer state, and loss histories are all restored done()marker — once training completes normally or via early stopping, the snapshot is marked"done": true; the nextrunstarts fresh. An interrupted run leavesdone: falseand auto-resumes on the nextrunresumeCLI command —python -m underPINN resume <config.yaml>verifies the MD5 config hash and resetsdoneso a completed run can be continued; warns if any field (lr, epochs, layers, …) changed since the last snapshot
- JAX's XLA BFC allocator pre-reserves ~90% of all free VRAM the moment
import jaxexecutes; on an 80 GB A100 this shows as ~73 GB reserved even for a 3-layer MLP - underPINN sets
XLA_PYTHON_CLIENT_PREALLOCATE=falseautomatically — inunderPINN/__main__.pyfor CLI runs and at the top of every example script for directpython examples/…runs — so on-demand GPU memory growth is the default behaviour out of the box
- 1-D Burgers equation (
u_t + uu_x = νu_xx) - 1-D / 2-D diffusion / heat — forward and inverse (recover thermal diffusivity α)
- 1-D wave equation (
u_tt = c²u_xx) - 2-D Helmholtz (
Δu + k²u = f, manufactured source) - 2-D steady incompressible Navier-Stokes (lid-driven cavity Re=100, NACA airfoil, circular cylinder Re=40)
- 2-D RANS k-ε turbulence model (turbulent channel, Re=10 000)
- 3-D steady incompressible Navier-Stokes (Hagen-Poiseuille pipe flow, AAA bulge)
- 3-D unsteady incompressible Navier-Stokes (
(x,y,z,t) → (u,v,w,p); pulsatile pipe via time-marching transfer) - 3-D generalized-Newtonian (Carreau) Navier-Stokes — shear-thinning blood rheology (
μ(γ̇) = μ∞ + (μ0−μ∞)[1+(λγ̇)²]^((n−1)/2)); pipe and AAA cases - 2-D steady compressible Euler — conservative flux-divergence form with optional artificial viscosity (oblique-shock ramp, Mach 3, θ=10°)
- 1-D unsteady compressible Euler — Sod shock tube with learnable artificial viscosity + exact Riemann reference
- 1-D unsteady compressible Euler — Toro test 3 (Woodward–Colella blast wave, 5-decade pressure jump) with exp/log positivity + reference-state non-dimensionalisation for the extreme dynamic range
- Unsteady pipe cross-section (
(y, z, t) → u) - Harmonic oscillator (
d²u/dt² + ω²u = 0) - Exponential decay ODE (
du/dt + λu = 0) - FBPINN ODE (
du/dx = cos(ω x)) — domain decomposition into overlapping subdomains; defeats the spectral bias that stalls a single PINN at high ω
- Interval — 1-D uniform / stratified sampler
- Rectangle — 2-D interior + boundary-aware sampling
- NACA 4-digit airfoil — symmetric and cambered profiles; angle of attack imposed by rotating the airfoil about the quarter-chord (inflow stays horizontal); SDF-weighted near-surface sampling; camber-aware upper/lower surface split
- Circular cylinder (2-D) — exterior cross-flow domain with analytic SDF (
Cylinder2D) - Cylindrical Pipe — interior, wall, inlet, and outlet face samplers
- AAA bulge (
BulgeGeometry) — axisymmetric vessel with a cosine-squared AAA bulgeR(x); interior / curved-wall / inlet / outlet samplers - Ramp — trapezoidal domain above a wedge surface for compressible shock problems
- Composite — boolean combinations of any geometry objects
- Shapely-backed polygon — arbitrary 2-D polygon sampler backed by Shapely 2.x
FBPINNSolver— space-time PDE training withlax.scan, RAR-D, and RestartManager integrationODESolver— lightweight ODE training with callbacks and checkpointingSteadySolver— stationary (no time dimension) PDE trainingLDCSolver— lid-driven cavity / FBPINN variantRANSSolver— k-ε turbulence model with RBA loss weighting
- Every runner saves
params.msgpack+params_meta.jsonto the output directory after training ModelPredictor.from_meta(path)rebuilds the exact model architecture from the JSON sidecar and loads weights — zero boilerplate, no need to re-specify layersModelCheckpointcallback saves the best model (by monitored metric) during training
- Parameter transfer — warm-start from a trained model when changing ν, Re, or diffusivity; converges 2–3× faster than training from scratch
- Temporal transfer — extend the time horizon by fine-tuning on a new time interval starting from a previously trained checkpoint
- Both modes use
solver.load_params(src_params)orsolver.restore_checkpoint(path)
- Joint optimisation of network weights + physics parameters (e.g. recover thermal diffusivity α from 50 sparse noisy observations)
- Log-parameterisation (
log_alpha = log(α)) ensures positivity without constraints - Gradient flows simultaneously through the PDE residual and the observation loss
- One command (
python -m underPINN bench) runs all registered problems across multiple epoch budgets - Outputs include PNG accuracy plots, convergence grids, CSV tables, wall-time charts, and a Markdown summary report
--from-jsonreplays plotting from a previous JSON result without re-training
run— run a single problem from a YAML configresume— verify the config MD5 hash against a stored snapshot and resetdonefor continuation; warns if any config field changedsweep— Cartesian product hyperparameter sweep; each combination gets its own sub-directorybench— full benchmark suitelist— list all registered runnersshow— print the resolved config without trainingversion— print the framework version
Calendar versioning (CalVer YYMM) — May 2026 → underPINN-v2605
# CPU / development
pip install jax flax optax matplotlib scipy shapely pandas pyyaml# GPU (CUDA 12)
pip install -U "jax[cuda12]" && pip install -r requirements-gpu.txt# TPU (Google Colab TPU runtime / Cloud TPU VM)
pip install -r requirements-tpu.txt
pip install -e . --no-deps # --no-deps so pip can't replace the TPU jax with the cpu pin
# underPINN auto-detects the TPU and forces full-float32 matmuls
# (the bf16 MXU default corrupts second-order PDE residuals)# From source (editable install — recommended)
git clone https://github.com/Aeroscience-Computations-Analysis-Lab/underPINN.git
cd underPINN-v2605
pip install -e .python -c "import jax; print(jax.devices())"
# Expected on GPU: [CudaDevice(id=0)]JAX's XLA BFC (Best-Fit with Coalescing) allocator pre-reserves approximately 90% of all free VRAM the moment import jax executes — before any tensor is created — to avoid memory fragmentation during training. On an 80 GB A100 this appears as ~73 GB reserved even for a tiny 3-layer MLP that actually uses only ~200 MB of active arrays.
This is a deliberate XLA design choice: by owning the memory pool upfront, it can coalesce and reuse buffers without ever calling cudaMalloc again during training. The downside is that two JAX processes cannot share a GPU gracefully unless you set explicit limits.
The environment variable is set in underPINN/__main__.py (for CLI runs) and at the top of every example script (for direct python examples/… runs) before import jax, so you get on-demand allocation out of the box:
# This is already done for you — shown here for transparency
import os
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
import jax # now allocates only what it actually needs# On-demand growth (default in underPINN) — frees all unreserved VRAM for other jobs
export XLA_PYTHON_CLIENT_PREALLOCATE=false
# Hard cap — useful when sharing a node; limits to e.g. 20% of VRAM
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.20
# Platform allocator — no XLA pool at all (slowest, minimal fragmentation)
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
# Multi-GPU: restrict to a single device (e.g. GPU 1)
export CUDA_VISIBLE_DEVICES=1import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.15"
import jax # now uses at most 15% of VRAM| Problem | Network | VRAM (approx) |
|---|---|---|
| Burgers 1-D | [2,64,64,64,1] | ~200 MB |
| Wave 1-D | FourierMLP | ~300 MB |
| Helmholtz 2-D | FourierMLP | ~400 MB |
| LDC 2-D | FBPINN | ~800 MB |
| Airfoil 2-D | [2,128,128,128,3] | ~1.2 GB |
| Pipe Flow 3-D | [3,64,64,64,64,4] | ~2.0 GB |
| Compressible Ramp | [2,80,80,80,80,80,4] | ~1.8 GB |
| k-ε Turbulence | FBPINN | ~3.0 GB |
python -m underPINN run examples/burgers/config.yaml
python -m underPINN run examples/wave/config.yaml
python -m underPINN run examples/pipe_flow/pipe_flow.yaml
python -m underPINN run examples/ramp/config.yaml
python -m underPINN run examples/cylinder/config.yaml # flow over a cylinder (Re=40)
python -m underPINN run examples/sod_shock/config.yaml # Sod tube, learnable ε
python -m underPINN run examples/AAA/config.yaml # 3-D AAA bulge
python -m underPINN run examples/pipe_flow_rheology/config.yaml # Carreau blood, pipe
python -m underPINN run examples/AAA_rheology/config.yaml # Carreau blood, AAA
python -m underPINN run examples/pipe_flow/pipe_flow_pulsatile_transfer.yaml # pulsatile, time-marching TLimport os
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
import jax, optax, jax.numpy as jnp
from underPINN.nn.mlp import MLP
from underPINN.pde.burgers import BurgersPDE
from underPINN.losses.loss import PINNLoss
from underPINN.solver.fbpinn import FBPINNSolver
from underPINN.core.config import TrainingConfig
from underPINN.callbacks.logging import ConsoleLogger
from underPINN.callbacks.early_stopping import EarlyStopping
model = MLP(layers=[2, 64, 64, 64, 1])
pde = BurgersPDE(model, nu=0.01)
loss = PINNLoss(model, pde, ic_weight=100.0, bc_weight=10.0, rba=True)
solver = FBPINNSolver(model, pde, loss=loss)
solver.init(jax.random.PRNGKey(0))
config = TrainingConfig(
epochs = 5000,
lr = 1e-3,
lr_schedule = optax.cosine_decay_schedule(1e-3, 5000, alpha=1e-2),
batch_r = 2048,
log_every = 500,
out_dir = "outputs/burgers", # enables auto-restart
save_restart_every = 500,
callbacks = [
ConsoleLogger(log_every=500),
EarlyStopping(patience=400),
],
)
solver.train(*data, config=config)Each example folder is self-contained — script + YAML live together. Run any problem directly:
python examples/burgers/burgers.py
python examples/wave/wave.py
python examples/helmholtz/helmholtz.py
python examples/heat/forward.py
python examples/heat/inverse.py
python examples/LDC/run_ldc.py
python examples/airfoil/airfoil_flow.py
python examples/cylinder/cylinder_flow.py
python examples/ode/ode_test.py
python examples/pipe_flow/pipe_flow.py
python examples/pipe_flow/pipe_flow_unsteady_transfer.py
python examples/pipe_flow/pipe_flow_pulsatile_transfer.py
python examples/pipe_flow_rheology/pipe_flow_rheology.py
python examples/AAA/AAA_flow.py
python examples/AAA_rheology/AAA_rheology.py
python examples/ramp/ramp.py
python examples/sod_shock/sod_shock.py
python examples/transfer/burgers_transfer.py
python examples/inverse/inverse_diffusion.py
# Pass a custom config as the first argument:
python examples/burgers/burgers.py my_custom.yaml# Steady pipe & AAA (Newtonian or Carreau) — axial-plane u contour + streamlines,
# pressure contour & line plots, wall shear stress, and an NPZ of the solution:
python examples/predict_steady.py outputs/pipe_flow
python examples/predict_steady.py outputs/AAA_rheology
# Pulsatile pipe (time-marching) — point queries, snapshot/spacetime plots, GIF:
python examples/pipe_flow/predict_pulsatile.py outputs/pipe_flow_pulsatile_transfer --t 2.7 --plot
python examples/pipe_flow/predict_pulsatile.py outputs/pipe_flow_pulsatile_transfer --spacetime --animate# Single run
python -m underPINN run examples/burgers/config.yaml
# Hyperparameter sweep (Cartesian product)
python -m underPINN sweep examples/burgers/burgers_nu_sweep.yaml
python -m underPINN sweep examples/pipe_flow/pipe_flow_re_sweep.yaml
# Benchmark all problems
python -m underPINN bench
python -m underPINN bench --problems burgers wave helmholtz --epochs 500 1000 2000 5000
python -m underPINN bench --all
python -m underPINN bench --from-json outputs/bench/results.json
# Utilities
python -m underPINN list # list all registered runners
python -m underPINN show examples/wave/config.yaml # inspect resolved config
python -m underPINN version # print version stringproblem: burgers # selects the runner (one of the registered problems)
network:
type : mlp # mlp | fourier_mlp
layers: [2, 64, 64, 64, 1]
physics:
nu: 0.01 # PDE parameters (problem-specific)
data:
T: 2.0 # time horizon
n_collocation: 6000 # interior collocation points
n_ic: 200 # initial-condition points
n_bc: 200 # boundary-condition points
training:
epochs : 5000
lr : 1.0e-3
early_stopping_patience : 400 # omit to disable
save_restart_every : 500 # snapshot every 500 epochs (0 to disable)
loss:
ic_weight: 100.0 # weight on IC loss term
bc_weight: 10.0 # weight on BC loss term
rba : true # residual-based adaptivity
output:
dir : outputs/burgers # predictions, loss, config, model saved here
save_params: true # write params.msgpack + params_meta.jsonbase: # shared config for all runs
problem: burgers
network:
type: mlp
layers: [2, 64, 64, 64, 1]
training:
epochs: 5000
sweep: # dot-separated key → list of values
physics.nu : [0.1, 0.05, 0.025, 0.01]
training.epochs : [3000, 5000]Each run gets its own sub-directory (outputs/…/run_000, run_001, …) with a saved config.yaml for full reproducibility.
1. Create examples/<mycase>/mycase.py — define run_mycase(cfg) -> dict
2. Create examples/<mycase>/config.yaml — set problem: mycase
3. Add ONE line to underPINN/runner/dispatch.py:
"mycase": ("examples/mycase/mycase.py", "run_mycase"),
No other files need to change.
| Field | Type | Default | Description |
|---|---|---|---|
epochs |
int | 1000 | Total training epochs |
lr |
float | 1e-3 | Base learning rate |
lr_schedule |
optax schedule | None | Overrides lr when set; use optax.cosine_decay_schedule |
batch_r |
int | 4096 | Collocation mini-batch size |
batch_i |
int | 512 | Initial-condition mini-batch |
batch_b |
int | 512 | Boundary-condition mini-batch |
log_every |
int | 100 | Print interval (used by ConsoleLogger) |
seed |
int | 0 | PRNG seed |
callbacks |
list | [] | List of Callback objects |
n_scan_steps |
int | 1 | Fuse N steps into one XLA kernel (1 = Python loop) |
resample_period |
int | 0 | RAR-D resampling every N outer steps (0 = off) |
resample_candidates |
int | 0 | Candidate pool size (0 → 5 × batch_r) |
resample_k |
float | 1.0 | Exponent in p ∝ |residual|^k |
out_dir |
str | "" | Output directory; enables auto-restart when non-empty |
save_restart_every |
int | 500 | Snapshot interval in epochs (0 = off) |
ConsoleLogger
ConsoleLogger(log_every=500)
# Prints: [epoch / total] loss=X.XXe-04 pde=X.XXe-04 ic=X.XXe-03 ...EarlyStopping
EarlyStopping(patience=400, monitor="loss", min_delta=1e-8)
# Raises StopIteration (caught by the solver) after `patience` epochs without improvement.
# Works correctly inside lax.scan loops — fires at the outer-step boundary.ModelCheckpoint
ModelCheckpoint(
out_dir="outputs/burgers/",
monitor="loss", # metric key from the loss aux dict
mode="min", # "min" or "max"
save_best_only=True, # skip non-improving epochs
metadata={"problem": "burgers", "network": {"type": "mlp", "layers": [2,64,64,64,1]}},
)
# Writes params.msgpack + params_meta.json whenever a new best is reached.Instead of a Python for loop that calls back into Python every epoch, lax.scan unrolls N gradient steps into a single compiled XLA program. The Python interpreter only touches the computation once per n_scan_steps iterations, dramatically reducing dispatch overhead:
config = TrainingConfig(
epochs = 5000,
lr = 1e-3,
n_scan_steps = 100, # 50 outer Python calls instead of 5000
callbacks = [ConsoleLogger(log_every=500)],
)
solver.train(*data, config=config)n_scan_steps |
Python calls / 5 000 epochs | Callback granularity | Use case |
|---|---|---|---|
| 1 (default) | 5 000 | every epoch | Development / debugging |
| 100 | 50 | every 100 epochs | GPU training, medium runs |
| 500 | 10 | every 500 epochs | Long GPU runs, production |
At every resample_period outer steps, the solver:
- Evaluates the PDE residual
r(x)at a pool ofresample_candidatescandidate points - Computes sampling probabilities
p(x) ∝ |r(x)|^k - Replaces the lowest-residual collocation points with new draws from this distribution
This concentrates compute on high-error regions without changing total batch size or requiring any geometry change.
training:
n_scan_steps : 100
resample_period : 5 # every 5 outer steps = every 500 epochs
resample_k : 1.0 # linear in |residual|The restart system lets you safely interrupt and resume any training run without losing progress. It is fully automatic — just set save_restart_every in your config.
-
Every
save_restart_everyepochs, a snapshot is written to<out_dir>/restart/:params.msgpack— Flax-serialised model parameters at that epochopt_state.msgpack— Flax-serialised optimizer state (Adam moments, step count)hists.npz— all loss history arrays accumulated so far (loss_hist,pde_hist, etc.)meta.json—{"epoch": N, "cfg_hash": "...", "done": false}
-
On re-run,
RestartManagerdetects the snapshot directory and readsmeta.json. Ifdone: false(the run was interrupted), params, optimizer state, and loss histories are restored and training continues from the saved epoch. Plots are continuous across restarts. -
Config-change safety — solvers do not hash-check configs automatically. An interrupted run will resume even if you changed
lr,epochs, or other fields. Usepython -m underPINN resume <config.yaml>to verify the config hash before resuming. To force a fresh start without the command, delete<out_dir>/restart/manually. -
After training finishes — either normally or via early stopping —
done()marks the snapshot with"done": true. The next run with the same config file starts fresh rather than re-resuming a completed run.
training:
save_restart_every: 500 # snapshot every 500 epochs; 0 to disableconfig = TrainingConfig(
epochs = 10000,
out_dir = "outputs/burgers",
save_restart_every = 500,
)
solver.train(*data, config=config)That's all. If the process is killed at epoch 3 700, the next run resumes from epoch 3 500 (the last snapshot) automatically.
| File | Contents |
|---|---|
params.msgpack |
Flax-serialised model parameters |
opt_state.msgpack |
Flax-serialised optimizer state |
hists.npz |
Loss history arrays (loss_hist, pde_hist, etc.) |
meta.json |
{"epoch": N, "cfg_hash": null, "done": false} — cfg_hash is null when run via solver; populated by the resume CLI |
Solvers gate resumption on the done flag only — not on a config hash. This means an interrupted run will auto-resume even if you changed lr, epochs, layers, or physics parameters. To check for config changes before resuming, use the resume CLI command:
python -m underPINN resume examples/burgers/config.yamlresume computes the MD5 of the current YAML, compares it against the hash stored in meta.json, and warns you if any field changed since the last snapshot. If everything is consistent, it resets done to false so the next run will resume.
To force a fresh start without the command, simply delete <out_dir>/restart/ or set save_restart_every: 0 temporarily.
Every runner writes two files to the output directory after training:
outputs/burgers/
params.msgpack ← exact Flax/msgpack serialization of all weights
params_meta.json ← {"problem": "burgers", "network": {"type": "mlp", "layers": [...]}, ...}
predictions.npz ← collocation-point predictions
config.yaml ← resolved training config (reproducibility)
loss_hist.npy
loss.png
from underPINN.callbacks.checkpoint import ModelCheckpoint
ModelCheckpoint(
out_dir="outputs/burgers/",
monitor="loss",
mode="min",
save_best_only=True,
metadata={"problem": "burgers", "network": {"type": "mlp", "layers": [2, 64, 64, 64, 1]}},
)from underPINN.utils.checkpoint import ModelPredictor
import jax.numpy as jnp
# Option A — auto-build model from saved metadata (zero boilerplate)
predictor = ModelPredictor.from_meta("outputs/burgers/")
# Option B — provide model explicitly
from underPINN.nn.mlp import MLP
predictor = ModelPredictor.from_checkpoint(
MLP(layers=[2, 64, 64, 64, 1]),
"outputs/burgers/",
)
# Run inference
x_new = jnp.linspace(-1.0, 1.0, 500)
t_new = jnp.full(500, 0.8)
u = predictor.predict(jnp.stack([x_new, t_new], axis=1))from underPINN.utils.checkpoint import save_checkpoint, load_checkpoint
# Save any param pytree
save_checkpoint(params, "my_dir/", metadata={"problem": "wave", "network": {"layers": [...]}})
# Load (model used as template for structure)
params = load_checkpoint(model, "my_dir/")underPINN supports two transfer learning modes, both using the same warm-start API.
# Phase 1: train source model (e.g. Burgers ν=0.1)
solver_src.train(*data_src, config=cfg_src)
solver_src.save_checkpoint("outputs/source/")
# Phase 2: warm-start target from source weights, then fine-tune (e.g. ν=0.01)
solver_tgt.load_params(solver_src.params) # or restore_checkpoint("outputs/source/")
solver_tgt.train(*data_tgt, config=cfg_tgt) # lower lr recommended (3e-4 instead of 1e-3)
# Converges 2-3× faster than training from scratch# Phase 1: train on t ∈ [0, T_1]
solver_phase1.train(*data_t1, config=cfg_phase1)
# Phase 2: extend to t ∈ [0, T_2], T_2 > T_1, warm-start from Phase 1
solver_phase2.load_params(solver_phase1.params)
solver_phase2.train(*data_t2, config=cfg_phase2)Both modes are demonstrated in examples/transfer/burgers_transfer.py and examples/pipe_flow/pipe_flow_unsteady_transfer.py.
The heat inverse problem (examples/heat/inverse.py) recovers the unknown thermal diffusivity α from 50 sparse noisy observations:
- Joint optimisation: the optimizer simultaneously updates network weights
θand the physics parameterlog_α = log(α)via a singlejax.gradcall - Log-parameterisation: optimising
log_αinstead ofαdirectly guarantees positivity without any constraints or projections; the true α is recovered asexp(log_α)after training - Observation loss: a separate MSE term penalises the discrepancy between model predictions at the 50 observation locations and the noisy measurements; the PDE residual loss is the regulariser
# Simplified view of the inverse problem setup
from underPINN.pde.diffusion import DiffusionInversePDE
pde = DiffusionInversePDE(model, log_alpha_init=jnp.log(0.5))
# pde.log_alpha is a trainable parameter alongside model weights
# After training: alpha_recovered = jnp.exp(pde.log_alpha)The 2-D diffusion inverse (examples/inverse/inverse_diffusion.py) follows the same pattern for a 2-D domain.
underPINN/
├── core/
│ ├── base.py # BasePDE, BaseLoss, BaseSolver (+ save/restore_checkpoint)
│ └── config.py # TrainingConfig dataclass with validation
│
├── nn/
│ ├── mlp.py # MLP, FourierMLP
│ ├── fbpinn.py # FBPINN (domain-decomposed network)
│ ├── attention.py # HybridAttention, SimpleGate
│ ├── embeddings.py # Fourier / positional embeddings
│ └── subdomain.py # SubdomainNetwork
│
├── pde/
│ ├── burgers.py # 1-D Burgers equation
│ ├── diffusion.py # 1-D unsteady diffusion / heat inverse
│ ├── heat.py # 2-D steady heat (Poisson)
│ ├── heat2d_unsteady.py # 2-D unsteady heat (x, y, t) → u
│ ├── helmholtz.py # 2-D Helmholtz Δu + k²u = f
│ ├── wave.py # 1-D wave equation u_tt = c²u_xx
│ ├── navier_stokes.py # 2-D steady incompressible N-S
│ ├── navier_stokes_3d.py# 3-D steady + UNSTEADY incompressible N-S
│ ├── carreau_ns_3d.py # 3-D Carreau (shear-thinning) N-S + 1-D exact profile
│ ├── compressible_euler.py # 2-D steady Euler — conservative form + artificial viscosity
│ ├── euler_1d_unsteady.py # 1-D unsteady Euler (Sod) — learnable artificial viscosity
│ ├── pipe_flow_unsteady.py # Unsteady pipe cross-section (y, z, t) → u
│ ├── k_epsilon.py # RANS k-ε turbulence model
│ └── ode.py # Exponential decay, Harmonic oscillator
│
├── geometry/
│ ├── interval.py # 1-D interval sampler
│ ├── rectangle.py # 2-D rectangle sampler
│ ├── airfoil.py # NACA 4-digit (sym/cambered) + AoA rotation + SDF sampling
│ ├── cylinder.py # 2-D circular cylinder (cross-flow exterior)
│ ├── pipe.py # Cylindrical pipe (interior, wall, inlet, outlet)
│ ├── aaa.py # BulgeGeometry — axisymmetric AAA bulge R(x)
│ ├── ramp.py # Trapezoidal ramp domain above a wedge (compressible Euler)
│ ├── composite.py # Boolean combination of geometries
│ └── shapely_geom.py # Shapely-backed arbitrary polygon sampler
│
├── solver/
│ ├── fbpinn.py # FBPINNSolver (space-time PDE, lax.scan, RAR-D)
│ ├── ode_solver.py # ODESolver
│ ├── steady_solver.py # SteadySolver (no time dimension)
│ ├── ldc_solver.py # LDCSolver (lid-driven cavity / FBPINN)
│ └── rans_solver.py # RANSSolver (k-ε turbulence)
│
├── losses/
│ ├── loss.py # PINNLoss (with optional RBA)
│ ├── ode_loss.py # ODELoss
│ └── steady_loss.py # SteadyLoss
│
├── callbacks/
│ ├── base.py # Callback ABC
│ ├── logging.py # ConsoleLogger
│ ├── early_stopping.py # EarlyStopping
│ └── checkpoint.py # ModelCheckpoint (save best model during training)
│
├── runner/ # CLI dispatch only — runner logic lives in examples/
│ ├── dispatch.py # _REGISTRY: problem → (script path, fn name)
│ ├── pipe_flow.py # pipe_flow runner helper
│ ├── wave.py # wave runner helper
│ └── heat_forward.py # heat_forward runner helper
│
├── training/
│ └── resample.py # rar_d_resample (RAR-D adaptive collocation)
│
├── config/
│ └── loader.py # load_config, generate_sweep_configs, cfg_get
│
├── benchmark_utils/
│ ├── evaluators.py # per-problem evaluators with exact solutions
│ ├── benchmark_suite.py # BenchmarkResult, BenchmarkRunner
│ └── report.py # plots, CSV, Markdown report generation
│
├── utils/
│ ├── io.py # save_predictions (NPZ archives)
│ ├── sampling.py # safe_choice (replace-safe mini-batching)
│ ├── seed.py # set_seed (Python + NumPy + JAX)
│ ├── checkpoint.py # save_checkpoint, load_checkpoint, ModelPredictor
│ ├── restart.py # RestartManager (snapshot + resume + done marker)
│ ├── timing.py # fmt_train_time (JIT-aware training time reporting)
│ ├── metrics.py # rel_l2, mse helpers
│ └── plotting.py # plot_losses, plot_ode_result
│
└── __main__.py # CLI entry point (python -m underPINN)
# sets XLA_PYTHON_CLIENT_PREALLOCATE=false before import jax
examples/ # self-contained: each folder holds script + YAML
│ # Adding a new case = create folder + add 1 line to dispatch.py
├── burgers/ burgers.py + config.yaml (1-D Burgers FBPINN + RBA)
├── wave/ wave.py + config.yaml (1-D wave FourierMLP)
├── heat/ forward.py + heat_forward.yaml (2-D steady heat / Poisson)
│ inverse.py + heat_inverse.yaml (recover α from noisy data)
├── helmholtz/ helmholtz.py + config.yaml (2-D Helmholtz FourierMLP)
├── ode/ ode_test.py + config.yaml (exp decay + harmonic osc.)
├── fbpinn_ode/ fbpinn_ode.py + config.yaml (FBPINN subdomains, du/dx=cos ωx)
├── inverse/ inverse_diffusion.py + config.yaml (2-D diffusion inverse)
├── LDC/ run_ldc.py + config.yaml (2-D Lid-Driven Cavity Re=100)
├── K-Epsilon/ run_kepsilon.py + config.yaml (k-ε RANS turbulent channel)
├── airfoil/ airfoil_flow.py + config.yaml (NACA airfoil, AoA via rotation)
├── cylinder/ cylinder_flow.py + config.yaml (cylinder cross-flow, Re=40)
├── pipe_flow/ pipe_flow.py + pipe_flow.yaml (3-D Hagen-Poiseuille)
│ pipe_flow_unsteady_transfer.py + yaml (Re + temporal transfer)
│ pipe_flow_pulsatile_transfer.py + yaml (3-D pulsatile, time-marching TL)
│ predict_pulsatile.py (window predictor: plots, GIF, spacetime)
├── pipe_flow_rheology/ pipe_flow_rheology.py + config.yaml (Carreau blood, pipe)
├── AAA/ AAA_flow.py + config.yaml (3-D AAA bulge, Newtonian)
├── AAA_rheology/ AAA_rheology.py + config.yaml (Carreau blood, AAA bulge)
├── ramp/ ramp.py + config.yaml (2-D compressible Euler, M=3, AV + RAR)
├── sod_shock/ sod_shock.py + config.yaml (Sod tube, learnable ε + RAR)
├── toro3/ toro3.py + config.yaml (Toro-3 blast wave, exp positivity + non-dim)
├── predict_steady.py (post-process steady pipe/AAA: WSS, contours, NPZ)
└── transfer/ burgers_transfer.py + yaml (Burgers param + temp. TL)
heat2d_transfer.py + yaml (2-D heat transfer)
docs/
└── index.html # Static framework documentation website
| Problem | PDE | Network | Key Features | Config |
|---|---|---|---|---|
| Exponential Decay | du/dt + λu = 0 | MLP [1,32,32,1] | ODESolver, TrainingConfig, callbacks |
examples/ode/config.yaml |
| Harmonic Oscillator | d²u/dt² + ω²u = 0 | MLP [1,32,32,1] | ODESolver, IC derivative |
examples/ode/config.yaml |
| FBPINN ODE | du/dx = cos(ω x) | FBPINN — 15 subnets [1,16,16,1] | Overlapping subdomains + partition-of-unity windows, hard IC constraint | examples/fbpinn_ode/config.yaml |
| 1-D Burgers | u_t + uu_x = νu_xx | MLP [2,64,64,64,1] | FBPINN, RBA, cosine LR | examples/burgers/config.yaml |
| 1-D Heat — Forward | u_t = αu_xx | MLP [2,64,64,64,1] | FBPINNSolver, exact Gaussian IC |
examples/heat/heat_forward.yaml |
| 1-D Heat — Inverse | u_t = αu_xx | MLP [2,64,64,64,1] | Recover α from 50 noisy observations | examples/heat/heat_inverse.yaml |
| 1-D Wave | u_tt = c²u_xx | FourierMLP [2,128,128,1] | Dual IC (u and u_t), n_fourier=32 | examples/wave/config.yaml |
| 2-D Helmholtz | Δu + k²u = f | FourierMLP [2,128,128,1] | k=4, manufactured source term | examples/helmholtz/config.yaml |
| 2-D Diffusion Inverse | u_t = α∇²u | MLP [3,64,64,64,1] | Log-param joint optimisation | examples/inverse/config.yaml |
| 2-D Lid-Driven Cavity | Steady N-S, Re=100 | FBPINN + SimpleGate | LDCSolver, attention, Re=100 |
examples/LDC/config.yaml |
| 2-D RANS k-ε | Turbulent channel | FBPINN | RANSSolver, RBA, Re=10000 |
examples/K-Epsilon/config.yaml |
| 2-D Compressible Ramp | Steady Euler (conservative), M=3 | MLP [2,80,80,80,80,80,4] | Oblique shock θ=10°, artificial viscosity (fixed/learnable), RAR | examples/ramp/config.yaml |
| 1-D Sod Shock Tube | Unsteady Euler (conservative) | MLP [2,80×5,3] | Learnable ε = softplus(log_av), exact Riemann reference, RAR | examples/sod_shock/config.yaml |
| 1-D Toro Test 3 (blast wave) | Unsteady Euler (conservative) | MLP [2,128×4,3] | exp/log positivity, reference-state non-dimensionalisation, learnable ε, RAR | examples/toro3/config.yaml |
| NACA Airfoil | Steady N-S, Re=100 | MLP / GatedMLP [2,128×6,3] | Cambered profiles, AoA via airfoil rotation, surface pressure & Cp | examples/airfoil/config.yaml |
| Cylinder Cross-flow | Steady N-S, Re=40 | MLP [2,128×6,3] | Pure-PINN recipe, Cp(θ) vs inviscid reference, wake pool | examples/cylinder/config.yaml |
| 3-D Pipe Flow | Steady 3-D N-S | MLP / GatedMLP [3,…,4] | Double-jacfwd Hessian, Hagen-Poiseuille exact | examples/pipe_flow/pipe_flow.yaml |
| 3-D AAA Bulge | Steady 3-D N-S | GatedMLP [3,192×5,4] | Cosine² bulge R(x), flow-rate balance check |
examples/AAA/config.yaml |
| Carreau Pipe (blood) | Steady Carreau N-S | GatedMLP [3,128×4,4] | Shear-thinning μ(γ̇), 1-D Carreau exact profile, β=16; same domain/Re as Newtonian pipe | examples/pipe_flow_rheology/config.yaml |
| Carreau AAA (blood) | Steady Carreau N-S | GatedMLP [3,192×5,4] | Blood rheology in the bulge, apparent-viscosity maps | examples/AAA_rheology/config.yaml |
| 3-D Pulsatile Pipe | Unsteady 3-D N-S | GatedMLP [4,…,4] | Time-marching transfer (windowed), per-window ckpts, window restart | examples/pipe_flow/pipe_flow_pulsatile_transfer.yaml |
| 3-D Unsteady Pipe Transfer | u_t = G + ν∇²u | MLP [3,64,64,64,64,1] | Bessel exact solution, Re + temporal TL | examples/pipe_flow/pipe_flow_unsteady_transfer.yaml |
| Burgers Transfer | Burgers | MLP [2,64,64,64,1] | Parameter transfer (ν) + temporal transfer | examples/transfer/burgers_transfer.yaml |
| Heat 2-D Transfer | 2-D heat | MLP [3,64,64,64,1] | Cross-diffusivity transfer + temporal | examples/transfer/heat2d_transfer.yaml |
| PDE | Equation | Key method | Used in |
|---|---|---|---|
| Burgers (1-D) | u_t + uu_x = νu_xx | BurgersPDE.residual |
examples/burgers/, examples/transfer/ |
| Diffusion / Heat (1-D) | u_t = αu_xx | DiffusionPDE.residual |
examples/heat/ |
| Heat (2-D unsteady) | u_t = α(u_xx + u_yy) | Heat2DPDE.residual |
examples/inverse/, examples/transfer/ |
| Wave (1-D) | u_tt = c²u_xx | WavePDE.residual |
examples/wave/ |
| Helmholtz (2-D) | Δu + k²u = f | HelmholtzPDE.residual |
examples/helmholtz/ |
| Navier-Stokes (2-D steady) | ∇·u=0, u·∇u = -∇p + ν∇²u | NavierStokesPDE.residual |
examples/LDC/, examples/airfoil/, examples/cylinder/ |
| Navier-Stokes (3-D steady) | Same + z-momentum | SteadyNS3DPDE.residual |
examples/pipe_flow/, examples/AAA/ |
| Navier-Stokes (3-D unsteady) | u_t + (u·∇)u = −∇p + ν∇²u | UnsteadyNS3DPDE.residual |
examples/pipe_flow/ (pulsatile) |
| Carreau N-S (3-D steady) | ∇·[μ*(γ̇)(∇u+∇uᵀ)] stress | CarreauNS3DPDE.residual |
examples/pipe_flow_rheology/, examples/AAA_rheology/ |
| Pipe unsteady | u_t = G + ν(u_yy + u_zz) | PipeUnsteadyPDE.residual |
examples/pipe_flow/ |
| RANS k-ε | N-S + k + ε transport | KEpsilonPDE.residual |
examples/K-Epsilon/ |
| Compressible Euler (2-D steady) | ∂F/∂x + ∂G/∂y = ε∇²U (conservative) | CompressibleEulerPDE.residual |
examples/ramp/ |
| Compressible Euler (1-D unsteady) | ∂U/∂t + ∂F/∂x = ε∂²U/∂x² | Euler1DUnsteadyPDE.residual |
examples/sod_shock/ |
| Exponential Decay | du/dt + λu = 0 | ExpDecayODE.residual |
examples/ode/ |
| Harmonic Oscillator | d²u/dt² + ω²u = 0 | HarmonicODE.residual |
examples/ode/ |
| Class | What it samples | Example uses |
|---|---|---|
Interval |
1-D uniform or Sobol interior + boundary points | Burgers, wave, heat (1-D) |
Rectangle |
2-D interior (LHS / Sobol) + all four boundary edges | Helmholtz, LDC, diffusion inverse |
NACAAirfoil |
NACA 4-digit (symmetric & cambered) exterior domain, SDF-weighted near-surface, AoA via quarter-chord rotation | examples/airfoil/ |
Cylinder2D |
Circular cylinder exterior cross-flow domain, analytic SDF, surface points | examples/cylinder/ |
Pipe |
3-D cylindrical interior, lateral wall, circular inlet, circular outlet | examples/pipe_flow/, examples/pipe_flow_rheology/ |
BulgeGeometry |
Axisymmetric AAA bulge R(x) (cosine²): interior, curved wall, inlet, outlet |
examples/AAA/, examples/AAA_rheology/ |
Ramp |
Trapezoidal domain above a wedge surface at angle θ | examples/ramp/ |
Composite |
Boolean union / intersection / difference of any two geometry objects | LDC (cavity minus any obstacle) |
ShapelyGeom |
Arbitrary 2-D polygon backed by Shapely 2.x; rejection-samples interior | Custom geometries |
# Run all fast problems with default epoch budgets [500, 1000, 2000, 5000]
python -m underPINN bench
# Select specific problems and budgets
python -m underPINN bench \
--problems burgers wave helmholtz heat_steady ode_exp ode_harmonic \
--epochs 500 1000 2000 5000 \
--output outputs/bench
# Include slow problems (3-D pipe flow, k-ε)
python -m underPINN bench --all
# Regenerate plots from a previous run without re-training
python -m underPINN bench --from-json outputs/bench/results.json| File | Description |
|---|---|
accuracy_vs_epochs.png |
Log-log rel-L² vs epoch budget, one line per problem |
accuracy_summary_bar.png |
Grouped bar chart of rel-L² at each epoch budget |
wall_time_vs_epochs.png |
Training time vs epoch budget |
ms_per_epoch.png |
Bar chart of training throughput per problem |
loss_grid.png |
Convergence curves for each problem |
benchmark_results.csv |
Full raw data table |
benchmark_summary.md |
Markdown table (one row per problem at max epochs) |
results.json |
Reusable JSON for --from-json replays |
Set XLA_PYTHON_CLIENT_PREALLOCATE=false before importing JAX. underPINN does this automatically in every entry point, but if you are writing a new script, add it at the top before import jax.
On GPU, each Python→XLA dispatch has ~1 ms of overhead. With 5 000 epochs and n_scan_steps=1 that is ~5 s of pure dispatch. With n_scan_steps=100 it drops to ~50 ms. For long runs, use n_scan_steps=500.
Enable RAR-D (resample_period > 0) when the solution has sharp gradients or shocks (Burgers at low ν, Euler ramp). A typical setting is resample_period=5, resample_k=1.0.
- Fast ODEs:
patience=200 - Medium PDEs (Burgers, wave, Helmholtz):
patience=400–800 - Complex PDEs (LDC, airfoil, 3-D pipe):
patience=1000–2000
JAX defaults to float32, which is optimal on all GPUs. Do not enable jax.config.update("jax_enable_x64", True) unless you have a specific reason; it halves throughput on CUDA devices.
Use CUDA_VISIBLE_DEVICES=1 to restrict to a specific GPU. Full multi-GPU sharded training is not currently implemented; single-device training on the fastest GPU is the recommended approach. To use several GPUs, launch one run per device (e.g. a parameter sweep with a different CUDA_VISIBLE_DEVICES per process).
Install with requirements-tpu.txt (see Installation). On import, underPINN detects the TPU backend and sets jax_default_matmul_precision = "highest" automatically — the TPU's default bfloat16 matmuls corrupt second-order PDE residuals (Hessians), so full-float32 matmuls are required for PINN accuracy. Override via the JAX_DEFAULT_MATMUL_PRECISION env var.
Always prefer optax.cosine_decay_schedule over a fixed learning rate for runs longer than 2 000 epochs. It provides free accuracy improvement at no cost by reducing the LR smoothly toward a small alpha value (recommended: alpha=1e-2).
Every solver prints a timing summary at the end of training:
Training complete — final loss 1.23e-04 | 45.2s [JIT≈12s + 3.3ms/ep]
The JIT≈… component appears when the first epoch is ≥ 3 s and at least 4× slower than the average of subsequent epochs, separating XLA compilation overhead from actual training time. The ms/ep figure is the mean wall-clock cost per epoch after JIT warm-up — useful for benchmarking solver configurations.
# Every PDE implements one method
class BasePDE(ABC):
@abstractmethod
def residual(self, params, *args): ...
# Every loss is callable and returns (total, aux_tuple)
class BaseLoss(ABC):
@abstractmethod
def __call__(self, params, *args, **kwargs): ...
# Every solver has init + train + checkpoint helpers (inherited)
class BaseSolver(ABC):
@abstractmethod
def init(self, key): ...
@abstractmethod
def train(self, *args, **kwargs): ...
# Concrete — available on every solver:
def save_checkpoint(self, out_dir, stem="params", metadata=None): ...
def restore_checkpoint(self, path): ...If you use underPINN in research or publications, please cite:
@software{underPINN_v2605,
author = {Kumar Prashant, Senthilkumar Lohith, Ranjan Rajesh},
title = {underPINN-v2605: A Modular JAX Framework for Physics-Informed Neural Networks},
year = {2026},
version = {v2605},
url = {https://github.com/Aeroscience-Computations-Analysis-Lab/underPINN.git}
}underPINN is released under the GPL-3.0 License. See LICENSE.txt for the full text.