Skip to content

Commit cd40a31

Browse files
committed
Support changing Bg for nonce/non-nonce
1 parent e828e7a commit cd40a31

File tree

8 files changed

+69
-20
lines changed

8 files changed

+69
-20
lines changed

include/params/128bit.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ struct lvl1param {
5050
static constexpr std::uint32_t lₐ = 3;
5151
static constexpr std::uint32_t l = 2;
5252
static constexpr std::uint32_t Bgbit = 6;
53+
static constexpr std::uint32_t Bgₐbit = 6;
5354
static constexpr std::uint32_t Bg = 1 << Bgbit;
55+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
5456
static constexpr ErrorDistribution errordist =
5557
ErrorDistribution::ModularGaussian;
5658
static const inline double α = std::pow(2.0, -25); // fresh noise
@@ -72,7 +74,9 @@ struct lvl2param {
7274
static constexpr std::uint32_t lₐ = 4;
7375
static constexpr std::uint32_t l = 4;
7476
static constexpr std::uint32_t Bgbit = 9;
77+
static constexpr std::uint32_t Bgₐbit = 9;
7578
static constexpr std::uint32_t Bg = 1 << Bgbit;
79+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
7680
static constexpr ErrorDistribution errordist =
7781
ErrorDistribution::ModularGaussian;
7882
static const inline double α = std::pow(2.0, -47); // fresh noise
@@ -93,7 +97,9 @@ struct lvl3param {
9397
static constexpr std::uint32_t lₐ = 4;
9498
static constexpr std::uint32_t l = 4;
9599
static constexpr std::uint32_t Bgbit = 9;
100+
static constexpr std::uint32_t Bgₐbit = 9;
96101
static constexpr std::uint32_t Bg = 1 << Bgbit;
102+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
97103
static constexpr ErrorDistribution errordist =
98104
ErrorDistribution::ModularGaussian;
99105
static const inline double α = std::pow(2.0, -47); // fresh noise

include/params/CGGI16.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ struct lvl1param {
4747
static constexpr std::uint32_t lₐ = 2;
4848
static constexpr std::uint32_t l = 2;
4949
static constexpr std::uint32_t Bgbit = 10;
50+
static constexpr std::uint32_t Bgₐbit = 10;
5051
static constexpr std::uint32_t Bg = 1 << Bgbit;
52+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
5153
static constexpr ErrorDistribution errordist =
5254
ErrorDistribution::ModularGaussian;
5355
static const inline double α = 3.73e-9;
@@ -68,7 +70,9 @@ struct lvl2param {
6870
static constexpr std::uint32_t lₐ = 4;
6971
static constexpr std::uint32_t l = 4;
7072
static constexpr std::uint32_t Bgbit = 9;
73+
static constexpr std::uint32_t Bgₐbit = 9;
7174
static constexpr std::uint32_t Bg = 1 << Bgbit;
75+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
7276
static constexpr ErrorDistribution errordist =
7377
ErrorDistribution::ModularGaussian;
7478
static const inline double α = std::pow(2.0, -44);
@@ -89,7 +93,9 @@ struct lvl3param {
8993
static constexpr std::uint32_t lₐ = 4;
9094
static constexpr std::uint32_t l = 4;
9195
static constexpr std::uint32_t Bgbit = 9;
96+
static constexpr std::uint32_t Bgₐbit = 9;
9297
static constexpr std::uint32_t Bg = 1 << Bgbit;
98+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
9399
static constexpr ErrorDistribution errordist =
94100
ErrorDistribution::ModularGaussian;
95101
static const inline double α = std::pow(2.0, -47); // fresh noise

include/params/CGGI19.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ struct lvl3param {
9090
static constexpr std::uint32_t k = 1;
9191
static constexpr std::uint32_t lₐ = 4;
9292
static constexpr std::uint32_t l = 4;
93+
static constexpr std::uint32_t Bgₐbit = 9;
9394
static constexpr std::uint32_t Bgbit = 9;
95+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
9496
static constexpr std::uint32_t Bg = 1 << Bgbit;
9597
static constexpr ErrorDistribution errordist =
9698
ErrorDistribution::ModularGaussian;

include/params/compress.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ struct lvl1param {
5353
static constexpr std::uint32_t k = 2;
5454
static constexpr std::uint32_t lₐ = 2;
5555
static constexpr std::uint32_t l = 2;
56+
static constexpr std::uint32_t Bgₐbit = 8;
5657
static constexpr std::uint32_t Bgbit = 8;
58+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
5759
static constexpr std::uint32_t Bg = 1 << Bgbit;
5860
static constexpr ErrorDistribution errordist =
5961
ErrorDistribution::CenteredBinomial;
@@ -78,7 +80,9 @@ struct lvl2param {
7880
static constexpr std::uint32_t k = 3;
7981
static constexpr std::uint32_t lₐ = 3;
8082
static constexpr std::uint32_t l = 3;
83+
static constexpr std::uint32_t Bgₐbit = 9;
8184
static constexpr std::uint32_t Bgbit = 9;
85+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
8286
static constexpr std::uint32_t Bg = 1 << Bgbit;
8387
static constexpr ErrorDistribution errordist =
8488
ErrorDistribution::CenteredBinomial;
@@ -101,7 +105,9 @@ struct lvl3param {
101105
static constexpr std::uint32_t k = 1;
102106
static constexpr std::uint32_t lₐ = 4;
103107
static constexpr std::uint32_t l = 4;
108+
static constexpr std::uint32_t Bgₐbit = 9;
104109
static constexpr std::uint32_t Bgbit = 9;
110+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
105111
static constexpr std::uint32_t Bg = 1 << Bgbit;
106112
static constexpr ErrorDistribution errordist =
107113
ErrorDistribution::ModularGaussian;

include/params/concrete.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ struct lvl1param {
5555
static constexpr std::uint32_t k = 2;
5656
static constexpr std::uint32_t lₐ = 2;
5757
static constexpr std::uint32_t l = 2;
58+
static constexpr std::uint32_t Bgₐbit = 8;
5859
static constexpr std::uint32_t Bgbit = 8;
60+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
5961
static constexpr std::uint32_t Bg = 1 << Bgbit;
6062
static constexpr ErrorDistribution errordist =
6163
ErrorDistribution::ModularGaussian;
@@ -78,7 +80,9 @@ struct lvl2param {
7880
static constexpr std::uint32_t k = 3;
7981
static constexpr std::uint32_t lₐ = 3;
8082
static constexpr std::uint32_t l = 3;
83+
static constexpr std::uint32_t Bgₐbit = 9;
8184
static constexpr std::uint32_t Bgbit = 9;
85+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
8286
static constexpr std::uint32_t Bg = 1 << Bgbit;
8387
static constexpr ErrorDistribution errordist =
8488
ErrorDistribution::ModularGaussian;
@@ -98,7 +102,9 @@ struct lvl3param {
98102
static constexpr std::uint32_t k = 1;
99103
static constexpr std::uint32_t lₐ = 4;
100104
static constexpr std::uint32_t l = 4;
105+
static constexpr std::uint32_t Bgₐbit = 9;
101106
static constexpr std::uint32_t Bgbit = 9;
107+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
102108
static constexpr std::uint32_t Bg = 1 << Bgbit;
103109
static constexpr ErrorDistribution errordist =
104110
ErrorDistribution::ModularGaussian;

include/params/ternary.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ struct lvl1param {
5050
static constexpr std::uint32_t k = 1;
5151
static constexpr std::uint32_t lₐ = 3;
5252
static constexpr std::uint32_t l = 3;
53+
static constexpr std::uint32_t Bgₐbit = 6;
5354
static constexpr std::uint32_t Bgbit = 6;
55+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
5456
static constexpr std::uint32_t Bg = 1 << Bgbit;
5557
static constexpr ErrorDistribution errordist =
5658
ErrorDistribution::ModularGaussian;
@@ -72,7 +74,9 @@ struct lvl2param {
7274
static constexpr std::uint32_t k = 1;
7375
static constexpr std::uint32_t lₐ = 4;
7476
static constexpr std::uint32_t l = 4;
77+
static constexpr std::uint32_t Bgₐbit = 9;
7578
static constexpr std::uint32_t Bgbit = 9;
79+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
7680
static constexpr std::uint32_t Bg = 1 << Bgbit;
7781
static constexpr ErrorDistribution errordist =
7882
ErrorDistribution::ModularGaussian;
@@ -92,7 +96,9 @@ struct lvl3param {
9296
static constexpr std::uint32_t k = 1;
9397
static constexpr std::uint32_t lₐ = 4;
9498
static constexpr std::uint32_t l = 4;
99+
static constexpr std::uint32_t Bgₐbit = 9;
95100
static constexpr std::uint32_t Bgbit = 9;
101+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
96102
static constexpr std::uint32_t Bg = 1 << Bgbit;
97103
static constexpr ErrorDistribution errordist =
98104
ErrorDistribution::ModularGaussian;

include/params/tfhe-rs.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ struct lvl1param {
5656
static constexpr std::uint32_t k = 3;
5757
static constexpr std::uint32_t lₐ = 1;
5858
static constexpr std::uint32_t l = 1;
59+
static constexpr std::uint32_t Bgₐbit = 18;
5960
static constexpr std::uint32_t Bgbit = 18;
61+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
6062
static constexpr std::uint32_t Bg = 1 << Bgbit;
6163
static constexpr ErrorDistribution errordist =
6264
ErrorDistribution::ModularGaussian;
@@ -79,7 +81,9 @@ struct lvl2param {
7981
static constexpr std::uint32_t k = 2;
8082
static constexpr std::uint32_t lₐ = 4;
8183
static constexpr std::uint32_t l = 4;
84+
static constexpr std::uint32_t Bgₐbit = 9;
8285
static constexpr std::uint32_t Bgbit = 9;
86+
static constexpr std::uint32_t Bgₐ = 1 << Bgₐbit;
8387
static constexpr std::uint32_t Bg = 1 << Bgbit;
8488
static constexpr ErrorDistribution errordist =
8589
ErrorDistribution::ModularGaussian;

include/trgsw.hpp

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ constexpr typename P::T nonceoffsetgen()
8888
{
8989
typename P::T offset = 0;
9090
for (int i = 1; i <= P::lₐ; i++)
91-
offset += P::Bg / 2 *
91+
offset += P::Bgₐ / 2 *
9292
(1ULL << (std::numeric_limits<typename P::T>::digits -
93-
i * P::Bgbit));
93+
i * P::Bgₐbit));
9494
return offset;
9595
}
9696

@@ -100,17 +100,17 @@ inline void NonceDecomposition(DecomposedNoncePolynomial<P> &decpoly,
100100
{
101101
constexpr typename P::T offset = nonceoffsetgen<P>();
102102
constexpr typename P::T roundoffset =
103-
1ULL << (std::numeric_limits<typename P::T>::digits - P::lₐ * P::Bgbit -
103+
1ULL << (std::numeric_limits<typename P::T>::digits - P::lₐ * P::Bgₐbit -
104104
1);
105105
constexpr typename P::T mask =
106-
static_cast<typename P::T>((1ULL << P::Bgbit) - 1);
107-
constexpr typename P::T halfBg = (1ULL << (P::Bgbit - 1));
106+
static_cast<typename P::T>((1ULL << P::Bgₐbit) - 1);
107+
constexpr typename P::T halfBg = (1ULL << (P::Bgₐbit - 1));
108108

109109
for (int i = 0; i < P::n; i++) {
110110
for (int l = 0; l < P::lₐ; l++)
111111
decpoly[l][i] = (((poly[i] + offset + roundoffset) >>
112112
(std::numeric_limits<typename P::T>::digits -
113-
(l + 1) * P::Bgbit)) &
113+
(l + 1) * P::Bgₐbit)) &
114114
mask) -
115115
halfBg;
116116
}
@@ -422,27 +422,38 @@ TRGSWNTT<P> TRGSW2NTT(const TRGSW<P> &trgsw)
422422
}
423423

424424
template <class P>
425-
constexpr std::array<typename P::T, P::lₐ> hgen()
425+
constexpr std::array<typename P::T, P::l> hgen()
426426
{
427-
// If the parameter is selected by resoble way, this is reasonable assumption
428-
static_assert(
429-
P::l <= P::lₐ,
430-
"Since lₐ is more noise sensitive, lₐ should be larger than or equal to l");
431-
std::array<typename P::T, P::lₐ> h{};
427+
std::array<typename P::T, P::l> h{};
432428
if constexpr (hasq<P>)
433429
for (int i = 0; i < P::lₐ; i++)
434430
h[i] = (P::q + (1ULL << ((i + 1) * P::Bgbit - 1))) >>
435431
((i + 1) * P::Bgbit);
436432
else
437-
for (int i = 0; i < P::lₐ; i++)
433+
for (int i = 0; i < P::l; i++)
438434
h[i] = 1ULL << (std::numeric_limits<typename P::T>::digits -
439435
(i + 1) * P::Bgbit);
440436
return h;
441437
}
442438

439+
template <class P>
440+
constexpr std::array<typename P::T, P::lₐ> noncehgen()
441+
{
442+
std::array<typename P::T, P::lₐ> h{};
443+
if constexpr (hasq<P>)
444+
for (int i = 0; i < P::lₐ; i++)
445+
h[i] = (P::q + (1ULL << ((i + 1) * P::Bgₐbit - 1))) >>
446+
((i + 1) * P::Bgₐbit);
447+
else
448+
for (int i = 0; i < P::lₐ; i++)
449+
h[i] = 1ULL << (std::numeric_limits<typename P::T>::digits -
450+
(i + 1) * P::Bgₐbit);
451+
return h;
452+
}
453+
443454
template <class P>
444455
inline void halftrgswhadd(HalfTRGSW<P>& halftrgsw, const Polynomial<P> &p){
445-
constexpr std::array<typename P::T, P::lₐ> h = hgen<P>();
456+
constexpr std::array<typename P::T, P::l> h = hgen<P>();
446457
for (int i = 0; i < P::l; i++) {
447458
for (int j = 0; j < P::n; j++) {
448459
halftrgsw[i][P::k][j] +=
@@ -453,15 +464,16 @@ inline void halftrgswhadd(HalfTRGSW<P>& halftrgsw, const Polynomial<P> &p){
453464

454465
template <class P>
455466
inline void trgswhadd(TRGSW<P>& trgsw, const Polynomial<P> &p){
456-
constexpr std::array<typename P::T, P::lₐ> h = hgen<P>();
467+
constexpr std::array<typename P::T, P::lₐ> nonceh = noncehgen<P>();
457468
for (int i = 0; i < P::lₐ; i++) {
458469
for (int k = 0; k < P::k; k++) {
459470
for (int j = 0; j < P::n; j++) {
460471
trgsw[i + k * P::lₐ][k][j] +=
461-
static_cast<typename P::T>(p[j]) * h[i];
472+
static_cast<typename P::T>(p[j]) * nonceh[i];
462473
}
463474
}
464475
}
476+
constexpr std::array<typename P::T, P::l> h = hgen<P>();
465477
for (int i = 0; i < P::l; i++) {
466478
for (int j = 0; j < P::n; j++) {
467479
trgsw[i + P::k * P::lₐ][P::k][j] +=
@@ -472,11 +484,12 @@ inline void trgswhadd(TRGSW<P>& trgsw, const Polynomial<P> &p){
472484

473485
template <class P>
474486
inline void trgswhoneadd(TRGSW<P>& trgsw){
475-
constexpr std::array<typename P::T, P::lₐ> h = hgen<P>();
487+
constexpr std::array<typename P::T, P::lₐ> nonceh = noncehgen<P>();
476488
for (int i = 0; i < P::lₐ; i++)
477489
for (int k = 0; k < P::k; k++)
478-
trgsw[i + k * P::lₐ][k][0] += h[i];
490+
trgsw[i + k * P::lₐ][k][0] += nonceh[i];
479491

492+
constexpr std::array<typename P::T, P::l> h = hgen<P>();
480493
for (int i = 0; i < P::l; i++)
481494
trgsw[i + P::k * P::lₐ][P::k][0] += h[i];
482495
}
@@ -515,7 +528,7 @@ template <class P>
515528
HalfTRGSW<P> halftrgswSymEncrypt(const Polynomial<P> &p, const double α,
516529
const Key<P> &key)
517530
{
518-
constexpr std::array<typename P::T, P::lₐ> h = hgen<P>();
531+
constexpr std::array<typename P::T, P::l> h = hgen<P>();
519532

520533
HalfTRGSW<P> halftrgsw;
521534
for (TRLWE<P> &trlwe : halftrgsw) trlwe = trlweSymEncryptZero<P>(α, key);
@@ -529,7 +542,7 @@ template <class P>
529542
HalfTRGSW<P> halftrgswSymEncrypt(const Polynomial<P> &p, const uint η,
530543
const Key<P> &key)
531544
{
532-
constexpr std::array<typename P::T, P::lₐ> h = hgen<P>();
545+
constexpr std::array<typename P::T, P::l> h = hgen<P>();
533546

534547
HalfTRGSW<P> halftrgsw;
535548
for (TRLWE<P> &trlwe : halftrgsw) trlwe = trlweSymEncryptZero<P>(η, key);

0 commit comments

Comments
 (0)