Skip to content

Commit 76837b3

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent a45300a commit 76837b3

File tree

50 files changed

+188
-214
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

50 files changed

+188
-214
lines changed

applications/Colossal-LLaMA/colossal_llama/model/init_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55
Initialize new model with updated tokenizer by calculating the mean values from original model
66
"""
7+
78
import argparse
89

910
import numpy as np

applications/ColossalChat/coati/dataset/tokenization_utils.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,8 @@ def tokenize_sft(
5656
template.messages = []
5757
for idx, mess in enumerate(messages):
5858
if mess["from"] != template.roles[idx % 2]:
59-
raise ValueError(
60-
f"Message should iterate between user and assistant and starts with a \
61-
line from the user. Got the following data:\n{messages}"
62-
)
59+
raise ValueError(f"Message should iterate between user and assistant and starts with a \
60+
line from the user. Got the following data:\n{messages}")
6361
template.append_message(mess["from"], mess["content"])
6462

6563
if len(template.messages) % 2 != 0:
@@ -245,10 +243,8 @@ def tokenize_rlhf(
245243

246244
for idx, mess in enumerate(context):
247245
if mess["from"] != template.roles[idx % 2]:
248-
raise ValueError(
249-
f"Message should iterate between user and assistant and starts with a \
250-
line from the user. Got the following data:\n{context}"
251-
)
246+
raise ValueError(f"Message should iterate between user and assistant and starts with a \
247+
line from the user. Got the following data:\n{context}")
252248
template.append_message(mess["from"], mess["content"])
253249

254250
if len(template.messages) % 2 != 1:
@@ -272,18 +268,14 @@ def tokenize_rlhf(
272268
rejected_continuation = data_point["rejected"]
273269
for round in range(len(chosen_continuation)):
274270
if chosen_continuation[round]["from"] != template.roles[(round + 1) % 2]:
275-
raise ValueError(
276-
f"Message should iterate between user and assistant and starts with a \
277-
line from the user. Got the following data:\n{chosen_continuation}"
278-
)
271+
raise ValueError(f"Message should iterate between user and assistant and starts with a \
272+
line from the user. Got the following data:\n{chosen_continuation}")
279273
chosen.append_message(chosen_continuation[round]["from"], chosen_continuation[round]["content"])
280274

281275
for round in range(len(rejected_continuation)):
282276
if rejected_continuation[round]["from"] != template.roles[(round + 1) % 2]:
283-
raise ValueError(
284-
f"Message should iterate between user and assistant and starts with a \
285-
line from the user. Got the following data:\n{rejected_continuation}"
286-
)
277+
raise ValueError(f"Message should iterate between user and assistant and starts with a \
278+
line from the user. Got the following data:\n{rejected_continuation}")
287279
rejected.append_message(rejected_continuation[round]["from"], rejected_continuation[round]["content"])
288280

289281
(
@@ -296,14 +288,14 @@ def tokenize_rlhf(
296288
) = (None, None, None, None, None, None)
297289

298290
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer)
299-
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
291+
chosen_input_ids, chosen_loss_mask, chosen_label_decode = (
300292
chosen_data_packed["input_ids"],
301293
chosen_data_packed["loss_mask"],
302294
chosen_data_packed["label_decode"],
303295
)
304296

305297
rejected_data_packed = apply_rlhf_data_format(rejected, tokenizer)
306-
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
298+
rejected_input_ids, rejected_loss_mask, rejected_label_decode = (
307299
rejected_data_packed["input_ids"],
308300
rejected_data_packed["loss_mask"],
309301
rejected_data_packed["label_decode"],

applications/ColossalChat/coati/distributed/reward/reward_fn.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
https://github.com/volcengine/verl
1818
"""
1919

20-
2120
import json
2221

2322
import torch

applications/ColossalChat/coati/trainer/kto.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _train(self, epoch: int):
130130
)
131131
for i, batch in enumerate(self.train_dataloader):
132132
batch = to_device(batch, self.device)
133-
(input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = (
133+
input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask = (
134134
batch["input_ids"],
135135
batch["attention_mask"],
136136
batch["loss_mask"],
@@ -279,7 +279,7 @@ def _eval(self, epoch: int):
279279
)
280280
for i, batch in enumerate(self.train_dataloader):
281281
batch = to_device(batch, self.device)
282-
(input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = (
282+
input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask = (
283283
batch["input_ids"],
284284
batch["attention_mask"],
285285
batch["loss_mask"],

applications/ColossalChat/examples/community/ray/train_prompts_on_ray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _init_optimizer(self):
120120
def _prepare_model_with_strategy(self, has_optimizer: bool):
121121
if has_optimizer:
122122
self._init_optimizer()
123-
(self._model, self._optimizer) = self._strategy.prepare((self._model, self._optimizer))
123+
self._model, self._optimizer = self._strategy.prepare((self._model, self._optimizer))
124124
else:
125125
self._model = self._strategy.prepare(self._model)
126126

applications/ColossalQA/examples/webui_demo/webui.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ def restart(chatbot, txt):
8181
)
8282
with gr.Row():
8383
btn = gr.UploadButton("📁", file_types=["file"], file_count="multiple", size="sm")
84-
restart_btn = gr.Button(str("\u21BB"), elem_id="restart-btn", scale=1)
84+
restart_btn = gr.Button(str("\u21bb"), elem_id="restart-btn", scale=1)
8585
txt = gr.Textbox(
8686
scale=8,
8787
show_label=False,
88-
placeholder="Enter text and press enter, or use 📁 to upload files, click \u21BB to clear loaded files and restart chat",
88+
placeholder="Enter text and press enter, or use 📁 to upload files, click \u21bb to clear loaded files and restart chat",
8989
container=True,
9090
autofocus=True,
9191
)

colossalai/auto_parallel/tensor_shard/solver/solver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""This code is adapted from Alpa
2-
https://github.com/alpa-projects/alpa/
3-
with some changes. """
2+
https://github.com/alpa-projects/alpa/
3+
with some changes."""
44

55
import multiprocessing
66
import time

colossalai/autochunk/select_chunk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
176176
return best_region
177177

178178
def _is_legal_region(self, cur_chunk_info, chunk_infos):
179-
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
179+
chunk_region_start, chunk_region_end = cur_chunk_info["region"]
180180
if cur_chunk_info in chunk_infos:
181181
return False
182182
if chunk_region_end < chunk_region_start:

colossalai/booster/plugin/gemini_plugin.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,8 @@ def load_sharded_optimizer(
338338
# Load param_groups.
339339
param_group_path = ckpt_index_file.get_param_group_filename()
340340
if param_group_path is None:
341-
raise RuntimeError(
342-
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
343-
Lacking param group file under current directory."
344-
)
341+
raise RuntimeError(f"Invalid index file path {checkpoint_index_file} for an optimizer. \
342+
Lacking param group file under current directory.")
345343
saved_param_groups = torch.load(param_group_path)
346344
optimizer.load_param_groups(saved_param_groups)
347345

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,8 @@ def load_sharded_optimizer(
268268
# Load param_groups
269269
param_group_path = ckpt_index_file.get_param_group_filename()
270270
if param_group_path is None:
271-
raise RuntimeError(
272-
f"Invalid index file path {index_file_path} for an optimizer. \
273-
Lacking param group file under current directory."
274-
)
271+
raise RuntimeError(f"Invalid index file path {index_file_path} for an optimizer. \
272+
Lacking param group file under current directory.")
275273
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
276274

277275
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

0 commit comments

Comments
 (0)