Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 41 additions & 94 deletions dingo/exec/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ def __init__(self, input_args: InputArgs):
self.input_args: InputArgs = input_args
self.llm: Optional[BaseLLM] = None
self.summary: SummaryModel = SummaryModel()
self.bad_info_list: List[ResultInfo] = []
self.good_info_list: List[ResultInfo] = []

def load_data(self) -> Generator[MetaData, None, None]:
"""
Expand Down Expand Up @@ -68,19 +66,11 @@ def execute(self) -> List[SummaryModel]:
eval_group=group_name,
input_path=input_path,
output_path=output_path if self.input_args.save_data else '',
create_time=create_time,
score=0,
num_good=0,
num_bad=0,
total=0,
type_ratio={},
name_ratio={}
create_time=create_time
)
self.evaluate()
self.summary = self.summarize(self.summary)
self.summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
if self.input_args.save_data:
self.save_data(output_path, self.input_args, self.bad_info_list, self.good_info_list, self.summary)
self.write_summary(self.summary.output_path, self.input_args, self.summary)

return [self.summary]

Expand All @@ -98,8 +88,6 @@ def evaluate(self):
pbar = tqdm(total=None, unit='items')

def process_batch(batch: List):
save_flag = False

futures=[]
for group_type, group in Model.get_group(self.input_args.eval_group).items():
if group_type == 'rule':
Expand All @@ -111,46 +99,19 @@ def process_batch(batch: List):

for future in concurrent.futures.as_completed(futures):
result_info = future.result()
# calculate summary ratio
for t in result_info.type_list:
self.summary.type_ratio[t] += 1
for n in result_info.name_list:
self.summary.name_ratio[n] += 1
if result_info.error_status:
self.bad_info_list.append(result_info)
self.summary.num_bad += 1
for t in result_info.type_list:
if t not in self.summary.type_ratio:
self.summary.type_ratio[t] = 1
else:
self.summary.type_ratio[t] += 1
for n in result_info.name_list:
if n not in self.summary.name_ratio:
self.summary.name_ratio[n] = 1
else:
self.summary.name_ratio[n] += 1
else:
if self.input_args.save_correct:
self.good_info_list.append(result_info)
for t in result_info.type_list:
if t not in self.summary.type_ratio:
self.summary.type_ratio[t] = 1
else:
self.summary.type_ratio[t] += 1
for n in result_info.name_list:
if n not in self.summary.name_ratio:
self.summary.name_ratio[n] = 1
else:
self.summary.name_ratio[n] += 1
self.summary.num_good += 1
self.summary.total += 1
if self.summary.total % self.input_args.interval_size == 0:
save_flag = True

self.write_single_data(self.summary.output_path, self.input_args, result_info)
pbar.update()
# save data in file
if self.input_args.save_data:
if save_flag:
tmp_summary = self.summarize(self.summary)
tmp_summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
tmp_output_path = self.summary.output_path
self.save_data(tmp_output_path, self.input_args, self.bad_info_list, self.good_info_list, tmp_summary)
self.bad_info_list = []
self.good_info_list = []
self.write_summary(self.summary.output_path, self.input_args, self.summarize(self.summary))
while True:
batch = list(itertools.islice(data_iter, self.input_args.batch_size))
if not batch:
Expand Down Expand Up @@ -270,9 +231,9 @@ def evaluate_prompt(self, group: List[BasePrompt], d: MetaData) -> ResultInfo:

def summarize(self, summary: SummaryModel) -> SummaryModel:
new_summary = copy.deepcopy(summary)
new_summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
if new_summary.total == 0:
return new_summary
new_summary.num_good = new_summary.total - new_summary.num_bad
new_summary.score = round(new_summary.num_good / new_summary.total * 100, 2)
for t in new_summary.type_ratio:
new_summary.type_ratio[t] = round(new_summary.type_ratio[t] / new_summary.total, 6)
Expand All @@ -282,52 +243,38 @@ def summarize(self, summary: SummaryModel) -> SummaryModel:
new_summary.name_ratio = dict(sorted(new_summary.name_ratio.items()))
return new_summary

def get_summary(self):
return self.summary
def write_single_data(self, path: str, input_args: InputArgs, result_info: ResultInfo):
if not input_args.save_data:
return

def get_bad_info_list(self):
return self.bad_info_list

def get_good_info_list(self):
return self.good_info_list
if not input_args.save_correct and not result_info.error_status:
return

def save_data(
self,
path: str,
input_args: InputArgs,
bad_info_list: List[ResultInfo],
good_info_list: List[ResultInfo],
summary: SummaryModel,
):
for result_info in bad_info_list:
for new_name in result_info.name_list:
t = str(new_name).split('-')[0]
n = str(new_name).split('-')[1]
p_t = os.path.join(path, t)
if not os.path.exists(p_t):
os.makedirs(p_t)
f_n = os.path.join(path, t, n) + ".jsonl"
with open(f_n, 'a', encoding='utf-8') as f:
if input_args.save_raw:
str_json = json.dumps(result_info.to_raw_dict(), ensure_ascii=False)
else:
str_json = json.dumps(result_info.to_dict(), ensure_ascii=False)
f.write(str_json + '\n')
if input_args.save_correct:
for result_info in good_info_list:
for new_name in result_info.name_list:
t = str(new_name).split('-')[0]
n = str(new_name).split('-')[1]
p_t = os.path.join(path, t)
if not os.path.exists(p_t):
os.makedirs(p_t)
f_n = os.path.join(path, t, n) + ".jsonl"
with open(f_n, 'a', encoding='utf-8') as f:
if input_args.save_raw:
str_json = json.dumps(result_info.to_raw_dict(), ensure_ascii=False)
else:
str_json = json.dumps(result_info.to_dict(), ensure_ascii=False)
f.write(str_json + '\n')
for new_name in result_info.name_list:
t = str(new_name).split('-')[0]
n = str(new_name).split('-')[1]
p_t = os.path.join(path, t)
if not os.path.exists(p_t):
os.makedirs(p_t)
f_n = os.path.join(path, t, n) + ".jsonl"
with open(f_n, 'a', encoding='utf-8') as f:
if input_args.save_raw:
str_json = json.dumps(result_info.to_raw_dict(), ensure_ascii=False)
else:
str_json = json.dumps(result_info.to_dict(), ensure_ascii=False)
f.write(str_json + '\n')

def write_summary(self, path: str, input_args: InputArgs, summary: SummaryModel):
    """Persist the aggregated run summary as ``summary.json`` under *path*.

    No-op when ``input_args.save_data`` is false, so callers may invoke this
    unconditionally at the end of a run.

    Args:
        path: Output directory that receives ``summary.json``.
        input_args: Run configuration; only ``save_data`` is consulted here.
        summary: Aggregated results, serialized via ``to_dict()``.
    """
    if not input_args.save_data:
        return
    # os.path.join avoids a doubled separator when *path* already ends with '/'.
    with open(os.path.join(path, 'summary.json'), 'w', encoding='utf-8') as f:
        json.dump(summary.to_dict(), f, indent=4, ensure_ascii=False)

def get_summary(self):
    """Interface stub retained for executor-API compatibility; returns None.

    NOTE(review): looks intentionally unimplemented after the streaming-write
    refactor — results are written to disk instead of held in memory.
    """
    return None

def get_bad_info_list(self):
    """Interface stub retained for executor-API compatibility; returns None.

    NOTE(review): the in-memory bad-result list was removed in favor of
    incremental on-disk writes — confirm no caller still expects a list.
    """
    return None

def get_good_info_list(self):
    """Interface stub retained for executor-API compatibility; returns None.

    NOTE(review): the in-memory good-result list was removed in favor of
    incremental on-disk writes — confirm no caller still expects a list.
    """
    return None
5 changes: 0 additions & 5 deletions dingo/io/input/InputArgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class InputArgs(BaseModel):
# Resume settings
start_index: int = 0
end_index: int = -1
interval_size: int = 1000

# Concurrent settings
max_workers: int = 1
Expand Down Expand Up @@ -89,10 +88,6 @@ def check_args(self):
if self.end_index >= 0 and self.end_index < self.start_index:
raise ValueError("if end_index is non negative, end_index must be greater than start_index")

# check interval size
if self.interval_size <= 0:
raise ValueError("interval_size must be positive.")

# check max workers
if self.max_workers <= 0:
raise ValueError("max_workers must be a positive integer.")
Expand Down
7 changes: 4 additions & 3 deletions dingo/io/output/SummaryModel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
from typing import Dict, List

from pydantic import BaseModel
from pydantic import BaseModel, Field


class SummaryModel(BaseModel):
Expand All @@ -15,8 +16,8 @@ class SummaryModel(BaseModel):
num_good: int = 0
num_bad: int = 0
total: int = 0
type_ratio: Dict[str, float] = {}
name_ratio: Dict[str, float] = {}
type_ratio: Dict[str, int] = Field(default_factory=lambda: defaultdict(int))
name_ratio: Dict[str, int] = Field(default_factory=lambda: defaultdict(int))

def to_dict(self):
return {
Expand Down
4 changes: 0 additions & 4 deletions dingo/run/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ def parse_args():
default=None, help="The number of data start to check.")
parser.add_argument("--end_index", type=int,
default=None, help="The number of data end to check.")
parser.add_argument("--interval_size", type=int,
default=None, help="The number of size to save while checking.")
parser.add_argument("--max_workers", type=int,
default=None, help="The number of max workers to concurrent check. ")
parser.add_argument("--batch_size", type=int,
Expand Down Expand Up @@ -112,8 +110,6 @@ def parse_args():
input_data['start_index'] = args.start_index
if args.end_index:
input_data['end_index'] = args.end_index
if args.interval_size:
input_data['interval_size'] = args.interval_size
if args.max_workers:
input_data['max_workers'] = args.max_workers
if args.batch_size:
Expand Down
2 changes: 0 additions & 2 deletions docs/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
| --save_raw | bool | False | No | whether save raw data. |
| --start_index | int | 0 | No | the index of the first data item to check. |
| --end_index | int | -1 | No | the index at which checking stops. If negative, all data from start_index to the end is included. |
| --interval_size | int | 1000 | No | the number of size to save while checking. |
| --max_workers | int | 1 | No | the number of max workers to concurrent check. |
| --batch_size | int | 1 | No | the number of max data for concurrent check. |
| --dataset | str | "hugging_face" | Yes | dataset type, in ['hugging_face', 'local'] |
Expand Down Expand Up @@ -46,7 +45,6 @@
| save_raw | bool | False | No | whether save raw data. |
| start_index | int | 0 | No | the index of the first data item to check. |
| end_index | int | -1 | No | the index at which checking stops. If negative, all data from start_index to the end is included. |
| interval_size | int | 1000 | No | the number of size to save while checking. |
| max_workers | int | 1 | No | the number of max workers to concurrent check. |
| batch_size | int | 1 | No | the number of max data for concurrent check. |
| dataset | str | "hugging_face" | Yes | dataset type, in ['hugging_face', 'local'] |
Expand Down
30 changes: 30 additions & 0 deletions test/scripts/test_write.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
import shutil

import pytest
from dingo.exec import Executor
from dingo.io import InputArgs


class TestWrite:
    """End-to-end check of the local executor's streaming file output."""

    def test_write_local_jsonl(self):
        # Build the run configuration with explicit keyword arguments.
        args = InputArgs(
            eval_group="qa_standard_v1",
            input_path="../data/test_local_jsonl.jsonl",
            save_data=True,
            save_correct=True,
            dataset="local",
            data_format="jsonl",
            column_id="id",
            column_content="content",
        )
        executor = Executor.exec_map["local"](args)
        summaries = executor.execute()
        # The executor reports where it wrote its results; that path must exist.
        written_to = summaries[0].output_path
        assert os.path.exists(written_to)
        # Clean up the generated output directory.
        shutil.rmtree('outputs')


# Allow running this test module directly (e.g. `python test_write.py`)
# without invoking pytest from the command line.
if __name__ == '__main__':
    pytest.main(["-s", "-q"])