Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
120 commits
Select commit Hold shift + click to select a range
f853803
multi-modal batch VAE
WinterHannah May 13, 2025
d49fdea
tests
WinterHannah May 13, 2025
fa358c5
batch encoding dual mod VAE
WinterHannah May 23, 2025
ab7de85
Merge pull request #9 from quadbio/main
Marius1311 May 23, 2025
ab0e4b4
Re-enable some workflows
Marius1311 May 23, 2025
5407a45
Merge pull request #10 from quadbio/chore/gha-spatial-branch-updates
Marius1311 May 23, 2025
92a7975
ci: comment out schedule trigger in test_linux workflow (resource sav…
Marius1311 May 23, 2025
b10b081
Set up simple linting workflow
Marius1311 May 23, 2025
23282c9
Fix linter
Marius1311 May 23, 2025
ec9b293
Modify the linux test workflow, depend on label
Marius1311 May 23, 2025
9b53764
Merge pull request #11 from quadbio/chore/ci-disable-schedule-test-linux
Marius1311 May 23, 2025
42fb0e3
Merge branch 'feature/spatial' into 2-basic-batch-encoding-like-in-glue
Marius1311 May 23, 2025
c94b551
updated tests
WinterHannah May 23, 2025
4812cb7
Merge branch '2-basic-batch-encoding-like-in-glue' of https://github.…
WinterHannah May 23, 2025
4526520
basemodelclass change
WinterHannah May 23, 2025
ef923d1
guidance graph and tensorboard
WinterHannah May 26, 2025
5f3d226
add init file
WinterHannah May 26, 2025
bcb085e
logging update
WinterHannah May 27, 2025
8d7dcc1
incorporated comments
WinterHannah May 27, 2025
043f445
different parameters for test
WinterHannah May 27, 2025
9875704
Merge pull request #8 from quadbio/2-basic-batch-encoding-like-in-glue
WinterHannah May 27, 2025
c44b70a
guidance graph
WinterHannah May 28, 2025
a693020
Merge branch '3-link-decoders-with-guidance-graph' into 3-link-decode…
WinterHannah May 28, 2025
9a06794
Merge pull request #12 from quadbio/3-link-decoders-with-guidance-gra…
WinterHannah May 28, 2025
a4274aa
ruff formatting
WinterHannah May 28, 2025
7820ea6
Update test_linux_custom_dataloader.yml
WinterHannah May 28, 2025
fe9bfaa
Merge pull request #14 from quadbio/feature/spatial
WinterHannah May 28, 2025
33c9b2e
update workflow env
WinterHannah May 28, 2025
194eea1
fixed tests
WinterHannah May 28, 2025
0d0dc84
improved edge sampling
WinterHannah May 31, 2025
89fbf30
standardized logging
WinterHannah Jun 2, 2025
ba3f110
comments on PR
WinterHannah Jun 3, 2025
27e69c9
function to impute values
WinterHannah Jun 4, 2025
4cdb7c6
graph consistency checks
WinterHannah Jun 4, 2025
8650789
clean up unnecessary variables
WinterHannah Jun 4, 2025
669ce03
custom target library/batch
WinterHannah Jun 5, 2025
895af16
input dictionary instead of indices
WinterHannah Jun 6, 2025
42385e4
delete comments
WinterHannah Jun 6, 2025
bb5f861
tests
WinterHannah Jun 6, 2025
ff5e0c0
logging
WinterHannah Jun 9, 2025
229264a
Merge pull request #15 from quadbio/3-link-decoders-with-guidance-graph
WinterHannah Jun 9, 2025
06aa591
unbalanced OT
WinterHannah Jun 11, 2025
1fcca64
geomloss sinkhorn
WinterHannah Jun 16, 2025
2fce25a
val step update
WinterHannah Jun 17, 2025
2d6f6c2
automatic optimization and epoch heuristic
WinterHannah Jun 18, 2025
fcff3a9
quick fix
WinterHannah Jun 18, 2025
ce5b8a8
log name fix
WinterHannah Jun 18, 2025
115e8d3
logging per epoch
WinterHannah Jun 19, 2025
2c746fb
save/load model
WinterHannah Jun 25, 2025
230b991
generalized load/save
WinterHannah Jun 25, 2025
878a91a
deleted comments
WinterHannah Jun 25, 2025
96d9f06
losses
WinterHannah Jun 27, 2025
3b030ba
gmm
WinterHannah Jun 28, 2025
f2f13d7
save load test
WinterHannah Jun 30, 2025
4f635b4
removed unneded function
WinterHannah Jun 30, 2025
46ed0ba
Merge pull request #23 from quadbio/22-uot-for-modality-mixing
WinterHannah Jul 1, 2025
664a28c
gmm
WinterHannah Jul 1, 2025
04d2938
Merge pull request #24 from quadbio/feature/spatial
WinterHannah Jul 1, 2025
e86e6f9
classifier
WinterHannah Jul 1, 2025
7b7f2d6
class loss param
WinterHannah Jul 2, 2025
41a40a0
gmm fix
WinterHannah Jul 8, 2025
11bf7bf
src/scvi/external/spaglue/_base_components.py
WinterHannah Jul 11, 2025
7f5abb0
normal likelihood
WinterHannah Jul 15, 2025
d89e5c0
removed comments
WinterHannah Jul 15, 2025
5f47710
Merge pull request #25 from quadbio/22-uot-for-modality-mixing-mixtur…
WinterHannah Jul 16, 2025
bdc3ae6
feature imputation score
WinterHannah Jul 17, 2025
710722d
extended guidance graph
WinterHannah Jul 17, 2025
4a142e3
ft embedding corr
WinterHannah Jul 31, 2025
d1d8584
uot loss annealing
WinterHannah Aug 3, 2025
9e4ccd9
confidence score tests
WinterHannah Aug 4, 2025
9795bd5
Merge pull request #27 from quadbio/26-provide-scoring-for-feature-im…
WinterHannah Aug 4, 2025
0c27cc2
Merge branch 'feature/spatial' into main
WinterHannah Aug 4, 2025
d823575
Merge pull request #28 from quadbio/main
WinterHannah Aug 4, 2025
56431b0
diagvi and nb mixture
WinterHannah Aug 10, 2025
5ed2763
renamed folder
WinterHannah Aug 10, 2025
dd6fee6
docstrings
WinterHannah Aug 10, 2025
dce3aff
loss fix
WinterHannah Aug 12, 2025
7042c37
shuffle batches
WinterHannah Aug 12, 2025
6b88ad4
loss annealing
WinterHannah Aug 12, 2025
ac9b7bf
protein decoder
WinterHannah Aug 12, 2025
cd44f32
loss ifx
WinterHannah Aug 12, 2025
65aeaec
shuffling
WinterHannah Aug 12, 2025
71e0264
multi-gpu-test
WinterHannah Aug 15, 2025
650caf4
optional common scale
WinterHannah Aug 29, 2025
e8dd964
fix sinkhorn loss
WinterHannah Aug 29, 2025
0057d56
remove print
WinterHannah Sep 3, 2025
226ca83
updated dependencies
WinterHannah Sep 3, 2025
218cf0b
add muon to dependencies
WinterHannah Sep 3, 2025
7f6098f
decorator and readme
WinterHannah Sep 3, 2025
394f021
tutorial
WinterHannah Sep 3, 2025
b63aa50
tutorial update
WinterHannah Sep 4, 2025
fc60e22
guidance graph utility
WinterHannah Sep 4, 2025
bea713b
glue decoder
WinterHannah Sep 12, 2025
2c38f87
protein decoder glue
WinterHannah Sep 13, 2025
ec0f159
remove prints
WinterHannah Sep 13, 2025
3ac6b30
delete adv task
WinterHannah Sep 14, 2025
4c41ab3
comments
WinterHannah Sep 21, 2025
0ab97ce
Merge pull request #30 from quadbio/29-totalvi-decoder
WinterHannah Sep 26, 2025
e8b51ef
Merge remote-tracking branch 'scvi-tools/main' into feature/spatial
ori-kron-wis Oct 21, 2025
9e73cc1
merge with main of scvi-tools
ori-kron-wis Oct 21, 2025
c4b9ee0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2025
e736702
updates
ori-kron-wis Oct 21, 2025
76e1d91
temp fix pyproj for diagvi
ori-kron-wis Oct 21, 2025
84f0cd7
test fix
ori-kron-wis Oct 21, 2025
7c9f420
Merge remote-tracking branch 'scvi-tools/main' into feature/spatial
ori-kron-wis Oct 30, 2025
79a4576
merge with main
ori-kron-wis Oct 30, 2025
1e11cfe
Merge branch 'main' into feature/spatial
ori-kron-wis Nov 11, 2025
0c6bb71
Merge branch 'main' into feature/spatial
ori-kron-wis Nov 24, 2025
d9a1700
Merge branch 'main' into feature/spatial
ori-kron-wis Dec 10, 2025
8adfc42
classification parameter
WinterHannah Dec 28, 2025
e69b4b0
merge
WinterHannah Dec 28, 2025
29bc4fa
Merge remote-tracking branch 'scvi-tools/main' into feature/spatial
ori-kron-wis Dec 29, 2025
e2f5dee
merge with main, fix precommits
ori-kron-wis Dec 29, 2025
46bc1ed
documentation and cleanup
WinterHannah Jan 24, 2026
4d19cae
full test coverage
WinterHannah Jan 24, 2026
2937f43
removed unused arguments
WinterHannah Jan 26, 2026
8123db9
make imputation deterministic
WinterHannah Jan 26, 2026
100c652
remove confidence score for now
WinterHannah Jan 26, 2026
0939d41
fix tests
WinterHannah Jan 26, 2026
012343e
Merge pull request #2 from WinterHannah/diagvi-cleanup
WinterHannah Jan 26, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ to [Semantic Versioning]. The full commit history is available in the [commit lo

#### Added

- Add {class}`scvi.external.DIAGVI` for integrating spatial and dissociated single-cell datasets {pr}`3575`.
- Add MLFlow support, {pr}`3573`.
- Add support for MuData during Ray autotune {pr}`3545`.
- Add {meth}`~scvi.external.TorchMRVI.get_normalized_expression`
Expand Down
8 changes: 8 additions & 0 deletions docs/tutorials/index_spatial.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
notebooks/spatial/resolVI_tutorial
notebooks/spatial/scVIVA_tutorial
notebooks/spatial/DestVI_tutorial
notebooks/spatial/DiagVI_tutorial
notebooks/spatial/gimvi_tutorial
notebooks/spatial/tangram_scvi_tools
notebooks/spatial/stereoscope_heart_LV_tutorial
Expand Down Expand Up @@ -33,6 +34,13 @@ Stratify spatial transcriptomics data into niche-aware cell states with scVIVA
Perform multi-resolution analysis on spatial transcriptomics data with DestVI
```

```{customcard}
:path: notebooks/spatial/DiagVI_tutorial
:tags: Analysis, Integration, Modality-imputation, Differential-comparison

Perform integration of spatial and dissociated single-cell multi-omics data with DiagVI
```

```{customcard}
:path: notebooks/spatial/gimvi_tutorial
:tags: Modality-imputation, Integration
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,13 @@ interpretability = ["captum", "shap", "decoupler"]
jax = ["jax", "jaxlib", "optax", "numpyro", "flax"]
# for custom dataloders
dataloaders = ["lamindb>=1.12.1", "cellxgene-census", "tiledbsoma", "tiledbsoma_ml", "torchdata"]
# for diagvi
diagvi = ["torch_geometric", "geomloss"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need geomloss? It doesn't seem to be actively maintained.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Can, at the moment we rely on geomloss to compute the unbalanced optimal transport loss component. However, we are working on another approach based on POT, which is more actively maintained.

# for mlflow
mlflow = ["mlflow","psutil","GPUtil"]

optional = [
"scvi-tools[autotune,mlflow,hub,jax,file_sharing,regseq,parallel,interpretability]",
"scvi-tools[autotune,mlflow,hub,jax,file_sharing,regseq,parallel,interpretability,diagvi]",
"igraph","leidenalg","pynndescent",
]
tutorials = [
Expand Down
2 changes: 2 additions & 0 deletions src/scvi/external/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .contrastivevi import ContrastiveVI
from .cytovi import CYTOVI
from .decipher import Decipher
from .diagvi import DIAGVI
from .gimvi import GIMVI
from .methylvi import METHYLANVI, METHYLVI
from .mrvi import MRVI
Expand Down Expand Up @@ -43,6 +44,7 @@
"RESOLVI",
"SCVIVA",
"CYTOVI",
"DIAGVI",
]


Expand Down
7 changes: 7 additions & 0 deletions src/scvi/external/diagvi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""DIAGVI model for multi-modal integration with guidance graphs."""

from ._model import DIAGVI
from ._module import DIAGVAE
from ._task import DiagTrainingPlan

__all__ = ["DIAGVI", "DIAGVAE", "DiagTrainingPlan"]
234 changes: 234 additions & 0 deletions src/scvi/external/diagvi/_base_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
"""Base neural network components for DIAGVI model."""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

from scvi.utils import dependencies

EPS = 1e-8


class DecoderRNA(nn.Module):
    """Decoder for the RNA modality built on graph feature embeddings.

    Maps a cell's latent representation to negative-binomial expression
    parameters. Each batch owns its own scale, bias and dispersion rows,
    while the feature embeddings ``v`` (from the guidance graph) are shared
    across batches.

    Parameters
    ----------
    n_output
        Number of output features (genes).
    n_batches
        Number of batches in the data.
    """

    def __init__(
        self,
        n_output: int,
        n_batches: int,
    ):
        super().__init__()
        self.n_output = n_output
        self.n_batches = n_batches

        # One parameter row per batch: multiplicative scale (pre-softplus),
        # additive bias, and NB dispersion (log theta).
        self.scale_lin = nn.Parameter(torch.zeros(n_batches, n_output))
        self.bias = nn.Parameter(torch.zeros(n_batches, n_output))
        self.log_theta = nn.Parameter(torch.zeros(n_batches, n_output))

        # Per-gene dropout logit, shared across batches; small random init.
        self.px_dropout_param = nn.Parameter(torch.randn(n_output) * 0.01)

    def forward(
        self,
        u: torch.Tensor,
        l: torch.Tensor,
        batch_index: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Decode latent representation to RNA expression parameters.

        Parameters
        ----------
        u
            Latent representation tensor of shape (n_cells, n_latent).
        l
            Log library size tensor of shape (n_cells, 1).
        batch_index
            Batch indices tensor of shape (n_cells,) or (n_cells, 1).
        v
            Feature embedding tensor from graph encoder.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Tuple of (px_scale, px_r, px_rate, px_dropout) for negative binomial
            distribution parameters.
        """
        # Flatten an (n_cells, 1) index column to (n_cells,).
        idx = batch_index.squeeze(-1) if batch_index.dim() > 1 else batch_index

        if idx.min() < 0 or idx.max() >= self.bias.shape[0]:
            raise IndexError(
                f"Batch index out of bounds: valid range is [0, {self.bias.shape[0] - 1}]"
            )

        # Gather the per-cell parameter rows for this batch assignment.
        per_batch_scale = F.softplus(self.scale_lin[idx])
        logits = per_batch_scale * (u @ v.T) + self.bias[idx]

        # Normalized expression frequencies; rate rescales them by library size.
        px_scale = torch.softmax(logits, dim=-1)
        px_rate = torch.exp(l) * px_scale
        px_dropout = F.softplus(self.px_dropout_param)

        return px_scale, self.log_theta[idx], px_rate, px_dropout


class DecoderProteinGLUE(nn.Module):
    """Decoder for protein modality using GLUE-style mixture model.

    Decodes latent representations to protein expression using a mixture
    of two negative-binomial components (background/foreground) that share
    a scale but have batch-specific biases. The mixture assignment logits
    are the difference of the two component logits.

    Parameters
    ----------
    n_output
        Number of output features (proteins).
    n_batches
        Number of batches in the data.
    """

    def __init__(
        self,
        n_output: int,
        n_batches: int,
    ):
        super().__init__()
        self.n_output = n_output
        self.n_batches = n_batches

        # Shared multiplicative scale (pre-softplus), one row per batch.
        self.scale_lin = nn.Parameter(torch.zeros(n_batches, n_output))

        # Component-specific additive biases, one row per batch.
        self.bias1 = nn.Parameter(torch.zeros(n_batches, n_output))
        self.bias2 = nn.Parameter(torch.zeros(n_batches, n_output))

        # NB dispersion (log theta), shared by both components.
        self.log_theta = nn.Parameter(torch.zeros(n_batches, n_output))

    def forward(
        self,
        u: torch.Tensor,
        l: torch.Tensor,
        batch_index: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[
        tuple[torch.Tensor, torch.Tensor],
        torch.Tensor,
        tuple[torch.Tensor, torch.Tensor],
        torch.Tensor,
    ]:
        """Decode latent representation to protein expression parameters.

        Parameters
        ----------
        u
            Latent representation tensor of shape (n_cells, n_latent).
        l
            Log library size tensor of shape (n_cells, 1).
        batch_index
            Batch indices tensor of shape (n_cells,) or (n_cells, 1).
        v
            Feature embedding tensor from graph encoder.

        Returns
        -------
        tuple
            Tuple of ((px_scale_1, px_scale_2), px_r, (px_rate_1, px_rate_2),
            mixture_logits) for negative binomial mixture distribution.
        """
        if batch_index.dim() > 1:
            batch_index = batch_index.squeeze(-1)

        if (batch_index.max() >= self.bias1.shape[0]) or (batch_index.min() < 0):
            raise IndexError(
                f"Batch index out of bounds: valid range is [0, {self.bias1.shape[0] - 1}]"
            )

        scale = F.softplus(self.scale_lin[batch_index])

        bias1 = self.bias1[batch_index]
        bias2 = self.bias2[batch_index]

        log_theta = self.log_theta[batch_index]

        # Hoist the shared projection: both components use the same
        # scale * (u @ v.T) term, so compute the matmul only once.
        scaled_proj = scale * (u @ v.T)
        raw_px_scale_1 = scaled_proj + bias1
        raw_px_scale_2 = scaled_proj + bias2

        px_scale_1 = torch.softmax(raw_px_scale_1, dim=-1)
        px_scale_2 = torch.softmax(raw_px_scale_2, dim=-1)

        px_rate_1 = torch.exp(l) * px_scale_1
        px_rate_2 = torch.exp(l) * px_scale_2

        # Logits favoring component 1 over component 2.
        mixture_logits = raw_px_scale_1 - raw_px_scale_2

        px_r = log_theta

        return (px_scale_1, px_scale_2), px_r, (px_rate_1, px_rate_2), mixture_logits


class GraphEncoder(nn.Module):
    """Graph convolutional encoder for feature embeddings.

    Learns a variational embedding for every feature node of the guidance
    graph: a learned node representation is propagated through one GCN
    layer, then projected to the mean and (softplus-positive) scale of a
    Normal posterior from which embeddings are sampled.

    Parameters
    ----------
    vnum
        Number of nodes (features) in the graph.
    out_features
        Dimensionality of the output feature embeddings.
    """

    @dependencies("torch_geometric")
    def __init__(self, vnum: int, out_features: int):
        # Imported lazily so the model is usable without torch_geometric
        # until a GraphEncoder is actually constructed.
        from torch_geometric.nn import GCNConv

        super().__init__()
        self.vrepr = nn.Parameter(torch.zeros(vnum, out_features))
        self.conv = GCNConv(out_features, out_features)
        self.loc = nn.Linear(out_features, out_features)
        self.std_lin = nn.Linear(out_features, out_features)

    def forward(self, edge_index: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode graph to feature embeddings.

        Parameters
        ----------
        edge_index
            Edge index tensor of shape (2, n_edges).

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            Tuple of (z, mu, logvar) where z is the sampled embedding,
            mu is the mean, and logvar is the log variance.
        """
        hidden = self.conv(self.vrepr, edge_index)

        # Posterior parameters; EPS keeps the scale strictly positive.
        mean = self.loc(hidden)
        scale = F.softplus(self.std_lin(hidden)) + EPS

        posterior = Normal(mean, scale)
        sample = posterior.rsample()
        log_variance = torch.log(posterior.scale**2)

        return sample, posterior.loc, log_variance
Loading