Skip to content

Commit 75f5c66

Browse files
committed
Update training CLI.
1 parent 2e3b973 commit 75f5c66

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

src/llm/train.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
"""Model training module."""
22

3+
import argparse
34
import logging
45
import os
56
import sys
@@ -131,6 +132,20 @@ def validate(model: LLM, dataloader: DataLoader, device: torch.device) -> float:
131132

132133
def main():
133134
"""Train and validate model."""
135+
parser = argparse.ArgumentParser(description="Trigger an LLM training.")
136+
137+
cfg_path = Path.cwd() / "src/llm/configs/llm_3.55b.yaml"
138+
parser.add_argument(
139+
"--config-path", type=Path, help="Path of LLM config", default=cfg_path
140+
)
141+
parser.add_argument(
142+
"--data-percentage", type=int, help="Use the first X% of data.", default=5
143+
)
144+
145+
args = parser.parse_args()
146+
147+
cfg_path, data_percentage = args.config_path, args.data_percentage
148+
134149
log.info("Starting training process...")
135150

136151
log.info("Dataset is going to be downloaded...")
@@ -140,7 +155,7 @@ def main():
140155
dataset = load_dataset(
141156
"bigcode/starcoderdata",
142157
data_dir="python",
143-
split="train",
158+
split=f"train[:{data_percentage}%]",
144159
token=os.getenv("HF_TOKEN"),
145160
)
146161
log.info(f"Dataset loaded successfully with {len(dataset)} samples")
@@ -258,4 +273,4 @@ def main():
258273

259274

260275
if __name__ == "__main__":
    # main() takes no parameters: it builds the argparse parser and reads
    # --config-path / --data-percentage itself. `args` is local to main(), so
    # passing args.config_path / args.data_percentage from module scope would
    # raise NameError (and TypeError on the unexpected keyword arguments).
    main()

0 commit comments

Comments
 (0)