Skip to content

Commit e8a6001

Browse files
authored by michaeljabbour
Merge pull request #277 from michaeljabbour/feat/tensor-parallelism
feat: add tensor parallelism awareness for multi-GPU model fitting
2 parents 9ec729a + 68b6169 commit e8a6001

File tree

9 files changed

+408
-11
lines changed

9 files changed

+408
-11
lines changed

llmfit-core/src/fit.rs

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,11 @@ pub enum FitLevel {
7272
/// This is the "optimization" dimension, independent of memory fit.
7373
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
7474
pub enum RunMode {
75-
Gpu, // Fully loaded into VRAM -- fast
76-
MoeOffload, // MoE: active experts in VRAM, inactive offloaded to RAM
77-
CpuOffload, // Partial GPU offload, spills to system RAM -- mixed
78-
CpuOnly, // Entirely in system RAM, no GPU -- slow
75+
Gpu, // Fully loaded into VRAM -- fast
76+
MoeOffload, // MoE: active experts in VRAM, inactive offloaded to RAM
77+
CpuOffload, // Partial GPU offload, spills to system RAM -- mixed
78+
CpuOnly, // Entirely in system RAM, no GPU -- slow
79+
TensorParallel, // Distributed via NCCL across cluster nodes
7980
}
8081

8182
/// Multi-dimensional score components (0-100 each).
@@ -165,6 +166,8 @@ impl ModelFit {
165166
// pre-quantized models default to vLLM, falling back to auto-detect.
166167
let runtime = if let Some(forced) = force_runtime {
167168
forced
169+
} else if system.cluster_mode {
170+
InferenceRuntime::Vllm
168171
} else if model.is_prequantized() {
169172
InferenceRuntime::Vllm
170173
} else if system.backend == GpuBackend::Metal && system.unified_memory {
@@ -177,7 +180,25 @@ impl ModelFit {
177180

178181
// Step 1: pick the best available execution path
179182
// Step 2: score memory fit purely on headroom in that path's memory pool
180-
let (run_mode, mem_required, mem_available) = if system.has_gpu {
183+
let (run_mode, mem_required, mem_available) = if system.cluster_mode {
184+
// Cluster mode: vLLM with tensor parallelism across multiple nodes.
185+
// Total VRAM is the sum across all nodes (NCCL handles distribution).
186+
let pool = system.total_gpu_vram_gb.unwrap_or(0.0);
187+
let tp_size = system.cluster_node_count;
188+
if let Some((_, best_mem)) = choose_quant(pool) {
189+
notes.push(format!(
190+
"Cluster: tensor-parallel across {} nodes via vLLM (TP={})",
191+
tp_size, tp_size
192+
));
193+
(RunMode::TensorParallel, best_mem, pool)
194+
} else {
195+
notes.push(format!(
196+
"Cluster: {} nodes but model exceeds aggregate VRAM ({:.1} GB)",
197+
tp_size, pool
198+
));
199+
(RunMode::TensorParallel, default_mem_required, pool)
200+
}
201+
} else if system.has_gpu {
181202
if system.unified_memory {
182203
// Unified memory (Apple Silicon or NVIDIA Tegra/Grace Blackwell):
183204
// GPU and CPU share the same memory pool.
@@ -391,6 +412,7 @@ impl ModelFit {
391412
pub fn run_mode_text(&self) -> &str {
392413
match self.run_mode {
393414
RunMode::Gpu => "GPU",
415+
RunMode::TensorParallel => "TP",
394416
RunMode::MoeOffload => "MoE",
395417
RunMode::CpuOffload => "CPU+GPU",
396418
RunMode::CpuOnly => "CPU",
@@ -413,7 +435,7 @@ fn score_fit(
413435
}
414436

415437
match run_mode {
416-
RunMode::Gpu => {
438+
RunMode::Gpu | RunMode::TensorParallel => {
417439
if recommended <= mem_available {
418440
FitLevel::Perfect
419441
} else if mem_available >= mem_required * 1.2 {
@@ -799,6 +821,7 @@ fn estimate_tps(
799821

800822
let mode_factor = match run_mode {
801823
RunMode::Gpu => 1.0,
824+
RunMode::TensorParallel => 0.9,
802825
RunMode::MoeOffload => 0.8,
803826
RunMode::CpuOffload => 0.5,
804827
RunMode::CpuOnly => unreachable!(),
@@ -835,10 +858,11 @@ fn estimate_tps(
835858

836859
// Run mode penalties
837860
match run_mode {
838-
RunMode::Gpu => {} // full speed
839-
RunMode::MoeOffload => base *= 0.8, // expert switching latency
840-
RunMode::CpuOffload => base *= 0.5, // significant penalty
841-
RunMode::CpuOnly => base *= 0.3, // worst case—override K to CPU
861+
RunMode::Gpu => {} // full speed
862+
RunMode::TensorParallel => base *= 0.9, // TP communication overhead
863+
RunMode::MoeOffload => base *= 0.8, // expert switching latency
864+
RunMode::CpuOffload => base *= 0.5, // significant penalty
865+
RunMode::CpuOnly => base *= 0.3, // worst case—override K to CPU
842866
}
843867

844868
// CPU-only should use CPU K regardless of detected GPU
@@ -1046,6 +1070,8 @@ mod tests {
10461070
gguf_sources: vec![],
10471071
capabilities: vec![],
10481072
format: models::ModelFormat::default(),
1073+
num_attention_heads: None,
1074+
num_key_value_heads: None,
10491075
}
10501076
}
10511077

@@ -1071,6 +1097,8 @@ mod tests {
10711097
GpuBackend::CpuX86
10721098
},
10731099
gpus: vec![],
1100+
cluster_mode: false,
1101+
cluster_node_count: 0,
10741102
}
10751103
}
10761104

@@ -1222,6 +1250,8 @@ mod tests {
12221250
gguf_sources: vec![],
12231251
capabilities: vec![],
12241252
format: models::ModelFormat::default(),
1253+
num_attention_heads: None,
1254+
num_key_value_heads: None,
12251255
};
12261256
let mut system = test_system(64.0, true, Some(8.0));
12271257
system.backend = GpuBackend::Cuda;
@@ -1255,6 +1285,8 @@ mod tests {
12551285
gguf_sources: vec![],
12561286
capabilities: vec![],
12571287
format: models::ModelFormat::default(),
1288+
num_attention_heads: None,
1289+
num_key_value_heads: None,
12581290
};
12591291
let system = test_system(12.0, true, Some(8.0));
12601292

@@ -1696,6 +1728,8 @@ mod tests {
16961728
unified_memory: false,
16971729
backend: GpuBackend::Cuda,
16981730
gpus: vec![],
1731+
cluster_mode: false,
1732+
cluster_node_count: 0,
16991733
}
17001734
}
17011735

llmfit-core/src/hardware.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ pub struct SystemSpecs {
5757
pub backend: GpuBackend,
5858
/// All detected GPUs (may span different vendors/backends).
5959
pub gpus: Vec<GpuInfo>,
60+
/// True when running in multi-node cluster mode (e.g. DGX Spark cluster).
61+
pub cluster_mode: bool,
62+
/// Number of nodes in the cluster (0 or 1 = single machine).
63+
pub cluster_node_count: u32,
6064
}
6165

6266
impl SystemSpecs {
@@ -112,6 +116,8 @@ impl SystemSpecs {
112116
unified_memory,
113117
backend,
114118
gpus,
119+
cluster_mode: false,
120+
cluster_node_count: 0,
115121
}
116122
}
117123

@@ -2552,6 +2558,8 @@ GPU id = 1 (NVIDIA GeForce RTX 4090)
25522558
unified_memory: false,
25532559
backend: super::GpuBackend::CpuX86,
25542560
gpus: vec![],
2561+
cluster_mode: false,
2562+
cluster_node_count: 0,
25552563
}
25562564
}
25572565

@@ -2575,6 +2583,8 @@ GPU id = 1 (NVIDIA GeForce RTX 4090)
25752583
count: 1,
25762584
unified_memory: false,
25772585
}],
2586+
cluster_mode: false,
2587+
cluster_node_count: 0,
25782588
}
25792589
}
25802590

0 commit comments

Comments (0)