Skip to content

Commit e2548c2

Browse files
nindanaoto and claude committed
Extend Double Decomposition support to TRLWE multiplication
Add full DD support for TRLWE multiplication (BFV-style):
- Add relinKeyDD and relinKeyFFTDD types for DD relinearization keys
- Add relinKeygenDD and relinKeyFFTgenDD for key generation
- Add relinKeySwitchDD and RelinearizationDD for DD key switching
- Add TRLWEMultWithoutRelinearizationDD for full DD multiplication
  - Decomposes both TRLWEs by l̅ using TRLWEBaseBbarDecompose
  - Computes polynomial convolution in decomposition index space
  - Rescales by Δ using shift: width - (k+2)*B̅gbit + plain_modulusbit
- Add TRLWEMultFullDD and TRLWEMultDD convenience wrappers
- Add TLWEMultDD for TLWE multiplication with DD

Also fix memory issues for large parameter sets (128-bit):
- Change stack allocations to heap in trgsw.hpp trgswSymEncrypt functions
- Fix gatebootstrappingtlwe2tlwedoubledecomposition test allocations

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 5f3088b commit e2548c2

File tree

6 files changed

+581
-29
lines changed

6 files changed

+581
-29
lines changed

include/bfv++.hpp

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ template <class P>
4242
inline void relinKeySwitch(TRLWE<P> &res, const Polynomial<P> &poly,
4343
const relinKeyFFT<P> &relinkeyfft)
4444
{
45+
static_assert(P::l̅ == 1,
46+
"relinKeySwitch only supports standard decomposition (l̅=1). "
47+
"Use relinKeySwitchDD for Double Decomposition.");
4548
DecomposedPolynomial<P> decvec;
4649
Decomposition<P>(decvec, poly);
4750
PolynomialInFD<P> decvecfft;
@@ -58,36 +61,259 @@ inline void relinKeySwitch(TRLWE<P> &res, const Polynomial<P> &poly,
5861
TwistFFT<P>(res[1], resfft[1]);
5962
}
6063

64+
// Double Decomposition variant of relinKeySwitch
65+
// Uses l * l̅ rows and accumulates l̅ separate results before recombining
66+
template <class P>
67+
inline void relinKeySwitchDD(TRLWE<P> &res, const Polynomial<P> &poly,
68+
const relinKeyFFTDD<P> &relinkeyfft)
69+
{
70+
alignas(64) DecomposedPolynomial<P> decvec;
71+
Decomposition<P>(decvec, poly);
72+
alignas(64) PolynomialInFD<P> decvecfft;
73+
74+
// l̅ separate accumulators in FD domain
75+
alignas(64) std::array<TRLWEInFD<P>, P::l̅> resfft_dd;
76+
77+
// Initialize all accumulators to zero
78+
for (int j = 0; j < P::l̅; j++)
79+
for (int m = 0; m <= P::k; m++)
80+
for (int n = 0; n < P::n; n++)
81+
resfft_dd[j][m][n] = 0.0;
82+
83+
// Process with standard decomposition (l levels), accumulate into l̅ results
84+
for (int i = 0; i < P::l; i++) {
85+
TwistIFFT<P>(decvecfft, decvec[i]);
86+
// Each decomposition level i multiplies with l̅ relinkey rows
87+
for (int j = 0; j < P::l̅; j++) {
88+
const int row_idx = i * P::l̅ + j;
89+
for (int m = 0; m <= P::k; m++) {
90+
if (i == 0 && j == 0)
91+
MulInFD<P::n>(resfft_dd[j][m], decvecfft,
92+
relinkeyfft[row_idx][m]);
93+
else
94+
FMAInFD<P::n>(resfft_dd[j][m], decvecfft,
95+
relinkeyfft[row_idx][m]);
96+
}
97+
}
98+
}
99+
100+
// FFT back to coefficient domain for each accumulator and recombine
101+
std::array<TRLWE<P>, P::l̅> results_dd;
102+
for (int j = 0; j < P::l̅; j++)
103+
for (int k = 0; k <= P::k; k++)
104+
TwistFFT<P>(results_dd[j][k], resfft_dd[j][k]);
105+
106+
// Recombine the l̅ TRLWEs back to single TRLWE
107+
RecombineTRLWEFromDD<P, false>(res, results_dd);
108+
}
109+
61110
template <class P>
62111
inline void Relinearization(TRLWE<P> &res, const TRLWE3<P> &mult,
63112
const relinKeyFFT<P> &relinkeyfft)
64113
{
114+
static_assert(P::l̅ == 1,
115+
"Relinearization only supports standard decomposition (l̅=1). "
116+
"Use RelinearizationDD for Double Decomposition.");
65117
TRLWE<P> squareterm;
66118
relinKeySwitch<P>(squareterm, mult[2], relinkeyfft);
67119
for (int i = 0; i < P::n; i++) res[0][i] = mult[0][i] + squareterm[0][i];
68120
for (int i = 0; i < P::n; i++) res[1][i] = mult[1][i] + squareterm[1][i];
69121
}
70122

123+
// Double Decomposition variant of Relinearization
124+
template <class P>
125+
inline void RelinearizationDD(TRLWE<P> &res, const TRLWE3<P> &mult,
126+
const relinKeyFFTDD<P> &relinkeyfft)
127+
{
128+
TRLWE<P> squareterm;
129+
relinKeySwitchDD<P>(squareterm, mult[2], relinkeyfft);
130+
for (int i = 0; i < P::n; i++) res[0][i] = mult[0][i] + squareterm[0][i];
131+
for (int i = 0; i < P::n; i++) res[1][i] = mult[1][i] + squareterm[1][i];
132+
}
133+
71134
template <class P>
72135
inline void TRLWEMult(TRLWE<P> &res, const TRLWE<P> &a, const TRLWE<P> &b,
73136
const relinKeyFFT<P> &relinkeyfft)
74137
{
138+
static_assert(P::l̅ == 1,
139+
"TRLWEMult only supports standard decomposition (l̅=1). "
140+
"Use TRLWEMultDD for Double Decomposition.");
75141
TRLWE3<P> resmult;
76142
TRLWEMultWithoutRelinerization<P>(resmult, a, b);
77143
Relinearization<P>(res, resmult, relinkeyfft);
78144
}
79145

146+
// Double Decomposition variant of TRLWEMult
147+
// Uses DD-based relinearization for improved noise management
148+
template <class P>
149+
inline void TRLWEMultDD(TRLWE<P> &res, const TRLWE<P> &a, const TRLWE<P> &b,
150+
const relinKeyFFTDD<P> &relinkeyfft)
151+
{
152+
TRLWE3<P> resmult;
153+
TRLWEMultWithoutRelinerization<P>(resmult, a, b);
154+
RelinearizationDD<P>(res, resmult, relinkeyfft);
155+
}
156+
157+
// Full Double Decomposition TRLWE Multiplication
158+
// Both TRLWEs are decomposed by l̅, multiplication is polynomial-like in decomposition indices
159+
// Algorithm:
160+
// 1. Decompose a[k] and b[k] into l̅ components each using base B̅g
161+
// 2. For polynomial product, compute convolution in decomposition index space:
162+
// (Σᵢ aᵢ·B̅g^i) × (Σⱼ bⱼ·B̅g^j) = Σₖ (Σᵢ₊ⱼ₌ₖ aᵢ·bⱼ)·B̅g^k
163+
// 3. Each aᵢ·bⱼ is computed via FFT polynomial multiplication
164+
// 4. IFFT to recover coefficients, rescale by Δ
165+
// 5. Recombine the 2l̅-1 terms back to proper scaling
166+
template <class P>
167+
void TRLWEMultWithoutRelinearizationDD(TRLWE3<P> &res, const TRLWE<P> &a,
168+
const TRLWE<P> &b)
169+
{
170+
constexpr int width = std::numeric_limits<typename P::T>::digits;
171+
172+
// Decompose all input polynomials into l̅ components
173+
// TRLWEBaseBbarDecompose gives: a = Σⱼ a_dec[j] * 2^(width - (j+1)*B̅gbit)
174+
// where a_dec[j] has coefficients in [-B̅g/2, B̅g/2)
175+
std::array<TRLWE<P>, P::l̅> a_dec, b_dec;
176+
TRLWEBaseBbarDecompose<P>(a_dec, a);
177+
TRLWEBaseBbarDecompose<P>(b_dec, b);
178+
179+
// FFT all decomposed components
180+
std::array<std::array<PolynomialInFD<P>, P::l̅>, P::k + 1> a_fft, b_fft;
181+
for (int poly_idx = 0; poly_idx <= P::k; poly_idx++) {
182+
for (int j = 0; j < P::l̅; j++) {
183+
TwistIFFT<P>(a_fft[poly_idx][j], a_dec[j][poly_idx]);
184+
TwistIFFT<P>(b_fft[poly_idx][j], b_dec[j][poly_idx]);
185+
}
186+
}
187+
188+
// Compute c[0] = a[0]*b[1] + a[1]*b[0] using DD polynomial multiplication
189+
std::array<PolynomialInFD<P>, 2 * P::l̅ - 1> c0_fft;
190+
for (int k = 0; k < 2 * P::l̅ - 1; k++) {
191+
for (int n = 0; n < P::n; n++) c0_fft[k][n] = 0.0;
192+
}
193+
for (int i = 0; i < P::l̅; i++) {
194+
for (int j = 0; j < P::l̅; j++) {
195+
FMAInFD<P::n>(c0_fft[i + j], a_fft[0][i], b_fft[1][j]);
196+
FMAInFD<P::n>(c0_fft[i + j], a_fft[1][i], b_fft[0][j]);
197+
}
198+
}
199+
200+
// Compute c[1] = a[1]*b[1]
201+
std::array<PolynomialInFD<P>, 2 * P::l̅ - 1> c1_fft;
202+
for (int k = 0; k < 2 * P::l̅ - 1; k++) {
203+
for (int n = 0; n < P::n; n++) c1_fft[k][n] = 0.0;
204+
}
205+
for (int i = 0; i < P::l̅; i++) {
206+
for (int j = 0; j < P::l̅; j++) {
207+
FMAInFD<P::n>(c1_fft[i + j], a_fft[1][i], b_fft[1][j]);
208+
}
209+
}
210+
211+
// Compute c[2] = a[0]*b[0]
212+
std::array<PolynomialInFD<P>, 2 * P::l̅ - 1> c2_fft;
213+
for (int k = 0; k < 2 * P::l̅ - 1; k++) {
214+
for (int n = 0; n < P::n; n++) c2_fft[k][n] = 0.0;
215+
}
216+
for (int i = 0; i < P::l̅; i++) {
217+
for (int j = 0; j < P::l̅; j++) {
218+
FMAInFD<P::n>(c2_fft[i + j], a_fft[0][i], b_fft[0][j]);
219+
}
220+
}
221+
222+
// Initialize results to zero
223+
for (int n = 0; n < P::n; n++) {
224+
res[0][n] = 0;
225+
res[1][n] = 0;
226+
res[2][n] = 0;
227+
}
228+
229+
// IFFT and recombine with Δ rescaling
230+
// The decomposition was: x = Σⱼ x_j * h̅[j] where h̅[j] = 2^(width - (j+1)*B̅gbit)
231+
// Product of positions i and j: scale = h̅[i] * h̅[j] = 2^(2*width - (i+j+2)*B̅gbit)
232+
// At convolution position k = i+j: scale = 2^(2*width - (k+2)*B̅gbit)
233+
//
234+
// The original ciphertext coefficients are scaled by Δ = 2^(width - plain_modulusbit).
235+
// After multiplying a × b = Δ² × plaintext_product.
236+
// We need to divide by Δ to get back to Δ scaling.
237+
//
238+
// For convolution position k:
239+
// raw_scale = 2^(2*width - (k+2)*B̅gbit)
240+
// after /Δ: scale = raw_scale / Δ = 2^(2*width - (k+2)*B̅gbit) / 2^(width - plain_modulusbit)
241+
// = 2^(width - (k+2)*B̅gbit + plain_modulusbit)
242+
//
243+
// So the shift for position k is: width - (k+2)*B̅gbit + plain_modulusbit
244+
245+
for (int k = 0; k < 2 * P::l̅ - 1; k++) {
246+
Polynomial<P> temp0, temp1, temp2;
247+
TwistFFT<P>(temp0, c0_fft[k]);
248+
TwistFFT<P>(temp1, c1_fft[k]);
249+
TwistFFT<P>(temp2, c2_fft[k]);
250+
251+
// Shift includes the Δ division: width - (k+2)*B̅gbit + plain_modulusbit
252+
// Positive shift means left shift, negative means right shift
253+
const int shift = width - (k + 2) * P::B̅gbit + P::plain_modulusbit;
254+
255+
if (shift >= 0 && shift < width) {
256+
// Left shift
257+
for (int n = 0; n < P::n; n++) {
258+
res[0][n] += temp0[n] << shift;
259+
res[1][n] += temp1[n] << shift;
260+
res[2][n] += temp2[n] << shift;
261+
}
262+
} else if (shift < 0 && -shift < width) {
263+
// Right shift
264+
const int right_shift = -shift;
265+
for (int n = 0; n < P::n; n++) {
266+
res[0][n] += static_cast<typename P::T>(
267+
static_cast<std::make_signed_t<typename P::T>>(temp0[n]) >> right_shift);
268+
res[1][n] += static_cast<typename P::T>(
269+
static_cast<std::make_signed_t<typename P::T>>(temp1[n]) >> right_shift);
270+
res[2][n] += static_cast<typename P::T>(
271+
static_cast<std::make_signed_t<typename P::T>>(temp2[n]) >> right_shift);
272+
}
273+
}
274+
// If |shift| >= width, contribution is zero
275+
}
276+
}
277+
278+
// Full DD TRLWE multiplication with relinearization
279+
template <class P>
280+
inline void TRLWEMultFullDD(TRLWE<P> &res, const TRLWE<P> &a, const TRLWE<P> &b,
281+
const relinKeyFFTDD<P> &relinkeyfft)
282+
{
283+
TRLWE3<P> resmult;
284+
TRLWEMultWithoutRelinearizationDD<P>(resmult, a, b);
285+
RelinearizationDD<P>(res, resmult, relinkeyfft);
286+
}
287+
80288
template <class P>
81289
inline void TLWEMult(TLWE<typename P::targetP> &res,
82290
const TLWE<typename P::domainP> &a,
83291
const TLWE<typename P::domainP> &b,
84292
const relinKeyFFT<typename P::targetP> &relinkeyfft,
85293
const PrivateKeySwitchingKey<P> &privksk)
86294
{
295+
static_assert(P::targetP::l̅ == 1,
296+
"TLWEMult only supports standard decomposition (l̅=1). "
297+
"Use TLWEMultDD for Double Decomposition.");
87298
TRLWE<typename P::targetP> trlweres, trlwea, trlweb;
88299
PrivKeySwitch<P>(trlwea, a, privksk);
89300
PrivKeySwitch<P>(trlweb, b, privksk);
90301
TRLWEMult<typename P::targetP>(trlweres, trlwea, trlweb, relinkeyfft);
91302
SampleExtractIndex<typename P::targetP>(res, trlweres, 0);
92303
}
304+
305+
// Double Decomposition variant of TLWEMult
306+
template <class P>
307+
inline void TLWEMultDD(TLWE<typename P::targetP> &res,
308+
const TLWE<typename P::domainP> &a,
309+
const TLWE<typename P::domainP> &b,
310+
const relinKeyFFTDD<typename P::targetP> &relinkeyfft,
311+
const PrivateKeySwitchingKey<P> &privksk)
312+
{
313+
TRLWE<typename P::targetP> trlweres, trlwea, trlweb;
314+
PrivKeySwitch<P>(trlwea, a, privksk);
315+
PrivKeySwitch<P>(trlweb, b, privksk);
316+
TRLWEMultDD<typename P::targetP>(trlweres, trlwea, trlweb, relinkeyfft);
317+
SampleExtractIndex<typename P::targetP>(res, trlweres, 0);
318+
}
93319
} // namespace TFHEpp

include/evalkeygens.hpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,9 @@ void subikskgen(SubsetKeySwitchingKey<P>& ksk, const SecretKey& sk)
272272
template <class P>
273273
relinKey<P> relinKeygen(const Key<P>& key)
274274
{
275+
static_assert(P::l̅ == 1,
276+
"relinKeygen only supports standard decomposition (l̅=1). "
277+
"Use relinKeygenDD for Double Decomposition.");
275278
Polynomial<P> keysquare;
276279
std::array<typename P::T, P::n> partkey;
277280
for (int i = 0; i < P::n; i++) partkey[i] = key[0 * P::n + i];
@@ -282,6 +285,28 @@ relinKey<P> relinKeygen(const Key<P>& key)
282285
return relinkey;
283286
}
284287

288+
// Double Decomposition variant of relinKeygen
289+
// Creates relinearization key with l * l̅ rows for DD-based TRLWE multiplication
290+
template <class P>
291+
relinKeyDD<P> relinKeygenDD(const Key<P>& key)
292+
{
293+
Polynomial<P> keysquare;
294+
std::array<typename P::T, P::n> partkey;
295+
for (int i = 0; i < P::n; i++) partkey[i] = key[0 * P::n + i];
296+
PolyMulNaive<P>(keysquare, partkey, partkey);
297+
298+
// Use halftrgswSymEncrypt which properly handles DD
299+
HalfTRGSW<P> halftrgsw;
300+
halftrgswSymEncrypt<P>(halftrgsw, keysquare, key);
301+
302+
// Copy to relinKeyDD (they have the same structure: l * l̅ TRLWEs)
303+
relinKeyDD<P> relinkey;
304+
for (int i = 0; i < P::l * P::l̅; i++)
305+
relinkey[i] = halftrgsw[i];
306+
307+
return relinkey;
308+
}
309+
285310
template <class P>
286311
void subprivkskgen(SubsetPrivateKeySwitchingKey<P>& privksk,
287312
const Polynomial<typename P::targetP>& func,
@@ -318,11 +343,27 @@ void subprivkskgen(SubsetPrivateKeySwitchingKey<P>& privksk,
318343
template <class P>
319344
relinKeyFFT<P> relinKeyFFTgen(const Key<P>& key)
320345
{
346+
static_assert(P::l̅ == 1,
347+
"relinKeyFFTgen only supports standard decomposition (l̅=1). "
348+
"Use relinKeyFFTgenDD for Double Decomposition.");
321349
relinKey<P> relinkey = relinKeygen<P>(key);
322350
relinKeyFFT<P> relinkeyfft;
323351
for (int i = 0; i < P::l; i++)
324352
for (int j = 0; j < 2; j++)
325353
TwistIFFT<P>(relinkeyfft[i][j], relinkey[i][j]);
326354
return relinkeyfft;
327355
}
356+
357+
// Double Decomposition variant of relinKeyFFTgen
358+
// Creates FFT relinearization key with l * l̅ rows
359+
template <class P>
360+
relinKeyFFTDD<P> relinKeyFFTgenDD(const Key<P>& key)
361+
{
362+
relinKeyDD<P> relinkey = relinKeygenDD<P>(key);
363+
relinKeyFFTDD<P> relinkeyfft;
364+
for (int i = 0; i < P::l * P::l̅; i++)
365+
for (int j = 0; j <= P::k; j++)
366+
TwistIFFT<P>(relinkeyfft[i][j], relinkey[i][j]);
367+
return relinkeyfft;
368+
}
328369
} // namespace TFHEpp

include/params.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,11 @@ template <class P>
229229
using relinKey = std::array<TRLWE<P>, P::l>;
230230
template <class P>
231231
using relinKeyFFT = std::array<TRLWEInFD<P>, P::l>;
232+
// Double Decomposition variants of the relinearization key types.
// Both hold P::l * P::l̅ rows, whereas relinKey / relinKeyFFT hold P::l rows.
233+
// relinKeyDD: coefficient-domain key, one TRLWE<P> per row.
template <class P>
234+
using relinKeyDD = std::array<TRLWE<P>, P::l * P::l̅>;
235+
// relinKeyFFTDD: FFT-domain key, one TRLWEInFD<P> per row. Note that it is
// declared with aligned_array, while the non-DD relinKeyFFT uses std::array.
template <class P>
236+
using relinKeyFFTDD = aligned_array<TRLWEInFD<P>, P::l * P::l̅>;
232237

233238
#define TFHEPP_EXPLICIT_INSTANTIATION_TLWE(fun) \
234239
fun(lvl0param); \

0 commit comments

Comments
 (0)