Skip to content

Commit 2a861c1

Browse files
nindanaotoclaude
andcommitted
Fix h̅gen shift formula bug in Double Decomposition
The h̅gen and nonceh̅gen functions used an incorrect shift, (i+1)*B̅gbit instead of i*B̅gbit. As a result, the gadget values h[i]*h̅[j] were off by a factor of 2^B̅gbit, and ExternalProductDD failed with a ~50% error rate. The correct formula is: h̅[0] = 1 (j = 0 means no auxiliary shift), and h̅[j] = 2^(width - j*B̅gbit) for j > 0. This matches the decomposition shift formula: width - (i+1)*Bgbit - j*B̅gbit. When l̅ = 1 (trivial auxiliary decomposition), h̅[0] = 1 correctly reduces double decomposition to standard decomposition. This commit also adds the externalproductdoubledecomposition test to verify correctness. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 894ed4c commit 2a861c1

File tree

10 files changed

+501
-11
lines changed

10 files changed

+501
-11
lines changed

include/params.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ using DecomposedPolynomial = std::array<Polynomial<P>, P::l>;
118118
template <class P>
119119
using DecomposedNoncePolynomial = std::array<Polynomial<P>, P::lₐ>;
120120
template <class P>
121+
using DecomposedPolynomialDD = std::array<Polynomial<P>, P::l * P::l̅>;
122+
template <class P>
123+
using DecomposedNoncePolynomialDD = std::array<Polynomial<P>, P::lₐ * P::l̅ₐ>;
124+
template <class P>
121125
using DecomposedPolynomialNTT = std::array<PolynomialNTT<P>, P::l>;
122126
template <class P>
123127
using DecomposedNoncePolynomialNTT = std::array<PolynomialNTT<P>, P::lₐ>;

include/params/128bit.hpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,41 @@ struct AHlvl2param {
153153
static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit;
154154
};
155155

156+
// New lvl3param with 128-bit Torus and non-trivial Double Decomposition
157+
// Double decomposition constraint: l * Bgbit + l̅ * B̅gbit <= 128
158+
// Using l=4, Bgbit=16, l̅=4, B̅gbit=16: 4*16 + 4*16 = 128 bits (fully utilized)
156159
struct lvl3param {
160+
static constexpr int32_t key_value_max = 1;
161+
static constexpr int32_t key_value_min = -1;
162+
static const std::uint32_t nbit = 12; // dimension must be a power of 2 for
163+
// ease of polynomial multiplication.
164+
static constexpr std::uint32_t n = 1 << nbit; // dimension = 4096
165+
static constexpr std::uint32_t k = 1;
166+
static constexpr std::uint32_t lₐ = 4;
167+
static constexpr std::uint32_t l = 4;
168+
static constexpr std::uint32_t Bgbit = 16;
169+
static constexpr std::uint32_t Bgₐbit = 16;
170+
static constexpr uint32_t Bg = 1U << Bgbit;
171+
static constexpr uint32_t Bgₐ = 1U << Bgₐbit;
172+
static constexpr ErrorDistribution errordist =
173+
ErrorDistribution::ModularGaussian;
174+
static const inline double α = std::pow(2.0, -105); // fresh noise
175+
using T = __uint128_t; // Torus representation
176+
static constexpr T μ = static_cast<T>(1) << 125;
177+
static constexpr uint32_t plain_modulusbit = 31;
178+
static constexpr __uint128_t plain_modulus = static_cast<T>(1) << plain_modulusbit;
179+
static constexpr double Δ =
180+
static_cast<double>(static_cast<T>(1) << (128 - plain_modulusbit - 1));
181+
// Double Decomposition (bivariate representation) parameters
182+
// Non-trivial values for testing actual double decomposition
183+
// Constraint: l * Bgbit + l̅ * B̅gbit <= 128
184+
static constexpr std::uint32_t l̅ = 4; // auxiliary decomposition levels
185+
static constexpr std::uint32_t l̅ₐ = 4;
186+
static constexpr std::uint32_t B̅gbit = 16; // 2^16 base for auxiliary
187+
static constexpr std::uint32_t B̅gₐbit = 16;
188+
};
189+
190+
struct lvl4param {
157191
static constexpr int32_t key_value_max = 1;
158192
static constexpr int32_t key_value_min = -1;
159193
static const std::uint32_t nbit = 13; // dimension must be a power of 2 for
@@ -175,7 +209,7 @@ struct lvl3param {
175209
static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit;
176210
static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1);
177211
// Double Decomposition (bivariate representation) parameters
178-
// For now, set to trivial values (no actual second decomposition)
212+
// Trivial values (no actual second decomposition)
179213
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
180214
static constexpr std::uint32_t l̅ₐ = l̅;
181215
static constexpr std::uint32_t B̅gbit =
@@ -270,3 +304,12 @@ struct lvl31param {
270304
using domainP = lvl3param;
271305
using targetP = lvl1param;
272306
};
307+
308+
struct lvl41param {
309+
static constexpr std::uint32_t t = 7; // number of addition in keyswitching
310+
static constexpr std::uint32_t basebit =
311+
2; // how many bit should be encrypted in keyswitching key
312+
static const inline double α = lvl1param::α; // key noise
313+
using domainP = lvl4param;
314+
using targetP = lvl1param;
315+
};

include/params/CGGI16.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ struct lvl3param {
127127
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
128128
};
129129

130+
// Dummy
131+
using lvl4param = lvl3param;
132+
130133
struct lvl10param {
131134
static constexpr std::uint32_t t = 8;
132135
static constexpr std::uint32_t basebit = 2;
@@ -212,4 +215,7 @@ struct lvl31param {
212215
static const inline double α = lvl1param::α; // key noise
213216
using domainP = lvl3param;
214217
using targetP = lvl1param;
215-
};
218+
};
219+
220+
// Dummy
221+
using lvl41param = lvl31param;

include/params/CGGI19.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ struct lvl3param {
125125
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
126126
};
127127

128+
// Dummy
129+
using lvl4param = lvl3param;
130+
128131
// Dummy
129132
struct lvl11param {
130133
static constexpr std::uint32_t t = 0; // number of addition in keyswitching
@@ -211,4 +214,7 @@ struct lvl31param {
211214
static const inline double α = lvl1param::α; // key noise
212215
using domainP = lvl3param;
213216
using targetP = lvl1param;
214-
};
217+
};
218+
219+
// Dummy
220+
using lvl41param = lvl31param;

include/params/compress.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ struct lvl3param {
140140
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
141141
};
142142

143+
// Dummy
144+
using lvl4param = lvl3param;
145+
143146
// Key Switching parameters
144147
struct lvl10param {
145148
static constexpr std::uint32_t t = 5; // number of addition in keyswitching
@@ -213,4 +216,7 @@ struct lvl31param {
213216
2; // how many bit should be encrypted in keyswitching key
214217
using domainP = lvl3param;
215218
using targetP = lvl1param;
216-
};
219+
};
220+
221+
// Dummy
222+
using lvl41param = lvl31param;

include/params/concrete.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ struct lvl3param {
201201
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
202202
};
203203

204+
// Dummy
205+
using lvl4param = lvl3param;
206+
204207
// Key Switching parameters
205208
struct lvl10param {
206209
static constexpr std::uint32_t t = 5; // number of addition in keyswitching
@@ -290,4 +293,7 @@ struct lvl31param {
290293
static const inline double α = lvl1param::α; // key noise
291294
using domainP = lvl3param;
292295
using targetP = lvl1param;
293-
};
296+
};
297+
298+
// Dummy
299+
using lvl41param = lvl31param;

include/params/ternary.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ struct lvl3param {
131131
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
132132
};
133133

134+
// Dummy
135+
using lvl4param = lvl3param;
136+
134137
// Key Switching parameters
135138
struct lvl10param {
136139
static constexpr std::uint32_t t = 7; // number of addition in keyswitching
@@ -219,4 +222,7 @@ struct lvl31param {
219222
static const inline double α = lvl1param::α; // key noise
220223
using domainP = lvl3param;
221224
using targetP = lvl1param;
222-
};
225+
};
226+
227+
// Dummy
228+
using lvl41param = lvl31param;

include/params/tfhe-rs.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ struct lvl3param {
137137
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
138138
};
139139

140+
// Dummy
141+
using lvl4param = lvl3param;
142+
140143
// Key Switching parameters
141144
struct lvl10param {
142145
static constexpr std::uint32_t t = 3; // number of addition in keyswitching
@@ -225,4 +228,7 @@ struct lvl31param {
225228
static const inline double α = lvl1param::α; // key noise
226229
using domainP = lvl3param;
227230
using targetP = lvl1param;
228-
};
231+
};
232+
233+
// Dummy
234+
using lvl41param = lvl31param;

include/trgsw.hpp

Lines changed: 153 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,102 @@ inline void NonceDecomposition(DecomposedNoncePolynomial<P> &decpoly,
117117
}
118118
}
119119

120+
// Double Decomposition (bivariate representation) for external product
121+
// Decomposes each coefficient a into l*l̅ components such that:
122+
// a ≈ Σᵢ Σⱼ aᵢⱼ * Bg^(l-i) * B̅g^(l̅-j)
123+
// When l̅=1 (j=0 only), this reduces to standard decomposition.
124+
template <class P>
constexpr typename P::T ddoffsetgen()
{
    // Returns the sum of (Bg / 2) * 2^(width - i*Bgbit - j*B̅gbit) over
    // i in [1, l] and j in [0, l̅), computed modulo 2^width.
    // DoubleDecomposition adds this constant to every coefficient before it
    // extracts digits and subtracts Bg / 2 from each of them.
    using Torus = typename P::T;
    constexpr int width = std::numeric_limits<Torus>::digits;
    const Torus halfbase = static_cast<Torus>(P::Bg) / 2;

    Torus total = 0;
    for (int lvl = 1; lvl <= P::l; lvl++) {
        for (int aux = 0; aux < P::l̅; aux++) {
            const Torus weight = static_cast<Torus>(1)
                                 << (width - lvl * P::Bgbit - aux * P::B̅gbit);
            total += halfbase * weight;
        }
    }
    return total;
}
136+
137+
template <class P>
138+
inline void DoubleDecomposition(DecomposedPolynomialDD<P> &decpoly,
139+
const Polynomial<P> &poly)
140+
{
141+
constexpr typename P::T offset = ddoffsetgen<P>();
142+
// Remaining bits after decomposition
143+
constexpr int remaining_bits = std::numeric_limits<typename P::T>::digits -
144+
P::l * P::Bgbit - P::l̅ * P::B̅gbit;
145+
// roundoffset is 0 if no remaining bits, otherwise 2^(remaining_bits-1)
146+
constexpr typename P::T roundoffset =
147+
remaining_bits > 0
148+
? (static_cast<typename P::T>(1) << (remaining_bits - 1))
149+
: static_cast<typename P::T>(0);
150+
constexpr typename P::T maskBg =
151+
static_cast<typename P::T>((1ULL << P::Bgbit) - 1);
152+
constexpr typename P::T halfBg = (1ULL << (P::Bgbit - 1));
153+
154+
for (int n = 0; n < P::n; n++) {
155+
typename P::T a = poly[n] + offset + roundoffset;
156+
for (int i = 0; i < P::l; i++) {
157+
for (int j = 0; j < P::l̅; j++) {
158+
// Shift to get the (i,j)-th digit in base Bg (after B̅g grouping)
159+
// When l̅=1 (j=0 only), this reduces to standard decomposition
160+
const int shift = std::numeric_limits<typename P::T>::digits -
161+
(i + 1) * P::Bgbit - j * P::B̅gbit;
162+
decpoly[i * P::l̅ + j][n] =
163+
static_cast<std::make_signed_t<typename P::T>>(
164+
((a >> shift) & maskBg) - halfBg);
165+
}
166+
}
167+
}
168+
}
169+
170+
template <class P>
constexpr typename P::T nonceddoffsetgen()
{
    // Nonce counterpart of ddoffsetgen: sums
    // (Bgₐ / 2) * 2^(width - i*Bgₐbit - j*B̅gₐbit) over i in [1, lₐ] and
    // j in [0, l̅ₐ), computed modulo 2^width.
    using Torus = typename P::T;
    constexpr int width = std::numeric_limits<Torus>::digits;
    const Torus halfbase = static_cast<Torus>(P::Bgₐ) / 2;

    Torus total = 0;
    for (int lvl = 1; lvl <= P::lₐ; lvl++) {
        for (int aux = 0; aux < P::l̅ₐ; aux++) {
            const Torus weight =
                static_cast<Torus>(1)
                << (width - lvl * P::Bgₐbit - aux * P::B̅gₐbit);
            total += halfbase * weight;
        }
    }
    return total;
}
182+
183+
template <class P>
184+
inline void NonceDoubleDecomposition(DecomposedNoncePolynomialDD<P> &decpoly,
185+
const Polynomial<P> &poly)
186+
{
187+
constexpr typename P::T offset = nonceddoffsetgen<P>();
188+
// Remaining bits after decomposition
189+
constexpr int remaining_bits = std::numeric_limits<typename P::T>::digits -
190+
P::lₐ * P::Bgₐbit - P::l̅ₐ * P::B̅gₐbit;
191+
// roundoffset is 0 if no remaining bits, otherwise 2^(remaining_bits-1)
192+
constexpr typename P::T roundoffset =
193+
remaining_bits > 0
194+
? (static_cast<typename P::T>(1) << (remaining_bits - 1))
195+
: static_cast<typename P::T>(0);
196+
constexpr typename P::T maskBg =
197+
static_cast<typename P::T>((1ULL << P::Bgₐbit) - 1);
198+
constexpr typename P::T halfBg = (1ULL << (P::Bgₐbit - 1));
199+
200+
for (int n = 0; n < P::n; n++) {
201+
typename P::T a = poly[n] + offset + roundoffset;
202+
for (int i = 0; i < P::lₐ; i++) {
203+
for (int j = 0; j < P::l̅ₐ; j++) {
204+
// Shift to get the (i,j)-th digit
205+
// When l̅ₐ=1 (j=0 only), this reduces to standard decomposition
206+
const int shift = std::numeric_limits<typename P::T>::digits -
207+
(i + 1) * P::Bgₐbit - j * P::B̅gₐbit;
208+
decpoly[i * P::l̅ₐ + j][n] =
209+
static_cast<std::make_signed_t<typename P::T>>(
210+
((a >> shift) & maskBg) - halfBg);
211+
}
212+
}
213+
}
214+
}
215+
120216
template <class P>
121217
void Decomposition(DecomposedPolynomialNTT<P> &decpolyntt,
122218
const Polynomial<P> &poly)
@@ -195,6 +291,53 @@ void ExternalProduct(TRLWE<P> &res, const TRLWE<P> &trlwe,
195291
for (int k = 0; k < P::k + 1; k++) TwistFFT<P>(res[k], restrlwefft[k]);
196292
}
197293

294+
// External product with Double Decomposition (bivariate representation)
295+
// Uses the full TRGSW structure with l*l̅ rows for "b" block and k*lₐ*l̅ₐ rows
296+
// for "a" blocks
297+
template <class P>
298+
void ExternalProductDD(TRLWE<P> &res, const TRLWE<P> &trlwe,
299+
const TRGSWFFT<P> &trgswfft)
300+
{
301+
alignas(64) PolynomialInFD<P> decpolyfft;
302+
alignas(64) TRLWEInFD<P> restrlwefft;
303+
304+
// Handle "a" polynomials (indices 0 to k-1 in TRLWE)
305+
// Uses NonceDoubleDecomposition with lₐ*l̅ₐ levels
306+
{
307+
alignas(64) DecomposedNoncePolynomialDD<P> decpoly;
308+
NonceDoubleDecomposition<P>(decpoly, trlwe[0]);
309+
TwistIFFT<P>(decpolyfft, decpoly[0]);
310+
for (int m = 0; m < P::k + 1; m++)
311+
MulInFD<P::n>(restrlwefft[m], decpolyfft, trgswfft[0][m]);
312+
for (int i = 1; i < P::lₐ * P::l̅ₐ; i++) {
313+
TwistIFFT<P>(decpolyfft, decpoly[i]);
314+
for (int m = 0; m < P::k + 1; m++)
315+
FMAInFD<P::n>(restrlwefft[m], decpolyfft, trgswfft[i][m]);
316+
}
317+
for (int k = 1; k < P::k; k++) {
318+
NonceDoubleDecomposition<P>(decpoly, trlwe[k]);
319+
for (int i = 0; i < P::lₐ * P::l̅ₐ; i++) {
320+
TwistIFFT<P>(decpolyfft, decpoly[i]);
321+
for (int m = 0; m < P::k + 1; m++)
322+
FMAInFD<P::n>(restrlwefft[m], decpolyfft,
323+
trgswfft[i + k * P::lₐ * P::l̅ₐ][m]);
324+
}
325+
}
326+
}
327+
328+
// Handle "b" polynomial (index k in TRLWE)
329+
// Uses DoubleDecomposition with l*l̅ levels
330+
alignas(64) DecomposedPolynomialDD<P> decpoly;
331+
DoubleDecomposition<P>(decpoly, trlwe[P::k]);
332+
for (int i = 0; i < P::l * P::l̅; i++) {
333+
TwistIFFT<P>(decpolyfft, decpoly[i]);
334+
for (int m = 0; m < P::k + 1; m++)
335+
FMAInFD<P::n>(restrlwefft[m], decpolyfft,
336+
trgswfft[i + P::k * P::lₐ * P::l̅ₐ][m]);
337+
}
338+
for (int k = 0; k < P::k + 1; k++) TwistFFT<P>(res[k], restrlwefft[k]);
339+
}
340+
198341
template <class P>
199342
void ExternalProduct(TRLWE<P> &res, const Polynomial<P> &poly,
200343
const HalfTRGSWFFT<P> &halftrgswfft)
@@ -453,23 +596,29 @@ constexpr std::array<typename P::T, P::lₐ> noncehgen()
453596
}
454597

455598
// Auxiliary h generation for Double Decomposition (bivariate representation)
599+
// h̅[j] values are used to construct gadget values h[i] * h̅[j] = 2^(width - (i+1)*Bgbit - j*B̅gbit)
600+
// For j=0: no auxiliary shift, so h̅[0] = 1
601+
// For j>0: h̅[j] = 2^(width - j*B̅gbit) which combines with h[i] via modular multiplication
456602
template <class P>
constexpr std::array<typename P::T, P::l̅> h̅gen()
{
    // Returns the auxiliary gadget factors:
    //   h̅[0] = 1                       (j = 0: no auxiliary shift)
    //   h̅[j] = 2^(width - j * B̅gbit)    for 0 < j < l̅
    // where width is the bit width of P::T.
    using Torus = typename P::T;
    static_assert(P::l̅ >= 1, "at least one auxiliary level is required");
    constexpr int width = std::numeric_limits<Torus>::digits;

    std::array<Torus, P::l̅> h̅{};
    h̅[0] = 1;
    for (int j = 1; j < P::l̅; j++)
        // Shift a Torus-typed 1: shifting the 64-bit literal 1ULL is undefined
        // once the count reaches 64, which happens for 128-bit Torus types.
        h̅[j] = static_cast<Torus>(1) << (width - j * P::B̅gbit);
    return h̅;
}
465612

613+
// Auxiliary h generation for nonce part of TRGSW with Double Decomposition
466614
template <class P>
constexpr std::array<typename P::T, P::l̅ₐ> nonceh̅gen()
{
    // Returns the auxiliary gadget factors for the nonce part:
    //   h̅[0] = 1                        (j = 0: no auxiliary shift)
    //   h̅[j] = 2^(width - j * B̅gₐbit)    for 0 < j < l̅ₐ
    // where width is the bit width of P::T.
    using Torus = typename P::T;
    static_assert(P::l̅ₐ >= 1, "at least one auxiliary level is required");
    constexpr int width = std::numeric_limits<Torus>::digits;

    std::array<Torus, P::l̅ₐ> h̅{};
    h̅[0] = 1;
    for (int j = 1; j < P::l̅ₐ; j++)
        // Shift a Torus-typed 1: shifting the 64-bit literal 1ULL is undefined
        // once the count reaches 64, which happens for 128-bit Torus types.
        h̅[j] = static_cast<Torus>(1) << (width - j * P::B̅gₐbit);
    return h̅;
}
475624

0 commit comments

Comments
 (0)