Skip to content

Commit 2b6a833

Browse files
A190nuxjosStorer
authored and committed
Cleaner fix for the state issue with llama.cpp
1 parent d3d171a commit 2b6a833

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

backend-python/utils/llama.py

Lines changed: 23 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -133,7 +133,7 @@ def clear_rwkv_state(self):
133133
def Llama(model_path: str, strategy: str) -> AbstractLlama:
134134
model_path = get_model_path(model_path)
135135

136-
from llama_cpp import Llama
136+
from llama_cpp import Llama as LlamaCpp
137137

138138
filename, _ = os.path.splitext(os.path.basename(model_path))
139139
n_ctx = 8192
@@ -142,18 +142,29 @@ def Llama(model_path: str, strategy: str) -> AbstractLlama:
142142
except:
143143
pass
144144

145-
model = Llama(
146-
model_path, n_gpu_layers=-1 if "cpu" not in strategy else 0, n_ctx=n_ctx
147-
)
148-
149-
# Only patch generate function if it is an RWKV model
150-
if "rwkv" in filename.lower():
151-
original_generate = model.generate
152-
def rwkv_generate(tokens, **kwargs):
153-
kwargs['reset'] = False
154-
return original_generate(tokens, **kwargs)
155-
model.generate = rwkv_generate
145+
# Check if this is an RWKV model
146+
is_rwkv = "rwkv" in filename.lower()
156147

148+
if is_rwkv:
149+
# RWKV models need reset=False to maintain sequential RNN state
150+
class RWKVLlama(LlamaCpp):
151+
"""Llama wrapper that forces reset=False for RWKV's sequential state"""
152+
def generate(self, tokens, reset=False, **kwargs):
153+
# Always use reset=False for RWKV to avoid state position mismatches
154+
return super().generate(tokens, reset=False, **kwargs)
155+
156+
model = RWKVLlama(
157+
model_path,
158+
n_gpu_layers=-1 if "cpu" not in strategy else 0,
159+
n_ctx=n_ctx
160+
)
161+
else:
162+
model = LlamaCpp(
163+
model_path,
164+
n_gpu_layers=-1 if "cpu" not in strategy else 0,
165+
n_ctx=n_ctx
166+
)
167+
157168
llama: AbstractLlama
158169
llama = TextLlama(model)
159170
llama.name = filename

0 commit comments

Comments
 (0)