4242 "num_iterations" : 20 ,
4343 "use_pipeline" : False
4444 },
45+ "gost_example" : {
46+ "topic_count" : 50 ,
47+ "alg_name" : "ga" ,
48+ "num_iterations" : 10 ,
49+ "use_pipeline" : True ,
50+ "dataset" : {
51+ "lang" : "ru" ,
52+ "col_to_process" : 'paragraph' ,
53+ "dataset_path" : "data/sample_corpora/clean_docs_v17_gost_only.csv" ,
54+ "dataset_name" : "gost"
55+ },
56+ # "individual_type": "llm"
57+ },
4558 "surrogate" : {
4659 "alg_name" : "ga" ,
4760 "num_iterations" : 20 ,
6174}
6275
6376
64- def run (alg_name : str , alg_params : Dict [str , Any ], dataset : Optional [Dict [str , Any ]] = None ):
77+ def run (alg_name : str ,
78+ alg_params : Dict [str , Any ],
79+ dataset : Optional [Dict [str , Any ]] = None ,
80+ col_to_process : Optional [str ] = None ,
81+ topic_count : int = 20 ):
6582 if not dataset :
6683 dataset = {
6784 "lang" : "ru" ,
@@ -76,10 +93,11 @@ def run(alg_name: str, alg_params: Dict[str, Any], dataset: Optional[Dict[str, A
7693 model_path = os .path .join (working_dir_path , "autotm_model" )
7794
7895 autotm = AutoTM (
79- topic_count = 20 ,
96+ topic_count = topic_count ,
8097 preprocessing_params = {
81- "lang" : dataset ['lang' ],
98+ "lang" : dataset ['lang' ]
8299 },
100+ texts_column_name = col_to_process ,
83101 alg_name = alg_name ,
84102 alg_params = alg_params ,
85103 working_dir_path = working_dir_path ,
@@ -110,12 +128,25 @@ def main(conf_name: str = "base"):
110128 del conf ['alg_name' ]
111129
112130 dataset = None
131+ col_to_process = None
113132 if 'dataset' in conf :
114133 dataset = conf ['dataset' ]
134+ col_to_process = conf ['dataset' ].get ('col_to_process' , None )
115135 del conf ['dataset' ]
116136
117- run (alg_name = alg_name , alg_params = conf , dataset = dataset )
137+ topic_count = 20
138+ if 'topic_count' in conf :
139+ topic_count = conf ['topic_count' ]
140+ del conf ['topic_count' ]
141+
142+ run (
143+ alg_name = alg_name ,
144+ alg_params = conf ,
145+ dataset = dataset ,
146+ col_to_process = col_to_process ,
147+ topic_count = topic_count
148+ )
118149
119150
120151if __name__ == "__main__" :
121- main (conf_name = "base_en " )
152+ main (conf_name = "gost_example " )
0 commit comments