@@ -42,6 +42,9 @@ template <class P>
4242inline void relinKeySwitch (TRLWE<P> &res, const Polynomial<P> &poly,
4343 const relinKeyFFT<P> &relinkeyfft)
4444{
45+ static_assert (P::l̅ == 1 ,
46+ " relinKeySwitch only supports standard decomposition (l̅=1). "
47+ " Use relinKeySwitchDD for Double Decomposition." );
4548 DecomposedPolynomial<P> decvec;
4649 Decomposition<P>(decvec, poly);
4750 PolynomialInFD<P> decvecfft;
@@ -58,36 +61,259 @@ inline void relinKeySwitch(TRLWE<P> &res, const Polynomial<P> &poly,
5861 TwistFFT<P>(res[1 ], resfft[1 ]);
5962}
6063
64+ // Double Decomposition variant of relinKeySwitch
65+ // Uses l * l̅ rows and accumulates l̅ separate results before recombining
66+ template <class P >
67+ inline void relinKeySwitchDD (TRLWE<P> &res, const Polynomial<P> &poly,
68+ const relinKeyFFTDD<P> &relinkeyfft)
69+ {
70+ alignas (64 ) DecomposedPolynomial<P> decvec;
71+ Decomposition<P>(decvec, poly);
72+ alignas (64 ) PolynomialInFD<P> decvecfft;
73+
74+ // l̅ separate accumulators in FD domain
75+ alignas (64 ) std::array<TRLWEInFD<P>, P::l̅> resfft_dd;
76+
77+ // Initialize all accumulators to zero
78+ for (int j = 0 ; j < P::l̅; j++)
79+ for (int m = 0 ; m <= P::k; m++)
80+ for (int n = 0 ; n < P::n; n++)
81+ resfft_dd[j][m][n] = 0.0 ;
82+
83+ // Process with standard decomposition (l levels), accumulate into l̅ results
84+ for (int i = 0 ; i < P::l; i++) {
85+ TwistIFFT<P>(decvecfft, decvec[i]);
86+ // Each decomposition level i multiplies with l̅ relinkey rows
87+ for (int j = 0 ; j < P::l̅; j++) {
88+ const int row_idx = i * P::l̅ + j;
89+ for (int m = 0 ; m <= P::k; m++) {
90+ if (i == 0 && j == 0 )
91+ MulInFD<P::n>(resfft_dd[j][m], decvecfft,
92+ relinkeyfft[row_idx][m]);
93+ else
94+ FMAInFD<P::n>(resfft_dd[j][m], decvecfft,
95+ relinkeyfft[row_idx][m]);
96+ }
97+ }
98+ }
99+
100+ // FFT back to coefficient domain for each accumulator and recombine
101+ std::array<TRLWE<P>, P::l̅> results_dd;
102+ for (int j = 0 ; j < P::l̅; j++)
103+ for (int k = 0 ; k <= P::k; k++)
104+ TwistFFT<P>(results_dd[j][k], resfft_dd[j][k]);
105+
106+ // Recombine the l̅ TRLWEs back to single TRLWE
107+ RecombineTRLWEFromDD<P, false >(res, results_dd);
108+ }
109+
61110template <class P >
62111inline void Relinearization (TRLWE<P> &res, const TRLWE3<P> &mult,
63112 const relinKeyFFT<P> &relinkeyfft)
64113{
114+ static_assert (P::l̅ == 1 ,
115+ " Relinearization only supports standard decomposition (l̅=1). "
116+ " Use RelinearizationDD for Double Decomposition." );
65117 TRLWE<P> squareterm;
66118 relinKeySwitch<P>(squareterm, mult[2 ], relinkeyfft);
67119 for (int i = 0 ; i < P::n; i++) res[0 ][i] = mult[0 ][i] + squareterm[0 ][i];
68120 for (int i = 0 ; i < P::n; i++) res[1 ][i] = mult[1 ][i] + squareterm[1 ][i];
69121}
70122
123+ // Double Decomposition variant of Relinearization
124+ template <class P >
125+ inline void RelinearizationDD (TRLWE<P> &res, const TRLWE3<P> &mult,
126+ const relinKeyFFTDD<P> &relinkeyfft)
127+ {
128+ TRLWE<P> squareterm;
129+ relinKeySwitchDD<P>(squareterm, mult[2 ], relinkeyfft);
130+ for (int i = 0 ; i < P::n; i++) res[0 ][i] = mult[0 ][i] + squareterm[0 ][i];
131+ for (int i = 0 ; i < P::n; i++) res[1 ][i] = mult[1 ][i] + squareterm[1 ][i];
132+ }
133+
71134template <class P >
72135inline void TRLWEMult (TRLWE<P> &res, const TRLWE<P> &a, const TRLWE<P> &b,
73136 const relinKeyFFT<P> &relinkeyfft)
74137{
138+ static_assert (P::l̅ == 1 ,
139+ " TRLWEMult only supports standard decomposition (l̅=1). "
140+ " Use TRLWEMultDD for Double Decomposition." );
75141 TRLWE3<P> resmult;
76142 TRLWEMultWithoutRelinerization<P>(resmult, a, b);
77143 Relinearization<P>(res, resmult, relinkeyfft);
78144}
79145
146+ // Double Decomposition variant of TRLWEMult
147+ // Uses DD-based relinearization for improved noise management
148+ template <class P >
149+ inline void TRLWEMultDD (TRLWE<P> &res, const TRLWE<P> &a, const TRLWE<P> &b,
150+ const relinKeyFFTDD<P> &relinkeyfft)
151+ {
152+ TRLWE3<P> resmult;
153+ TRLWEMultWithoutRelinerization<P>(resmult, a, b);
154+ RelinearizationDD<P>(res, resmult, relinkeyfft);
155+ }
156+
157+ // Full Double Decomposition TRLWE Multiplication
158+ // Both TRLWEs are decomposed by l̅, multiplication is polynomial-like in decomposition indices
159+ // Algorithm:
160+ // 1. Decompose a[k] and b[k] into l̅ components each using base B̅g
161+ // 2. For polynomial product, compute convolution in decomposition index space:
162+ // (Σᵢ aᵢ·B̅g^i) × (Σⱼ bⱼ·B̅g^j) = Σₖ (Σᵢ₊ⱼ₌ₖ aᵢ·bⱼ)·B̅g^k
163+ // 3. Each aᵢ·bⱼ is computed via FFT polynomial multiplication
164+ // 4. IFFT to recover coefficients, rescale by Δ
165+ // 5. Recombine the 2l̅-1 terms back to proper scaling
166+ template <class P >
167+ void TRLWEMultWithoutRelinearizationDD (TRLWE3<P> &res, const TRLWE<P> &a,
168+ const TRLWE<P> &b)
169+ {
170+ constexpr int width = std::numeric_limits<typename P::T>::digits;
171+
172+ // Decompose all input polynomials into l̅ components
173+ // TRLWEBaseBbarDecompose gives: a = Σⱼ a_dec[j] * 2^(width - (j+1)*B̅gbit)
174+ // where a_dec[j] has coefficients in [-B̅g/2, B̅g/2)
175+ std::array<TRLWE<P>, P::l̅> a_dec, b_dec;
176+ TRLWEBaseBbarDecompose<P>(a_dec, a);
177+ TRLWEBaseBbarDecompose<P>(b_dec, b);
178+
179+ // FFT all decomposed components
180+ std::array<std::array<PolynomialInFD<P>, P::l̅>, P::k + 1 > a_fft, b_fft;
181+ for (int poly_idx = 0 ; poly_idx <= P::k; poly_idx++) {
182+ for (int j = 0 ; j < P::l̅; j++) {
183+ TwistIFFT<P>(a_fft[poly_idx][j], a_dec[j][poly_idx]);
184+ TwistIFFT<P>(b_fft[poly_idx][j], b_dec[j][poly_idx]);
185+ }
186+ }
187+
188+ // Compute c[0] = a[0]*b[1] + a[1]*b[0] using DD polynomial multiplication
189+ std::array<PolynomialInFD<P>, 2 * P::l̅ - 1 > c0_fft;
190+ for (int k = 0 ; k < 2 * P::l̅ - 1 ; k++) {
191+ for (int n = 0 ; n < P::n; n++) c0_fft[k][n] = 0.0 ;
192+ }
193+ for (int i = 0 ; i < P::l̅; i++) {
194+ for (int j = 0 ; j < P::l̅; j++) {
195+ FMAInFD<P::n>(c0_fft[i + j], a_fft[0 ][i], b_fft[1 ][j]);
196+ FMAInFD<P::n>(c0_fft[i + j], a_fft[1 ][i], b_fft[0 ][j]);
197+ }
198+ }
199+
200+ // Compute c[1] = a[1]*b[1]
201+ std::array<PolynomialInFD<P>, 2 * P::l̅ - 1 > c1_fft;
202+ for (int k = 0 ; k < 2 * P::l̅ - 1 ; k++) {
203+ for (int n = 0 ; n < P::n; n++) c1_fft[k][n] = 0.0 ;
204+ }
205+ for (int i = 0 ; i < P::l̅; i++) {
206+ for (int j = 0 ; j < P::l̅; j++) {
207+ FMAInFD<P::n>(c1_fft[i + j], a_fft[1 ][i], b_fft[1 ][j]);
208+ }
209+ }
210+
211+ // Compute c[2] = a[0]*b[0]
212+ std::array<PolynomialInFD<P>, 2 * P::l̅ - 1 > c2_fft;
213+ for (int k = 0 ; k < 2 * P::l̅ - 1 ; k++) {
214+ for (int n = 0 ; n < P::n; n++) c2_fft[k][n] = 0.0 ;
215+ }
216+ for (int i = 0 ; i < P::l̅; i++) {
217+ for (int j = 0 ; j < P::l̅; j++) {
218+ FMAInFD<P::n>(c2_fft[i + j], a_fft[0 ][i], b_fft[0 ][j]);
219+ }
220+ }
221+
222+ // Initialize results to zero
223+ for (int n = 0 ; n < P::n; n++) {
224+ res[0 ][n] = 0 ;
225+ res[1 ][n] = 0 ;
226+ res[2 ][n] = 0 ;
227+ }
228+
229+ // IFFT and recombine with Δ rescaling
230+ // The decomposition was: x = Σⱼ x_j * h̅[j] where h̅[j] = 2^(width - (j+1)*B̅gbit)
231+ // Product of positions i and j: scale = h̅[i] * h̅[j] = 2^(2*width - (i+j+2)*B̅gbit)
232+ // At convolution position k = i+j: scale = 2^(2*width - (k+2)*B̅gbit)
233+ //
234+ // The original ciphertext coefficients are scaled by Δ = 2^(width - plain_modulusbit).
235+ // After multiplying a × b = Δ² × plaintext_product.
236+ // We need to divide by Δ to get back to Δ scaling.
237+ //
238+ // For convolution position k:
239+ // raw_scale = 2^(2*width - (k+2)*B̅gbit)
240+ // after /Δ: scale = raw_scale / Δ = 2^(2*width - (k+2)*B̅gbit) / 2^(width - plain_modulusbit)
241+ // = 2^(width - (k+2)*B̅gbit + plain_modulusbit)
242+ //
243+ // So the shift for position k is: width - (k+2)*B̅gbit + plain_modulusbit
244+
245+ for (int k = 0 ; k < 2 * P::l̅ - 1 ; k++) {
246+ Polynomial<P> temp0, temp1, temp2;
247+ TwistFFT<P>(temp0, c0_fft[k]);
248+ TwistFFT<P>(temp1, c1_fft[k]);
249+ TwistFFT<P>(temp2, c2_fft[k]);
250+
251+ // Shift includes the Δ division: width - (k+2)*B̅gbit + plain_modulusbit
252+ // Positive shift means left shift, negative means right shift
253+ const int shift = width - (k + 2 ) * P::B̅gbit + P::plain_modulusbit;
254+
255+ if (shift >= 0 && shift < width) {
256+ // Left shift
257+ for (int n = 0 ; n < P::n; n++) {
258+ res[0 ][n] += temp0[n] << shift;
259+ res[1 ][n] += temp1[n] << shift;
260+ res[2 ][n] += temp2[n] << shift;
261+ }
262+ } else if (shift < 0 && -shift < width) {
263+ // Right shift
264+ const int right_shift = -shift;
265+ for (int n = 0 ; n < P::n; n++) {
266+ res[0 ][n] += static_cast <typename P::T>(
267+ static_cast <std::make_signed_t <typename P::T>>(temp0[n]) >> right_shift);
268+ res[1 ][n] += static_cast <typename P::T>(
269+ static_cast <std::make_signed_t <typename P::T>>(temp1[n]) >> right_shift);
270+ res[2 ][n] += static_cast <typename P::T>(
271+ static_cast <std::make_signed_t <typename P::T>>(temp2[n]) >> right_shift);
272+ }
273+ }
274+ // If |shift| >= width, contribution is zero
275+ }
276+ }
277+
278+ // Full DD TRLWE multiplication with relinearization
279+ template <class P >
280+ inline void TRLWEMultFullDD (TRLWE<P> &res, const TRLWE<P> &a, const TRLWE<P> &b,
281+ const relinKeyFFTDD<P> &relinkeyfft)
282+ {
283+ TRLWE3<P> resmult;
284+ TRLWEMultWithoutRelinearizationDD<P>(resmult, a, b);
285+ RelinearizationDD<P>(res, resmult, relinkeyfft);
286+ }
287+
80288template <class P >
81289inline void TLWEMult (TLWE<typename P::targetP> &res,
82290 const TLWE<typename P::domainP> &a,
83291 const TLWE<typename P::domainP> &b,
84292 const relinKeyFFT<typename P::targetP> &relinkeyfft,
85293 const PrivateKeySwitchingKey<P> &privksk)
86294{
295+ static_assert (P::targetP::l̅ == 1 ,
296+ " TLWEMult only supports standard decomposition (l̅=1). "
297+ " Use TLWEMultDD for Double Decomposition." );
87298 TRLWE<typename P::targetP> trlweres, trlwea, trlweb;
88299 PrivKeySwitch<P>(trlwea, a, privksk);
89300 PrivKeySwitch<P>(trlweb, b, privksk);
90301 TRLWEMult<typename P::targetP>(trlweres, trlwea, trlweb, relinkeyfft);
91302 SampleExtractIndex<typename P::targetP>(res, trlweres, 0 );
92303}
304+
305+ // Double Decomposition variant of TLWEMult
306+ template <class P >
307+ inline void TLWEMultDD (TLWE<typename P::targetP> &res,
308+ const TLWE<typename P::domainP> &a,
309+ const TLWE<typename P::domainP> &b,
310+ const relinKeyFFTDD<typename P::targetP> &relinkeyfft,
311+ const PrivateKeySwitchingKey<P> &privksk)
312+ {
313+ TRLWE<typename P::targetP> trlweres, trlwea, trlweb;
314+ PrivKeySwitch<P>(trlwea, a, privksk);
315+ PrivKeySwitch<P>(trlweb, b, privksk);
316+ TRLWEMultDD<typename P::targetP>(trlweres, trlwea, trlweb, relinkeyfft);
317+ SampleExtractIndex<typename P::targetP>(res, trlweres, 0 );
318+ }
93319} // namespace TFHEpp
0 commit comments