We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 56733b0 commit f14800a (Copy full SHA for f14800a)
models/loader.py
@@ -5,6 +5,8 @@
5
import warnings
6
from typing import Optional
7
8
+import torch
9
+
10
from configs.model_config import DEVICE
11
from models.adapter import get_llm_model_adapter
12
from models.compression import compress_module
@@ -52,7 +54,6 @@ class ModelLoader(metaclass=Singleton):
52
54
def __init__(self, model_path) -> None:
53
55
self.device = DEVICE
56
self.model_path = model_path
- import torch
57
self.kwargs = {
58
"torch_dtype": torch.float16,
59
"device_map": "auto",
@@ -67,7 +68,6 @@ def loader(
67
68
cpu_offloading=False,
69
max_gpu_memory: Optional[str] = None,
70
):
71
if self.device == "cpu":
72
kwargs = {"torch_dtype": torch.float32}
73
0 commit comments