Sicheng Mo¹, Thao Nguyen², Richard Zhang³, Nicholas Kolkin³, Siddharth Srinivasan Iyer³, Eli Shechtman³, Krishna Kumar Singh³, Yong Jae Lee², Bolei Zhou¹, Yuheng Li³

¹UCLA, ²UW-Madison, ³Adobe
This is the official implementation of Group Diffusion (GroupDiff), a generative AI algorithm that enhances image generation via cross-sample attention.
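To give a feel for the core idea: instead of each image attending only to its own tokens, the tokens of all samples in a group are treated as one joint sequence, so information can flow across samples during denoising. Below is a minimal, illustrative sketch of that mechanism in PyTorch; the function name, shapes, and head count are our own simplification, not the repo's actual module:

```python
# Minimal, illustrative sketch of cross-sample attention (NOT the repo's module).
import torch
import torch.nn.functional as F

def cross_sample_attention(x, num_heads=8):
    """x: (G, N, D) -- a group of G samples with N tokens of dim D each."""
    G, N, D = x.shape
    head_dim = D // num_heads
    # Flatten the group into one joint sequence so attention mixes tokens
    # both within and across samples.
    qkv = x.reshape(1, G * N, num_heads, head_dim).transpose(1, 2)
    out = F.scaled_dot_product_attention(qkv, qkv, qkv)
    return out.transpose(1, 2).reshape(G, N, D)

group = torch.randn(4, 256, 64)              # e.g., 4 samples, 256 tokens each
print(cross_sample_attention(group).shape)   # torch.Size([4, 256, 64])
```

During training, groups appear to be formed from visually similar images retrieved with the FAISS index built in the steps below (see `--query_sim feat` and `--num_max_sample 4` in the training commands).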
Download the code (example URL, please update as needed):

```bash
git clone https://github.com/adobe-research/GroupDiff.git
cd GroupDiff
```

Create and activate the conda environment:
```bash
conda create -n gdiff python=3.10 -y && conda activate gdiff
pip install -r requirements.txt
pip install 'tensorflow[and-cuda]'
```

Download the ImageNet dataset and place it under `data/imagenet/`.
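Before running feature extraction, it can be worth sanity-checking that the environment sees the GPU and that the key libraries import cleanly. This assumes `requirements.txt` installs `torch` and `faiss`; adjust the imports to whatever your environment actually pins:

```python
# Quick environment sanity check (assumes torch and faiss come from requirements.txt).
import torch
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())

import faiss  # used below to build the nearest-neighbor index
print("faiss:", faiss.__version__)
```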
Then run the following script to extract VAE latents and feature embeddings for each image:
```bash
mkdir data
# Download / link ImageNet train to data/imagenet
ln -s YOUR_IMAGENET_PATH data/imagenet
torchrun --nnodes=1 --nproc_per_node=8 dataset/extract_feats.py \
  --data-path data/imagenet/imagenet/train \
  --features-path data/imagenet_feats/train \
  --models all
```
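Conceptually, the extraction step encodes each image twice: with the SD VAE to get latents for training (the `vae-256` feature) and with DINOv2-L to get a global embedding for retrieval (the `dinov2-l` feature). The sketch below illustrates this with public checkpoints via `diffusers` and `torch.hub`; the exact models, preprocessing, and on-disk format used by `dataset/extract_feats.py` may differ:

```python
# Illustrative per-image feature extraction (dataset/extract_feats.py may use
# different models, preprocessing, and output formats).
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from diffusers import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device).eval()
dino = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14").to(device).eval()

prep = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()])

img = prep(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(device)  # hypothetical file
with torch.no_grad():
    latent = vae.encode(img * 2 - 1).latent_dist.sample()       # (1, 4, 32, 32)
    # DINOv2 needs spatial dims divisible by its 14-pixel patches; 224 = 14 * 16.
    # (Proper ImageNet normalization is omitted for brevity.)
    feat = dino(F.interpolate(img, size=224, mode="bilinear"))  # (1, 1024)
print(latent.shape, feat.shape)
```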
Then, create a FAISS index for fast similarity search during training (this takes around 30 minutes):

```bash
python dataset/create_fasiss.py \
  --metadata data/imagenet_feats/train/metadata.json \
  --feature_key dinov2-l \
  --output_dir data/imagenet_feats/train_index/dinov2-l \
  --index_type ivfpq \
  --num_workers 32
```

This creates a FAISS index from the extracted DINOv2-L features, which enables efficient nearest-neighbor search during GroupDiff training. The `ivfpq` index type provides a good balance between search speed and memory usage.
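To see what that index is doing, the snippet below builds a small IVFPQ index directly with the `faiss` Python API and queries it. The dimensions and cluster counts are toy values of our choosing, not the script's settings; DINOv2-L embeddings are 1024-dimensional:

```python
# Toy IVFPQ index: coarse clustering (IVF) plus product quantization (PQ).
import faiss
import numpy as np

d = 1024                                           # DINOv2-L feature dimension
xb = np.random.rand(10000, d).astype("float32")    # stand-in database features
xq = np.random.rand(5, d).astype("float32")        # stand-in query features

nlist, m, nbits = 64, 64, 8     # coarse clusters, PQ sub-vectors, bits per code
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
index.train(xb)                 # learn coarse centroids and PQ codebooks
index.add(xb)

index.nprobe = 8                # clusters scanned per query: speed/recall knob
distances, ids = index.search(xq, 4)
print(ids)                      # row i holds the 4 nearest database ids for query i
```

IVFPQ stores compressed codes rather than raw vectors, which is why it trades a little recall for much lower memory than a flat index.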
| Model | Resume From | Pre-training Iters | Fine-tuning Iters | Link |
|---|---|---|---|---|
| GroupDiff-l-4 (SiT-XL) | REPA-SiT | 4M | 500K | ckpt |
| GroupDiff-l-4 (DiT-XL) | REPA-SiT | 4M | 500K | ckpt |
Run the following script to download our pre-trained checkpoints and the FID stats file (from ADM's TensorFlow evaluation suite):
```bash
huggingface-cli download Sichengmo/GroupDiff --local-dir released_model
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz -O data/VIRTUAL_imagenet256_labeled.npz
```
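As a quick integrity check, the reference batch can be opened with NumPy to list the arrays it contains (no assumptions here about ADM's key names):

```python
# Verify the ADM reference batch loads, and inspect its contents.
import numpy as np

with np.load("data/VIRTUAL_imagenet256_labeled.npz") as stats:
    for key in stats.files:
        print(key, stats[key].shape, stats[key].dtype)
```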
Detailed scripts for GroupDiff training can be found in `scripts/`.

Train GroupDiff-l with DiT-XL (800 epochs):
```bash
project=GroupDiff
exp_name=groupdiff-l-4-dit-xl-dinov2
batch_size=32 # per-GPU batch size; global batch size = batch_size x num_gpus = 32 x 8 = 256
epochs=800
YOUR_WANDB_ENTITY="YourWandbEntity"
accelerate launch --num_processes 8 --multi_gpu --mixed_precision=bf16 train.py \
  --project $project --exp_name $exp_name --auto_resume \
  --model DiT_xl --patch_size 2 --num_max_sample 4 \
  --batch_size $batch_size --epochs $epochs \
  --lr 1e-4 --num_sampling_steps 250 \
  --data_path data/imagenet/train \
  --query_sim feat --use_cached_tokens \
  --entity $YOUR_WANDB_ENTITY
```

**Optional: using custom feature paths.** If the extracted features live somewhere other than the defaults, add these arguments to the training command:
```bash
--metadata_path data/imagenet_feats/train/metadata.json \
--faiss_index_path data/imagenet_feats/train_index/dinov2-l/imagenet_index_ivfpq.faiss \
--faiss_feature_name "dinov2-l" \
--features_root data/imagenet_feats/train \
--load_latent \
--latent_feature_name "vae-256" \
--latent_root data/imagenet_feats/train
```

Train GroupDiff-l with SiT-XL (800 epochs):
```bash
project=GroupDiff
exp_name=groupdiff-l-4-sit-xl-dinov2
batch_size=32 # per-GPU batch size; global batch size = batch_size x num_gpus = 32 x 8 = 256
epochs=800
YOUR_WANDB_ENTITY="YourWandbEntity"
accelerate launch --num_processes 8 --multi_gpu --mixed_precision=bf16 train.py \
  --project $project --exp_name $exp_name --auto_resume \
  --model SiT_xl --patch_size 2 --num_max_sample 4 \
  --batch_size $batch_size --epochs $epochs \
  --lr 1e-4 --num_sampling_steps 250 \
  --data_path data/imagenet/train \
  --query_sim feat --use_cached_tokens \
  --entity $YOUR_WANDB_ENTITY
```

To evaluate a trained model, point the evaluation script at its latest checkpoint (set `--model` to match the architecture you trained):

```bash
checkpoint_path=work_dirs/${project}/${exp_name}/checkpoints/lastest.pth
accelerate launch --num_processes 8 --multi_gpu --mixed_precision=bf16 eval.py \
  --project $project --exp_name $exp_name --auto_resume \
  --model DiT_base --patch_size 2 --num_max_sample 4 \
  --batch_size $batch_size --eval_bsz 128 \
  --num_sampling_steps 250 --cfg 2.5 \
  --guidance_low 0.0 --guidance_high 1.0 \
  --cond_group_size 1 --uncond_group_size 4 \
  --num_images 50000 --seed 0 \
  --load_from ${checkpoint_path} --use_ema \
  --fid_stats_path data/VIRTUAL_imagenet256_labeled.npz \
  --entity $YOUR_WANDB_ENTITY
```
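Several of these flags configure classifier-free guidance: `--cfg` sets the guidance scale, `--guidance_low`/`--guidance_high` appear to restrict guidance to a sub-interval of the (normalized) sampling trajectory, and `--cond_group_size`/`--uncond_group_size` appear to set how many samples share a group in the conditional and unconditional branches. Here is a minimal sketch of interval guidance, with our own illustrative names for the model's conditional and unconditional predictions:

```python
# Sketch of interval classifier-free guidance (all names are illustrative).
import torch

def guided_prediction(eps_cond, eps_uncond, t, cfg_scale, g_low, g_high):
    """Apply CFG only while the normalized time t lies inside [g_low, g_high]."""
    if g_low <= t <= g_high:
        return eps_uncond + cfg_scale * (eps_cond - eps_uncond)
    return eps_cond  # outside the interval, use the plain conditional prediction

eps_c, eps_u = torch.randn(1, 4, 32, 32), torch.randn(1, 4, 32, 32)
out = guided_prediction(eps_c, eps_u, t=0.5, cfg_scale=2.5, g_low=0.25, g_high=0.75)
print(out.shape)  # torch.Size([1, 4, 32, 32])
```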
Evaluate the released pre-trained GroupDiff-l (DiT-XL) checkpoint:

```bash
project=GroupDiff-l-pretrained
exp_name=gdiff-l-4-dit-xl-2-resume
checkpoint_path=released_model/${exp_name}.pth
batch_size=32 # per-GPU batch size
YOUR_WANDB_ENTITY="YourWandbEntity"
accelerate launch --num_processes 8 --multi_gpu --mixed_precision=bf16 eval.py \
  --project $project --exp_name $exp_name --auto_resume \
  --model DiT_xl --patch_size 2 --num_max_sample 4 \
  --batch_size $batch_size --eval_bsz 128 \
  --num_sampling_steps 250 --cfg 1.65 \
  --guidance_low 0.0 --guidance_high 1.0 \
  --cond_group_size 1 --uncond_group_size 4 \
  --num_images 50000 --seed 0 \
  --load_from ${checkpoint_path} --use_ema \
  --fid_stats_path data/VIRTUAL_imagenet256_labeled.npz \
  --entity $YOUR_WANDB_ENTITY
```

Evaluate the released pre-trained GroupDiff-l (SiT-XL) checkpoint:

```bash
project=GroupDiff-l-pretrained
exp_name=gdiff-l-4-sit-xl-2-repa-resume
checkpoint_path=released_model/${exp_name}.pth
batch_size=32 # per-GPU batch size
YOUR_WANDB_ENTITY="YourWandbEntity"
accelerate launch --num_processes 8 --multi_gpu --mixed_precision=bf16 eval.py \
  --project $project --exp_name $exp_name --auto_resume \
  --model SiT_xl --patch_size 2 --num_max_sample 4 \
  --batch_size $batch_size --eval_bsz 128 \
  --num_sampling_steps 250 --cfg 2.585 \
  --guidance_low 0.25 --guidance_high 0.75 \
  --cond_group_size 1 --uncond_group_size 4 \
  --num_images 50000 --seed 0 \
  --load_from ${checkpoint_path} --use_ema \
  --fid_stats_path data/VIRTUAL_imagenet256_labeled.npz \
  --entity $YOUR_WANDB_ENTITY
```

We thank the authors of DiT and SiT for their foundational work.
Our codebase builds upon several excellent open-source projects, including DeTok and MAR. We are grateful to the communities behind them.
Contact: Sicheng Mo ([email protected])
This codebase has been cleaned up but has not undergone extensive testing. If you encounter any issues or have questions, please open a GitHub issue. We appreciate your feedback!
