22 changes: 13 additions & 9 deletions main.py
@@ -96,23 +96,27 @@ async def main():
     questions = load_questions(args.dataset)[args.start_qid:args.end_qid]

     target_tokenizer = AutoTokenizer.from_pretrained(args.model)
-    draft_tokenizer = AutoTokenizer.from_pretrained(args.draft_model)

     target_config = get_model_config(args.model)
-    draft_config = get_model_config(args.draft_model)

     os.environ['CUDA_VISIBLE_DEVICES'] = args.target_gpu_id
+    # Set environment variables for GPU usage
     target_model = Targeter(args.model, eos_id=target_config['eos_id'], target_gpu_id=args.target_gpu_id,
                             enable_n_gram=args.enable_n_gram, vllm_config={'force_eager': False, 'num_speculative_tokens': args.num_speculative_tokens, 'prompt_lookup_max': args.prompt_lookup_max})

-    os.environ['CUDA_VISIBLE_DEVICES'] = args.draft_gpu_id
-    draft_model = Drafter(args.draft_model, eos_id=draft_config['eos_id'], draft_gpu_id=args.draft_gpu_id,
-                          enable_n_gram=args.enable_n_gram, vllm_config={'force_eager': False, 'num_speculative_tokens': args.num_speculative_tokens, 'prompt_lookup_max': args.prompt_lookup_max})
-
-    assert target_config['name'] == draft_config['name'], \
-        "Target and draft models must be of the same type (e.g., both Qwen3)."
+    if args.use_spec:
+        draft_tokenizer = AutoTokenizer.from_pretrained(args.draft_model)
+        draft_config = get_model_config(args.draft_model)
+
+        os.environ['CUDA_VISIBLE_DEVICES'] = args.draft_gpu_id
+        draft_model = Drafter(args.draft_model, eos_id=draft_config['eos_id'], draft_gpu_id=args.draft_gpu_id,
+                              enable_n_gram=args.enable_n_gram, vllm_config={'force_eager': False, 'num_speculative_tokens': args.num_speculative_tokens, 'prompt_lookup_max': args.prompt_lookup_max})
+
+        assert target_config['name'] == draft_config['name'], \
+            "Target and draft models must be of the same type (e.g., both Qwen3)."
+    else:
+        draft_tokenizer = None
+        draft_model = None
+        draft_config = None

     target_config['judge_model'] = args.judge_model
     print(f"Target Model Config: {target_config}")
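A note on the ordering in this hunk: assigning os.environ['CUDA_VISIBLE_DEVICES'] only affects CUDA contexts created after the assignment, so each env write has to precede the constructor that initializes its engine. Below is a minimal sketch of the gating pattern, assuming Targeter and Drafter create their CUDA contexts inside their constructors; build_models is a hypothetical wrapper used only for illustration, and the constructor keyword arguments are elided (see the diff above).

import os

def build_models(args):
    # Visibility must be set before the first CUDA context is created;
    # changing CUDA_VISIBLE_DEVICES later does not re-map a live context.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.target_gpu_id
    target_model = Targeter(args.model)  # kwargs elided, as in the diff

    draft_model = None
    if args.use_spec:
        # Only pay for the draft engine when speculative decoding is on.
        os.environ['CUDA_VISIBLE_DEVICES'] = args.draft_gpu_id
        draft_model = Drafter(args.draft_model)  # kwargs elided, as in the diff

    return target_model, draft_model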
43 changes: 25 additions & 18 deletions src/lr.py
@@ -82,27 +82,34 @@ async def run_problem(question, i, target_model, draft_model, \
         tokenize=False,
     )

-    draft_prompt = draft_tokenizer.apply_chat_template(
-        [{"role": "user", "content": prompt}],
-        add_generation_prompt=True,
-        tokenize=False,
-    )
-

     target_token_ids = target_tokenizer.encode(target_prompt, add_special_tokens=False)
-    draft_token_ids = draft_tokenizer.encode(draft_prompt, add_special_tokens=False)
     # response = await target_model.target(prompt_token_ids)

-    if target_config['name'] == draft_config['name']:
-        target2draft = lambda x: x
-        draft2target = lambda x: x
+    if use_spec:
+        draft_prompt = draft_tokenizer.apply_chat_template(
+            [{"role": "user", "content": prompt}],
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+        draft_token_ids = draft_tokenizer.encode(draft_prompt, add_special_tokens=False)
+
+        if target_config['name'] == draft_config['name']:
+            target2draft = lambda x: x
+            draft2target = lambda x: x
+        else:
+            target2draft = lambda x: token_transform(x, target_tokenizer, draft_tokenizer)
+            draft2target = lambda x: token_transform(x, draft_tokenizer, target_tokenizer)
+
+        print('Running question:', question_id,
+              'Draft prompt:', [draft_prompt], 'Target prompt:', [target_prompt],
+              'Tokens: ', [target_token_ids, draft_token_ids])
     else:
-        target2draft = lambda x: token_transform(x, target_tokenizer, draft_tokenizer)
-        draft2target = lambda x: token_transform(x, draft_tokenizer, target_tokenizer)
-
-    print('Running question:', question_id,
-          'Draft prompt:', [draft_prompt], 'Target prompt:', [target_prompt],
-          'Tokens: ', [target_token_ids, draft_token_ids])
+        draft_prompt = None
+        draft_token_ids = None
+        target2draft = None
+        draft2target = None
+        print('Running question:', question_id,
+              'Target prompt:', [target_prompt],
+              'Tokens: ', [target_token_ids])

     t0 = time.time()
     if use_spec:
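The fallback branch in this hunk relies on a token_transform helper that the diff does not show. Below is a minimal sketch of what such a cross-tokenizer mapping presumably does, assuming it simply round-trips through text; the actual helper in src/ may treat special tokens and partial tokens differently.

def token_transform(token_ids, src_tokenizer, dst_tokenizer):
    # Decode with the source tokenizer, then re-encode with the destination
    # tokenizer. Token boundaries can shift between vocabularies, so the
    # output sequence need not be the same length as the input.
    text = src_tokenizer.decode(token_ids, skip_special_tokens=True)
    return dst_tokenizer.encode(text, add_special_tokens=False)

When the two models share a tokenizer family (the target_config['name'] == draft_config['name'] case), the hunk skips this round trip and uses identity lambdas, which avoids any re-tokenization drift.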