88from typing import Any
99
1010from torch import nn
11+ from trainer .io import get_user_data_dir
1112
1213from TTS .config import load_config
1314from TTS .utils .manage import ModelManager
@@ -81,6 +82,7 @@ def __init__(
8182 self .synthesizer : Synthesizer | None = None
8283 self .voice_converter : Synthesizer | None = None
8384 self .model_name = ""
85+ self .voice_dir = None
8486
8587 self .vocoder_path = vocoder_path
8688 self .vocoder_config_path = vocoder_config_path
@@ -93,6 +95,7 @@ def __init__(
9395 warnings .warn ("`gpu` will be deprecated. Please use `tts.to(device)` instead." )
9496
9597 if model_name is not None and len (model_name ) > 0 :
98+ self .voice_dir = get_user_data_dir ("tts" ) / model_name / "voices"
9699 if "tts_models" in model_name :
97100 self .load_tts_model_by_name (model_name , vocoder_name , gpu = gpu )
98101 elif "voice_conversion_models" in model_name :
@@ -158,22 +161,10 @@ def list_models() -> list[str]:
158161
159162 def download_model_by_name (
160163 self , model_name : str , vocoder_name : str | None = None
161- ) -> tuple [Path | None , Path | None , Path | None , Path | None , Path | None ]:
164+ ) -> tuple [Path | None , Path | None , Path | None , Path | None ]:
162165 model_path , config_path , model_item = self .manager .download_model (model_name )
163- if (
164- "fairseq" in model_name
165- or "openvoice" in model_name
166- or (
167- model_item is not None
168- and isinstance (model_item ["model_url" ], list )
169- and len (model_item ["model_url" ]) > 2
170- )
171- ):
172- # return model directory if there are multiple files
173- # we assume that the model knows how to load itself
174- return None , None , None , None , model_path
175166 if model_item .get ("default_vocoder" ) is None :
176- return model_path , config_path , None , None , None
167+ return model_path , config_path , None , None
177168 if vocoder_name is None :
178169 vocoder_name = model_item ["default_vocoder" ]
179170 vocoder_path , vocoder_config_path = None , None
@@ -183,7 +174,7 @@ def download_model_by_name(
183174 vocoder_config_path = self .vocoder_config_path
184175 if vocoder_path is None or vocoder_config_path is None :
185176 vocoder_path , vocoder_config_path , _ = self .manager .download_model (vocoder_name )
186- return model_path , config_path , vocoder_path , vocoder_config_path , None
177+ return model_path , config_path , vocoder_path , vocoder_config_path
187178
188179 def load_model_by_name (self , model_name : str , vocoder_name : str | None = None , * , gpu : bool = False ) -> None :
189180 """Load one of the 🐸TTS models by name.
@@ -202,15 +193,15 @@ def load_vc_model_by_name(self, model_name: str, vocoder_name: str | None = None
202193 gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
203194 """
204195 self .model_name = model_name
205- model_path , config_path , vocoder_path , vocoder_config_path , model_dir = self .download_model_by_name (
196+ model_path , config_path , vocoder_path , vocoder_config_path = self .download_model_by_name (
206197 model_name , vocoder_name
207198 )
208199 self .voice_converter = Synthesizer (
209200 vc_checkpoint = model_path ,
210201 vc_config = config_path ,
211202 vocoder_checkpoint = vocoder_path ,
212203 vocoder_config = vocoder_config_path ,
213- model_dir = model_dir ,
204+ voice_dir = self . voice_dir ,
214205 use_cuda = gpu ,
215206 )
216207
@@ -225,7 +216,7 @@ def load_tts_model_by_name(self, model_name: str, vocoder_name: str | None = Non
225216 """
226217 self .model_name = model_name
227218
228- model_path , config_path , vocoder_path , vocoder_config_path , model_dir = self .download_model_by_name (
219+ model_path , config_path , vocoder_path , vocoder_config_path = self .download_model_by_name (
229220 model_name , vocoder_name
230221 )
231222
@@ -240,7 +231,7 @@ def load_tts_model_by_name(self, model_name: str, vocoder_name: str | None = Non
240231 vocoder_config = vocoder_config_path ,
241232 encoder_checkpoint = self .encoder_path ,
242233 encoder_config = self .encoder_config_path ,
243- model_dir = model_dir ,
234+ voice_dir = self . voice_dir ,
244235 use_cuda = gpu ,
245236 )
246237
@@ -266,6 +257,7 @@ def load_tts_model_by_path(self, model_path: str, config_path: str, *, gpu: bool
266257 encoder_config = self .encoder_config_path ,
267258 use_cuda = gpu ,
268259 )
260+ self .voice_dir = self .synthesizer .voice_dir
269261
270262 def _check_arguments (
271263 self ,
0 commit comments