Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 30 additions & 40 deletions rfdiffusion/Attention_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,14 @@ def forward(self, query, key, value):
B, Q = query.shape[:2]
B, K = key.shape[:2]
#
query = self.to_q(query).reshape(B, Q, self.h, self.dim)
key = self.to_k(key).reshape(B, K, self.h, self.dim)
value = self.to_v(value).reshape(B, K, self.h, self.dim)
#
query = query * self.scaling
attn = einsum('bqhd,bkhd->bhqk', query, key)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bhqk,bkhd->bqhd', attn, value)
out = out.reshape(B, Q, self.h*self.dim)
#
out = self.to_out(out)

return out
# (B, seq, h, d) -> (B, h, seq, d) for scaled_dot_product_attention
query = self.to_q(query).reshape(B, Q, self.h, self.dim).transpose(1, 2)
key = self.to_k(key ).reshape(B, K, self.h, self.dim).transpose(1, 2)
value = self.to_v(value).reshape(B, K, self.h, self.dim).transpose(1, 2)
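# Softmax and the default 1/sqrt(head_dim) scaling are applied inside
# scaled_dot_product_attention. NOTE(review): this matches the removed
# `query * self.scaling` only if self.scaling == 1/sqrt(self.dim) — confirm.
# PyTorch selects a fused kernel (e.g. FlashAttention) only when one is eligible.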
out = F.scaled_dot_product_attention(query, key, value) # (B, h, Q, d)
out = out.transpose(1, 2).reshape(B, Q, self.h * self.dim)
return self.to_out(out)

class AttentionWithBias(nn.Module):
def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32):
Expand Down Expand Up @@ -117,22 +111,17 @@ def forward(self, x, bias):
x = self.norm_in(x)
bias = self.norm_bias(bias)
#
query = self.to_q(x).reshape(B, L, self.h, self.dim)
key = self.to_k(x).reshape(B, L, self.h, self.dim)
value = self.to_v(x).reshape(B, L, self.h, self.dim)
bias = self.to_b(bias) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(x))
#
key = key * self.scaling
attn = einsum('bqhd,bkhd->bqkh', query, key)
attn = attn + bias
attn = F.softmax(attn, dim=-2)
#
out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
# (B, L, h, d) -> (B, h, L, d); bias (B, L, L, h) -> (B, h, L, L)
query = self.to_q(x).reshape(B, L, self.h, self.dim).transpose(1, 2)
key = self.to_k(x).reshape(B, L, self.h, self.dim).transpose(1, 2)
value = self.to_v(x).reshape(B, L, self.h, self.dim).transpose(1, 2)
bias = self.to_b(bias).permute(0, 3, 1, 2) # (B, h, L, L)
gate = torch.sigmoid(self.to_g(x))
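# The float attn_mask is added to the scaled logits before the softmax.
# NOTE(review): an explicit attn_mask generally rules out the FlashAttention
# backend; expect the memory-efficient or math kernel here — confirm on the
# targeted PyTorch version.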
out = F.scaled_dot_product_attention(query, key, value, attn_mask=bias)
out = out.transpose(1, 2).reshape(B, L, -1) # (B, L, h*d)
out = gate * out
#
out = self.to_out(out)
return out
return self.to_out(out)

# MSA Attention (row/column) from AlphaFold architecture
class SequenceWeight(nn.Module):
Expand Down Expand Up @@ -265,19 +254,20 @@ def forward(self, msa):
msa = self.norm_msa(msa)
#
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
gate = torch.sigmoid(self.to_g(msa))
#
query = query * self.scaling
attn = einsum('bqihd,bkihd->bihqk', query, key)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bihqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1)
gate = torch.sigmoid(self.to_g(msa))
# Column attention: for each residue position, attend across N sequences.
# Reshape to (B*L, h, N, d) so scaled_dot_product_attention operates over N.
q = query.permute(0, 2, 3, 1, 4).reshape(B * L, self.h, N, self.dim)
k = key .permute(0, 2, 3, 1, 4).reshape(B * L, self.h, N, self.dim)
v = value.permute(0, 2, 3, 1, 4).reshape(B * L, self.h, N, self.dim)
out = F.scaled_dot_product_attention(q, k, v) # (B*L, h, N, d)
out = (out.reshape(B, L, self.h, N, self.dim)
.permute(0, 3, 1, 2, 4)
.reshape(B, N, L, -1))
out = gate * out
#
out = self.to_out(out)
return out
return self.to_out(out)

class MSAColGlobalAttention(nn.Module):
def __init__(self, d_msa=64, n_head=8, d_hidden=8):
Expand Down
100 changes: 59 additions & 41 deletions rfdiffusion/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,56 @@
import pickle
import numpy as np
import os
import math
import logging

from scipy.spatial.transform import Rotation as scipy_R

from rfdiffusion.util import rigid_from_3_points

from rfdiffusion.util_module import ComputeAllAtomCoords

from rfdiffusion import igso3
import time

# Module-level, in-memory cache of IGSO3 lookup tables so they survive across
# Diffuser instantiations (avoids redundant disk I/O when generating batches of
# designs). Keys are the on-disk cache file path (`cache_fname` in
# `_calc_igso3_vals`); values are the loaded/computed `igso3_vals` object.
# NOTE(review): entries are only ever added in the code visible here — no
# eviction — so the cache grows with the number of distinct cache file names.
_igso3_cache: dict = {}

torch.set_printoptions(sci_mode=False)


def get_beta_schedule(T, b0, bT, schedule_type, schedule_params=None, inference=False):
    """
    Given a noise schedule type, create the beta schedule.

    schedule_type options:
        "linear" — Ho et al. (2020) linear schedule, scaled to T steps.
        "cosine" — Nichol & Dhariwal (2021) cosine schedule; b0/bT ignored.

    Args:
        T: number of discrete diffusion steps (must be >= 15).
        b0: first beta of the equivalent 200-step linear schedule.
        bT: last beta of the equivalent 200-step linear schedule.
        schedule_type: "linear" or "cosine".
        schedule_params: optional dict of extra settings; the cosine schedule
            reads the offset "s" (default 0.008). None means no extra settings.
        inference: when True, print a one-line summary of the schedule.

    Returns:
        Tuple (schedule, alpha_schedule, alphabar_t_schedule) of three
        length-T float32 tensors: beta_t, 1 - beta_t and cumprod(1 - beta_t).

    Raises:
        AssertionError: unknown schedule_type, or T < 15.
    """
    assert schedule_type in ["linear", "cosine"], (
        f"Unknown schedule type '{schedule_type}'. Choose 'linear' or 'cosine'."
    )
    assert T >= 15, "With discrete time and T < 15, the schedule is badly approximated"

    # A dict default would be shared between calls; build a fresh one instead.
    if schedule_params is None:
        schedule_params = {}

    if schedule_type == "linear":
        # Scale endpoints (exactly once) to be equivalent to a 200-step schedule.
        b0 *= 200 / T
        bT *= 200 / T
        schedule = torch.linspace(b0, bT, T)
    else:
        # Cosine schedule from Nichol & Dhariwal (2021), Improved DDPM.
        s = schedule_params.get("s", 0.008)
        steps = torch.arange(T + 1, dtype=torch.float64)
        f = torch.cos((steps / T + s) / (1.0 + s) * math.pi / 2.0) ** 2
        alphabar = f / f[0]
        # Take the ratio in float64 and cast last: early betas are tiny
        # differences of numbers close to 1 and lose digits in float32.
        schedule = torch.clamp(
            1.0 - alphabar[1:] / alphabar[:-1], max=0.999
        ).float()

    # get alphabar_t for convenience
    alpha_schedule = 1.0 - schedule
    alphabar_t_schedule = torch.cumprod(alpha_schedule, dim=0)

    if inference:
        print(
            f"Beta schedule: {schedule_type}, "
            f"beta_0={schedule[0].item():.5f}, beta_T={schedule[-1].item():.5f}, "
            f"alpha_bar_T={alphabar_t_schedule[-1].item():.5f}"
        )

    return schedule, alpha_schedule, alphabar_t_schedule
Expand Down Expand Up @@ -228,6 +238,10 @@ def _calc_igso3_vals(self, L=2000):
if not os.path.isdir(self.cache_dir):
os.makedirs(self.cache_dir)

if cache_fname in _igso3_cache:
self._log.info("Using in-memory IGSO3 cache.")
return _igso3_cache[cache_fname]

if os.path.exists(cache_fname):
self._log.info("Using cached IGSO3.")
igso3_vals = read_pkl(cache_fname)
Expand All @@ -241,6 +255,7 @@ def _calc_igso3_vals(self, L=2000):
)
write_pkl(cache_fname, igso3_vals)

_igso3_cache[cache_fname] = igso3_vals
return igso3_vals

@property
Expand Down Expand Up @@ -288,23 +303,29 @@ def sigma(self, t: torch.tensor):

def g(self, t):
    r"""
    Return g(t), the coefficient tied to the noise scale sigma(t) by

        sigma(t)^2 := \int_0^t g(s)^2 ds,

    which is inverted here as g(t) = sqrt(d/dt sigma(t)^2).

    For the "linear" schedule the derivative is taken analytically, assuming
    sigma(t) = min_sigma + t*min_b + 0.5*t^2*(max_b - min_b), so that
        d/dt sigma(t)^2 = 2*sigma(t) * (min_b + t*(max_b - min_b)).
    NOTE(review): this closed form must stay in sync with self.sigma().
    Any other schedule falls back to autograd through self.sigma().

    Args:
        t: time in [0, 1]; a Python/NumPy number or a torch tensor.

    Returns:
        Tensor holding g(t), with the same shape as t.
    """
    if not torch.is_tensor(t):
        t = torch.tensor(t, dtype=torch.float32)

    if self.schedule == "linear":
        sigma_t = self.sigma(t)
        dsigma_dt = self.min_b + t * (self.max_b - self.min_b)
        return torch.sqrt(2.0 * sigma_t * dsigma_dt)

    # Autograd fallback. Work on a private floating-point leaf so the caller's
    # tensor is never switched to requires_grad, and re-enable grad so the call
    # also works when the caller is inside torch.no_grad().
    if not t.is_floating_point():
        t = t.to(torch.float32)
    with torch.enable_grad():
        t_leaf = t.detach().clone().requires_grad_(True)
        sigma_sqr = self.sigma(t_leaf) ** 2
        grads = torch.autograd.grad(sigma_sqr.sum(), t_leaf)[0]
    return torch.sqrt(grads)

def sample(self, ts, n_samples=1):
"""
Expand Down Expand Up @@ -427,12 +448,9 @@ def diffuse_frames(self, xyz, t_list, diffusion_mask=None):
non_diffusion_mask = 1 - diffusion_mask[None, :, None]
sampled_rots = sampled_rots * non_diffusion_mask

# Apply sampled rot.
R_sampled = (
scipy_R.from_rotvec(sampled_rots.reshape(-1, 3))
.as_matrix()
.reshape(self.T, num_res, 3, 3)
)
# Apply sampled rot — the torch Exp map replaces the scipy call. The data still
# goes numpy -> torch -> numpy on the CPU, and is cast to float32 on the way.
sampled_rots_t = torch.from_numpy(sampled_rots.reshape(-1, 3)).float()
R_sampled = igso3.Exp_torch(sampled_rots_t).numpy().reshape(self.T, num_res, 3, 3)
R_perturbed = np.einsum("tnij,njk->tnik", R_sampled, R_true)
perturbed_crds = (
np.einsum(
Expand Down Expand Up @@ -494,11 +512,10 @@ def reverse_sample_vectorized(
differential equations. arXiv preprint arXiv:2011.13456.
"""
# compute rotation vector corresponding to prediction of how r_t goes to r_0
R_0, R_t = torch.tensor(R_0), torch.tensor(R_t)
R_0, R_t = torch.as_tensor(R_0), torch.as_tensor(R_t)
R_0t = torch.einsum("...ij,...kj->...ik", R_t, R_0)
R_0t_rotvec = torch.tensor(
scipy_R.from_matrix(R_0t.cpu().numpy()).as_rotvec()
).to(R_0.device)
# torch-native Log map: stays on-device, no CPU/scipy roundtrip
R_0t_rotvec = igso3.Log_torch(R_0t).to(dtype=torch.float32, device=R_0.device)

# Approximate the score based on the prediction of R0.
# R_t @ hat(Score_approx) is the score approximation in the Lie algebra
Expand Down Expand Up @@ -527,7 +544,8 @@ def reverse_sample_vectorized(
Perturb_tangent = Delta_r + rot_g * np.sqrt(self.step_size) * Z
if mask is not None:
Perturb_tangent *= (1 - mask.long())[:, None, None]
Perturb = igso3.Exp(Perturb_tangent)
# torch-native Exp map: stays on-device, no scipy roundtrip
Perturb = igso3.Exp_torch(Perturb_tangent)

if return_perturb:
return Perturb
Expand Down
47 changes: 43 additions & 4 deletions rfdiffusion/igso3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,53 @@ def hat(v):
hat_v[:, 0, 1], hat_v[:, 0, 2], hat_v[:, 1, 2] = -v[:, 2], v[:, 1], -v[:, 0]
return hat_v + -hat_v.transpose(2, 1)

# Logarithmic map from SO(3) to R^3 (i.e. rotation vector)
def hat_batch(v):
    """Map vectors [..., 3] to their skew-symmetric cross-product matrices [..., 3, 3].

    For each vector v the returned matrix H satisfies H @ w == cross(v, w).
    The result has the same device and dtype as ``v``.
    """
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    zero = torch.zeros_like(x)
    # Row-major entries of [[0, -z, y], [z, 0, -x], [-y, x, 0]].
    entries = [
        zero, -z, y,
        z, zero, -x,
        -y, x, zero,
    ]
    return torch.stack(entries, dim=-1).reshape(*v.shape[:-1], 3, 3)

def Log_torch(R):
    """On-device rotation matrix -> rotation vector. R: [..., 3, 3] -> [..., 3].

    Stays on the input's device/dtype — no scipy or CPU transfers.

    The angle comes from atan2(sin, cos) rather than acos(trace), so small
    rotations keep their precision in float32. Near 180 degrees the
    antisymmetric part of R vanishes, so the axis is read from the symmetric
    part instead. Exactly at 180 degrees the sign of the axis is arbitrary
    (both signs describe the same rotation).

    Args:
        R: tensor of rotation matrices, shape [..., 3, 3].

    Returns:
        Tensor of rotation vectors (axis * angle, angle in [0, pi]), shape [..., 3].
    """
    # Antisymmetric part: skew == 2 * sin(theta) * axis.
    skew = torch.stack(
        [
            R[..., 2, 1] - R[..., 1, 2],
            R[..., 0, 2] - R[..., 2, 0],
            R[..., 1, 0] - R[..., 0, 1],
        ],
        dim=-1,
    )
    trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
    cos_theta = torch.clamp((trace - 1.0) / 2.0, -1.0, 1.0)
    sin_theta = 0.5 * torch.norm(skew, dim=-1)
    theta = torch.atan2(sin_theta, cos_theta)

    # Generic branch: rotvec = (skew / 2) * theta / sin(theta); the ratio tends
    # to 1 as theta -> 0, so tiny angles use 1 and avoid 0/0.
    tiny = sin_theta < 1e-6
    sin_safe = torch.where(tiny, torch.ones_like(sin_theta), sin_theta)
    ratio = torch.where(tiny, torch.ones_like(theta), theta / sin_safe)
    rotvec = 0.5 * skew * ratio[..., None]

    # Near-pi branch: R + R^T - 2*cos*I == 2*(1 - cos) * axis axis^T, so its
    # column with the largest diagonal entry is a well-scaled multiple of the axis.
    eye = torch.eye(3, device=R.device, dtype=R.dtype)
    outer = R + R.transpose(-1, -2) - 2.0 * cos_theta[..., None, None] * eye
    k = torch.argmax(torch.diagonal(outer, dim1=-2, dim2=-1), dim=-1)
    idx = k[..., None, None].expand(*R.shape[:-2], 3, 1)
    col = torch.gather(outer, -1, idx).squeeze(-1)
    axis = col / torch.clamp(torch.norm(col, dim=-1, keepdim=True), min=1e-12)
    # sin(theta) >= 0 on [0, pi], so the axis must point along skew.
    along = (axis * skew).sum(dim=-1, keepdim=True)
    axis = torch.where(along < 0, -axis, axis)
    rotvec_pi = axis * theta[..., None]

    near_pi = cos_theta < -0.99
    return torch.where(near_pi[..., None], rotvec_pi, rotvec)

def Exp_torch(v):
    """On-device rotation vector -> rotation matrix. v: [..., 3] -> [..., 3, 3].

    Rodrigues formula written on the unnormalised vector:

        R = I + (sin t / t) * K + ((1 - cos t) / t^2) * K @ K,
        K = hat_batch(v),  t = |v|.

    Both coefficients switch to their Taylor series for small t, so tiny
    rotations keep their first-order term instead of collapsing to the
    identity, and (1 - cos t) is evaluated as 2*sin(t/2)^2 to avoid
    cancellation. A zero vector maps exactly to the identity.
    Stays on the input's device/dtype.

    Args:
        v: tensor of rotation vectors (axis * angle), shape [..., 3].

    Returns:
        Tensor of rotation matrices, shape [..., 3, 3].
    """
    theta = torch.norm(v, dim=-1)
    theta_sq = theta * theta
    small = theta < 1e-4
    # Dummy denominator where the series is used, so no 0/0 is ever formed.
    theta_safe = torch.where(small, torch.ones_like(theta), theta)
    half = 0.5 * theta_safe

    # sin(t)/t and (1 - cos t)/t^2 == 0.5 * (sin(t/2) / (t/2))^2.
    coef_k = torch.where(
        small, 1.0 - theta_sq / 6.0, torch.sin(theta_safe) / theta_safe
    )
    coef_kk = torch.where(
        small, 0.5 - theta_sq / 24.0, 0.5 * (torch.sin(half) / half) ** 2
    )

    K = hat_batch(v)
    eye = torch.eye(3, device=v.device, dtype=v.dtype)
    return eye + coef_k[..., None, None] * K + coef_kk[..., None, None] * (K @ K)

# Logarithmic map from SO(3) to R^3 (i.e. rotation vector) — legacy CPU version
def Log(R):
    """Convert CPU rotation matrices to rotation vectors through scipy."""
    rotvec = Rotation.from_matrix(R.numpy()).as_rotvec()
    return torch.tensor(rotvec)

# logarithmic map from SO(3) to so(3), this is the matrix logarithm
def log(R):
    """Return hat(Log(R)): the rotation vector of R passed through the hat map."""
    rotvec = Log(R)
    return hat(rotvec)

# Exponential map from vector space of so(3) to SO(3) — legacy CPU version
def Exp(A):
    """Convert CPU rotation vectors to rotation matrices through scipy."""
    rotation = Rotation.from_rotvec(A.numpy())
    return torch.tensor(rotation.as_matrix())

# Angle of rotation SO(3) to R^+
Expand Down
Loading