@@ -68,10 +68,9 @@ inline void TwistNTT(Polynomial<P> &res, PolynomialNTT<P> &a)
6868 cuHEpp::TwistNTT<typename lvl1param::T, lvl1param::nbit>(
6969 res, a, (*ntttablelvl1)[0 ], (*ntttwistlvl1)[0 ]);
7070#endif
71- else if constexpr (std::is_same_v<typename P::T, uint64_t >) {
71+ else if constexpr (std::is_same_v<typename P::T, uint64_t >)
7272 cuHEpp::TwistNTT<typename lvl2param::T, lvl2param::nbit>(
7373 res, a, (*ntttablelvl2)[0 ], (*ntttwistlvl2)[0 ]);
74- }
7574 else
7675 static_assert (false_v<typename P::T>, " Undefined TwistNTT!" );
7776}
@@ -89,6 +88,14 @@ inline void TwistFFT(Polynomial<P> &res, const PolynomialInFD<P> &a)
8988 else if constexpr (std::is_same_v<typename P::T, uint64_t >)
9089 fftplvl1.execute_direct_torus64 (res.data (), a.data ());
9190 }
91+ else if constexpr (std::is_same_v<P, lvl3param>) {
92+ // For 128-bit lvl3param, use 64-bit FFT and shift result to top 64 bits
93+ // This preserves the Torus semantics (most significant bits)
94+ alignas (64 ) std::array<uint64_t , P::n> temp;
95+ fftplvl3.execute_direct_torus64 (temp.data (), a.data ());
96+ for (int i = 0 ; i < P::n; i++)
97+ res[i] = static_cast <__uint128_t >(temp[i]) << 64 ;
98+ }
9299 else if constexpr (std::is_same_v<typename P::T, uint64_t >)
93100 fftplvl2.execute_direct_torus64 (res.data (), a.data ());
94101 else
@@ -143,6 +150,14 @@ inline void TwistIFFT(PolynomialInFD<P> &res, const Polynomial<P> &a)
143150 if constexpr (std::is_same_v<typename P::T, uint64_t >)
144151 fftplvl1.execute_reverse_torus64 (res.data (), a.data ());
145152 }
153+ else if constexpr (std::is_same_v<P, lvl3param>) {
154+ // For 128-bit lvl3param, use top 64 bits for FFT
155+ // This preserves the Torus semantics (most significant bits)
156+ alignas (64 ) std::array<uint64_t , P::n> temp;
157+ for (int i = 0 ; i < P::n; i++)
158+ temp[i] = static_cast <uint64_t >(a[i] >> 64 );
159+ fftplvl3.execute_reverse_torus64 (res.data (), temp.data ());
160+ }
146161 else if constexpr (std::is_same_v<typename P::T, uint64_t >)
147162 fftplvl2.execute_reverse_torus64 (res.data (), a.data ());
148163 else
@@ -301,8 +316,21 @@ inline void PolyMul(Polynomial<P> &res, const Polynomial<P> &a,
301316 for (int i = 0 ; i < P::n; i++) ntta[i] *= nttb[i];
302317 TwistNTT<P>(res, ntta);
303318 }
319+ else if constexpr (std::is_same_v<typename P::T, __uint128_t >) {
320+ // Naive for 128-bit types (FFT/NTT don't support 128-bit precision)
321+ for (int i = 0 ; i < P::n; i++) {
322+ __uint128_t ri = 0 ;
323+ for (int j = 0 ; j <= i; j++)
324+ ri += static_cast <__int128_t >(a[j]) *
325+ static_cast <__int128_t >(b[i - j]);
326+ for (int j = i + 1 ; j < P::n; j++)
327+ ri -= static_cast <__int128_t >(a[j]) *
328+ static_cast <__int128_t >(b[P::n + i - j]);
329+ res[i] = ri;
330+ }
331+ }
304332 else {
305- // Naieve
333+ // Naive for other types
306334 for (int i = 0 ; i < P::n; i++) {
307335 typename P::T ri = 0 ;
308336 for (int j = 0 ; j <= i; j++)
@@ -339,17 +367,33 @@ template <class P>
339367inline void PolyMulNaive (Polynomial<P> &res, const Polynomial<P> &a,
340368 const Polynomial<P> &b)
341369{
342- for (int i = 0 ; i < P::n; i++) {
343- typename P::T ri = 0 ;
344- for (int j = 0 ; j <= i; j++)
345- ri += static_cast <typename std::make_signed<typename P::T>::type>(
346- a[j]) *
347- b[i - j];
348- for (int j = i + 1 ; j < P::n; j++)
349- ri -= static_cast <typename std::make_signed<typename P::T>::type>(
350- a[j]) *
351- b[P::n + i - j];
352- res[i] = ri;
370+ if constexpr (std::is_same_v<typename P::T, __uint128_t >) {
371+ for (int i = 0 ; i < P::n; i++) {
372+ __uint128_t ri = 0 ;
373+ for (int j = 0 ; j <= i; j++)
374+ ri += static_cast <__int128_t >(a[j]) *
375+ static_cast <__int128_t >(b[i - j]);
376+ for (int j = i + 1 ; j < P::n; j++)
377+ ri -= static_cast <__int128_t >(a[j]) *
378+ static_cast <__int128_t >(b[P::n + i - j]);
379+ res[i] = ri;
380+ }
381+ }
382+ else {
383+ for (int i = 0 ; i < P::n; i++) {
384+ typename P::T ri = 0 ;
385+ for (int j = 0 ; j <= i; j++)
386+ ri +=
387+ static_cast <typename std::make_signed<typename P::T>::type>(
388+ a[j]) *
389+ b[i - j];
390+ for (int j = i + 1 ; j < P::n; j++)
391+ ri -=
392+ static_cast <typename std::make_signed<typename P::T>::type>(
393+ a[j]) *
394+ b[P::n + i - j];
395+ res[i] = ri;
396+ }
353397 }
354398}
355399
0 commit comments