diff --git a/rfdiffusion/Attention_module.py b/rfdiffusion/Attention_module.py index f8868fc2..0e345733 100644 --- a/rfdiffusion/Attention_module.py +++ b/rfdiffusion/Attention_module.py @@ -60,20 +60,14 @@ def forward(self, query, key, value): B, Q = query.shape[:2] B, K = key.shape[:2] # - query = self.to_q(query).reshape(B, Q, self.h, self.dim) - key = self.to_k(key).reshape(B, K, self.h, self.dim) - value = self.to_v(value).reshape(B, K, self.h, self.dim) - # - query = query * self.scaling - attn = einsum('bqhd,bkhd->bhqk', query, key) - attn = F.softmax(attn, dim=-1) - # - out = einsum('bhqk,bkhd->bqhd', attn, value) - out = out.reshape(B, Q, self.h*self.dim) - # - out = self.to_out(out) - - return out + # (B, seq, h, d) -> (B, h, seq, d) for scaled_dot_product_attention + query = self.to_q(query).reshape(B, Q, self.h, self.dim).transpose(1, 2) + key = self.to_k(key ).reshape(B, K, self.h, self.dim).transpose(1, 2) + value = self.to_v(value).reshape(B, K, self.h, self.dim).transpose(1, 2) + # scaling and softmax handled internally; uses Flash Attention when available + out = F.scaled_dot_product_attention(query, key, value) # (B, h, Q, d) + out = out.transpose(1, 2).reshape(B, Q, self.h * self.dim) + return self.to_out(out) class AttentionWithBias(nn.Module): def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32): @@ -117,22 +111,17 @@ def forward(self, x, bias): x = self.norm_in(x) bias = self.norm_bias(bias) # - query = self.to_q(x).reshape(B, L, self.h, self.dim) - key = self.to_k(x).reshape(B, L, self.h, self.dim) - value = self.to_v(x).reshape(B, L, self.h, self.dim) - bias = self.to_b(bias) # (B, L, L, h) - gate = torch.sigmoid(self.to_g(x)) - # - key = key * self.scaling - attn = einsum('bqhd,bkhd->bqkh', query, key) - attn = attn + bias - attn = F.softmax(attn, dim=-2) - # - out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1) + # (B, L, h, d) -> (B, h, L, d); bias (B, L, L, h) -> (B, h, L, L) + query = self.to_q(x).reshape(B, L, self.h, self.dim).transpose(1, 2) + key = self.to_k(x).reshape(B, L, self.h, self.dim).transpose(1, 2) + value = self.to_v(x).reshape(B, L, self.h, self.dim).transpose(1, 2) + bias = self.to_b(bias).permute(0, 3, 1, 2) # (B, h, L, L) + gate = torch.sigmoid(self.to_g(x)) + # bias added to logits before softmax; Flash Attention used when available + out = F.scaled_dot_product_attention(query, key, value, attn_mask=bias) + out = out.transpose(1, 2).reshape(B, L, -1) # (B, L, h*d) out = gate * out - # - out = self.to_out(out) - return out + return self.to_out(out) # MSA Attention (row/column) from AlphaFold architecture class SequenceWeight(nn.Module): @@ -265,19 +254,20 @@ def forward(self, msa): msa = self.norm_msa(msa) # query = self.to_q(msa).reshape(B, N, L, self.h, self.dim) - key = self.to_k(msa).reshape(B, N, L, self.h, self.dim) + key = self.to_k(msa).reshape(B, N, L, self.h, self.dim) value = self.to_v(msa).reshape(B, N, L, self.h, self.dim) - gate = torch.sigmoid(self.to_g(msa)) - # - query = query * self.scaling - attn = einsum('bqihd,bkihd->bihqk', query, key) - attn = F.softmax(attn, dim=-1) - # - out = einsum('bihqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1) + gate = torch.sigmoid(self.to_g(msa)) + # Column attention: for each residue position, attend across N sequences. + # Reshape to (B*L, h, N, d) so scaled_dot_product_attention operates over N. + q = query.permute(0, 2, 3, 1, 4).reshape(B * L, self.h, N, self.dim) + k = key .permute(0, 2, 3, 1, 4).reshape(B * L, self.h, N, self.dim) + v = value.permute(0, 2, 3, 1, 4).reshape(B * L, self.h, N, self.dim) + out = F.scaled_dot_product_attention(q, k, v) # (B*L, h, N, d) + out = (out.reshape(B, L, self.h, N, self.dim) + .permute(0, 3, 1, 2, 4) + .reshape(B, N, L, -1)) out = gate * out - # - out = self.to_out(out) - return out + return self.to_out(out) class MSAColGlobalAttention(nn.Module): def __init__(self, d_msa=64, n_head=8, d_hidden=8): diff --git a/rfdiffusion/diffusion.py b/rfdiffusion/diffusion.py index a67e5794..14261492 100644 --- a/rfdiffusion/diffusion.py +++ b/rfdiffusion/diffusion.py @@ -3,46 +3,56 @@ import pickle import numpy as np import os +import math import logging -from scipy.spatial.transform import Rotation as scipy_R - from rfdiffusion.util import rigid_from_3_points - from rfdiffusion.util_module import ComputeAllAtomCoords - from rfdiffusion import igso3 import time +# Module-level cache so IGSO3 lookup tables survive across Diffuser instantiations +# (avoids redundant disk I/O when generating batches of designs). +_igso3_cache: dict = {} + torch.set_printoptions(sci_mode=False) def get_beta_schedule(T, b0, bT, schedule_type, schedule_params={}, inference=False): """ - Given a noise schedule type, create the beta schedule - """ - assert schedule_type in ["linear"] + Given a noise schedule type, create the beta schedule. - # Adjust b0 and bT if T is not 200 - # This is a good approximation, with the beta correction below, unless T is very small + schedule_type options: + "linear" — Ho et al. (2020) linear schedule, scaled to T steps. + "cosine" — Nichol & Dhariwal (2021) cosine schedule; b0/bT ignored. + """ + assert schedule_type in ["linear", "cosine"], ( + f"Unknown schedule type '{schedule_type}'. Choose 'linear' or 'cosine'." + ) assert T >= 15, "With discrete time and T < 15, the schedule is badly approximated" - b0 *= 200 / T - bT *= 200 / T - # linear noise schedule if schedule_type == "linear": + # Scale endpoints to be equivalent to a 200-step schedule + b0 *= 200 / T + bT *= 200 / T schedule = torch.linspace(b0, bT, T) - else: - raise NotImplementedError(f"Schedule of type {schedule_type} not implemented.") + elif schedule_type == "cosine": + # Cosine schedule from Nichol & Dhariwal (2021), Improved DDPM + s = schedule_params.get("s", 0.008) + steps = torch.arange(T + 1, dtype=torch.float64) + f = torch.cos((steps / T + s) / (1.0 + s) * math.pi / 2.0) ** 2 + alphabar = (f / f[0]).float() + schedule = torch.clamp(1.0 - alphabar[1:] / alphabar[:-1], max=0.999) - # get alphabar_t for convenience - alpha_schedule = 1 - schedule + alpha_schedule = 1.0 - schedule alphabar_t_schedule = torch.cumprod(alpha_schedule, dim=0) if inference: print( - f"With this beta schedule ({schedule_type} schedule, beta_0 = {round(b0, 3)}, beta_T = {round(bT,3)}), alpha_bar_T = {alphabar_t_schedule[-1]}" + f"Beta schedule: {schedule_type}, " + f"beta_0={schedule[0].item():.5f}, beta_T={schedule[-1].item():.5f}, " + f"alpha_bar_T={alphabar_t_schedule[-1].item():.5f}" ) return schedule, alpha_schedule, alphabar_t_schedule @@ -228,6 +238,10 @@ def _calc_igso3_vals(self, L=2000): if not os.path.isdir(self.cache_dir): os.makedirs(self.cache_dir) + if cache_fname in _igso3_cache: + self._log.info("Using in-memory IGSO3 cache.") + return _igso3_cache[cache_fname] + if os.path.exists(cache_fname): self._log.info("Using cached IGSO3.") igso3_vals = read_pkl(cache_fname) @@ -241,6 +255,7 @@ def _calc_igso3_vals(self, L=2000): ) write_pkl(cache_fname, igso3_vals) + _igso3_cache[cache_fname] = igso3_vals return igso3_vals @property @@ -288,23 +303,29 @@ def sigma(self, t: torch.tensor): def g(self, t): """ - g returns the drift coefficient at time t + g returns the drift coefficient at time t. - since - sigma(t)^2 := \int_0^t g(s)^2 ds, - for arbitrary sigma(t) we invert this relationship to compute - g(t) = sqrt(d/dt sigma(t)^2). + g(t) = sqrt(d/dt sigma(t)^2) - Args: - t: scalar time between 0 and 1 + For the linear schedule sigma(t) = min_sigma + t*min_b + 0.5*t^2*(max_b - min_b), + we derive analytically: + d/dt sigma(t)^2 = 2*sigma(t) * (min_b + t*(max_b - min_b)) + which avoids a per-step autograd call. - Returns: - drift cooeficient as a scalar. + For the exponential schedule, autograd is still used as a fallback. """ - t = torch.tensor(t, requires_grad=True) - sigma_sqr = self.sigma(t) ** 2 - grads = torch.autograd.grad(sigma_sqr.sum(), t)[0] - return torch.sqrt(grads) + if not torch.is_tensor(t): + t = torch.tensor(t, dtype=torch.float32) + + if self.schedule == "linear": + sigma_t = self.sigma(t) + dsigma_dt = self.min_b + t * (self.max_b - self.min_b) + return torch.sqrt(2.0 * sigma_t * dsigma_dt) + else: + t = t.requires_grad_(True) + sigma_sqr = self.sigma(t) ** 2 + grads = torch.autograd.grad(sigma_sqr.sum(), t)[0] + return torch.sqrt(grads) def sample(self, ts, n_samples=1): """ @@ -427,12 +448,9 @@ def diffuse_frames(self, xyz, t_list, diffusion_mask=None): non_diffusion_mask = 1 - diffusion_mask[None, :, None] sampled_rots = sampled_rots * non_diffusion_mask - # Apply sampled rot. - R_sampled = ( - scipy_R.from_rotvec(sampled_rots.reshape(-1, 3)) - .as_matrix() - .reshape(self.T, num_res, 3, 3) - ) + # Apply sampled rot — torch-native Exp map avoids scipy/CPU roundtrip. + sampled_rots_t = torch.from_numpy(sampled_rots.reshape(-1, 3)).float() + R_sampled = igso3.Exp_torch(sampled_rots_t).numpy().reshape(self.T, num_res, 3, 3) R_perturbed = np.einsum("tnij,njk->tnik", R_sampled, R_true) perturbed_crds = ( np.einsum( @@ -494,11 +512,10 @@ def reverse_sample_vectorized( differential equations. arXiv preprint arXiv:2011.13456. """ # compute rotation vector corresponding to prediction of how r_t goes to r_0 - R_0, R_t = torch.tensor(R_0), torch.tensor(R_t) + R_0, R_t = torch.as_tensor(R_0), torch.as_tensor(R_t) R_0t = torch.einsum("...ij,...kj->...ik", R_t, R_0) - R_0t_rotvec = torch.tensor( - scipy_R.from_matrix(R_0t.cpu().numpy()).as_rotvec() - ).to(R_0.device) + # torch-native Log map: stays on-device, no CPU/scipy roundtrip + R_0t_rotvec = igso3.Log_torch(R_0t).to(dtype=torch.float32, device=R_0.device) # Approximate the score based on the prediction of R0. # R_t @ hat(Score_approx) is the score approximation in the Lie algebra @@ -527,7 +544,8 @@ def reverse_sample_vectorized( Perturb_tangent = Delta_r + rot_g * np.sqrt(self.step_size) * Z if mask is not None: Perturb_tangent *= (1 - mask.long())[:, None, None] - Perturb = igso3.Exp(Perturb_tangent) + # torch-native Exp map: stays on-device, no scipy roundtrip + Perturb = igso3.Exp_torch(Perturb_tangent) if return_perturb: return Perturb diff --git a/rfdiffusion/igso3.py b/rfdiffusion/igso3.py index 6d90bdb2..e10fa4c6 100644 --- a/rfdiffusion/igso3.py +++ b/rfdiffusion/igso3.py @@ -15,14 +15,53 @@ def hat(v): hat_v[:, 0, 1], hat_v[:, 0, 2], hat_v[:, 1, 2] = -v[:, 2], v[:, 1], -v[:, 0] return hat_v + -hat_v.transpose(2, 1) -# Logarithmic map from SO(3) to R^3 (i.e. rotation vector) +def hat_batch(v): + """Batch hat map: [..., 3] -> [..., 3, 3] (cross-product / skew-symmetric matrix).""" + bshape = v.shape[:-1] + h = torch.zeros(*bshape, 3, 3, device=v.device, dtype=v.dtype) + h[..., 0, 1] = -v[..., 2] + h[..., 0, 2] = v[..., 1] + h[..., 1, 0] = v[..., 2] + h[..., 1, 2] = -v[..., 0] + h[..., 2, 0] = -v[..., 1] + h[..., 2, 1] = v[..., 0] + return h + +def Log_torch(R): + """On-device rotation matrix -> rotation vector. R: [..., 3, 3] -> [..., 3]. + Stays on the original device/dtype — no scipy or CPU transfers.""" + trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2] + theta = torch.acos(torch.clamp((trace - 1.0) / 2.0, -1.0, 1.0)) + skew = torch.stack([ + R[..., 2, 1] - R[..., 1, 2], + R[..., 0, 2] - R[..., 2, 0], + R[..., 1, 0] - R[..., 0, 1], + ], dim=-1) + sin_theta = torch.clamp(torch.sin(theta), min=1e-7) + axis = skew / (2.0 * sin_theta[..., None]) + rotvec = axis * theta[..., None] + return torch.where(theta[..., None] < 1e-6, torch.zeros_like(rotvec), rotvec) + +def Exp_torch(v): + """On-device rotation vector -> rotation matrix. v: [..., 3] -> [..., 3, 3]. + Rodrigues formula. Stays on the original device/dtype.""" + theta = torch.norm(v, dim=-1) + theta_safe = torch.clamp(theta, min=1e-7) + axis = v / theta_safe[..., None] + K = hat_batch(axis) + I = torch.eye(3, device=v.device, dtype=v.dtype).expand(*v.shape[:-1], 3, 3) + sin_t = torch.sin(theta)[..., None, None] + cos_t = torch.cos(theta)[..., None, None] + R = I + sin_t * K + (1.0 - cos_t) * (K @ K) + return torch.where(theta[..., None, None] < 1e-7, I, R) + +# Logarithmic map from SO(3) to R^3 (i.e. rotation vector) — legacy CPU version def Log(R): return torch.tensor(Rotation.from_matrix(R.numpy()).as_rotvec()) - + # logarithmic map from SO(3) to so(3), this is the matrix logarithm def log(R): return hat(Log(R)) -# Exponential map from vector space of so(3) to SO(3), this is the matrix -# exponential combined with the "hat" map +# Exponential map from vector space of so(3) to SO(3) — legacy CPU version def Exp(A): return torch.tensor(Rotation.from_rotvec(A.numpy()).as_matrix()) # Angle of rotation SO(3) to R^+ diff --git a/rfdiffusion/inference/utils.py b/rfdiffusion/inference/utils.py index 3fb14112..9c72e713 100644 --- a/rfdiffusion/inference/utils.py +++ b/rfdiffusion/inference/utils.py @@ -4,7 +4,6 @@ import torch import torch.nn.functional as nn from rfdiffusion.diffusion import get_beta_schedule -from scipy.spatial.transform import Rotation as scipy_R from rfdiffusion.util import rigid_from_3_points from rfdiffusion.util_module import ComputeAllAtomCoords from rfdiffusion import util @@ -53,9 +52,9 @@ def get_next_frames(xt, px0, t, diffuser, so3_type, diffusion_mask, noise_scale= R_t, Ca_t = rigid_from_3_points(N_t, Ca_t, C_t) - # this must be to normalize them or something - R_0 = scipy_R.from_matrix(R_0.squeeze().numpy()).as_matrix() - R_t = scipy_R.from_matrix(R_t.squeeze().numpy()).as_matrix() + # rigid_from_3_points already returns proper rotation matrices; convert to numpy. + R_0 = R_0.squeeze().numpy() + R_t = R_t.squeeze().numpy() L = R_t.shape[0] all_rot_transitions = np.broadcast_to(np.identity(3), (L, 3, 3)).copy() @@ -122,6 +121,33 @@ def get_mu_xt_x0(xt, px0, t, beta_schedule, alphabar_schedule, eps=1e-6): return mu, sigma +def get_mu_xt_x0_ddim(xt, px0, t, alphabar_schedule, eps=1e-8): + """ + Deterministic DDIM update for Cα coordinates (Song et al., 2021). + + Unlike DDPM, DDIM skips the stochastic noise term and uses: + x_{t-1} = sqrt(alpha_bar_{t-1}) * x̂_0 + + sqrt(1 - alpha_bar_{t-1}) * epsilon_theta(x_t, t) + where epsilon_theta is the implied noise direction derived from x_t and x̂_0. + + Setting noise_scale=0 in DDPM is not equivalent — DDIM uses a different mean. + """ + t_idx = t - 1 + xt_ca = xt[:, 1, :] + px0_ca = px0[:, 1, :] + + alphabar_t = alphabar_schedule[t_idx] + alphabar_tm1 = alphabar_schedule[t_idx - 1] if t_idx > 0 else torch.ones(1, dtype=xt.dtype, device=xt.device) + + # Implied noise direction + eps_theta = (xt_ca - torch.sqrt(alphabar_t + eps) * px0_ca) / torch.sqrt(1.0 - alphabar_t + eps) + + # DDIM deterministic update + x_tm1 = torch.sqrt(alphabar_tm1) * px0_ca + torch.sqrt(1.0 - alphabar_tm1) * eps_theta + delta = x_tm1 - xt_ca + return delta + + def get_next_ca( xt, px0, @@ -131,6 +157,7 @@ def get_next_ca( beta_schedule, alphabar_schedule, noise_scale=1.0, + ddim=False, ): """ Given full atom x0 prediction (xyz coordinates), diffuse to x(t-1) @@ -155,24 +182,24 @@ def get_next_ca( get_allatom = ComputeAllAtomCoords().to(device=xt.device) L = len(xt) - # bring to origin after global alignment (when don't have a motif) or replace input motif and bring to origin, and then scale px0 = px0 * crd_scale - xt = xt * crd_scale + xt = xt * crd_scale - # get mu(xt, x0) - mu, sigma = get_mu_xt_x0( - xt, px0, t, beta_schedule=beta_schedule, alphabar_schedule=alphabar_schedule - ) - - sampled_crds = torch.normal(mu, torch.sqrt(sigma * noise_scale)) - delta = sampled_crds - xt[:, 1, :] # check sign of this is correct + if ddim: + # Deterministic DDIM update — faster convergence, no stochastic noise + delta = get_mu_xt_x0_ddim(xt, px0, t, alphabar_schedule=alphabar_schedule) + else: + # Stochastic DDPM update + mu, sigma = get_mu_xt_x0( + xt, px0, t, beta_schedule=beta_schedule, alphabar_schedule=alphabar_schedule + ) + sampled_crds = torch.normal(mu, torch.sqrt(sigma * noise_scale)) + delta = sampled_crds - xt[:, 1, :] if not diffusion_mask is None: - # Don't move motif delta[diffusion_mask, ...] = 0 out_crds = xt + delta[:, None, :] - return out_crds / crd_scale, delta / crd_scale @@ -243,13 +270,14 @@ def __init__( crd_scale=1 / 15, potential_manager=None, partial_T=None, + ddim=False, ): """ - Parameters: noise_level: scaling on the noise added (set to 0 to use no noise, to 1 to have full noise) - + ddim: use deterministic DDIM update for Cα coordinates instead of + stochastic DDPM. Enables fewer-step inference at equivalent quality. """ self.T = T self.L = L @@ -267,6 +295,7 @@ def __init__( self.final_noise_scale_frame = final_noise_scale_frame self.frame_noise_schedule_type = frame_noise_schedule_type self.potential_manager = potential_manager + self.ddim = ddim self._log = logging.getLogger(__name__) self.schedule, self.alpha_schedule, self.alphabar_schedule = get_beta_schedule( @@ -464,6 +493,7 @@ def get_next_pose( beta_schedule=self.schedule, alphabar_schedule=self.alphabar_schedule, noise_scale=noise_scale_ca, + ddim=self.ddim, ) # get the next set of backbone frames (coordinates) diff --git a/rfdiffusion/kinematics.py b/rfdiffusion/kinematics.py index 8d548394..f67cf372 100644 --- a/rfdiffusion/kinematics.py +++ b/rfdiffusion/kinematics.py @@ -47,7 +47,7 @@ def get_ang(a, b, c): w /= torch.norm(w, dim=-1, keepdim=True) vw = torch.sum(v*w, dim=-1) - return torch.acos(vw) + return torch.acos(torch.clamp(vw, -1.0, 1.0)) # ============================================================ def get_dih(a, b, c, d):