- [2025-11] Set up the GitHub project of SonoNexus!
SonoNexus is a foundation model-powered sensing system that acts as a hardware-agnostic Rosetta Stone for interpreting images across the entire sensor landscape. It is built upon two cornerstone contributions. First, we construct Sono-21M, the largest and most diverse ultrasound dataset to date, comprising 21.14 million images of 20 major organ types, purposefully curated from 10 distinct mainstream sensor models across 17 hospitals. Second, we develop SonoNexus via a self-supervised learning strategy, enabling seamless performance across a broad spectrum of devices and downstream clinical applications.
Here, we provide the inference code to show the effectiveness of the pre-trained models on reconstructing the masked US images and capturing the discriminative features.
Detailed feature visualization and image inference codes are defined in test_model.py. To calculate the activation maps, we provide two query anchors: the max-pooled token and the average-pooled token among the patch tokens.
import torch
import torch.nn.functional as F
from load_model import VisionUlt
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
from dataset import get_data
# Select the compute device: pin this process to GPU 0 when CUDA is
# available, otherwise fall back to CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ==========================================
# 1. Helper Functions
# ==========================================
def denormalize(img_tensor):
    """Undo ImageNet normalization and return a uint8 image array.

    Args:
        img_tensor: [3, H, W] tensor standardized with ImageNet mean/std.
    Returns:
        uint8 numpy array of shape (H, W, 3) with values in [0, 255].
    """
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=img_tensor.device).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=img_tensor.device).view(3, 1, 1)
    # Reverse the standardization and clip back into the valid [0, 1] range.
    restored = torch.clamp(img_tensor * imagenet_std + imagenet_mean, 0, 1)
    # CHW -> HWC for image libraries, then scale to 8-bit.
    hwc = restored.permute(1, 2, 0).detach().cpu().numpy()
    return (hwc * 255).astype(np.uint8)
def compute_similarity_heatmap(feats, img_size, pool_type='avg'):
    """Cosine-similarity map between spatial features and a pooled global feature.

    Args:
        feats: [B, H, W, C] feature tensor.
        img_size: (target_h, target_w) size the map is upsampled to.
        pool_type: 'avg' or 'max' spatial pooling used to form the global query.
    Returns:
        [B, target_h, target_w] similarity map (cosine values in [-1, 1]).
    Raises:
        ValueError: if pool_type is neither 'avg' nor 'max'.
    """
    if pool_type == 'avg':
        query = feats.mean(dim=(1, 2), keepdim=True)            # [B, 1, 1, C]
    elif pool_type == 'max':
        query = torch.amax(feats, dim=(1, 2), keepdim=True)     # [B, 1, 1, C]
    else:
        raise ValueError("pool_type must be 'avg' or 'max'")
    # Broadcast the query against every spatial location; similarity is taken
    # along the channel axis, yielding [B, H, W].
    sim = F.cosine_similarity(feats, query, dim=-1)
    # Bilinear upsampling expects a channel dimension: [B, 1, H, W].
    sim = F.interpolate(sim.unsqueeze(1), size=img_size,
                        mode='bilinear', align_corners=False)
    return sim.squeeze(1)
def apply_heatmap_overlay(img_rgb, heatmap_tensor):
    """Colorize a heatmap and blend it onto an RGB image.

    Args:
        img_rgb: uint8 RGB image of shape (H, W, 3).
        heatmap_tensor: [H, W] tensor of similarity scores.
    Returns:
        (heatmap_color, overlay): the JET-colored heatmap and the
        0.6 * image + 0.4 * heatmap blend, both uint8 RGB.
    """
    scores = heatmap_tensor.detach().cpu().numpy()
    # Min-max normalize to [0, 1]; cosine similarities start in roughly [-1, 1].
    shifted = scores - np.min(scores)
    normalized = shifted / (np.max(shifted) + 1e-8)
    # Map scalar scores to a JET pseudo-color image (OpenCV emits BGR).
    colored = cv2.applyColorMap((normalized * 255).astype(np.uint8), cv2.COLORMAP_JET)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(img_rgb, 0.6, colored, 0.4, 0)
    return colored, blended
# ==========================================
# 2. Model Loading
# ==========================================
# Path to the pre-trained checkpoint; replace with your downloaded .pth file.
path = "your_downloaded_pth_file"
model = VisionUlt().to(device)
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — only load checkpoints obtained from a trusted source.
checkpoint = torch.load(path, map_location=device)
# Strip the "module." prefix added when the model was saved from a
# DataParallel/DistributedDataParallel wrapper.
state_dict = {k.replace("module.", ""): v for k, v in checkpoint.items()}
# strict=False tolerates key mismatches between the checkpoint and the model.
model.load_state_dict(state_dict, strict=False)
model.eval()
print("Model loaded successfully.")
# Evaluation dataloader; point data_root at your own test images.
dataloader = get_data(data_root="your_test_images", batch_size=4)
# ==========================================
# 3. Main Loop and Visualization
# ==========================================
# ==========================================
# 3. Main Loop and Visualization
# ==========================================
for data in dataloader:
    image, mask, _ = data
    image = image.to(device)
    mask = mask.to(device)
    # Inference: run reconstruction AND feature extraction under no_grad.
    # (Fix: feats was previously computed outside no_grad, needlessly
    # building an autograd graph and wasting memory at inference time.)
    with torch.no_grad():
        image_recon = model(image, mask)
        # Features of the masked input: [B, H, W, C]
        feats = model.model(image * (1 - mask))[3]
        feats = model.merge(feats)
    print(f"Feats shape: {feats.shape}")
    # Reconstruction error on the full (still-normalized) image.
    mse = ((image_recon - image) ** 2).mean()
    print(f"mse is {mse}")
    # Prepare data for visualization
    batch_size = image.shape[0]
    img_h, img_w = image.shape[2], image.shape[3]
    # Similarity heatmap between patch features and the pooled global feature;
    # pool_type can be 'avg' or 'max'.
    heatmaps_resized = compute_similarity_heatmap(feats, (img_h, img_w), pool_type='max')
    # One row per sample: original | masked input | reconstruction | overlay.
    fig, axs = plt.subplots(batch_size, 4, figsize=(16, 4 * batch_size))
    if batch_size == 1:
        axs = axs[None, :]  # keep a 2-D axes grid even for a single sample
    for i in range(batch_size):
        img_orig = denormalize(image[i])
        img_recon = denormalize(image_recon[i])
        # Masked input: zero out the masked patches of the de-normalized image.
        mask_np = mask[i].permute(1, 2, 0).cpu().detach().numpy()
        img_masked = (img_orig * (1 - mask_np)).astype(np.uint8)
        # Similarity heatmap overlaid on the original image.
        heatmap_vis, overlay_vis = apply_heatmap_overlay(img_orig, heatmaps_resized[i])
        panels = [
            (img_orig, "Original Image"),
            (img_masked, "Masked Input"),
            (img_recon, "Recon Image"),
            (overlay_vis, "Overlay"),
        ]
        for col, (panel, title) in enumerate(panels):
            axs[i, col].imshow(panel)
            axs[i, col].set_title(title)
            axs[i, col].axis('off')
    plt.tight_layout()
    save_path = "visualization_similarity.png"
    plt.savefig(save_path)
    # Fix: release the figure so repeated iterations do not accumulate
    # open matplotlib figures.
    plt.close(fig)
    print(f"Visualization saved to {save_path}")
    break

If one is willing to pre-train SonoNexus on an in-house dataset, please refer to the following:
The training and testing datasets are defined in ./dataset_mae_cnn.py, along with the data pre-processing/augmentation pipeline and the masking strategy. Our in-house pre-training data consists of a large-scale dataset of 21,140,761 images covering 20 major organs, collected from 10 types of ultrasound equipment/sensors, enabling comprehensive model training and evaluation.
The model is in ./model/swin.py, including the model definition, masked image reconstruction loss and contrastive loss.
The training process is in ./train_mae_cnn.py and the running file is ./main_mae_cnn.py
- β Fetal ultrasound view classification
- β Organ segmentation
- β Anatomical structure detection
- β Disease classification
When the pre-training period is finished, one can easily transfer the model to diverse downstream tasks for US images. In the main paper, we focus on four tasks, including fetal ultrasound view classification, organ segmentation, anatomical structure detection and disease classification.
This project is released under the Apache 2.0 License.







