@@ -54,6 +54,15 @@ class CPreconditioner;
5454 * Absolute tolerance, target residual is tol*||b||. ---*/
5555enum class LinearToleranceType { RELATIVE, ABSOLUTE };
5656
57+ /* !
58+ * \brief Modes of using FGCRODR.
59+ * \ingroup SpLinSys
60+ */
61+ enum class FgcrodrMode {
62+ NORMAL, /* !< \brief Solve the linear system. */
63+ SAME_MAT, /* !< \brief "NORMAL" but knowing the matrix did not change. */
64+ };
65+
5766/* !
5867 * \class CSysSolve
5968 * \ingroup SpLinSys
@@ -83,8 +92,8 @@ class CSysSolve {
8392 ScalarType Residual = 1e-20 ; /* !< \brief Residual at the end of a call to Solve or Solve_b. */
8493 unsigned long Iterations = 0 ; /* !< \brief Iterations done in Solve or Solve_b. */
8594
86- LINEAR_SOLVER_MODE
87- lin_sol_mode; /* !< \brief Type of operation for the linear system solver, changes the source of solver options. */
95+ /* !< \brief Type of operation for the linear system solver, changes the source of solver options. */
96+ LINEAR_SOLVER_MODE lin_sol_mode;
8897
8998 mutable bool cg_ready; /* !< \brief Indicate if memory used by CG is allocated. */
9099 mutable bool bcg_ready; /* !< \brief Indicate if memory used by BCGSTAB is allocated. */
@@ -98,23 +107,24 @@ class CSysSolve {
98107 mutable VectorType r_0; /* !< \brief The "arbitrary" vector in BCGSTAB. */
99108 mutable VectorType v; /* !< \brief BCGSTAB "v" vector (v = A * M^-1 * p). */
100109
101- mutable std::vector<VectorType> W; /* !< \brief Large matrix used by FGMRES, w^i+1 = A * z^i. */
102- mutable std::vector<VectorType> Z; /* !< \brief Large matrix used by FGMRES, preconditioned W. */
103-
104- VectorType
105- LinSysSol_tmp; /* !< \brief Temporary used when it is necessary to interface between active and passive types. */
106- VectorType
107- LinSysRes_tmp; /* !< \brief Temporary used when it is necessary to interface between active and passive types. */
108- VectorType*
109- LinSysSol_ptr; /* !< \brief Pointer to appropriate LinSysSol (set to original or temporary in call to Solve). */
110- const VectorType*
111- LinSysRes_ptr; /* !< \brief Pointer to appropriate LinSysRes (set to original or temporary in call to Solve). */
112-
113- LinearToleranceType tol_type =
114- LinearToleranceType::ABSOLUTE; /* !< \brief How the linear solvers interpret the tolerance. */
115- bool xIsZero = false ; /* !< \brief If true assume the initial solution is always 0. */
116- bool recomputeRes = false ; /* !< \brief Recompute the residual after inner iterations, if monitoring. */
117- unsigned long monitorFreq = 10 ; /* !< \brief Monitoring frequency. */
110+ mutable unsigned long k = 0 ;
111+ mutable std::vector<VectorType> Z, V; /* !< \brief Large matrices used by FGMRES, v^i+1 = A * z^i. */
112+ mutable std::vector<VectorType> W, T; /* !< \brief Large matrices used by FGCRODR for deflation vectors. */
113+
114+ /* !< \brief Temporary used when it is necessary to interface between active and passive types. */
115+ VectorType LinSysSol_tmp;
116+ /* !< \brief Temporary used when it is necessary to interface between active and passive types. */
117+ VectorType LinSysRes_tmp;
118+ /* !< \brief Pointer to appropriate LinSysSol (set to original or temporary in call to Solve). */
119+ VectorType* LinSysSol_ptr;
120+ /* !< \brief Pointer to appropriate LinSysRes (set to original or temporary in call to Solve). */
121+ const VectorType* LinSysRes_ptr;
122+
123+ /* !< \brief How the linear solvers interpret the tolerance. */
124+ mutable LinearToleranceType tol_type = LinearToleranceType::ABSOLUTE;
125+ mutable bool xIsZero = false ; /* !< \brief If true assume the initial solution is always 0. */
126+ bool recomputeRes = false ; /* !< \brief Recompute the residual after inner iterations, if monitoring. */
127+ unsigned long monitorFreq = 10 ; /* !< \brief Monitoring frequency. */
118128
119129 /* !< \brief Inner solver for nested preconditioning. */
120130 std::unique_ptr<CSysSolve<ScalarType>> inner_solver;
@@ -225,72 +235,65 @@ class CSysSolve {
225235
226236 /* !
227237 * \brief Used by Solve for compatibility between passive and active CSysVector.
228- * \note Same type specialization, temporary variables are not required.
229- * \param[in] LinSysRes - Linear system residual
230- * \param[in,out] LinSysSol - Linear system solution
231- */
232- template <class OtherType , su2enable_if<std::is_same<ScalarType, OtherType>::value> = 0 >
233- void HandleTemporariesIn (const CSysVector<OtherType>& LinSysRes, CSysVector<OtherType>& LinSysSol) {
234- /* --- Set the pointers. ---*/
235- BEGIN_SU2_OMP_SAFE_GLOBAL_ACCESS {
236- LinSysRes_ptr = &LinSysRes;
237- LinSysSol_ptr = &LinSysSol;
238- }
239- END_SU2_OMP_SAFE_GLOBAL_ACCESS
240- }
241-
242- /* !
243- * \brief Used by Solve for compatibility between passive and active CSysVector.
244- * \note Different type specialization, copy data into temporary solution and residual vectors.
245238 * \param[in] LinSysRes - Linear system residual
246239 * \param[in,out] LinSysSol - Linear system solution
247240 */
248- template <class OtherType , su2enable_if<!std::is_same<ScalarType, OtherType>::value> = 0 >
241+ template <class OtherType >
249242 void HandleTemporariesIn (const CSysVector<OtherType>& LinSysRes, CSysVector<OtherType>& LinSysSol) {
250- /* --- Copy data, the solution is also copied as it serves as initial condition. ---*/
251- LinSysRes_tmp.PassiveCopy (LinSysRes);
252- LinSysSol_tmp.PassiveCopy (LinSysSol);
253-
254- /* --- Set the pointers. ---*/
255- BEGIN_SU2_OMP_SAFE_GLOBAL_ACCESS {
256- LinSysRes_ptr = &LinSysRes_tmp;
257- LinSysSol_ptr = &LinSysSol_tmp;
243+ if constexpr (std::is_same_v<ScalarType, OtherType>) {
244+ /* --- Same type specialization, temporary variables are not required. ---*/
245+ BEGIN_SU2_OMP_SAFE_GLOBAL_ACCESS {
246+ LinSysRes_ptr = &LinSysRes;
247+ LinSysSol_ptr = &LinSysSol;
248+ }
249+ END_SU2_OMP_SAFE_GLOBAL_ACCESS
250+ } else {
251+ /* --- Copy data, the solution is also copied as it serves as initial condition. ---*/
252+ LinSysRes_tmp.PassiveCopy (LinSysRes);
253+ LinSysSol_tmp.PassiveCopy (LinSysSol);
254+
255+ /* --- Set the pointers. ---*/
256+ BEGIN_SU2_OMP_SAFE_GLOBAL_ACCESS {
257+ LinSysRes_ptr = &LinSysRes_tmp;
258+ LinSysSol_ptr = &LinSysSol_tmp;
259+ }
260+ END_SU2_OMP_SAFE_GLOBAL_ACCESS
258261 }
259- END_SU2_OMP_SAFE_GLOBAL_ACCESS
260262 }
261263
262264 /* !
263265 * \brief Used by Solve for compatibility between passive and active CSysVector.
264- * \note Same type specialization, temporary variables are not required.
265266 * \param[out] LinSysSol - Linear system solution
266267 */
267- template <class OtherType , su2enable_if<std::is_same<ScalarType, OtherType>::value> = 0 >
268+ template <class OtherType >
268269 void HandleTemporariesOut (CSysVector<OtherType>& LinSysSol) {
269- /* --- Reset the pointers. ---*/
270- BEGIN_SU2_OMP_SAFE_GLOBAL_ACCESS {
271- LinSysRes_ptr = nullptr ;
272- LinSysSol_ptr = nullptr ;
270+ if constexpr (std::is_same_v<ScalarType, OtherType>) {
271+ /* --- Same type specialization, temporary variables are not required. ---*/
272+ BEGIN_SU2_OMP_SAFE_GLOBAL_ACCESS {
273+ LinSysRes_ptr = nullptr ;
274+ LinSysSol_ptr = nullptr ;
275+ }
276+ END_SU2_OMP_SAFE_GLOBAL_ACCESS
277+ } else {
278+ /* --- Copy data, only the temporary solution needs to be copied. ---*/
279+ LinSysSol.PassiveCopy (LinSysSol_tmp);
280+
281+ /* --- Reset the pointers. ---*/
282+ BEGIN_SU2_OMP_SAFE_GLOBAL_ACCESS {
283+ LinSysRes_ptr = nullptr ;
284+ LinSysSol_ptr = nullptr ;
285+ }
286+ END_SU2_OMP_SAFE_GLOBAL_ACCESS
273287 }
274- END_SU2_OMP_SAFE_GLOBAL_ACCESS
275288 }
276289
277- /* !
278- * \brief Used by Solve for compatibility between passive and active CSysVector.
279- * \note Different type specialization, copy data from the temporary solution vector.
280- * \param[out] LinSysSol - Linear system solution
281- */
282- template <class OtherType , su2enable_if<!std::is_same<ScalarType, OtherType>::value> = 0 >
283- void HandleTemporariesOut (CSysVector<OtherType>& LinSysSol) {
284- /* --- Copy data, only the temporary solution needs to be copied. ---*/
285- LinSysSol.PassiveCopy (LinSysSol_tmp);
286-
287- /* --- Reset the pointers. ---*/
288- BEGIN_SU2_OMP_SAFE_GLOBAL_ACCESS {
289- LinSysRes_ptr = nullptr ;
290- LinSysSol_ptr = nullptr ;
291- }
292- END_SU2_OMP_SAFE_GLOBAL_ACCESS
293- }
290+ /* --- TODO(pedro): The deflation part using Eigen does not compile in forward AD mode.
291+ * So we need a dummy template to avoid instantiating this function for directdiff. ---*/
292+ template <class Dummy = int >
293+ unsigned long FGCRODR_LinSolverImpl (const VectorType& b, VectorType& x, const ProductType& mat_vec,
294+ const PrecondType& precond, ScalarType tol, unsigned long max_iter,
295+ ScalarType& residual, bool monitoring, const CConfig* config,
296+ FgcrodrMode mode) const ;
294297
295298 public:
296299 /* !
@@ -335,7 +338,25 @@ class CSysSolve {
335338 */
336339 unsigned long RFGMRES_LinSolver (const VectorType& b, VectorType& x, const ProductType& mat_vec,
337340 const PrecondType& precond, ScalarType tol, unsigned long m, ScalarType& residual,
338- bool monitoring, const CConfig* config);
341+ bool monitoring, const CConfig* config) const ;
342+
343+ /* !
344+ * \brief Flexible Generalized Conjugate Residual Method with Inner Orthogonalization and Deflated Restarting.
345+ * \param[in] b - the right hand size vector
346+ * \param[in,out] x - on entry the intial guess, on exit the solution
347+ * \param[in] mat_vec - object that defines matrix-vector product
348+ * \param[in] precond - object that defines preconditioner
349+ * \param[in] tol - tolerance with which to solve the system
350+ * \param[in] max_iter - maximum number of iterations
351+ * \param[out] residual - final normalized residual
352+ * \param[in] monitoring - turn on priting residuals from solver to screen.
353+ * \param[in] config - Definition of the particular problem.
354+ * \param[in] mode - See FgcrodrMode.
355+ */
356+ unsigned long FGCRODR_LinSolver (const VectorType& b, VectorType& x, const ProductType& mat_vec,
357+ const PrecondType& precond, ScalarType tol, unsigned long max_iter,
358+ ScalarType& residual, bool monitoring, const CConfig* config,
359+ FgcrodrMode mode = FgcrodrMode::NORMAL) const ;
339360
340361 /* !
341362 * \brief Biconjugate Gradient Stabilized Method (BCGSTAB)
@@ -390,7 +411,7 @@ class CSysSolve {
390411 * \param[in] directCall - If this method is called directly, or in AD context.
391412 */
392413 unsigned long Solve_b (MatrixType& Jacobian, const CSysVector<su2double>& LinSysRes, CSysVector<su2double>& LinSysSol,
393- CGeometry* geometry, const CConfig* config, const bool directCall = true );
414+ CGeometry* geometry, const CConfig* config, bool directCall = true );
394415
395416 /* !
396417 * \brief Get the number of iterations.
@@ -423,4 +444,9 @@ class CSysSolve {
423444 * \brief Set the screen output frequency during monitoring.
424445 */
425446 inline void SetMonitoringFrequency (bool frequency) { monitorFreq = frequency; }
447+
448+ /* !
449+ * \brief Discard FGCRODR's deflation vectors for the next solve.
450+ */
451+ inline void ResetDeflation () const { k = 0 ; }
426452};
0 commit comments