Skip to content

Commit d0b524a

Browse files
committed
add new example
1 parent d8af632 commit d0b524a

File tree

1 file changed

+36
-5
lines changed

1 file changed

+36
-5
lines changed

examples/examples_autotm_fit_predict.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,19 @@
4242
"num_iterations": 20,
4343
"use_pipeline": False
4444
},
45+
"gost_example": {
46+
"topic_count": 50,
47+
"alg_name": "ga",
48+
"num_iterations": 10,
49+
"use_pipeline": True,
50+
"dataset": {
51+
"lang": "ru",
52+
"col_to_process": 'paragraph',
53+
"dataset_path": "data/sample_corpora/clean_docs_v17_gost_only.csv",
54+
"dataset_name": "gost"
55+
},
56+
# "individual_type": "llm"
57+
},
4558
"surrogate": {
4659
"alg_name": "ga",
4760
"num_iterations": 20,
@@ -61,7 +74,11 @@
6174
}
6275

6376

64-
def run(alg_name: str, alg_params: Dict[str, Any], dataset: Optional[Dict[str, Any]] = None):
77+
def run(alg_name: str,
78+
alg_params: Dict[str, Any],
79+
dataset: Optional[Dict[str, Any]] = None,
80+
col_to_process: Optional[str] = None,
81+
topic_count: int = 20):
6582
if not dataset:
6683
dataset = {
6784
"lang": "ru",
@@ -76,10 +93,11 @@ def run(alg_name: str, alg_params: Dict[str, Any], dataset: Optional[Dict[str, A
7693
model_path = os.path.join(working_dir_path, "autotm_model")
7794

7895
autotm = AutoTM(
79-
topic_count=20,
96+
topic_count=topic_count,
8097
preprocessing_params={
81-
"lang": dataset['lang'],
98+
"lang": dataset['lang']
8299
},
100+
texts_column_name=col_to_process,
83101
alg_name=alg_name,
84102
alg_params=alg_params,
85103
working_dir_path=working_dir_path,
@@ -110,12 +128,25 @@ def main(conf_name: str = "base"):
110128
del conf['alg_name']
111129

112130
dataset = None
131+
col_to_process = None
113132
if 'dataset' in conf:
114133
dataset = conf['dataset']
134+
col_to_process = conf['dataset'].get('col_to_process', None)
115135
del conf['dataset']
116136

117-
run(alg_name=alg_name, alg_params=conf, dataset=dataset)
137+
topic_count = 20
138+
if 'topic_count' in conf:
139+
topic_count = conf['topic_count']
140+
del conf['topic_count']
141+
142+
run(
143+
alg_name=alg_name,
144+
alg_params=conf,
145+
dataset=dataset,
146+
col_to_process=col_to_process,
147+
topic_count=topic_count
148+
)
118149

119150

120151
if __name__ == "__main__":
121-
main(conf_name="base_en")
152+
main(conf_name="gost_example")

0 commit comments

Comments
 (0)