-
-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathconfig.py
More file actions
115 lines (91 loc) · 3.7 KB
/
config.py
File metadata and controls
115 lines (91 loc) · 3.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import logging
from dotenv import load_dotenv
# Populate os.environ from a local .env file (if present) before any of the
# os.environ.get() lookups below run.
load_dotenv()
# Module-level logger named after this module, per stdlib logging convention.
logger = logging.getLogger(__name__)
def _env_float(key, default):
"""Read a float from env, falling back to default on invalid values."""
try:
return float(os.environ.get(key, default))
except (ValueError, TypeError):
logger.warning("Invalid value for %s, using default %s", key, default)
return float(default)
def _env_int(key, default):
"""Read an int from env, falling back to default on invalid values."""
try:
return int(os.environ.get(key, default))
except (ValueError, TypeError):
logger.warning("Invalid value for %s, using default %s", key, default)
return int(default)
# Absolute path of the directory that contains this config file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Directory for model checkpoints; override with the CKPT_DIR env var.
CKPT_DIR = os.environ.get("CKPT_DIR", os.path.join(BASE_DIR, "ckpt"))
# Directory for generated output; override with the OUTPUT_DIR env var.
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.join(BASE_DIR, "output"))
# HuggingFace repo id and the files fetched from it for generation
# (tokenizer + generation config).
HEARTMULGEN_REPO = "HeartMuLa/HeartMuLaGen"
HEARTMULGEN_FILES = ["tokenizer.json", "gen_config.json"]
# --- Model Variants ---
# Each variant has its own HuggingFace repo and local directory.
# The "version" string is passed to heartlib's from_pretrained() which constructs
# the model path as: CKPT_DIR/HeartMuLa-oss-{version}/
MODEL_VARIANTS = {
    # "Happy New Year" fine-tune of the 3B model.
    "hny": {
        "name": "HeartMuLa 3B HNY",
        "repo_id": "HeartMuLa/HeartMuLa-oss-3B-happy-new-year",
        "local_dir": "HeartMuLa-oss-3B-happy-new-year",
        "version": "3B-happy-new-year",
    },
    # RL-tuned 3B model.
    "rl": {
        "name": "HeartMuLa 3B RL",
        "repo_id": "HeartMuLa/HeartMuLa-RL-oss-3B-20260123",
        "local_dir": "HeartMuLa-oss-3B-RL",
        "version": "3B-RL",
    },
    # Base (untuned) 3B model.
    "base": {
        "name": "HeartMuLa 3B",
        "repo_id": "HeartMuLa/HeartMuLa-oss-3B",
        "local_dir": "HeartMuLa-oss-3B",
        "version": "3B",
    },
}
# Audio codec model used alongside the generator (no "version" key: it is not
# selected via the variant mechanism above).
CODEC_MODEL = {
    "name": "HeartCodec",
    "repo_id": "HeartMuLa/HeartCodec-oss-20260123",
    "local_dir": "HeartCodec-oss",
}
# Human-readable labels for the UI; keys mirror MODEL_VARIANTS.
MODEL_VARIANT_LABELS = {
    "hny": "HeartMuLa 3B HNY (Recommended)",
    "rl": "HeartMuLa 3B RL",
    "base": "HeartMuLa 3B (Base)",
}
# Model variant selected at startup (env var MODEL_VARIANT, default "hny").
# Validate against MODEL_VARIANTS so a typo falls back to the known-good
# default here, with a warning, instead of failing later with a confusing
# missing-key/missing-repo error far from the cause.
DEFAULT_MODEL_VARIANT = os.environ.get("MODEL_VARIANT", "hny")
if DEFAULT_MODEL_VARIANT not in MODEL_VARIANTS:
    logger.warning(
        "Unknown MODEL_VARIANT %r, falling back to 'hny'", DEFAULT_MODEL_VARIANT
    )
    DEFAULT_MODEL_VARIANT = "hny"
# Default sampling parameters for music generation; each is overridable via
# the named environment variable (invalid values fall back to the defaults
# shown, via _env_float/_env_int).
DEFAULT_GENERATION_PARAMS = {
    "temperature": _env_float("MUSIC_TEMPERATURE", "1.0"),
    "cfg_scale": _env_float("MUSIC_CFG_SCALE", "1.5"),
    "topk": _env_int("MUSIC_TOPK", "50"),
    # Env var is in seconds; converted to milliseconds here.
    "max_audio_length_ms": _env_int("MUSIC_MAX_LENGTH_SEC", "240") * 1000,
}
# How many variants to generate per request.
DEFAULT_NUM_VARIANTS = _env_int("MUSIC_NUM_VARIANTS", "1")
# Truthy values accepted for LAZY_LOAD: "1", "true", "yes" (case-insensitive).
DEFAULT_LAZY_LOAD = os.environ.get("LAZY_LOAD", "true").lower() in ("1", "true", "yes")
# --- LLM backend settings ---
DEFAULT_OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
DEFAULT_OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "glm-4.7-flash")
DEFAULT_OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
DEFAULT_OPENAI_URL = os.environ.get("OPENAI_URL", "https://api.openai.com/v1")
DEFAULT_OPENAI_KEY = os.environ.get("OPENAI_API_KEY", "")
# Comma-separated model list from OPENAI_MODELS; entries are stripped and
# empty entries dropped.
DEFAULT_OPENAI_MODELS = [
    m.strip()
    for m in os.environ.get(
        "OPENAI_MODELS", "gpt-4o-mini,gpt-4o,gpt-4.1-mini,gpt-4.1,o4-mini,gpt-5-mini,gpt-5,gpt-5.2"
    ).split(",")
    if m.strip()
]
DEFAULT_LLM_BACKEND = os.environ.get("LLM_BACKEND", "Ollama")
DEFAULT_LLM_TEMPERATURE = _env_float("LLM_TEMPERATURE", "0.7")
# LLM request timeout in seconds.
DEFAULT_LLM_TIMEOUT = _env_int("LLM_TIMEOUT", "120")
# Feature flags; same truthy convention as LAZY_LOAD above.
STYLE_TRANSFER_ENABLED = os.environ.get("STYLE_TRANSFER", "true").lower() in ("1", "true", "yes")
TRANSCRIPTION_ENABLED = os.environ.get("TRANSCRIPTION", "true").lower() in ("1", "true", "yes")
# Lyrics-transcription model (used when TRANSCRIPTION_ENABLED).
TRANSCRIPTOR_MODEL = {
    "name": "HeartTranscriptor",
    "repo_id": "HeartMuLa/HeartTranscriptor-oss",
    "local_dir": "HeartTranscriptor-oss",
}
# Web server bind address/port.
SERVER_HOST = os.environ.get("SERVER_HOST", "127.0.0.1")
SERVER_PORT = _env_int("SERVER_PORT", "7860")