@@ -145,21 +145,36 @@ def forward(self, x):
145145 return result_dict
146146
147147 def _forward_type_2 (self , x ):
148- output = self .lstm_model (x ["ag" ])
148+ output = x ["ag" ]
149+ for layer in self .lstm_model :
150+ output = layer (output )
151+ if isinstance (output , tuple ):
152+ output = output [0 ]
153+
149154 eta_pred = output [:, :, 0 : self .output_size ]
150155 eta_dot_pred = output [:, :, self .output_size : 2 * self .output_size ]
151156 g_pred = output [:, :, 2 * self .output_size :]
152157
153158 # for ag_c
154- output_c = self .lstm_model (x ["ag_c" ])
159+ output_c = x ["ag_c" ]
160+ for layer in self .lstm_model :
161+ output_c = layer (output_c )
162+ if isinstance (output_c , tuple ):
163+ output_c = output_c [0 ]
164+
155165 eta_pred_c = output_c [:, :, 0 : self .output_size ]
156166 eta_dot_pred_c = output_c [:, :, self .output_size : 2 * self .output_size ]
157167 g_pred_c = output_c [:, :, 2 * self .output_size :]
158168 eta_t_pred_c = paddle .matmul (x ["phi" ], eta_pred_c )
159169 eta_tt_pred_c = paddle .matmul (x ["phi" ], eta_dot_pred_c )
160170 eta_dot1_pred_c = eta_dot_pred_c [:, :, 0 :1 ]
161171 tmp = paddle .concat ([eta_pred_c , eta_dot1_pred_c , g_pred_c ], 2 )
162- f = self .lstm_model_f (tmp )
172+ f = tmp
173+ for layer in self .lstm_model_f :
174+ f = layer (f )
175+ if isinstance (f , tuple ):
176+ f = f [0 ]
177+
163178 lift_pred_c = eta_tt_pred_c + f
164179
165180 return {
@@ -173,12 +188,22 @@ def _forward_type_2(self, x):
173188
174189 def _forward_type_3 (self , x ):
175190 # physics informed neural networks
176- output = self .lstm_model (x ["ag" ])
191+ output = x ["ag" ]
192+ for layer in self .lstm_model :
193+ output = layer (output )
194+ if isinstance (output , tuple ):
195+ output = output [0 ]
196+
177197 eta_pred = output [:, :, 0 : self .output_size ]
178198 eta_dot_pred = output [:, :, self .output_size : 2 * self .output_size ]
179199 g_pred = output [:, :, 2 * self .output_size :]
180200
181- output_c = self .lstm_model (x ["ag_c" ])
201+ output_c = x ["ag_c" ]
202+ for layer in self .lstm_model :
203+ output_c = layer (output_c )
204+ if isinstance (output_c , tuple ):
205+ output_c = output_c [0 ]
206+
182207 eta_pred_c = output_c [:, :, 0 : self .output_size ]
183208 eta_dot_pred_c = output_c [:, :, self .output_size : 2 * self .output_size ]
184209 g_pred_c = output_c [:, :, 2 * self .output_size :]
@@ -187,11 +212,20 @@ def _forward_type_3(self, x):
187212 eta_tt_pred_c = paddle .matmul (x ["phi" ], eta_dot_pred_c )
188213 g_t_pred_c = paddle .matmul (x ["phi" ], g_pred_c )
189214
190- f = self .lstm_model_f (paddle .concat ([eta_pred_c , eta_dot_pred_c , g_pred_c ], 2 ))
215+ f = paddle .concat ([eta_pred_c , eta_dot_pred_c , g_pred_c ], 2 )
216+ for layer in self .lstm_model_f :
217+ f = layer (f )
218+ if isinstance (f , tuple ):
219+ f = f [0 ]
220+
191221 lift_pred_c = eta_tt_pred_c + f
192222
193223 eta_dot1_pred_c = eta_dot_pred_c [:, :, 0 :1 ]
194- g_dot_pred_c = self .lstm_model_g (paddle .concat ([eta_dot1_pred_c , g_pred_c ], 2 ))
224+ g_dot_pred_c = paddle .concat ([eta_dot1_pred_c , g_pred_c ], 2 )
225+ for layer in self .lstm_model_g :
226+ g_dot_pred_c = layer (g_dot_pred_c )
227+ if isinstance (g_dot_pred_c , tuple ):
228+ g_dot_pred_c = g_dot_pred_c [0 ]
195229
196230 return {
197231 "eta_pred" : eta_pred ,
0 commit comments