Skip to content

Commit c57befc

Browse files
nindanaotoclaude
andcommitted
Add 128-bit Torus support for Double Decomposition with FFT
- Update lvl3param to use __uint128_t with l=l̅=4, Bgbit=B̅gbit=16 - Extend FFT to support nbit=12 (n=4096) via fftplvl3 - Add TwistFFT/TwistIFFT handling for 128-bit types using 64-bit FFT - Add UniformTorusRandom<P>() helper for 128-bit random generation - Add ModularGaussian support for __uint128_t - Fix decomposition functions to scale values for 128-bit FFT compatibility - Add lvl03param bootstrapping parameter (lvl0 → lvl3) - Add GateBootstrappingTLWE2TLWEDD test for non-trivial DD (l̅=4) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent dd23b69 commit c57befc

File tree

12 files changed

+342
-59
lines changed

12 files changed

+342
-59
lines changed

include/detwfa.hpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,52 @@ void CMUXwithPolynomialMulByXaiMinusOne(TRLWE<P> &acc,
116116
for (int i = 0; i < P::n; i++) acc[k][i] += temp[k][i];
117117
}
118118

119+
// Double Decomposition variants
120+
template <class P>
121+
void CMUXFFTDD(TRLWE<P> &res, const TRGSWFFT<P> &cs, const TRLWE<P> &c1,
122+
const TRLWE<P> &c0)
123+
{
124+
for (int k = 0; k < P::k + 1; k++)
125+
for (int i = 0; i < P::n; i++) res[k][i] = c1[k][i] - c0[k][i];
126+
ExternalProductDD<P>(res, res, cs);
127+
for (int k = 0; k < P::k + 1; k++)
128+
for (int i = 0; i < P::n; i++) res[k][i] += c0[k][i];
129+
}
130+
131+
template <class bkP>
132+
void CMUXwithPolynomialMulByXaiMinusOneDD(
133+
TRLWE<typename bkP::targetP> &acc,
134+
const BootstrappingKeyElementFFT<bkP> &cs, const int a)
135+
{
136+
if constexpr (bkP::domainP::key_value_diff == 1) {
137+
alignas(64) TRLWE<typename bkP::targetP> temp;
138+
for (int k = 0; k < bkP::targetP::k + 1; k++)
139+
PolynomialMulByXaiMinusOne<typename bkP::targetP>(temp[k], acc[k],
140+
a);
141+
ExternalProductDD<typename bkP::targetP>(temp, temp, cs[0]);
142+
for (int k = 0; k < bkP::targetP::k + 1; k++)
143+
for (int i = 0; i < bkP::targetP::n; i++) acc[k][i] += temp[k][i];
144+
}
145+
else {
146+
alignas(32) TRLWE<typename bkP::targetP> temp;
147+
int count = 0;
148+
for (int i = bkP::domainP::key_value_min;
149+
i <= bkP::domainP::key_value_max; i++) {
150+
if (i != 0) {
151+
const int mod = (a * i) % (2 * bkP::targetP::n);
152+
const int index = mod > 0 ? mod : mod + (2 * bkP::targetP::n);
153+
for (int k = 0; k < bkP::targetP::k + 1; k++)
154+
PolynomialMulByXaiMinusOne<typename bkP::targetP>(
155+
temp[k], acc[k], index);
156+
ExternalProductDD<typename bkP::targetP>(temp, temp,
157+
cs[count]);
158+
for (int k = 0; k < bkP::targetP::k + 1; k++)
159+
for (int i = 0; i < bkP::targetP::n; i++)
160+
acc[k][i] += temp[k][i];
161+
count++;
162+
}
163+
}
164+
}
165+
}
166+
119167
} // namespace TFHEpp

include/gatebootstrapping.hpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,46 @@ constexpr Polynomial<P> μpolygen()
263263
return poly;
264264
}
265265

266+
// Double Decomposition variants
267+
template <class P, uint32_t num_out = 1>
268+
void BlindRotateDD(TRLWE<typename P::targetP> &res,
269+
const TLWE<typename P::domainP> &tlwe,
270+
const BootstrappingKeyFFT<P> &bkfft,
271+
const Polynomial<typename P::targetP> &testvector)
272+
{
273+
constexpr uint32_t bitwidth = bits_needed<num_out - 1>();
274+
const uint32_t b̄ = 2 * P::targetP::n -
275+
((tlwe[P::domainP::k * P::domainP::n] >>
276+
(std::numeric_limits<typename P::domainP::T>::digits -
277+
1 - P::targetP::nbit + bitwidth))
278+
<< bitwidth);
279+
res = {};
280+
PolynomialMulByXai<typename P::targetP>(res[P::targetP::k], testvector, b̄);
281+
for (int i = 0; i < P::domainP::k * P::domainP::n; i++) {
282+
constexpr typename P::domainP::T roundoffset =
283+
1ULL << (std::numeric_limits<typename P::domainP::T>::digits - 2 -
284+
P::targetP::nbit + bitwidth);
285+
const uint32_t ā =
286+
(tlwe[i] + roundoffset) >>
287+
(std::numeric_limits<typename P::domainP::T>::digits - 1 -
288+
P::targetP::nbit + bitwidth)
289+
<< bitwidth;
290+
if (ā == 0) continue;
291+
CMUXwithPolynomialMulByXaiMinusOneDD<P>(res, bkfft[i], ā);
292+
}
293+
}
294+
295+
template <class P>
296+
void GateBootstrappingTLWE2TLWEDD(
297+
TLWE<typename P::targetP> &res, const TLWE<typename P::domainP> &tlwe,
298+
const BootstrappingKeyFFT<P> &bkfft,
299+
const Polynomial<typename P::targetP> &testvector)
300+
{
301+
alignas(64) TRLWE<typename P::targetP> acc;
302+
BlindRotateDD<P>(acc, tlwe, bkfft, testvector);
303+
SampleExtractIndex<typename P::targetP>(res, acc, 0);
304+
}
305+
266306
template <class bkP, typename bkP::targetP::T μ, class iksP>
267307
void GateBootstrapping(TLWE<typename iksP::targetP> &res,
268308
const TLWE<typename bkP::domainP> &tlwe,

include/mulfft.hpp

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,9 @@ inline void TwistNTT(Polynomial<P> &res, PolynomialNTT<P> &a)
6868
cuHEpp::TwistNTT<typename lvl1param::T, lvl1param::nbit>(
6969
res, a, (*ntttablelvl1)[0], (*ntttwistlvl1)[0]);
7070
#endif
71-
else if constexpr (std::is_same_v<typename P::T, uint64_t>) {
71+
else if constexpr (std::is_same_v<typename P::T, uint64_t>)
7272
cuHEpp::TwistNTT<typename lvl2param::T, lvl2param::nbit>(
7373
res, a, (*ntttablelvl2)[0], (*ntttwistlvl2)[0]);
74-
}
7574
else
7675
static_assert(false_v<typename P::T>, "Undefined TwistNTT!");
7776
}
@@ -89,6 +88,14 @@ inline void TwistFFT(Polynomial<P> &res, const PolynomialInFD<P> &a)
8988
else if constexpr (std::is_same_v<typename P::T, uint64_t>)
9089
fftplvl1.execute_direct_torus64(res.data(), a.data());
9190
}
91+
else if constexpr (std::is_same_v<P, lvl3param>) {
92+
// For 128-bit lvl3param, use 64-bit FFT and shift result to top 64 bits
93+
// This preserves the Torus semantics (most significant bits)
94+
alignas(64) std::array<uint64_t, P::n> temp;
95+
fftplvl3.execute_direct_torus64(temp.data(), a.data());
96+
for (int i = 0; i < P::n; i++)
97+
res[i] = static_cast<__uint128_t>(temp[i]) << 64;
98+
}
9299
else if constexpr (std::is_same_v<typename P::T, uint64_t>)
93100
fftplvl2.execute_direct_torus64(res.data(), a.data());
94101
else
@@ -143,6 +150,14 @@ inline void TwistIFFT(PolynomialInFD<P> &res, const Polynomial<P> &a)
143150
if constexpr (std::is_same_v<typename P::T, uint64_t>)
144151
fftplvl1.execute_reverse_torus64(res.data(), a.data());
145152
}
153+
else if constexpr (std::is_same_v<P, lvl3param>) {
154+
// For 128-bit lvl3param, use top 64 bits for FFT
155+
// This preserves the Torus semantics (most significant bits)
156+
alignas(64) std::array<uint64_t, P::n> temp;
157+
for (int i = 0; i < P::n; i++)
158+
temp[i] = static_cast<uint64_t>(a[i] >> 64);
159+
fftplvl3.execute_reverse_torus64(res.data(), temp.data());
160+
}
146161
else if constexpr (std::is_same_v<typename P::T, uint64_t>)
147162
fftplvl2.execute_reverse_torus64(res.data(), a.data());
148163
else
@@ -301,8 +316,21 @@ inline void PolyMul(Polynomial<P> &res, const Polynomial<P> &a,
301316
for (int i = 0; i < P::n; i++) ntta[i] *= nttb[i];
302317
TwistNTT<P>(res, ntta);
303318
}
319+
else if constexpr (std::is_same_v<typename P::T, __uint128_t>) {
320+
// Naive for 128-bit types (FFT/NTT don't support 128-bit precision)
321+
for (int i = 0; i < P::n; i++) {
322+
__uint128_t ri = 0;
323+
for (int j = 0; j <= i; j++)
324+
ri += static_cast<__int128_t>(a[j]) *
325+
static_cast<__int128_t>(b[i - j]);
326+
for (int j = i + 1; j < P::n; j++)
327+
ri -= static_cast<__int128_t>(a[j]) *
328+
static_cast<__int128_t>(b[P::n + i - j]);
329+
res[i] = ri;
330+
}
331+
}
304332
else {
305-
// Naieve
333+
// Naive for other types
306334
for (int i = 0; i < P::n; i++) {
307335
typename P::T ri = 0;
308336
for (int j = 0; j <= i; j++)
@@ -339,17 +367,33 @@ template <class P>
339367
inline void PolyMulNaive(Polynomial<P> &res, const Polynomial<P> &a,
340368
const Polynomial<P> &b)
341369
{
342-
for (int i = 0; i < P::n; i++) {
343-
typename P::T ri = 0;
344-
for (int j = 0; j <= i; j++)
345-
ri += static_cast<typename std::make_signed<typename P::T>::type>(
346-
a[j]) *
347-
b[i - j];
348-
for (int j = i + 1; j < P::n; j++)
349-
ri -= static_cast<typename std::make_signed<typename P::T>::type>(
350-
a[j]) *
351-
b[P::n + i - j];
352-
res[i] = ri;
370+
if constexpr (std::is_same_v<typename P::T, __uint128_t>) {
371+
for (int i = 0; i < P::n; i++) {
372+
__uint128_t ri = 0;
373+
for (int j = 0; j <= i; j++)
374+
ri += static_cast<__int128_t>(a[j]) *
375+
static_cast<__int128_t>(b[i - j]);
376+
for (int j = i + 1; j < P::n; j++)
377+
ri -= static_cast<__int128_t>(a[j]) *
378+
static_cast<__int128_t>(b[P::n + i - j]);
379+
res[i] = ri;
380+
}
381+
}
382+
else {
383+
for (int i = 0; i < P::n; i++) {
384+
typename P::T ri = 0;
385+
for (int j = 0; j <= i; j++)
386+
ri +=
387+
static_cast<typename std::make_signed<typename P::T>::type>(
388+
a[j]) *
389+
b[i - j];
390+
for (int j = i + 1; j < P::n; j++)
391+
ri -=
392+
static_cast<typename std::make_signed<typename P::T>::type>(
393+
a[j]) *
394+
b[P::n + i - j];
395+
res[i] = ri;
396+
}
353397
}
354398
}
355399

include/params.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,16 @@ struct lvl02param {
6767
#endif
6868
};
6969

70+
struct lvl03param {
71+
using domainP = lvl0param;
72+
using targetP = lvl3param;
73+
#ifdef USE_KEY_BUNDLE
74+
static constexpr uint32_t Addends = 2;
75+
#else
76+
static constexpr uint32_t Addends = 1;
77+
#endif
78+
};
79+
7080
struct lvlh2param {
7181
using domainP = lvlhalfparam;
7282
using targetP = lvl2param;

include/params/128bit.hpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,35 +153,36 @@ struct AHlvl2param {
153153
static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit;
154154
};
155155

156-
// lvl3param with 64-bit Torus and non-trivial Double Decomposition
157-
// Double decomposition constraint: l * Bgbit + l̅ * B̅gbit <= 64
158-
// Using l=2, Bgbit=16, l̅=2, B̅gbit=16: 2*16 + 2*16 = 64 bits (fully utilized)
156+
// lvl3param with 128-bit Torus and non-trivial Double Decomposition
157+
// Double decomposition constraint: l * Bgbit + l̅ * B̅gbit <= 128
158+
// Using l=4, Bgbit=16, l̅=4, B̅gbit=16: 4*16 + 4*16 = 128 bits (fully utilized)
159159
struct lvl3param {
160160
static constexpr int32_t key_value_max = 1;
161161
static constexpr int32_t key_value_min = -1;
162162
static const std::uint32_t nbit = 12; // dimension must be a power of 2 for
163163
// ease of polynomial multiplication.
164164
static constexpr std::uint32_t n = 1 << nbit; // dimension = 4096
165165
static constexpr std::uint32_t k = 1;
166-
static constexpr std::uint32_t lₐ = 2;
167-
static constexpr std::uint32_t l = 2;
166+
static constexpr std::uint32_t lₐ = 4;
167+
static constexpr std::uint32_t l = 4;
168168
static constexpr std::uint32_t Bgbit = 16;
169169
static constexpr std::uint32_t Bgₐbit = 16;
170170
static constexpr uint32_t Bg = 1U << Bgbit;
171171
static constexpr uint32_t Bgₐ = 1U << Bgₐbit;
172172
static constexpr ErrorDistribution errordist =
173173
ErrorDistribution::ModularGaussian;
174-
static const inline double α = std::pow(2.0, -51); // fresh noise
175-
using T = uint64_t; // Torus representation
176-
static constexpr T μ = 1ULL << 61;
174+
static const inline double α = std::pow(2.0, -105); // fresh noise
175+
using T = __uint128_t; // Torus representation
176+
static constexpr T μ = static_cast<T>(1) << 125;
177177
static constexpr uint32_t plain_modulusbit = 31;
178-
static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit;
179-
static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1);
178+
static constexpr T plain_modulus = static_cast<T>(1) << plain_modulusbit;
179+
static constexpr double Δ =
180+
static_cast<double>(static_cast<T>(1) << (128 - plain_modulusbit - 1));
180181
// Double Decomposition (bivariate representation) parameters
181182
// Non-trivial values for testing actual double decomposition
182-
// Constraint: l * Bgbit + l̅ * B̅gbit <= 64
183-
static constexpr std::uint32_t l̅ = 2; // auxiliary decomposition levels
184-
static constexpr std::uint32_t l̅ₐ = 2;
183+
// Constraint: l * Bgbit + l̅ * B̅gbit <= 128
184+
static constexpr std::uint32_t l̅ = 4; // auxiliary decomposition levels
185+
static constexpr std::uint32_t l̅ₐ = 4;
185186
static constexpr std::uint32_t B̅gbit = 16; // 2^16 base for auxiliary
186187
static constexpr std::uint32_t B̅gₐbit = 16;
187188
};

include/tlwe.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@ template <class P>
1313
void tlweSymEncrypt(TLWE<P> &res, const typename P::T p, const double α,
1414
const Key<P> &key)
1515
{
16-
std::uniform_int_distribution<typename P::T> Torusdist(
17-
0, std::numeric_limits<typename P::T>::max());
1816
res = {};
1917
res[P::k * P::n] = ModularGaussian<P>(p, α);
2018
for (int k = 0; k < P::k; k++)
2119
for (int i = 0; i < P::n; i++) {
22-
res[k * P::n + i] = Torusdist(generator);
20+
res[k * P::n + i] = UniformTorusRandom<P>();
2321
res[P::k * P::n] += res[k * P::n + i] * key[k * P::n + i];
2422
}
2523
}
@@ -122,7 +120,7 @@ typename P::T tlweSymIntDecrypt(const TLWE<P> &c, const Key<P> &key)
122120
constexpr double Δ =
123121
2 *
124122
static_cast<double>(
125-
1ULL << (std::numeric_limits<typename P::T>::digits - 1)) /
123+
static_cast<typename P::T>(1) << (std::numeric_limits<typename P::T>::digits - 1)) /
126124
plain_modulus;
127125
const typename P::T phase = tlweSymPhase<P>(c, key);
128126
typename P::T res = static_cast<typename P::T>(std::round(phase / Δ));

0 commit comments

Comments
 (0)