@@ -117,6 +117,102 @@ inline void NonceDecomposition(DecomposedNoncePolynomial<P> &decpoly,
117117 }
118118}
119119
120+ // Double Decomposition (bivariate representation) for external product
121+ // Decomposes each coefficient a into l*l̅ components such that:
122+ // a ≈ Σᵢ Σⱼ aᵢⱼ * Bg^(l-i) * B̅g^(l̅-j)
123+ // When l̅=1 (j=0 only), this reduces to standard decomposition.
124+ template <class P >
125+ constexpr typename P::T ddoffsetgen ()
126+ {
127+ typename P::T offset = 0 ;
128+ for (int i = 1 ; i <= P::l; i++)
129+ for (int j = 0 ; j < P::l̅; j++)
130+ offset += (static_cast <typename P::T>(P::Bg) / 2 ) *
131+ (static_cast <typename P::T>(1 )
132+ << (std::numeric_limits<typename P::T>::digits -
133+ i * P::Bgbit - j * P::B̅gbit));
134+ return offset;
135+ }
136+
137+ template <class P >
138+ inline void DoubleDecomposition (DecomposedPolynomialDD<P> &decpoly,
139+ const Polynomial<P> &poly)
140+ {
141+ constexpr typename P::T offset = ddoffsetgen<P>();
142+ // Remaining bits after decomposition
143+ constexpr int remaining_bits = std::numeric_limits<typename P::T>::digits -
144+ P::l * P::Bgbit - P::l̅ * P::B̅gbit;
145+ // roundoffset is 0 if no remaining bits, otherwise 2^(remaining_bits-1)
146+ constexpr typename P::T roundoffset =
147+ remaining_bits > 0
148+ ? (static_cast <typename P::T>(1 ) << (remaining_bits - 1 ))
149+ : static_cast <typename P::T>(0 );
150+ constexpr typename P::T maskBg =
151+ static_cast <typename P::T>((1ULL << P::Bgbit) - 1 );
152+ constexpr typename P::T halfBg = (1ULL << (P::Bgbit - 1 ));
153+
154+ for (int n = 0 ; n < P::n; n++) {
155+ typename P::T a = poly[n] + offset + roundoffset;
156+ for (int i = 0 ; i < P::l; i++) {
157+ for (int j = 0 ; j < P::l̅; j++) {
158+ // Shift to get the (i,j)-th digit in base Bg (after B̅g grouping)
159+ // When l̅=1 (j=0 only), this reduces to standard decomposition
160+ const int shift = std::numeric_limits<typename P::T>::digits -
161+ (i + 1 ) * P::Bgbit - j * P::B̅gbit;
162+ decpoly[i * P::l̅ + j][n] =
163+ static_cast <std::make_signed_t <typename P::T>>(
164+ ((a >> shift) & maskBg) - halfBg);
165+ }
166+ }
167+ }
168+ }
169+
170+ template <class P >
171+ constexpr typename P::T nonceddoffsetgen ()
172+ {
173+ typename P::T offset = 0 ;
174+ for (int i = 1 ; i <= P::lₐ; i++)
175+ for (int j = 0 ; j < P::l̅ₐ; j++)
176+ offset += (static_cast <typename P::T>(P::Bgₐ) / 2 ) *
177+ (static_cast <typename P::T>(1 )
178+ << (std::numeric_limits<typename P::T>::digits -
179+ i * P::Bgₐbit - j * P::B̅gₐbit));
180+ return offset;
181+ }
182+
183+ template <class P >
184+ inline void NonceDoubleDecomposition (DecomposedNoncePolynomialDD<P> &decpoly,
185+ const Polynomial<P> &poly)
186+ {
187+ constexpr typename P::T offset = nonceddoffsetgen<P>();
188+ // Remaining bits after decomposition
189+ constexpr int remaining_bits = std::numeric_limits<typename P::T>::digits -
190+ P::lₐ * P::Bgₐbit - P::l̅ₐ * P::B̅gₐbit;
191+ // roundoffset is 0 if no remaining bits, otherwise 2^(remaining_bits-1)
192+ constexpr typename P::T roundoffset =
193+ remaining_bits > 0
194+ ? (static_cast <typename P::T>(1 ) << (remaining_bits - 1 ))
195+ : static_cast <typename P::T>(0 );
196+ constexpr typename P::T maskBg =
197+ static_cast <typename P::T>((1ULL << P::Bgₐbit) - 1 );
198+ constexpr typename P::T halfBg = (1ULL << (P::Bgₐbit - 1 ));
199+
200+ for (int n = 0 ; n < P::n; n++) {
201+ typename P::T a = poly[n] + offset + roundoffset;
202+ for (int i = 0 ; i < P::lₐ; i++) {
203+ for (int j = 0 ; j < P::l̅ₐ; j++) {
204+ // Shift to get the (i,j)-th digit
205+ // When l̅ₐ=1 (j=0 only), this reduces to standard decomposition
206+ const int shift = std::numeric_limits<typename P::T>::digits -
207+ (i + 1 ) * P::Bgₐbit - j * P::B̅gₐbit;
208+ decpoly[i * P::l̅ₐ + j][n] =
209+ static_cast <std::make_signed_t <typename P::T>>(
210+ ((a >> shift) & maskBg) - halfBg);
211+ }
212+ }
213+ }
214+ }
215+
120216template <class P >
121217void Decomposition (DecomposedPolynomialNTT<P> &decpolyntt,
122218 const Polynomial<P> &poly)
@@ -195,6 +291,53 @@ void ExternalProduct(TRLWE<P> &res, const TRLWE<P> &trlwe,
195291 for (int k = 0 ; k < P::k + 1 ; k++) TwistFFT<P>(res[k], restrlwefft[k]);
196292}
197293
294+ // External product with Double Decomposition (bivariate representation)
295+ // Uses the full TRGSW structure with l*l̅ rows for "b" block and k*lₐ*l̅ₐ rows
296+ // for "a" blocks
297+ template <class P >
298+ void ExternalProductDD (TRLWE<P> &res, const TRLWE<P> &trlwe,
299+ const TRGSWFFT<P> &trgswfft)
300+ {
301+ alignas (64 ) PolynomialInFD<P> decpolyfft;
302+ alignas (64 ) TRLWEInFD<P> restrlwefft;
303+
304+ // Handle "a" polynomials (indices 0 to k-1 in TRLWE)
305+ // Uses NonceDoubleDecomposition with lₐ*l̅ₐ levels
306+ {
307+ alignas (64 ) DecomposedNoncePolynomialDD<P> decpoly;
308+ NonceDoubleDecomposition<P>(decpoly, trlwe[0 ]);
309+ TwistIFFT<P>(decpolyfft, decpoly[0 ]);
310+ for (int m = 0 ; m < P::k + 1 ; m++)
311+ MulInFD<P::n>(restrlwefft[m], decpolyfft, trgswfft[0 ][m]);
312+ for (int i = 1 ; i < P::lₐ * P::l̅ₐ; i++) {
313+ TwistIFFT<P>(decpolyfft, decpoly[i]);
314+ for (int m = 0 ; m < P::k + 1 ; m++)
315+ FMAInFD<P::n>(restrlwefft[m], decpolyfft, trgswfft[i][m]);
316+ }
317+ for (int k = 1 ; k < P::k; k++) {
318+ NonceDoubleDecomposition<P>(decpoly, trlwe[k]);
319+ for (int i = 0 ; i < P::lₐ * P::l̅ₐ; i++) {
320+ TwistIFFT<P>(decpolyfft, decpoly[i]);
321+ for (int m = 0 ; m < P::k + 1 ; m++)
322+ FMAInFD<P::n>(restrlwefft[m], decpolyfft,
323+ trgswfft[i + k * P::lₐ * P::l̅ₐ][m]);
324+ }
325+ }
326+ }
327+
328+ // Handle "b" polynomial (index k in TRLWE)
329+ // Uses DoubleDecomposition with l*l̅ levels
330+ alignas (64 ) DecomposedPolynomialDD<P> decpoly;
331+ DoubleDecomposition<P>(decpoly, trlwe[P::k]);
332+ for (int i = 0 ; i < P::l * P::l̅; i++) {
333+ TwistIFFT<P>(decpolyfft, decpoly[i]);
334+ for (int m = 0 ; m < P::k + 1 ; m++)
335+ FMAInFD<P::n>(restrlwefft[m], decpolyfft,
336+ trgswfft[i + P::k * P::lₐ * P::l̅ₐ][m]);
337+ }
338+ for (int k = 0 ; k < P::k + 1 ; k++) TwistFFT<P>(res[k], restrlwefft[k]);
339+ }
340+
198341template <class P >
199342void ExternalProduct (TRLWE<P> &res, const Polynomial<P> &poly,
200343 const HalfTRGSWFFT<P> &halftrgswfft)
@@ -453,23 +596,29 @@ constexpr std::array<typename P::T, P::lₐ> noncehgen()
453596}
454597
455598// Auxiliary h generation for Double Decomposition (bivariate representation)
599+ // h̅[j] values are used to construct gadget values h[i] * h̅[j] = 2^(width - (i+1)*Bgbit - j*B̅gbit)
600+ // For j=0: no auxiliary shift, so h̅[0] = 1
601+ // For j>0: h̅[j] = 2^(width - j*B̅gbit) which combines with h[i] via modular multiplication
456602template <class P >
457603constexpr std::array<typename P::T, P::l̅> h̅gen()
458604{
459605 std::array<typename P::T, P::l̅> h̅{};
460- for (int i = 0 ; i < P::l̅; i++)
606+ h̅[0 ] = 1 ; // j=0 means no auxiliary shift
607+ for (int i = 1 ; i < P::l̅; i++)
461608 h̅[i] = 1ULL << (std::numeric_limits<typename P::T>::digits -
462- (i + 1 ) * P::B̅gbit);
609+ i * P::B̅gbit);
463610 return h̅;
464611}
465612
613+ // Auxiliary h generation for nonce part of TRGSW with Double Decomposition
466614template <class P >
467615constexpr std::array<typename P::T, P::l̅ₐ> nonceh̅gen()
468616{
469617 std::array<typename P::T, P::l̅ₐ> h̅{};
470- for (int i = 0 ; i < P::l̅ₐ; i++)
618+ h̅[0 ] = 1 ; // j=0 means no auxiliary shift
619+ for (int i = 1 ; i < P::l̅ₐ; i++)
471620 h̅[i] = 1ULL << (std::numeric_limits<typename P::T>::digits -
472- (i + 1 ) * P::B̅gₐbit);
621+ i * P::B̅gₐbit);
473622 return h̅;
474623}
475624
0 commit comments