
Commit 11da407

CU-8696kd70p improve the eval mode and add plot on metrics
1 parent: e0e015b

6 files changed (+135, −65 lines changed)


app/api/routers/supervised_training.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ async def train_supervised(request: Request,
                            trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")],
                            epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1,
                            lr_override: Annotated[Union[float, None], Query(description="The override of the initial learning rate", gt=0.0)] = None,
-                           test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage", ge=0.0)] = 0.2,
+                           test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage. (For a 'huggingface-ner' model, a negative value can be used to apply the train-validation-test split if implicitly defined in trainer export: 'projects[0]' is used for training, 'projects[1]' for validation, and 'projects[2]' for testing)")] = 0.2,
                            log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1,
                            description: Annotated[Union[str, None], Form(description="The description of the training or change logs")] = None,
                            model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
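
For orientation, here is a minimal sketch of how a negative test_size might be consumed downstream for a 'huggingface-ner' model, following the behaviour described in the new parameter description. The helper name split_trainer_export and the export layout used here are assumptions for illustration, not code from this commit.

def split_trainer_export(trainer_export: dict, test_size: float):
    # Hypothetical helper mirroring the documented behaviour of 'test_size'.
    projects = trainer_export["projects"]
    if test_size is not None and test_size < 0:
        # Negative value: honour the split implicitly defined by the export itself.
        train_docs = projects[0]["documents"]
        validation_docs = projects[1]["documents"] if len(projects) > 1 else []
        test_docs = projects[2]["documents"] if len(projects) > 2 else []
    else:
        # Non-negative value: pool all documents and hold out 'test_size' of them.
        all_docs = [doc for project in projects for doc in project["documents"]]
        cut = int(len(all_docs) * (1 - (test_size if test_size is not None else 0.2)))
        train_docs, validation_docs, test_docs = all_docs[:cut], [], all_docs[cut:]
    return train_docs, validation_docs, test_docs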

app/management/tracker_client.py

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ def send_model_stats(stats: Dict, step: int) -> None:
         mlflow.log_metrics(metrics, step)
 
     @staticmethod
-    def send_hf_training_logs(logs: Dict, step: int) -> None:
+    def send_hf_metrics_logs(logs: Dict, step: int) -> None:
         mlflow.log_metrics(logs, step)
 
     @staticmethod
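
The rename signals that the method forwards Hugging Face metrics logs (training or evaluation) to MLflow, not only training losses. Below is a sketch of how it might be driven from a transformers TrainerCallback; the callback class and the tracker wiring are assumptions for illustration, not part of this commit.

from transformers import TrainerCallback

class MetricsLoggingCallback(TrainerCallback):
    # Hypothetical callback that forwards Hugging Face train/eval logs to the tracker.
    def __init__(self, tracker_client):
        self._tracker_client = tracker_client

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            # mlflow.log_metrics() only accepts numeric values, so filter the rest out.
            metrics = {k: v for k, v in logs.items() if isinstance(v, (int, float))}
            self._tracker_client.send_hf_metrics_logs(metrics, step=state.global_step)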

app/model_services/huggingface_ner_model.py

Lines changed: 3 additions & 3 deletions
@@ -79,8 +79,8 @@ def from_model(cls, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase)
         model_service.tokenizer = tokenizer
         _pipeline = partial(pipeline,
                             task="ner",
-                            model=model,
-                            tokenizer=tokenizer,
+                            model=model_service.model,
+                            tokenizer=model_service.tokenizer,
                             stride=10,
                             aggregation_strategy=get_settings().HF_PIPELINE_AGGREGATION_STRATEGY)
         if non_default_device_is_available(get_settings().DEVICE):
@@ -139,7 +139,7 @@ def annotate(self, text: str) -> Dict:
             df = pd.DataFrame(columns=["label_name", "label_id", "start", "end", "accuracy"])
         else:
             for idx, row in df.iterrows():
-                df.loc[idx, "label_id"] = str(self._model.config.label2id[row["entity_group"]])
+                df.loc[idx, "label_id"] = row["entity_group"]
             df.rename(columns={"entity_group": "label_name", "score": "accuracy"}, inplace=True)
         records = df.to_dict("records")
         return records
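
A standalone walk-through of the annotate() post-processing above. One plausible motivation for the change is that aggregated groups such as 'ORG' are not keys in config.label2id, which stores prefixed labels like 'B-ORG', so the old lookup was fragile. The model name and sample text below are placeholders, not values from this commit.

import pandas as pd
from transformers import pipeline

ner = pipeline("ner", model="dslim/bert-base-NER", aggregation_strategy="simple")
df = pd.DataFrame(ner("Acme Corp hired Jane Doe in London."))
# With aggregation enabled, each row carries entity_group, score, word, start and end.
for idx, row in df.iterrows():
    # After this change, label_id is the aggregated group name itself.
    df.loc[idx, "label_id"] = row["entity_group"]
df.rename(columns={"entity_group": "label_name", "score": "accuracy"}, inplace=True)
records = df[["label_name", "label_id", "start", "end", "accuracy"]].to_dict("records")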

app/trainers/base.py

Lines changed: 5 additions & 5 deletions
@@ -73,7 +73,10 @@ def start_training(self,
             self._tracker_client.save_processed_artifact(data_file.name, self._model_name)
 
         dataset = None
-        if training_type == TrainingType.UNSUPERVISED.value and isinstance(data_file, TextIO):
+        if training_type == TrainingType.UNSUPERVISED.value and isinstance(data_file, tempfile.TemporaryDirectory):
+            dataset = datasets.load_from_disk(data_file.name)
+            self._tracker_client.save_train_dataset(dataset)
+        elif training_type == TrainingType.UNSUPERVISED.value:
             try:
                 dataset = datasets.load_dataset(doc_dataset.__file__,
                                                 data_files={"documents": data_file.name},
@@ -84,7 +87,7 @@ def start_training(self,
             finally:
                 if dataset is not None:
                     dataset.cleanup_cache_files()
-        elif training_type == TrainingType.SUPERVISED.value and isinstance(data_file, TextIO):
+        elif training_type == TrainingType.SUPERVISED.value:
             try:
                 dataset = datasets.load_dataset(anno_dataset.__file__,
                                                 data_files={"annotations": data_file.name},
@@ -95,9 +98,6 @@ def start_training(self,
             finally:
                 if dataset is not None:
                     dataset.cleanup_cache_files()
-        elif training_type == TrainingType.UNSUPERVISED.value and isinstance(data_file, tempfile.TemporaryDirectory):
-            dataset = datasets.load_from_disk(data_file.name)
-            self._tracker_client.save_train_dataset(dataset)
        else:
            raise ValueError(f"Unknown training type: {training_type}")
 
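A minimal illustration of the reordered first branch: an unsupervised corpus packed as a Hugging Face dataset on disk (inside a tempfile.TemporaryDirectory) now takes the load_from_disk path, while plain text uploads fall through to the line-by-line load_dataset path. The sample texts below are placeholders.

import tempfile
import datasets

# Pack a tiny corpus the way a caller might before handing it to the trainer.
packed_dir = tempfile.TemporaryDirectory()
datasets.Dataset.from_dict({"text": ["first document", "second document"]}).save_to_disk(packed_dir.name)

# Inside start_training(), an input of this shape is detected via
# isinstance(data_file, tempfile.TemporaryDirectory) and loaded directly:
dataset = datasets.load_from_disk(packed_dir.name)
print(dataset.num_rows)  # 2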