"""cli.py: delta compression for OPT models. Quantizes a low-rank
approximation of the weight delta between a base model and a finetuned
target model, then evaluates the compressed model's perplexity."""
import argparse
import copy

import torch
from loguru import logger

from core_compression import opt_delta_lr
from datautils import get_loaders
from evaluation import opt_eval
from modelutils import get_opt
from save_and_load import save_lr_tensors


@torch.no_grad()
def quantize_with_lowrank(base_model, target_model, dataloader, rank, wbits, n_samples):
    """Low-rank-approximate the weight delta (finetuned - base), then quantize it."""
    original_finetuned_model = copy.deepcopy(target_model)
    # Replace the finetuned weights with the delta (finetuned - base).
    for base_p, finetuned_p in zip(base_model.parameters(), target_model.parameters()):
        finetuned_p.data = finetuned_p.data - base_p.data
    r_quantizers, l_quantizers, lr_tensors = opt_delta_lr(
        original_finetuned_model,
        target_model,
        dataloader,
        nsamples=n_samples,
        wbits=wbits,
        sym=True,
        trits=False,
        rank=rank,
        args={
            'percdamp': 0.01,
            'groupsize': -1,
            'actorder': False,
        },
    )
    target_model.to(base_model.device)
    # Restore the full weights: add the base weights back onto the compressed delta.
    for base_p, finetuned_p in zip(base_model.parameters(), target_model.parameters()):
        finetuned_p.data = finetuned_p.data + base_p.data
    # target_model now approximates the original finetuned model; the low-rank
    # factors in lr_tensors let us estimate how many parameters were saved.
    return r_quantizers, l_quantizers, lr_tensors
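

# A minimal sketch (not part of the original script) of the parameter-savings
# estimate mentioned above. It assumes each entry of `lr_tensors` maps a layer
# name to a pair of low-rank factors (lhs, rhs) with shapes
# (out_features, rank) and (rank, in_features); adapt it to whatever structure
# opt_delta_lr actually returns.
def estimate_param_savings(lr_tensors):
    """Return (dense_params, lowrank_params) summed over all decomposed layers."""
    dense_params, lowrank_params = 0, 0
    for _name, (lhs, rhs) in lr_tensors.items():
        dense_params += lhs.shape[0] * rhs.shape[1]  # elements in the dense delta
        lowrank_params += lhs.numel() + rhs.numel()  # elements in the stored factors
    return dense_params, lowrank_params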


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--base-model', type=str, required=True)
    argparser.add_argument('--target-model', type=str, required=True)
    argparser.add_argument('--dataset', type=str, required=True)
    # BooleanOptionalAction keeps the default of True while still allowing
    # --no-delta / --no-sym; `store_true` with default=True could never be disabled.
    argparser.add_argument('--delta', action=argparse.BooleanOptionalAction, default=True, help='Whether to use delta compression')
    argparser.add_argument('--rank', type=int, default=-1, help='Rank to use for decomposing each matrix')
    argparser.add_argument('--nsamples', type=int, default=128, help='Number of samples to use for quantization')
    argparser.add_argument('--seed', type=int, default=42, help='Seed to use for quantization')
    argparser.add_argument('--save', type=str, default='', help='Path to save the quantized model')
    argparser.add_argument('--wbits', type=int, default=8, help='Number of bits to use for quantization')
    argparser.add_argument('--sym', action=argparse.BooleanOptionalAction, default=True, help='Whether to use symmetric quantization')
    argparser.add_argument('--trits', action='store_true', help='Whether to use trits')
    args = argparser.parse_args()

    seed = args.seed
    base_model = get_opt(args.base_model)
    target_model = get_opt(args.target_model)
    base_model.to('cuda')
    target_model.to('cuda')
    base_model.eval()
    target_model.eval()

    trainloader, loader_enc = get_loaders(
        args.dataset,
        nsamples=args.nsamples,
        seed=seed,
        model=args.target_model,
        seqlen=base_model.seqlen,
    )
    r_quantizer, l_quantizer, lr_tensors = quantize_with_lowrank(
        base_model,
        target_model,
        trainloader,
        args.rank,
        args.wbits,
        args.nsamples,
    )
    if args.save:
        save_lr_tensors(lr_tensors, f"{args.save}/{args.target_model.replace('/', '.')}-r{args.rank}-w{args.wbits}-lr.safetensors")
    ppl = opt_eval(target_model, loader_enc, args, target_model.device)
    logger.info(f"Perplexity: {ppl}")
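
# Example invocation (model names, dataset, and output directory are illustrative):
#   python cli.py --base-model facebook/opt-1.3b --target-model <finetuned-opt-checkpoint> \
#       --dataset wikitext2 --rank 16 --wbits 4 --save out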