@@ -133,7 +133,7 @@ def clear_rwkv_state(self):
 def Llama(model_path: str, strategy: str) -> AbstractLlama:
     model_path = get_model_path(model_path)
 
-    from llama_cpp import Llama
+    from llama_cpp import Llama as LlamaCpp
 
     filename, _ = os.path.splitext(os.path.basename(model_path))
     n_ctx = 8192
@@ -142,18 +142,29 @@ def Llama(model_path: str, strategy: str) -> AbstractLlama:
     except:
         pass
 
-    model = Llama(
-        model_path, n_gpu_layers=-1 if "cpu" not in strategy else 0, n_ctx=n_ctx
-    )
-
-    # Only patch generate function if it is an RWKV model
-    if "rwkv" in filename.lower():
-        original_generate = model.generate
-        def rwkv_generate(tokens, **kwargs):
-            kwargs['reset'] = False
-            return original_generate(tokens, **kwargs)
-        model.generate = rwkv_generate
+    # Check if this is an RWKV model
+    is_rwkv = "rwkv" in filename.lower()
 
+    if is_rwkv:
+        # RWKV models need reset=False to maintain sequential RNN state
+        class RWKVLlama(LlamaCpp):
+            """Llama wrapper that forces reset=False for RWKV's sequential state"""
+            def generate(self, tokens, reset=False, **kwargs):
+                # Always use reset=False for RWKV to avoid state position mismatches
+                return super().generate(tokens, reset=False, **kwargs)
+
+        model = RWKVLlama(
+            model_path,
+            n_gpu_layers=-1 if "cpu" not in strategy else 0,
+            n_ctx=n_ctx
+        )
+    else:
+        model = LlamaCpp(
+            model_path,
+            n_gpu_layers=-1 if "cpu" not in strategy else 0,
+            n_ctx=n_ctx
+        )
+
     llama: AbstractLlama
     llama = TextLlama(model)
     llama.name = filename