Skip to content

Commit 5ddc7d8

Browse files
nindanaotoclaude
andcommitted
Add Double Decomposition (bivariate representation) infrastructure
Implement foundational support for the Double Decomposition technique from "Revisiting Key Decomposition Techniques for FHE" (ePrint 2023/771). Changes: - Add auxiliary decomposition parameters (l̅, l̅ₐ, B̅gbit, B̅gₐbit) to all parameter structs with trivial default values (l̅=1, B̅g=2^digits) - Update TRGSW type definitions to use k*lₐ*l̅ₐ + l*l̅ row structure - Add h̅gen() and nonceh̅gen() for auxiliary h value generation - Modify trgswhadd, halftrgswhadd, trgswhoneadd for double decomposition - Update ApplyFFT2trgsw, ApplyNTT2trgsw, ApplyRAINTT2trgsw loop bounds With trivial values, behavior is unchanged (h̅[0]=1, sizes identical). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 251b531 commit 5ddc7d8

File tree

5 files changed

+138
-33
lines changed

5 files changed

+138
-33
lines changed

include/params.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,17 @@ template <class P>
138138
using TRLWERAINTT = std::array<PolynomialRAINTT<P>, P::k + 1>;
139139

140140
template <class P>
141-
using TRGSW = std::array<TRLWE<P>, P::k * P::lₐ + P::l>;
141+
using TRGSW = std::array<TRLWE<P>, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>;
142142
template <class P>
143-
using HalfTRGSW = std::array<TRLWE<P>, P::l>;
143+
using HalfTRGSW = std::array<TRLWE<P>, P::l * P::l̅>;
144144
template <class P>
145-
using TRGSWFFT = aligned_array<TRLWEInFD<P>, P::k * P::lₐ + P::l>;
145+
using TRGSWFFT = aligned_array<TRLWEInFD<P>, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>;
146146
template <class P>
147-
using HalfTRGSWFFT = aligned_array<TRLWEInFD<P>, P::l>;
147+
using HalfTRGSWFFT = aligned_array<TRLWEInFD<P>, P::l * P::l̅>;
148148
template <class P>
149-
using TRGSWNTT = std::array<TRLWENTT<P>, P::k * P::lₐ + P::l>;
149+
using TRGSWNTT = std::array<TRLWENTT<P>, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>;
150150
template <class P>
151-
using TRGSWRAINTT = std::array<TRLWERAINTT<P>, P::k * P::lₐ + P::l>;
151+
using TRGSWRAINTT = std::array<TRLWERAINTT<P>, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>;
152152

153153
#ifdef USE_KEY_BUNDLE
154154
template <class P>

include/params/128bit.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ struct lvl1param {
6262
static constexpr double Δ =
6363
static_cast<double>(1ULL << std::numeric_limits<T>::digits) /
6464
plain_modulus;
65+
// Double Decomposition (bivariate representation) parameters
66+
// For now, set to trivial values (no actual second decomposition)
67+
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
68+
static constexpr std::uint32_t l̅ₐ = l̅;
69+
static constexpr std::uint32_t B̅gbit =
70+
std::numeric_limits<T>::digits; // full coefficient width
71+
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
6572
};
6673

6774
struct AHlvl1param {
@@ -83,6 +90,11 @@ struct AHlvl1param {
8390
static constexpr std::make_signed_t<T> μ = baseP::μ;
8491
static constexpr uint32_t plain_modulus = baseP::plain_modulus;
8592
static constexpr double Δ = baseP::Δ;
93+
// Double Decomposition parameters inherited from baseP
94+
static constexpr std::uint32_t l̅ = baseP::l̅;
95+
static constexpr std::uint32_t l̅ₐ = baseP::l̅ₐ;
96+
static constexpr std::uint32_t B̅gbit = baseP::B̅gbit;
97+
static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit;
8698
};
8799

88100
struct lvl2param {
@@ -106,6 +118,13 @@ struct lvl2param {
106118
static constexpr uint32_t plain_modulus = 8;
107119
static constexpr double Δ =
108120
static_cast<double>(1ULL << (std::numeric_limits<T>::digits - 4));
121+
// Double Decomposition (bivariate representation) parameters
122+
// For now, set to trivial values (no actual second decomposition)
123+
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
124+
static constexpr std::uint32_t l̅ₐ = l̅;
125+
static constexpr std::uint32_t B̅gbit =
126+
std::numeric_limits<T>::digits; // full coefficient width
127+
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
109128
};
110129

111130
struct AHlvl2param {
@@ -127,6 +146,11 @@ struct AHlvl2param {
127146
static constexpr std::make_signed_t<T> μ = baseP::μ;
128147
static constexpr uint32_t plain_modulus = baseP::plain_modulus;
129148
static constexpr double Δ = baseP::Δ;
149+
// Double Decomposition parameters inherited from baseP
150+
static constexpr std::uint32_t l̅ = baseP::l̅;
151+
static constexpr std::uint32_t l̅ₐ = baseP::l̅ₐ;
152+
static constexpr std::uint32_t B̅gbit = baseP::B̅gbit;
153+
static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit;
130154
};
131155

132156
struct lvl3param {
@@ -150,6 +174,13 @@ struct lvl3param {
150174
static constexpr uint32_t plain_modulusbit = 31;
151175
static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit;
152176
static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1);
177+
// Double Decomposition (bivariate representation) parameters
178+
// For now, set to trivial values (no actual second decomposition)
179+
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
180+
static constexpr std::uint32_t l̅ₐ = l̅;
181+
static constexpr std::uint32_t B̅gbit =
182+
std::numeric_limits<T>::digits; // full coefficient width
183+
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
153184
};
154185

155186
// Key Switching parameters

include/params/CGGI16.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ struct lvl1param {
5959
static constexpr double Δ =
6060
static_cast<double>(1ULL << std::numeric_limits<T>::digits) /
6161
plain_modulus;
62+
// Double Decomposition (bivariate representation) parameters
63+
// For now, set to trivial values (no actual second decomposition)
64+
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
65+
static constexpr std::uint32_t l̅ₐ = l̅;
66+
static constexpr std::uint32_t B̅gbit =
67+
std::numeric_limits<T>::digits; // full coefficient width
68+
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
6269
};
6370

6471
struct lvl2param {
@@ -80,6 +87,13 @@ struct lvl2param {
8087
static constexpr T μ = 1ULL << 61;
8188
static constexpr uint32_t plain_modulus = 8;
8289
static constexpr double Δ = μ;
90+
// Double Decomposition (bivariate representation) parameters
91+
// For now, set to trivial values (no actual second decomposition)
92+
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
93+
static constexpr std::uint32_t l̅ₐ = l̅;
94+
static constexpr std::uint32_t B̅gbit =
95+
std::numeric_limits<T>::digits; // full coefficient width
96+
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
8397
};
8498

8599
// Dummy
@@ -104,6 +118,13 @@ struct lvl3param {
104118
static constexpr uint32_t plain_modulusbit = 31;
105119
static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit;
106120
static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1);
121+
// Double Decomposition (bivariate representation) parameters
122+
// For now, set to trivial values (no actual second decomposition)
123+
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
124+
static constexpr std::uint32_t l̅ₐ = l̅;
125+
static constexpr std::uint32_t B̅gbit =
126+
std::numeric_limits<T>::digits; // full coefficient width
127+
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
107128
};
108129

109130
struct lvl10param {

include/params/CGGI19.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ struct lvl1param {
5959
static constexpr double Δ =
6060
static_cast<double>(1ULL << std::numeric_limits<T>::digits) /
6161
plain_modulus;
62+
// Double Decomposition (bivariate representation) parameters
63+
// For now, set to trivial values (no actual second decomposition)
64+
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
65+
static constexpr std::uint32_t l̅ₐ = l̅;
66+
static constexpr std::uint32_t B̅gbit =
67+
std::numeric_limits<T>::digits; // full coefficient width
68+
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
6269
};
6370

6471
struct lvl2param {
@@ -78,6 +85,13 @@ struct lvl2param {
7885
static constexpr T μ = 1ULL << 61;
7986
static constexpr uint32_t plain_modulus = 8;
8087
static constexpr double Δ = μ;
88+
// Double Decomposition (bivariate representation) parameters
89+
// For now, set to trivial values (no actual second decomposition)
90+
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
91+
static constexpr std::uint32_t l̅ₐ = l̅;
92+
static constexpr std::uint32_t B̅gbit =
93+
std::numeric_limits<T>::digits; // full coefficient width
94+
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
8195
};
8296

8397
// Dummy
@@ -102,6 +116,13 @@ struct lvl3param {
102116
static constexpr uint32_t plain_modulusbit = 31;
103117
static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit;
104118
static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1);
119+
// Double Decomposition (bivariate representation) parameters
120+
// For now, set to trivial values (no actual second decomposition)
121+
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
122+
static constexpr std::uint32_t l̅ₐ = l̅;
123+
static constexpr std::uint32_t B̅gbit =
124+
std::numeric_limits<T>::digits; // full coefficient width
125+
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
105126
};
106127

107128
// Dummy

include/trgsw.hpp

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ template <class P>
353353
TRGSWFFT<P> ApplyFFT2trgsw(const TRGSW<P> &trgsw)
354354
{
355355
alignas(64) TRGSWFFT<P> trgswfft;
356-
for (int i = 0; i < P::k * P::lₐ + P::l; i++)
356+
for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++)
357357
for (int j = 0; j < (P::k + 1); j++)
358358
TwistIFFT<P>(trgswfft[i][j], trgsw[i][j]);
359359
return trgswfft;
@@ -362,7 +362,7 @@ TRGSWFFT<P> ApplyFFT2trgsw(const TRGSW<P> &trgsw)
362362
template <class P>
363363
void ApplyFFT2trgsw(TRGSWFFT<P> &trgswfft, const TRGSW<P> &trgsw)
364364
{
365-
for (int i = 0; i < P::k * P::lₐ + P::l; i++)
365+
for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++)
366366
for (int j = 0; j < (P::k + 1); j++)
367367
TwistIFFT<P>(trgswfft[i][j], trgsw[i][j]);
368368
}
@@ -371,7 +371,7 @@ template <class P>
371371
HalfTRGSWFFT<P> ApplyFFT2halftrgsw(const HalfTRGSW<P> &trgsw)
372372
{
373373
alignas(64) HalfTRGSWFFT<P> halftrgswfft;
374-
for (int i = 0; i < P::l; i++)
374+
for (int i = 0; i < P::l * P::l̅; i++)
375375
for (int j = 0; j < (P::k + 1); j++)
376376
TwistIFFT<P>(halftrgswfft[i][j], trgsw[i][j]);
377377
return halftrgswfft;
@@ -381,7 +381,7 @@ template <class P>
381381
TRGSWNTT<P> ApplyNTT2trgsw(const TRGSW<P> &trgsw)
382382
{
383383
TRGSWNTT<P> trgswntt;
384-
for (int i = 0; i < P::k * P::lₐ + P::l; i++)
384+
for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++)
385385
for (int j = 0; j < P::k + 1; j++)
386386
TwistINTT<P>(trgswntt[i][j], trgsw[i][j]);
387387
return trgswntt;
@@ -392,7 +392,7 @@ TRGSWRAINTT<P> ApplyRAINTT2trgsw(const TRGSW<P> &trgsw)
392392
{
393393
constexpr uint8_t remainder = ((P::nbit - 1) % 3) + 1;
394394
TRGSWRAINTT<P> trgswntt;
395-
for (int i = 0; i < P::k * P::lₐ + P::l; i++)
395+
for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++)
396396
for (int j = 0; j < P::k + 1; j++) {
397397
raintt::TwistINTT<typename P::T, P::nbit, true>(
398398
trgswntt[i][j], trgsw[i][j], (*raintttable)[1],
@@ -412,7 +412,7 @@ template <class P>
412412
TRGSWNTT<P> TRGSW2NTT(const TRGSW<P> &trgsw)
413413
{
414414
TRGSWNTT<P> trgswntt;
415-
for (int i = 0; i < P::k * P::lₐ + P::l; i++)
415+
for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++)
416416
for (int j = 0; j < P::k + 1; j++) {
417417
PolynomialNTT<P> temp;
418418
TwistINTT<P>(temp, trgsw[i][j]);
@@ -452,13 +452,38 @@ constexpr std::array<typename P::T, P::lₐ> noncehgen()
452452
return h;
453453
}
454454

455+
// Auxiliary h generation for Double Decomposition (bivariate representation)
456+
template <class P>
457+
constexpr std::array<typename P::T, P::l̅> h̅gen()
458+
{
459+
std::array<typename P::T, P::l̅> h̅{};
460+
for (int i = 0; i < P::l̅; i++)
461+
h̅[i] = 1ULL << (std::numeric_limits<typename P::T>::digits -
462+
(i + 1) * P::B̅gbit);
463+
return h̅;
464+
}
465+
466+
template <class P>
467+
constexpr std::array<typename P::T, P::l̅ₐ> nonceh̅gen()
468+
{
469+
std::array<typename P::T, P::l̅ₐ> h̅{};
470+
for (int i = 0; i < P::l̅ₐ; i++)
471+
h̅[i] = 1ULL << (std::numeric_limits<typename P::T>::digits -
472+
(i + 1) * P::B̅gₐbit);
473+
return h̅;
474+
}
475+
455476
template <class P>
456477
inline void halftrgswhadd(HalfTRGSW<P> &halftrgsw, const Polynomial<P> &p)
457478
{
458479
constexpr std::array<typename P::T, P::l> h = hgen<P>();
480+
constexpr std::array<typename P::T, P::l̅> h̅ = h̅gen<P>();
459481
for (int i = 0; i < P::l; i++) {
460-
for (int j = 0; j < P::n; j++) {
461-
halftrgsw[i][P::k][j] += static_cast<typename P::T>(p[j]) * h[i];
482+
for (int ī = 0; ī < P::l̅; ī++) {
483+
for (int j = 0; j < P::n; j++) {
484+
halftrgsw[i * P::l̅ + ī][P::k][j] +=
485+
static_cast<typename P::T>(p[j]) * h[i] * h̅[ī];
486+
}
462487
}
463488
}
464489
}
@@ -467,19 +492,26 @@ template <class P>
467492
inline void trgswhadd(TRGSW<P> &trgsw, const Polynomial<P> &p)
468493
{
469494
constexpr std::array<typename P::T, P::lₐ> nonceh = noncehgen<P>();
495+
constexpr std::array<typename P::T, P::l̅ₐ> nonceh̅ = nonceh̅gen<P>();
470496
for (int i = 0; i < P::lₐ; i++) {
471-
for (int k = 0; k < P::k; k++) {
472-
for (int j = 0; j < P::n; j++) {
473-
trgsw[i + k * P::lₐ][k][j] +=
474-
static_cast<typename P::T>(p[j]) * nonceh[i];
497+
for (int ī = 0; ī < P::l̅ₐ; ī++) {
498+
for (int k = 0; k < P::k; k++) {
499+
for (int j = 0; j < P::n; j++) {
500+
trgsw[(i * P::l̅ₐ + ī) + k * P::lₐ * P::l̅ₐ][k][j] +=
501+
static_cast<typename P::T>(p[j]) * nonceh[i] *
502+
nonceh̅[ī];
503+
}
475504
}
476505
}
477506
}
478507
constexpr std::array<typename P::T, P::l> h = hgen<P>();
508+
constexpr std::array<typename P::T, P::l̅> h̅ = h̅gen<P>();
479509
for (int i = 0; i < P::l; i++) {
480-
for (int j = 0; j < P::n; j++) {
481-
trgsw[i + P::k * P::lₐ][P::k][j] +=
482-
static_cast<typename P::T>(p[j]) * h[i];
510+
for (int ī = 0; ī < P::l̅; ī++) {
511+
for (int j = 0; j < P::n; j++) {
512+
trgsw[(i * P::l̅ + ī) + P::k * P::lₐ * P::l̅ₐ][P::k][j] +=
513+
static_cast<typename P::T>(p[j]) * h[i] * h̅[ī];
514+
}
483515
}
484516
}
485517
}
@@ -488,11 +520,19 @@ template <class P>
488520
inline void trgswhoneadd(TRGSW<P> &trgsw)
489521
{
490522
constexpr std::array<typename P::T, P::lₐ> nonceh = noncehgen<P>();
523+
constexpr std::array<typename P::T, P::l̅ₐ> nonceh̅ = nonceh̅gen<P>();
491524
for (int i = 0; i < P::lₐ; i++)
492-
for (int k = 0; k < P::k; k++) trgsw[i + k * P::lₐ][k][0] += nonceh[i];
525+
for (int ī = 0; ī < P::l̅ₐ; ī++)
526+
for (int k = 0; k < P::k; k++)
527+
trgsw[(i * P::l̅ₐ + ī) + k * P::lₐ * P::l̅ₐ][k][0] +=
528+
nonceh[i] * nonceh̅[ī];
493529

494530
constexpr std::array<typename P::T, P::l> h = hgen<P>();
495-
for (int i = 0; i < P::l; i++) trgsw[i + P::k * P::lₐ][P::k][0] += h[i];
531+
constexpr std::array<typename P::T, P::l̅> h̅ = h̅gen<P>();
532+
for (int i = 0; i < P::l; i++)
533+
for (int ī = 0; ī < P::l̅; ī++)
534+
trgsw[(i * P::l̅ + ī) + P::k * P::lₐ * P::l̅ₐ][P::k][0] +=
535+
h[i] * h̅[ī];
496536
}
497537

498538
template <class P>
@@ -525,24 +565,16 @@ template <class P>
525565
void halftrgswSymEncrypt(HalfTRGSW<P> &halftrgsw, const Polynomial<P> &p,
526566
const double α, const Key<P> &key)
527567
{
528-
constexpr std::array<typename P::T, P::l> h = hgen<P>();
529-
530568
for (TRLWE<P> &trlwe : halftrgsw) trlweSymEncryptZero<P>(trlwe, α, key);
531-
for (int i = 0; i < P::l; i++)
532-
for (int j = 0; j < P::n; j++)
533-
halftrgsw[i][P::k][j] += static_cast<typename P::T>(p[j]) * h[i];
569+
halftrgswhadd<P>(halftrgsw, p);
534570
}
535571

536572
template <class P>
537573
void halftrgswSymEncrypt(HalfTRGSW<P> &halftrgsw, const Polynomial<P> &p,
538574
const uint η, const Key<P> &key)
539575
{
540-
constexpr std::array<typename P::T, P::l> h = hgen<P>();
541-
542576
for (TRLWE<P> &trlwe : halftrgsw) trlweSymEncryptZero<P>(trlwe, η, key);
543-
for (int i = 0; i < P::l; i++)
544-
for (int j = 0; j < P::n; j++)
545-
halftrgsw[i][P::k][j] += static_cast<typename P::T>(p[j]) * h[i];
577+
halftrgswhadd<P>(halftrgsw, p);
546578
}
547579

548580
template <class P>

0 commit comments

Comments
 (0)