Model: DiagVI (new) #3575
Open: ori-kron-wis wants to merge 120 commits into scverse:main (base) from quadbio:feature/spatial.
Changes from all commits (120):
f853803  multi-modal batch VAE (WinterHannah)
d49fdea  tests (WinterHannah)
fa358c5  batch encoding dual mod VAE (WinterHannah)
ab7de85  Merge pull request #9 from quadbio/main (Marius1311)
ab0e4b4  Re-enable some workflows (Marius1311)
5407a45  Merge pull request #10 from quadbio/chore/gha-spatial-branch-updates (Marius1311)
92a7975  ci: comment out schedule trigger in test_linux workflow (resource sav…) (Marius1311)
b10b081  Set up simple linting workflow (Marius1311)
23282c9  Fix linter (Marius1311)
ec9b293  Modify the linux test workflow, depend on label (Marius1311)
9b53764  Merge pull request #11 from quadbio/chore/ci-disable-schedule-test-linux (Marius1311)
42fb0e3  Merge branch 'feature/spatial' into 2-basic-batch-encoding-like-in-glue (Marius1311)
c94b551  updated tests (WinterHannah)
4812cb7  Merge branch '2-basic-batch-encoding-like-in-glue' of https://github.… (WinterHannah)
4526520  basemodelclass change (WinterHannah)
ef923d1  guidance graph and tensorboard (WinterHannah)
5f3d226  add init file (WinterHannah)
bcb085e  logging update (WinterHannah)
8d7dcc1  incorporated comments (WinterHannah)
043f445  different parameters for test (WinterHannah)
9875704  Merge pull request #8 from quadbio/2-basic-batch-encoding-like-in-glue (WinterHannah)
c44b70a  guidance graph (WinterHannah)
a693020  Merge branch '3-link-decoders-with-guidance-graph' into 3-link-decode… (WinterHannah)
9a06794  Merge pull request #12 from quadbio/3-link-decoders-with-guidance-gra… (WinterHannah)
a4274aa  ruff formatting (WinterHannah)
7820ea6  Update test_linux_custom_dataloader.yml (WinterHannah)
fe9bfaa  Merge pull request #14 from quadbio/feature/spatial (WinterHannah)
33c9b2e  update workflow env (WinterHannah)
194eea1  fixed tests (WinterHannah)
0d0dc84  improved edge sampling (WinterHannah)
89fbf30  standardized logging (WinterHannah)
ba3f110  comments on PR (WinterHannah)
27e69c9  function to impute values (WinterHannah)
4cdb7c6  graph consistency checks (WinterHannah)
8650789  clean up unnecessary variables (WinterHannah)
669ce03  custom target library/batch (WinterHannah)
895af16  input dictionary instead of indices (WinterHannah)
42385e4  delete comments (WinterHannah)
bb5f861  tests (WinterHannah)
ff5e0c0  logging (WinterHannah)
229264a  Merge pull request #15 from quadbio/3-link-decoders-with-guidance-graph (WinterHannah)
06aa591  unbalanced OT (WinterHannah)
1fcca64  geomloss sinkhorn (WinterHannah)
2fce25a  val step update (WinterHannah)
2d6f6c2  automatic optimization and epoch heuristic (WinterHannah)
fcff3a9  quick fix (WinterHannah)
ce5b8a8  log name fix (WinterHannah)
115e8d3  logging per epoch (WinterHannah)
2c746fb  save/load model (WinterHannah)
230b991  generalized load/save (WinterHannah)
878a91a  deleted comments (WinterHannah)
96d9f06  losses (WinterHannah)
3b030ba  gmm (WinterHannah)
f2f13d7  save load test (WinterHannah)
4f635b4  removed unneeded function (WinterHannah)
46ed0ba  Merge pull request #23 from quadbio/22-uot-for-modality-mixing (WinterHannah)
664a28c  gmm (WinterHannah)
04d2938  Merge pull request #24 from quadbio/feature/spatial (WinterHannah)
e86e6f9  classifier (WinterHannah)
7b7f2d6  class loss param (WinterHannah)
41a40a0  gmm fix (WinterHannah)
11bf7bf  src/scvi/external/spaglue/_base_components.py (WinterHannah)
7f5abb0  normal likelihood (WinterHannah)
d89e5c0  removed comments (WinterHannah)
5f47710  Merge pull request #25 from quadbio/22-uot-for-modality-mixing-mixtur… (WinterHannah)
bdc3ae6  feature imputation score (WinterHannah)
710722d  extended guidance graph (WinterHannah)
4a142e3  ft embedding corr (WinterHannah)
d1d8584  uot loss annealing (WinterHannah)
9e4ccd9  confidence score tests (WinterHannah)
9795bd5  Merge pull request #27 from quadbio/26-provide-scoring-for-feature-im… (WinterHannah)
0c27cc2  Merge branch 'feature/spatial' into main (WinterHannah)
d823575  Merge pull request #28 from quadbio/main (WinterHannah)
56431b0  diagvi and nb mixture (WinterHannah)
5ed2763  renamed folder (WinterHannah)
dd6fee6  docstrings (WinterHannah)
dce3aff  loss fix (WinterHannah)
7042c37  shuffle batches (WinterHannah)
6b88ad4  loss annealing (WinterHannah)
ac9b7bf  protein decoder (WinterHannah)
cd44f32  loss fix (WinterHannah)
65aeaec  shuffling (WinterHannah)
71e0264  multi-gpu-test (WinterHannah)
650caf4  optional common scale (WinterHannah)
e8dd964  fix sinkhorn loss (WinterHannah)
0057d56  remove print (WinterHannah)
226ca83  updated dependencies (WinterHannah)
218cf0b  add muon to dependencies (WinterHannah)
7f6098f  decorator and readme (WinterHannah)
394f021  tutorial (WinterHannah)
b63aa50  tutorial update (WinterHannah)
fc60e22  guidance graph utility (WinterHannah)
bea713b  glue decoder (WinterHannah)
2c38f87  protein decoder glue (WinterHannah)
ec0f159  remove prints (WinterHannah)
3ac6b30  delete adv task (WinterHannah)
4c41ab3  comments (WinterHannah)
0ab97ce  Merge pull request #30 from quadbio/29-totalvi-decoder (WinterHannah)
e8b51ef  Merge remote-tracking branch 'scvi-tools/main' into feature/spatial (ori-kron-wis)
9e73cc1  merge with main of scvi-tools (ori-kron-wis)
c4b9ee0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
e736702  updates (ori-kron-wis)
76e1d91  temp fix pyproj for diagvi (ori-kron-wis)
84f0cd7  test fix (ori-kron-wis)
7c9f420  Merge remote-tracking branch 'scvi-tools/main' into feature/spatial (ori-kron-wis)
79a4576  merge with main (ori-kron-wis)
1e11cfe  Merge branch 'main' into feature/spatial (ori-kron-wis)
0c6bb71  Merge branch 'main' into feature/spatial (ori-kron-wis)
d9a1700  Merge branch 'main' into feature/spatial (ori-kron-wis)
8adfc42  classification parameter (WinterHannah)
e69b4b0  merge (WinterHannah)
29bc4fa  Merge remote-tracking branch 'scvi-tools/main' into feature/spatial (ori-kron-wis)
e2f5dee  merge with main, fix precommits (ori-kron-wis)
46bc1ed  documentation and cleanup (WinterHannah)
4d19cae  full test coverage (WinterHannah)
2937f43  removed unused arguments (WinterHannah)
8123db9  make imputation deterministic (WinterHannah)
100c652  remove confidence score for now (WinterHannah)
0939d41  fix tests (WinterHannah)
012343e  Merge pull request #2 from WinterHannah/diagvi-cleanup (WinterHannah)
New file (+7 lines), the package `__init__.py`:

```python
"""DIAGVI model for multi-modal integration with guidance graphs."""

from ._model import DIAGVI
from ._module import DIAGVAE
from ._task import DiagTrainingPlan

__all__ = ["DIAGVI", "DIAGVAE", "DiagTrainingPlan"]
```
New file (+234 lines), the base neural network components:

```python
"""Base neural network components for DIAGVI model."""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

from scvi.utils import dependencies

EPS = 1e-8


class DecoderRNA(nn.Module):
    """Decoder for RNA modality using feature embeddings.

    Decodes latent representations to RNA expression using batch-specific
    scale and bias parameters combined with feature embeddings from the
    guidance graph.

    Parameters
    ----------
    n_output
        Number of output features (genes).
    n_batches
        Number of batches in the data.
    """

    def __init__(
        self,
        n_output: int,
        n_batches: int,
    ):
        super().__init__()
        self.n_output = n_output
        self.n_batches = n_batches

        self.scale_lin = nn.Parameter(torch.zeros(n_batches, n_output))
        self.bias = nn.Parameter(torch.zeros(n_batches, n_output))
        self.log_theta = nn.Parameter(torch.zeros(n_batches, n_output))

        self.px_dropout_param = nn.Parameter(torch.randn(n_output) * 0.01)

    def forward(
        self,
        u: torch.Tensor,
        l: torch.Tensor,
        batch_index: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Decode latent representation to RNA expression parameters.

        Parameters
        ----------
        u
            Latent representation tensor of shape (n_cells, n_latent).
        l
            Log library size tensor of shape (n_cells, 1).
        batch_index
            Batch indices tensor of shape (n_cells,) or (n_cells, 1).
        v
            Feature embedding tensor from graph encoder.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Tuple of (px_scale, px_r, px_rate, px_dropout) for negative binomial
            distribution parameters.
        """
        if batch_index.dim() > 1:
            batch_index = batch_index.squeeze(-1)

        if (batch_index.max() >= self.bias.shape[0]) or (batch_index.min() < 0):
            raise IndexError(
                f"Batch index out of bounds: valid range is [0, {self.bias.shape[0] - 1}]"
            )

        scale = F.softplus(self.scale_lin[batch_index])
        bias = self.bias[batch_index]
        log_theta = self.log_theta[batch_index]

        raw_px_scale = scale * (u @ v.T) + bias
        px_scale = torch.softmax(raw_px_scale, dim=-1)
        px_rate = torch.exp(l) * px_scale

        px_dropout = F.softplus(self.px_dropout_param)

        px_r = log_theta

        return px_scale, px_r, px_rate, px_dropout


class DecoderProteinGLUE(nn.Module):
    """Decoder for protein modality using GLUE-style mixture model.

    Decodes latent representations to protein expression using a mixture
    of two components with batch-specific parameters.

    Parameters
    ----------
    n_output
        Number of output features (proteins).
    n_batches
        Number of batches in the data.
    """

    def __init__(
        self,
        n_output: int,
        n_batches: int,
    ):
        super().__init__()
        self.n_output = n_output
        self.n_batches = n_batches

        self.scale_lin = nn.Parameter(torch.zeros(n_batches, n_output))

        self.bias1 = nn.Parameter(torch.zeros(n_batches, n_output))
        self.bias2 = nn.Parameter(torch.zeros(n_batches, n_output))

        self.log_theta = nn.Parameter(torch.zeros(n_batches, n_output))

    def forward(
        self,
        u: torch.Tensor,
        l: torch.Tensor,
        batch_index: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[
        tuple[torch.Tensor, torch.Tensor],
        torch.Tensor,
        tuple[torch.Tensor, torch.Tensor],
        torch.Tensor,
    ]:
        """Decode latent representation to protein expression parameters.

        Parameters
        ----------
        u
            Latent representation tensor of shape (n_cells, n_latent).
        l
            Log library size tensor of shape (n_cells, 1).
        batch_index
            Batch indices tensor of shape (n_cells,) or (n_cells, 1).
        v
            Feature embedding tensor from graph encoder.

        Returns
        -------
        tuple
            Tuple of ((px_scale_1, px_scale_2), px_r, (px_rate_1, px_rate_2),
            mixture_logits) for negative binomial mixture distribution.
        """
        if batch_index.dim() > 1:
            batch_index = batch_index.squeeze(-1)

        if (batch_index.max() >= self.bias1.shape[0]) or (batch_index.min() < 0):
            raise IndexError(
                f"Batch index out of bounds: valid range is [0, {self.bias1.shape[0] - 1}]"
            )

        scale = F.softplus(self.scale_lin[batch_index])

        bias1 = self.bias1[batch_index]
        bias2 = self.bias2[batch_index]

        log_theta = self.log_theta[batch_index]

        raw_px_scale_1 = scale * (u @ v.T) + bias1
        raw_px_scale_2 = scale * (u @ v.T) + bias2

        px_scale_1 = torch.softmax(raw_px_scale_1, dim=-1)
        px_scale_2 = torch.softmax(raw_px_scale_2, dim=-1)

        px_rate_1 = torch.exp(l) * px_scale_1
        px_rate_2 = torch.exp(l) * px_scale_2

        mixture_logits = raw_px_scale_1 - raw_px_scale_2

        px_r = log_theta

        return (px_scale_1, px_scale_2), px_r, (px_rate_1, px_rate_2), mixture_logits


class GraphEncoder(nn.Module):
    """Graph convolutional encoder for feature embeddings.

    Encodes feature nodes in the guidance graph to learn feature
    embeddings that capture cross-modality relationships.

    Parameters
    ----------
    vnum
        Number of nodes (features) in the graph.
    out_features
        Dimensionality of the output feature embeddings.
    """

    @dependencies("torch_geometric")
    def __init__(self, vnum: int, out_features: int):
        import torch_geometric

        super().__init__()
        self.vrepr = nn.Parameter(torch.zeros(vnum, out_features))
        self.conv = torch_geometric.nn.GCNConv(out_features, out_features)
        self.loc = nn.Linear(out_features, out_features)
        self.std_lin = nn.Linear(out_features, out_features)

    def forward(self, edge_index: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode graph to feature embeddings.

        Parameters
        ----------
        edge_index
            Edge index tensor of shape (2, n_edges).

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            Tuple of (z, mu, logvar) where z is the sampled embedding,
            mu is the mean, and logvar is the log variance.
        """
        h = self.conv(self.vrepr, edge_index)
        loc = self.loc(h)
        std = F.softplus(self.std_lin(h)) + EPS

        dist = Normal(loc, std)
        mu = dist.loc
        std = dist.scale
        logvar = torch.log(std**2)
        z = dist.rsample()

        return z, mu, logvar
```
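To make the decoder arithmetic concrete, here is a minimal NumPy sketch of the `DecoderRNA` forward pass: softplus-activated batch-specific scale, inner product of cell representations with feature embeddings, softmax normalization, and library-size scaling. All shapes and values below are made up for illustration; the actual module operates on torch tensors and learned parameters.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax, mirroring torch.softmax in the decoder.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def softplus(x):
    # log(1 + exp(x)), mirroring F.softplus.
    return np.log1p(np.exp(x))

rng = np.random.default_rng(0)
n_cells, n_latent, n_genes, n_batches = 4, 8, 10, 2

u = rng.normal(size=(n_cells, n_latent))     # latent cell representations
v = rng.normal(size=(n_genes, n_latent))     # feature embeddings from the graph encoder
log_lib = np.log(np.full((n_cells, 1), 1000.0))  # log library size per cell
batch_index = np.array([0, 0, 1, 1])

scale_lin = np.zeros((n_batches, n_genes))   # decoder parameters at initialization
bias = np.zeros((n_batches, n_genes))

scale = softplus(scale_lin[batch_index])     # batch-specific positive scale
raw = scale * (u @ v.T) + bias[batch_index]
px_scale = softmax(raw, axis=-1)             # normalized expression fractions
px_rate = np.exp(log_lib) * px_scale         # expected counts

print(px_scale.sum(axis=-1))  # each row sums to 1
print(px_rate.sum(axis=-1))   # each row sums to the library size (1000)
```

The key design point visible here is that `px_scale` is a probability simplex over genes, so the negative binomial mean `px_rate` always decomposes into library size times composition.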
Do you need geomloss? It doesn't seem to be actively maintained.
Hi Can, at the moment we rely on geomloss to compute the unbalanced optimal transport loss component. However, we are working on another approach based on POT, which is more actively maintained.
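For readers unfamiliar with the loss being discussed, below is a minimal log-domain unbalanced Sinkhorn sketch in plain NumPy. It illustrates the general technique only: the PR delegates this computation to geomloss, and the function name and parameters here (`sinkhorn_unbalanced`, `eps`, `rho`) are illustrative choices, not the API of geomloss or POT.

```python
import numpy as np

def logsumexp(x, axis):
    # Stable log(sum(exp(x))) along an axis.
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def sinkhorn_unbalanced(x, y, eps=0.1, rho=1.0, n_iter=200):
    """Entropic unbalanced OT between two point clouds with uniform weights.

    rho controls how softly the marginal constraints are enforced
    (rho -> inf recovers balanced Sinkhorn). Returns the transport plan.
    """
    C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared Euclidean cost
    log_a = np.full(len(x), -np.log(len(x)))            # uniform source weights
    log_b = np.full(len(y), -np.log(len(y)))            # uniform target weights
    f = np.zeros(len(x))
    g = np.zeros(len(y))
    tau = rho / (rho + eps)  # damping factor from the KL marginal penalty
    for _ in range(n_iter):
        f = -tau * eps * logsumexp((g[None, :] - C) / eps + log_b[None, :], axis=1)
        g = -tau * eps * logsumexp((f[:, None] - C) / eps + log_a[:, None], axis=0)
    # Transport plan from the dual potentials.
    return np.exp((f[:, None] + g[None, :] - C) / eps + log_a[:, None] + log_b[None, :])

rng = np.random.default_rng(0)
P = sinkhorn_unbalanced(rng.normal(size=(5, 2)), rng.normal(size=(6, 2)))
print(P.shape)       # (5, 6)
print(P.min() >= 0)  # True: the plan is nonnegative by construction
```

With `tau < 1` the iterations only partially enforce the marginals, so total transported mass can fall below one when transport is costly, which is what makes the loss robust to unmatched cells across modalities.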