11"""Model training module."""
22
3+ import argparse
34import logging
45import os
56import sys
@@ -131,6 +132,20 @@ def validate(model: LLM, dataloader: DataLoader, device: torch.device) -> float:
131132
132133def main ():
133134 """Train and validate model."""
135+ parser = argparse .ArgumentParser (description = "Trigger an LLM training." )
136+
137+ cfg_path = Path .cwd () / "src/llm/configs/llm_3.55b.yaml"
138+ parser .add_argument (
139+ "--config-path" , type = Path , help = "Path of LLM config" , default = cfg_path
140+ )
141+ parser .add_argument (
142+ "--data-percentage" , type = int , help = "Use the first X% of data." , default = 5
143+ )
144+
145+ args = parser .parse_args ()
146+
147+ cfg_path , data_percentage = args .config_path , args .data_percentage
148+
134149 log .info ("Starting training process..." )
135150
136151 log .info ("Dataset is going to be downloaded..." )
@@ -140,7 +155,7 @@ def main():
140155 dataset = load_dataset (
141156 "bigcode/starcoderdata" ,
142157 data_dir = "python" ,
143- split = "train" ,
158+ split = f "train[: { data_percentage } %] " ,
144159 token = os .getenv ("HF_TOKEN" ),
145160 )
146161 log .info (f"Dataset loaded successfully with { len (dataset )} samples" )
@@ -258,4 +273,4 @@ def main():
258273
259274
260275if __name__ == "__main__" :
261- main ()
276+ main (cfg_path = args . config_path , data_percentage = args . data_percentage )
0 commit comments