diff --git a/ArkLib.lean b/ArkLib.lean index bcef19351..489dc589f 100644 --- a/ArkLib.lean +++ b/ArkLib.lean @@ -112,11 +112,13 @@ import ArkLib.ProofSystem.Binius.BinaryBasefold.General import ArkLib.ProofSystem.Binius.BinaryBasefold.Prelude import ArkLib.ProofSystem.Binius.BinaryBasefold.QueryPhase import ArkLib.ProofSystem.Binius.BinaryBasefold.ReductionLogic +import ArkLib.ProofSystem.Binius.BinaryBasefold.Soundness import ArkLib.ProofSystem.Binius.BinaryBasefold.Spec import ArkLib.ProofSystem.Binius.BinaryBasefold.Steps import ArkLib.ProofSystem.Binius.FRIBinius.CoreInteractionPhase import ArkLib.ProofSystem.Binius.FRIBinius.General import ArkLib.ProofSystem.Binius.FRIBinius.Prelude +import ArkLib.ProofSystem.Binius.RingSwitching.BBFSmallFieldIOPCS import ArkLib.ProofSystem.Binius.RingSwitching.BatchingPhase import ArkLib.ProofSystem.Binius.RingSwitching.General import ArkLib.ProofSystem.Binius.RingSwitching.Prelude @@ -160,6 +162,7 @@ import ArkLib.ToMathlib.Data.IndexedBinaryTree.Basic import ArkLib.ToMathlib.Data.IndexedBinaryTree.Equiv import ArkLib.ToMathlib.Data.IndexedBinaryTree.Lemmas import ArkLib.ToMathlib.Finset.Basic +import ArkLib.ToMathlib.MvPolynomial.Equiv import ArkLib.ToVCVio.DistEq import ArkLib.ToVCVio.Lemmas import ArkLib.ToVCVio.Oracle diff --git a/ArkLib/Data/FieldTheory/AdditiveNTT/AdditiveNTT.lean b/ArkLib/Data/FieldTheory/AdditiveNTT/AdditiveNTT.lean index 68a198cf6..90c73e6e9 100644 --- a/ArkLib/Data/FieldTheory/AdditiveNTT/AdditiveNTT.lean +++ b/ArkLib/Data/FieldTheory/AdditiveNTT/AdditiveNTT.lean @@ -1159,6 +1159,21 @@ noncomputable def iteratedQuotientMap [NeZero ℓ] (i : Fin r) {destIdx : Fin r} _ = (normalizedW 𝔽q β destIdx).eval u := by rw [h_comp_eq] exact ⟨y, h_mem⟩ +omit [DecidableEq 𝔽q] hF₂ in +lemma iteratedQuotientMap_congr_k + (i : Fin r) {destIdx : Fin r} {k₁ k₂ : ℕ} + (hk : k₁ = k₂) + (h_destIdx₁ : destIdx.val = i.val + k₁) + (h_destIdx₂ : destIdx.val = i.val + k₂) + (h_destIdx_le : destIdx.val ≤ ℓ) + (x : sDomain 𝔽q β h_ℓ_add_R_rate i) : + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate + (i := i) (k := k₁) (h_destIdx := h_destIdx₁) (h_destIdx_le := h_destIdx_le) x + = + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate + (i := i) (k := k₂) (h_destIdx := h_destIdx₂) (h_destIdx_le := h_destIdx_le) x := by + subst hk; rfl + omit [DecidableEq 𝔽q] [NeZero ℓ] hF₂ h_β₀_eq_1 in /-- The evaluation of qMap on an element from sDomain i belongs to sDomain (i+1). This is a key property that qMap maps between successive domains. -/ diff --git a/ArkLib/Data/Misc/Basic.lean b/ArkLib/Data/Misc/Basic.lean index ede93ef9d..57ff0167b 100644 --- a/ArkLib/Data/Misc/Basic.lean +++ b/ArkLib/Data/Misc/Basic.lean @@ -28,6 +28,8 @@ def subtypeSumComplEquiv {α : Type*} {p : α → Prop} [DecidablePred p] : lemma fun_eta_expansion {α β : Type*} (f : α → β) : f = (fun x => f x) := rfl +lemma fun_eta_expansion_apply {α β : Type*} (f : α → β) (x : α) : (f x) = (fun x => f x) x := rfl + /-- Casting a function equals the function that casts its argument. -/ lemma cast_fun_eq_fun_cast_arg.{u, v} {A B : Type u} {C : Type v} (h : A = B) (f : A → C) : cast (congrArg (· → C) h) f = fun x => f (cast h.symm x) := by diff --git a/ArkLib/Data/Probability/Instances.lean b/ArkLib/Data/Probability/Instances.lean index 841b89652..1b6223d6b 100644 --- a/ArkLib/Data/Probability/Instances.lean +++ b/ArkLib/Data/Probability/Instances.lean @@ -9,10 +9,12 @@ import ArkLib.Data.Probability.Notation import CompPoly.Data.Fin.BigOperators import CompPoly.Data.Nat.Bitwise import Mathlib.Algebra.MvPolynomial.SchwartzZippel - +import CompPoly.Data.MvPolynomial.Notation +import ArkLib.ToMathlib.MvPolynomial.Equiv +import VCVio.EvalDist.Instances.OptionT open ProbabilityTheory Filter open NNReal Finset Function -open scoped BigOperators ProbabilityTheory +open scoped BigOperators ProbabilityTheory Polynomial MvPolynomial open Real -- TODO(dtumad): Move most of the stuff in this file to VCV and generalize as possible @@ -242,13 +244,10 @@ theorem prob_split_uniform_sampling_of_equiv_prod {α γ δ : Type} -- Decidability for the predicates [DecidablePred P] [DecidablePred (fun xy : γ × δ => P (e.symm xy))] : - -- LHS: Probability over the combined space Pr_{ let r ← $ᵖ α }[ P r ] = - -- RHS: Probability over the sequential, split spaces Pr_{ let x ← $ᵖ γ; let y ← $ᵖ δ }[ P (e.symm (x, y)) ] := by - -- 1. Unroll the LHS (a single `let`) using `prStx_unfold_final` -- LHS = ∑' r, Pr[r] * (if P r then 1 else 0) rw [prob_tsum_form_singleton] @@ -259,7 +258,6 @@ theorem prob_split_uniform_sampling_of_equiv_prod {α γ δ : Type} conv_rhs => apply prob_tsum_form_split_first (D := $ᵖ γ) (D_rest := D_rest) simp_rw [D_rest] - simp only [PMF.uniformOfFintype_apply, mul_ite, mul_one, mul_zero] simp_rw [prob_tsum_form_singleton] -- ⊢ (∑' (x : α), ... = ∑' (x : γ), (↑(Fintype.card γ))⁻¹ * ∑' (x_1 : δ), ... @@ -281,7 +279,6 @@ theorem prob_split_uniform_sampling_of_equiv_prod {α γ δ : Type} )] -- ⊢ (∑ b : α, .. = ..) = (∑ b : γ × δ, ..) have hcard_of_equiv: (Fintype.card α) = (Fintype.card (γ × δ)) := Fintype.card_congr e - rw [Finset.sum_equiv (s := Finset.univ (α := α)) (t := Finset.univ (α := γ × δ)) (f := fun x => if P x then (↑(Fintype.card α))⁻¹ else 0) (g := fun x => (↑(Fintype.card γ))⁻¹ * (($ᵖ δ) x.2 * if P (e.symm x) then 1 else 0)) @@ -307,7 +304,6 @@ theorem prob_split_last_uniform_sampling_of_finFun {ϑ : ℕ} {F : Type} [Fintyp Pr_{ let r ← $ᵖ (Fin (ϑ + 1) → F) }[ P (r (Fin.last ϑ)) (fun i ↦ r i.castSucc) ] = Pr_{ let r_last ← $ᵖ F; let r_init ← $ᵖ (Fin ϑ → F) }[ P r_last r_init ] := by rw [prob_tsum_form_doubleton] - let e : (Fin (ϑ + 1) → F) ≃ F × (Fin ϑ → F) := equivFinFunSplitLast conv_lhs => rw [prob_split_uniform_sampling_of_equiv_prod (e := e)] @@ -330,7 +326,6 @@ theorem prob_marginalization_first_of_prod {α β : Type} [Fintype α] [Fintype [Nonempty α] [Nonempty β] (P : α → Prop) [DecidablePred P] : Pr_{let r ← $ᵖ (α × β) }[ P r.1 ] = Pr_{ let x ← $ᵖ α }[ P x ] := by rw [prob_split_uniform_sampling_of_prod] - let D_rest := fun (x : α) => (do let y ← $ᵖ β pure (P (x, y).1) @@ -374,6 +369,26 @@ theorem Pr_le_Pr_of_implies {α : Type} (D : PMF α) -- 5. Prove the factor `D r` is non-negative · exact zero_le (D r) -- Probabilities are always non-negative +alias prob_mono := Pr_le_Pr_of_implies + +/-- **Union bound**: Pr[A ∨ B] ≤ Pr[A] + Pr[B]. -/ +theorem Pr_or_le {α : Type} (D : PMF α) + (f g : α → Prop) [DecidablePred f] [DecidablePred g] [DecidablePred (fun r => f r ∨ g r)] : + Pr_{ let r ← D }[ f r ∨ g r ] ≤ Pr_{ let r ← D }[ f r ] + Pr_{ let r ← D }[ g r ] := by + rw [prob_tsum_form_singleton D (fun r => f r ∨ g r), + prob_tsum_form_singleton D f, prob_tsum_form_singleton D g] + trans ∑' r, (D r * (if f r then 1 else 0) + D r * (if g r then 1 else 0)) + · apply ENNReal.tsum_le_tsum + intro r + by_cases hf : f r + · by_cases hg : g r + · simp only [hf, hg, or_true, ↓reduceIte, mul_one]; exact le_add_of_nonneg_right (zero_le (D r)) + · simp only [hf, hg, or_true, true_or, ↓reduceIte, mul_one, mul_zero, add_zero]; exact le_refl (D r) + · by_cases hg : g r + · simp only [hf, hg, true_or, or_true, ↓reduceIte, mul_one, mul_zero, zero_add]; exact le_refl (D r) + · simp only [hf, hg, false_or, ↓reduceIte, mul_zero, zero_add]; exact le_refl 0 + · rw [ENNReal.tsum_add]; + theorem Pr_multi_let_equiv_single_let {α β : Type} (D₁ : PMF α) (D₂ : PMF β) -- Assuming D₂ is independent for simplicity (P : α → β → Prop) [∀ x, DecidablePred (P x)] : @@ -455,6 +470,98 @@ lemma Pr_congr {α : Type} {D : PMF α} {P Q : α → Prop} congr 2; funext x; congr 1; exact propext (h x) +section ProbabilitySplitting +variable {A : Type} [Fintype A] [Nonempty A] + +/-- Helper: Probability over functions can be split into iterated product. +For uniform sampling over `Fin n → A`, if the predicate factors as a conjunction +over each index, the probability equals the product of individual probabilities. + +This is the key lemma for showing that independent repetitions multiply their error rates. -/ +theorem prob_pow_of_forall_finFun + (n : ℕ) (P : A → Prop) [DecidablePred P] + [DecidablePred (fun (f : Fin n → A) => ∀ i, P (f i))] : + Pr_{ let f ← $ᵖ (Fin n → A) }[ ∀ i, P (f i) ] = + (Pr_{ let a ← $ᵖ A }[ P a ])^n := by + induction n with + | zero => + simp only [IsEmpty.forall_iff, PMF.monad_pure_eq_pure, PMF.monad_bind_eq_bind, PMF.bind_const, + PMF.pure_apply, ↓reduceIte, PMF.bind_apply, PMF.uniformOfFintype_apply, eq_iff_iff, true_iff, + mul_ite, mul_one, mul_zero, pow_zero] + | succ n ih => + -- Shorter equivalence proof + have h_eqv (f : Fin (n + 1) → A) : (∀ i, P (f i)) ↔ P (f (Fin.last n)) ∧ ∀ (i : Fin n), P (f i.castSucc) := by + constructor + · intro h; exact ⟨h _, fun i => h _⟩ + · rintro ⟨h_last, h_init⟩ ⟨i, hi⟩ + by_cases h : i < n + · exact h_init ⟨i, h⟩ + · have : i = n := by omega + simp only [this] + exact h_last + rw [Pr_congr (h := h_eqv)] + -- Chain the splitting and independence results + calc Pr_{ let f ← $ᵖ (Fin (n + 1) → A) }[ P (f (Fin.last n)) ∧ ∀ (i : Fin n), P (f i.castSucc) ] + _ = Pr_{ let a ← $ᵖ A; let f_init ← $ᵖ (Fin n → A) }[ P a ∧ ∀ (i : Fin n), P (f_init i) ] := by + have h := prob_split_last_uniform_sampling_of_finFun (ϑ := n) + (P := fun a (f_init : Fin n → A) => P a ∧ ∀ (i : Fin n), P (f_init i)) + simpa using h + _ = Pr_{ let a ← $ᵖ A }[ P a ] * Pr_{ let f_init ← $ᵖ (Fin n → A) }[ ∀ (i : Fin n), P (f_init i) ] := by + -- Convert sequential bind to single uniform over product + have h_prod : Pr_{ let a ← $ᵖ A; let f_init ← $ᵖ (Fin n → A) }[ P a ∧ ∀ (i : Fin n), P (f_init i) ] = + Pr_{ let p ← $ᵖ (A × (Fin n → A)) }[ P p.1 ∧ ∀ (i : Fin n), P (p.2 i) ] := by + rw [prob_split_uniform_sampling_of_prod] + rw [h_prod] + -- Use counting formula for product and components + rw [prob_uniform_eq_card_filter_div_card, prob_uniform_eq_card_filter_div_card, + prob_uniform_eq_card_filter_div_card] + simp only [Fintype.card_prod, ENNReal.div_eq_inv_mul] + -- Filter cardinality multiplies + have h_filter : + (Finset.filter (fun (p : A × (Fin n → A)) => P p.1 ∧ ∀ (i : Fin n), P (p.2 i)) Finset.univ).card = + (Finset.filter (fun (a : A) => P a) Finset.univ).card * + (Finset.filter (fun (f : Fin n → A) => ∀ (i : Fin n), P (f i)) Finset.univ).card := by + have : Finset.filter (fun (p : A × (Fin n → A)) => P p.1 ∧ ∀ (i : Fin n), P (p.2 i)) Finset.univ = + (Finset.filter (fun (a : A) => P a) Finset.univ) ×ˢ + (Finset.filter (fun (f : Fin n → A) => ∀ (i : Fin n), P (f i)) Finset.univ) := by + ext ⟨a, f⟩; simp + rw [this, Finset.card_product] + rw [h_filter] + simp only [Fintype.card_pi, Finset.prod_const, Finset.card_univ, Fintype.card_fin, + Nat.cast_mul, Nat.cast_pow, ENNReal.coe_mul, ENNReal.coe_natCast, ENNReal.coe_pow] + -- (↑(card A))^n = ↑((card A)^n) so lemma pattern (↑a * ↑b)⁻¹ matches + rw [← Nat.cast_pow] + rw [← ENNReal.mul_inv_rev_ENNReal (ha := Fintype.card_ne_zero) + (hb := ne_of_gt (Nat.pow_pos Fintype.card_pos))] + have h_eq : (Finset.filter (fun a => P a) Finset.univ) = Finset.filter P Finset.univ := + Finset.filter_congr (fun a _ => by rfl) + rw [← h_eq] + conv_lhs => + rw [←mul_assoc]; + rw [mul_assoc (c := ((Finset.filter (fun a => P a) Finset.univ).card : ENNReal))] + rw [mul_comm (b := ((Finset.filter (fun a => P a) Finset.univ).card : ENNReal))] + rw [←mul_assoc] + rw [mul_assoc (a := ((Fintype.card A):ENNReal)⁻¹ * (Finset.filter (fun a => P a) Finset.univ).card)] + _ = (Pr_{ let a ← $ᵖ A }[ P a ]) ^ (n + 1) := by + rw [ih, pow_succ', mul_comm] + +/-- Specialization: Probability bound for failing all proximity checks. +When each repetition independently bounds bad events by ε, running n repetitions +has cumulative bound ε^n (product rule for independent events). -/ +theorem prob_pow_bound_of_forall + (n : ℕ) (P : A → Prop) [DecidablePred P] + [DecidablePred (fun (f : Fin n → A) => ∀ i, P (f i))] + (ε : ENNReal) (h_bound : Pr_{ let a ← $ᵖ A}[P a] ≤ ε) : + Pr_{ let f ← $ᵖ (Fin n → A) }[ ∀ i, P (f i) ] ≤ ε^n := by + calc Pr_{ let f ← $ᵖ (Fin n → A) }[ ∀ i, P (f i) ] + = (Pr_{ let a ← $ᵖ A }[ P a ])^n := prob_pow_of_forall_finFun n P + _ ≤ ε^n := by + -- Use the fact that x ≤ y implies x^n ≤ y^n for ENNReal + apply pow_le_pow_left' + · exact h_bound + +end ProbabilitySplitting + /-- **Schwartz-Zippel Lemma** (Probability Form): For a non-zero multivariate polynomial `P` of total degree at most `d` over a finite field `L`, the probability that `P(r)` evaluates to 0 for a uniformly random `r` is at most `d / |L|`. -/ @@ -494,4 +601,126 @@ lemma prob_schwartz_zippel_mv_polynomial {R : Type} [CommRing R] [IsDomain R] [F rw [Nat.cast_pow] at sz_bound_ENNReal exact sz_bound_ENNReal +/-- **Schwartz-Zippel for univariate (Fin 1) polynomials with arbitrary degree bound**. +For a non-zero `P : MvPolynomial (Fin 1) R` with `P.totalDegree ≤ d`, the probability that +`P(r)` is 0 for uniform `r : Fin 1 → R` is at most `d / |R|`. -/ +lemma prob_schwartz_zippel_univariate_deg {R : Type} [CommRing R] [IsDomain R] [Fintype R] + (d : ℕ) (P : MvPolynomial (Fin 1) R) (h_nonzero : P ≠ 0) + (h_deg : P.totalDegree ≤ d) : + Pr_{ let r ←$ᵖ (Fin 1 → R) }[ MvPolynomial.eval r P = 0 ] ≤ + (d : ℝ≥0) / (Fintype.card R : ℝ≥0) := by + classical + rw [prob_uniform_eq_card_filter_div_card] + push_cast + have sz_bound := MvPolynomial.schwartz_zippel_totalDegree (R := R) (n := 1) + (p := P) (hp := h_nonzero) (S := Finset.univ) + simp only [Fintype.piFinset_univ, card_univ] at sz_bound + have sz_bound_le_d_div_card_R : ((#{f | (MvPolynomial.eval f) P = 0}) : ℚ≥0) + / ((Fintype.card R ^ 1)) ≤ (d : ℚ≥0) / ((#(Finset.univ : Finset R)) : ℚ≥0) := by + calc + _ ≤ (P.totalDegree : ℚ≥0) / ((#(Finset.univ : Finset R)) : ℚ≥0) := sz_bound + _ ≤ (d : ℚ≥0) / ((#(Finset.univ : Finset R)) : ℚ≥0) := by + simp only [card_univ] + apply div_le_of_le_mul₀ (hb := by simp only [zero_le]) (hc := by simp only [zero_le]) + rw [div_mul_cancel₀ (h := by simp only [ne_eq, Nat.cast_eq_zero, Fintype.card_ne_zero, + not_false_eq_true])] + exact Nat.cast_le.mpr h_deg + have sz_bound_le_d_div_card_R' : ((#{f | (MvPolynomial.eval f) P = 0}) : ℚ≥0) + / (Fintype.card R : ℚ≥0) ≤ (d : ℚ≥0) / (Fintype.card R : ℚ≥0) := by + rw [pow_one, card_univ] at sz_bound_le_d_div_card_R + exact sz_bound_le_d_div_card_R + have sz_bound_ENNReal : ((#{f | (MvPolynomial.eval f) P = 0}) : ENNReal) + / (Fintype.card R : ENNReal) ≤ (d : ENNReal) / (Fintype.card R : ENNReal) := by + simp_rw [ENNReal.coe_Nat_coe_NNRat] + conv_lhs => rw [ENNReal.coe_div_of_NNRat (hb := by + simp only [pow_one, ne_eq, Nat.cast_eq_zero, Fintype.card_ne_zero, not_false_eq_true])] + conv_rhs => rw [ENNReal.coe_div_of_NNRat (hb := by simp only [ne_eq, Nat.cast_eq_zero, + Fintype.card_ne_zero, not_false_eq_true])] + rw [ENNReal.coe_le_of_NNRat] + exact sz_bound_le_d_div_card_R' + simp only [Fintype.card_pi, prod_const, card_univ, Fintype.card_fin, pow_one, ge_iff_le] + exact sz_bound_ENNReal + +/-- **Schwartz-Zippel for degree-1 univariate polynomials**. +For two distinct degree-1 univariate polynomials over a commutative ring, the probability +that they agree at a random point is at most `1 / |R|`. -/ +lemma prob_poly_agreement_degree_one {R : Type} [CommRing R] [IsDomain R] [Fintype R] + (p q : R⦃≤ 1⦄[X]) + (h_ne : p ≠ q) : + Pr_{ let r ←$ᵖ R }[ p.val.eval r = q.val.eval r ] ≤ + (1 : ℝ≥0) / (Fintype.card R : ℝ≥0) := by + classical + -- 1. Setup the multivariate polynomial P = p - q + let P := (p.val - q.val).toMvPolynomial (σ := Fin 1) 0 + -- 2. Prove P is non-zero (immediate from p ≠ q) + have h_nz : P ≠ 0 := by + rw [Polynomial.toMvPolynomial_ne_zero_iff, sub_ne_zero] + exact fun h => h_ne (Subtype.ext h) + have h_p_deg : p.val.degree ≤ 1 := + Polynomial.mem_degreeLE (f := p.val) (n := 1).mp (by simp only [SetLike.coe_mem]) + have h_q_deg : q.val.degree ≤ 1 := + Polynomial.mem_degreeLE (f := q.val) (n := 1).mp (by simp only [SetLike.coe_mem]) + -- 3. Prove totalDegree P ≤ 1 + have h_deg : P.totalDegree ≤ 1 := by + apply (Polynomial.toMvPolynomial_totalDegree_le _ _).trans + apply (Polynomial.natDegree_sub_le _ _).trans + -- Use the fact that p, q are in R{≤1}[X] directly + simp only [max_le_iff] + constructor <;> apply Polynomial.natDegree_le_of_degree_le <;> + first | exact h_p_deg | exact h_q_deg + -- 4. Apply Schwartz-Zippel + calc Pr_{ let r ←$ᵖ R }[ p.val.eval r = q.val.eval r ] + _ = Pr_{ let r ←$ᵖ R }[ (p.val - q.val).eval r = 0 ] := by + apply Pr_congr; simp [sub_eq_zero] + _ = Pr_{ let r ←$ᵖ R }[ MvPolynomial.eval (fun _ ↦ r) P = 0 ] := by + apply Pr_congr; intro; simp [P, MvPolynomial.eval_toMvPolynomial] + _ = Pr_{ let f ←$ᵖ (Fin 1 → R) }[ MvPolynomial.eval f P = 0 ] := by + -- Move to function space (Fin 1 → R) + rw [← prob_uniform_singleton_finFun_eq] + congr; funext f + -- Collapse: eval f (toMv p) = eval (f 0) p + simp [P, MvPolynomial.eval_toMvPolynomial] + _ ≤ _ := by + have h := prob_schwartz_zippel_mv_polynomial P h_nz h_deg + simp only [Nat.cast_one] at h + exact h + +alias prob_schwartz_zippel_univariate_poly := prob_poly_agreement_degree_one + +/-- **Schwartz-Zippel for degree-2 univariate polynomials**. +For two distinct degree-2 univariate polynomials over a commutative ring, the probability +that they agree at a random point is at most `2 / |R|`. -/ +lemma prob_poly_agreement_degree_two {R : Type} [CommRing R] [IsDomain R] [Fintype R] + (p q : R⦃≤ 2⦄[X]) + (h_ne : p ≠ q) : + Pr_{ let r ←$ᵖ R }[ p.val.eval r = q.val.eval r ] ≤ + (2 : ℝ≥0) / (Fintype.card R : ℝ≥0) := by + classical + let P := (p.val - q.val).toMvPolynomial (σ := Fin 1) 0 + have h_nz : P ≠ 0 := by + rw [Polynomial.toMvPolynomial_ne_zero_iff, sub_ne_zero] + exact fun h => h_ne (Subtype.ext h) + have h_p_deg : p.val.degree ≤ 2 := + Polynomial.mem_degreeLE (f := p.val) (n := 2).mp (by simp only [SetLike.coe_mem]) + have h_q_deg : q.val.degree ≤ 2 := + Polynomial.mem_degreeLE (f := q.val) (n := 2).mp (by simp only [SetLike.coe_mem]) + have h_deg : P.totalDegree ≤ 2 := by + apply (Polynomial.toMvPolynomial_totalDegree_le _ _).trans + apply (Polynomial.natDegree_sub_le _ _).trans + simp only [max_le_iff] + constructor <;> apply Polynomial.natDegree_le_of_degree_le <;> + first | exact h_p_deg | exact h_q_deg + calc Pr_{ let r ←$ᵖ R }[ p.val.eval r = q.val.eval r ] + _ = Pr_{ let r ←$ᵖ R }[ (p.val - q.val).eval r = 0 ] := by apply Pr_congr; simp [sub_eq_zero] + _ = Pr_{ let r ←$ᵖ R }[ MvPolynomial.eval (fun _ ↦ r) P = 0 ] := by + apply Pr_congr; intro; simp [P, MvPolynomial.eval_toMvPolynomial] + _ = Pr_{ let f ←$ᵖ (Fin 1 → R) }[ MvPolynomial.eval f P = 0 ] := by + rw [← prob_uniform_singleton_finFun_eq] + congr; funext f + simp [P, MvPolynomial.eval_toMvPolynomial] + _ ≤ _ := by + -- The lemma returns (P.totalDegree / card R), which is ≤ (2 / card R) + have h := prob_schwartz_zippel_univariate_deg 2 P h_nz h_deg + exact h + end ProbabilityTools diff --git a/ArkLib/OracleReduction/Basic.lean b/ArkLib/OracleReduction/Basic.lean index 0f8310bcb..e4be13615 100644 --- a/ArkLib/OracleReduction/Basic.lean +++ b/ArkLib/OracleReduction/Basic.lean @@ -772,6 +772,15 @@ end IsSingleRound def FullTranscript.mk1 {pSpec : ProtocolSpec 1} (msg0 : pSpec.«Type» 0) : FullTranscript pSpec := fun | ⟨0, _⟩ => msg0 +@[simp] +theorem FullTranscript.mk1_eq_snoc {pSpec : ProtocolSpec 1} (msg0 : pSpec.«Type» 0) : + FullTranscript.mk1 msg0 = (default : pSpec.Transcript 0).concat msg0 := by + unfold FullTranscript.mk1 Transcript.concat + simp only [default, Fin.isValue] + funext i + have hi : i = 0 := by omega + subst hi; simp [Fin.snoc] + @[inline, reducible] def FullTranscript.mk2 {pSpec : ProtocolSpec 2} (msg0 : pSpec.«Type» 0) (msg1 : pSpec.«Type» 1) : FullTranscript pSpec := fun | ⟨0, _⟩ => msg0 | ⟨1, _⟩ => msg1 diff --git a/ArkLib/OracleReduction/Completeness.lean b/ArkLib/OracleReduction/Completeness.lean index a1a82bb1c..de65e2de9 100644 --- a/ArkLib/OracleReduction/Completeness.lean +++ b/ArkLib/OracleReduction/Completeness.lean @@ -5,6 +5,7 @@ Authors: Chung Thai Nguyen, Quang Dao -/ import ArkLib.OracleReduction.Security.Basic import ArkLib.ToVCVio.Simulation +import ArkLib.OracleReduction.Security.RoundByRound import ArkLib.ToVCVio.Lemmas /-! @@ -142,22 +143,17 @@ theorem unroll_n_message_reduction_perfectCompleteness unfold OracleReduction.perfectCompleteness simp only [Reduction.perfectCompleteness_eq_prob_one] simp only [probEvent_eq_one_iff] - simp only [Prod.forall] at * - apply forall_congr'; intro stmtIn apply forall_congr'; intro oStmtIn apply forall_congr'; intro witIn apply imp_congr_right; intro h_relIn - simp only [Reduction_run_def, Prover.run, Prover.runToRound] have h_init_probFailure_eq_0 : Pr[⊥ | init] = 0 := by rw [probFailure_eq_zero_iff]; exact hInit - conv_lhs => simp only rw [OptionT.probFailure_mk_bind_eq_zero_iff] - conv_lhs => simp only [h_init_probFailure_eq_0, true_and] enter [1, x, 2] @@ -196,7 +192,6 @@ theorem unroll_n_message_reduction_perfectCompleteness support_liftM, QueryImpl.mapQuery, OracleQuery.input_apply, OracleQuery.cont_apply, liftM_map] using hq )] - simp only [liftM_bind] simp only [ChallengeIdx, Challenge, liftM_pure, bind_pure_comp, liftM_OptionT_eq, Prod.mk.eta, bind_assoc, bind_map_left, OptionT.support_mk, Set.mem_setOf_eq, Prod.mk.injEq, @@ -207,7 +202,6 @@ theorem unroll_n_message_reduction_perfectCompleteness rw [OptionT.probFailure_mk_do_bind_bindT_eq_zero_iff] simp only [OptionT.probFailure_mk_do_bindT_eq_zero_iff] simp only [OracleReduction.toReduction] - have h_init_support_nonempty := support_nonempty_of_neverFails init hInit have elim_vacuous_quant : ∀ {α : Type} {S : Set α} {P : Prop}, (∀ x ∈ S, P) ↔ (S.Nonempty → P) := by @@ -230,7 +224,6 @@ theorem unroll_n_message_reduction_perfectCompleteness dsimp only [Functor.map, OptionT.instMonad] simp only [OptionT.mem_support_OptionT_bind_run_some_iff, Challenge, Function.comp_apply, Prod.exists] - apply and_congr · constructor · intro h tr lastPrvState h_mem_prvRun @@ -604,4 +597,575 @@ theorem unroll_2_message_reduction_perfectCompleteness end TwoMessageProtocol +/-! ## Round-by-Round Knowledge Soundness Unroll Lemmas + +This section provides unroll lemmas for `rbrKnowledgeSoundness` that mirror the structure +of the completeness unroll lemmas. These lemmas convert the probabilistic soundness bounds +into factored tsum forms that are easier to work with for probability reasoning. + +**Key differences from completeness:** +- Completeness: `probEvent = 1` → pure logic/support statements +- Soundness: `probEvent ≤ error` → tsum factorization → probability bounds + +**Main Results:** +- `unroll_rbrKnowledgeSoundness`: Generic lemma that factors the probEvent bound into a tsum + over initial states, enabling uniform bounds on the inner computation. +- Future: Specific versions for 1-message and 2-message protocols (similar to completeness) +-/ + +section RoundByRoundKnowledgeSoundness + +open NNReal ENNReal + +variable {ι : Type} {oSpec : OracleSpec ι} [oSpec.Fintype] + {StmtIn WitIn StmtOut WitOut : Type} {n : ℕ} {pSpec : ProtocolSpec n} + [∀ i, SampleableType (pSpec.Challenge i)] + {σ : Type} (init : ProbComp σ) (impl : QueryImpl oSpec (StateT σ ProbComp)) + +lemma tsum_mul_le_of_le_of_sum_le_one_nnreal {α : Type*} + {f g : α → ℝ≥0} {ε : ℝ≥0} + (hf_summable : Summable f) -- Required for NNReal tsum arithmetic + (hg : ∀ x, g x ≤ ε) + (hf : ∑' x, f x ≤ 1) : + ∑' x, f x * g x ≤ ε := by + -- 1. Establish that the upper bound series (f x * ε) is summable + have h_mul_summable : Summable (fun x ↦ f x * ε) := + hf_summable.mul_right ε + -- 2. Establish that the target series (f x * g x) is summable by comparison + have h_fg_summable : Summable (fun x ↦ f x * g x) := by + refine NNReal.summable_of_le (fun x ↦ ?_) h_mul_summable + exact mul_le_mul_of_nonneg_left (hg x) (zero_le (f x)) + -- 3. The calculation + calc ∑' x, f x * g x + _ ≤ ∑' x, f x * ε := by + apply Summable.tsum_le_tsum _ h_fg_summable h_mul_summable + intro x + exact mul_le_mul_of_nonneg_left (hg x) (zero_le _) + _ = (∑' x, f x) * ε := tsum_mul_right f ε + _ ≤ 1 * ε := mul_le_mul_of_nonneg_right hf (zero_le _) + _ = ε := one_mul ε + +lemma ENNReal.tsum_mul_le_of_le_of_sum_le_one {α : Type*} {f g : α → ℝ≥0∞} {ε : ℝ≥0∞} + (hg : ∀ x, g x ≤ ε) -- The conditional probability is bounded + (hf : ∑' x, f x ≤ 1) : -- The weights sum to at most 1 + ∑' x, f x * g x ≤ ε := by + calc ∑' x, f x * g x + _ ≤ ∑' x, f x * ε := + ENNReal.tsum_le_tsum (fun x ↦ mul_le_mul_left' (hg x) _) + _ = (∑' x, f x) * ε := ENNReal.tsum_mul_right + _ ≤ 1 * ε := mul_le_mul_right' hf ε + _ = ε := one_mul ε + +omit [oSpec.Fintype] in +/-- **Unroll lemma for round-by-round knowledge soundness (uniform bound form)** + +This is the preferred formulation for proving round-by-round knowledge soundness. +Instead of proving the tsum bound directly, we prove a **uniform bound for all states**: + +``` +∀ (s : σ), [doom_event | (simulateQ ...).run s] ≤ rbrKnowledgeError i +``` + +This implies `rbrKnowledgeSoundness` because: +- `∑' s, [= s | init] * [doom_event | ...run s] ≤ ∑' s, [= s | init] * ε` +- `= ε * ∑' s, [= s | init]` +- `≤ ε * 1 = ε` (since `∑' s, [= s | init] ≤ 1` for any probability distribution) + +This form is convenient because: +1. The initial state `s` is fixed, simplifying the probability reasoning +2. The bound holds uniformly regardless of `init`, making proofs more modular +3. It aligns with how we typically apply tools like Schwartz-Zippel +-/ +theorem unroll_rbrKnowledgeSoundness + (verifier : Verifier oSpec StmtIn StmtOut pSpec) + (relIn : Set (StmtIn × WitIn)) (relOut : Set (StmtOut × WitOut)) + (rbrKnowledgeError : pSpec.ChallengeIdx → ℝ≥0) + (WitMid : Fin (n + 1) → Type) + (extractor : Extractor.RoundByRound oSpec StmtIn WitIn WitOut pSpec WitMid) + (kSF : verifier.KnowledgeStateFunction init impl relIn relOut extractor) + (h_single_bound : ∀ stmtIn : StmtIn, + ∀ witIn : WitIn, + ∀ prover : Prover oSpec StmtIn WitIn StmtOut WitOut pSpec, + ∀ i : pSpec.ChallengeIdx, + ∀ s : σ, + (Pr[fun ⟨⟨transcript, challenge, _proveQueryLog⟩, _initState⟩ => + ∃ witMid, + ¬ kSF i.1.castSucc stmtIn transcript + (extractor.extractMid i.1 stmtIn (transcript.concat challenge) witMid) ∧ + kSF i.1.succ stmtIn (transcript.concat challenge) witMid + | (simulateQ (impl.addLift challengeQueryImpl : QueryImpl _ (StateT σ ProbComp)) + (do + let ⟨⟨transcript, _⟩, proveQueryLog⟩ + ← prover.runWithLogToRound i.1.castSucc stmtIn witIn + let challenge ← liftComp (pSpec.getChallenge i) _ + return (transcript, challenge, proveQueryLog))).run s] ≤ + rbrKnowledgeError i)) : + (verifier.rbrKnowledgeSoundness init impl relIn relOut rbrKnowledgeError) := by + -- Provide the witnesses from hypotheses + use WitMid, extractor, kSF + intro stmtIn witIn prover i + rw [probEvent_bind_eq_tsum] + apply ENNReal.tsum_mul_le_of_le_of_sum_le_one (α := σ) (f := fun s => Pr[= s | init]) + · intro s + simp only [StateT.run'] + rw [probEvent_map] + let res := h_single_bound stmtIn witIn prover i s + exact res + · apply tsum_probOutput_le_one + +end RoundByRoundKnowledgeSoundness + +/-! ## Probability Event Simplification Lemmas for Soundness Proofs + +This section provides lemmas for simplifying `probEvent` expressions when the predicate +ignores certain parts of the output (like query logs or final states). These are essential +for reducing complex soundness goals to cleaner forms suitable for Schwartz-Zippel-style bounds. + +### Key Patterns Addressed + +1. **State marginalization**: When predicate ignores the final state from `StateT` +2. **Query log elimination**: When predicate ignores the query log from `runWithLogToRound` +3. **Combined patterns**: Full simplification for the common soundness proof shape + +### Usage + +Apply these lemmas (or use them as `simp` lemmas) to transform goals of the form: +```lean +[fun ⟨⟨transcript, challenge, _log⟩, _state⟩ => P transcript challenge | + (simulateQ impl (do ... runWithLogToRound ... getChallenge ...)).run s] +``` +into cleaner forms: +```lean +[fun ⟨transcript, challenge⟩ => P transcript challenge | + simulateQ impl (do ... runToRound ... getChallenge ...)] +``` +-/ + +section ProbEventSimplification + +open NNReal ENNReal + +variable {ι : Type} {oSpec : OracleSpec ι} [oSpec.Fintype] + {StmtIn WitIn StmtOut WitOut : Type} {n : ℕ} {pSpec : ProtocolSpec n} + [∀ i, SampleableType (pSpec.Challenge i)] + {σ : Type} + +/-! ### Lemma 1: State Marginalization + +When the predicate ignores the final state, we can use `run'` instead of `run`. -/ + +/-- When the predicate ignores the final state from a stateful computation, + the probability event can be computed using `run'` (which discards state). -/ +theorem probEvent_StateT_run_ignore_state {α : Type} + (comp : StateT σ ProbComp α) (s : σ) + (P : α → Prop) [DecidablePred P] [DecidablePred (fun x : α × σ => P x.1)] : + Pr[fun x : α × σ => P x.1 | comp.run s] = Pr[P | comp.run' s] := by + simp only [StateT.run'_eq, probEvent_map] + congr 1 + +omit [oSpec.Fintype] in +/-- Version for `simulateQ` with stateful implementation. -/ +theorem probEvent_simulateQ_run_ignore_state {α : Type} + (impl : QueryImpl oSpec (StateT σ ProbComp)) + (oa : OracleComp oSpec α) (s : σ) + (P : α → Prop) [DecidablePred P] [DecidablePred (fun x : α × σ => P x.1)] : + Pr[fun x : α × σ => P x.1 | (simulateQ impl oa).run s] = + Pr[P | (simulateQ impl oa).run' s] := by + simp only [StateT.run'_eq, probEvent_map] + congr 1 + +/-! ### Lemma 2: Query Log Elimination + +When the predicate ignores the query log from `runWithLogToRound`, we can +eliminate the logging layer entirely using `runToRound`. -/ + +omit [oSpec.Fintype] [(i : pSpec.ChallengeIdx) → SampleableType (pSpec.Challenge i)] in +/-- When the predicate ignores the query log, `runWithLogToRound` can be replaced + with `runToRound`. This is the fundamental query log elimination lemma. -/ +theorem probEvent_runWithLogToRound_ignore_log + [(oSpec + [pSpec.Challenge]ₒ).Fintype] + [(oSpec + [pSpec.Challenge]ₒ).Inhabited] + (prover : Prover oSpec StmtIn WitIn StmtOut WitOut pSpec) + (i : Fin (n + 1)) (stmt : StmtIn) (wit : WitIn) + (P : pSpec.Transcript i × prover.PrvState i → Prop) + [DecidablePred P] + [DecidablePred (fun x : (pSpec.Transcript i × prover.PrvState i) × + QueryLog (oSpec + [pSpec.Challenge]ₒ) => P x.1)] : + Pr[fun x => P x.1 | prover.runWithLogToRound i stmt wit] = + Pr[P | prover.runToRound i stmt wit] := by + rw [← Prover.runWithLogToRound_discard_log_eq_runToRound, probEvent_map] + congr 1 + +/-! ### Lemma 3: Combined Transcript-Challenge Pattern + +These lemmas handle the common pattern in soundness proofs where we compute +`(transcript, challenge, queryLog)` but only care about `(transcript, challenge)`. -/ + +/-- Projection function that extracts `(transcript, challenge)` from the full tuple + `((transcript, challenge, queryLog), state)`. -/ +@[reducible] +def projTranscriptChallenge {T C L S : Type} : ((T × C × L) × S) → T × C := + fun ⟨⟨t, c, _⟩, _⟩ => (t, c) + +/-- Projection function that extracts `(transcript, challenge)` from the inner tuple + `(transcript, challenge, queryLog)`. -/ +@[reducible] +def projTranscriptChallengeInner {T C L : Type} : (T × C × L) → T × C := + fun ⟨t, c, _⟩ => (t, c) + +/-- When computing `(transcript, challenge, queryLog)` inside a stateful simulation, + but the predicate only uses `(transcript, challenge)`, we can eliminate both + the query log and the state tracking. + This transforms: + ``` + Pr[fun ⟨⟨tr, chal, _log⟩, _state⟩ => P tr chal | (simulateQ impl computation).run s] + ``` + into a cleaner form suitable for probability analysis. -/ +theorem probEvent_proj_transcript_challenge + {T C L : Type} + (comp : StateT σ ProbComp (T × C × L)) + (s : σ) (P : T × C → Prop) + [DecidablePred P] + [DecidablePred (P ∘ projTranscriptChallenge (T := T) (C := C) (L := L) (S := σ))] : + Pr[P ∘ projTranscriptChallenge | comp.run s] = + Pr[P ∘ projTranscriptChallengeInner | comp.run' s] := by + simp only [StateT.run'_eq, probEvent_map, Function.comp_def, projTranscriptChallenge, + projTranscriptChallengeInner] + +/-! ### Lemma 4: Master Log Unrolling for Soundness Goals + +The ultimate lemmas that handle the full pattern appearing in `unroll_rbrKnowledgeSoundness`, +eliminating both the query log and state when the predicate doesn't use them. -/ + +omit [oSpec.Fintype] in +/-- **Master log unrolling lemma for soundness bounds.** + +This transforms the complex goal shape from `unroll_rbrKnowledgeSoundness`: +```lean +Pr[fun ⟨⟨transcript, challenge, _log⟩, _state⟩ => P transcript challenge | + (simulateQ (impl ++ₛₒ challengeQueryImpl) + (do + let ⟨⟨transcript, _⟩, proveQueryLog⟩ ← runWithLogToRound ... + let challenge ← getChallenge.liftComp ... + pure (transcript, challenge, proveQueryLog))).run s] +``` + +into the cleaner form without logging: +```lean +Pr[fun ⟨transcript, challenge⟩ => P transcript challenge | + (simulateQ (impl ++ₛₒ challengeQueryImpl) + (do + let ⟨transcript, _⟩ ← runToRound ... + let challenge ← getChallenge.liftComp ... + pure (transcript, challenge))).run' s] +``` + +This cleaner form is suitable for applying `probEvent_bind_eq_tsum` to factor +out the challenge for Schwartz-Zippel-style probability bounds. +-/ +theorem probEvent_soundness_goal_unroll_log + [∀ i, Fintype (pSpec.Challenge i)] [∀ i, Inhabited (pSpec.Challenge i)] + [(oSpec + [pSpec.Challenge]ₒ).Fintype] + (impl : QueryImpl oSpec (StateT σ ProbComp)) + (prover : Prover oSpec StmtIn WitIn StmtOut WitOut pSpec) + (i : pSpec.ChallengeIdx) (stmt : StmtIn) (wit : WitIn) (s : σ) + (P : pSpec.Transcript i.1.castSucc × pSpec.Challenge i → Prop) + [DecidablePred P] + [DecidablePred (fun x : (pSpec.Transcript i.1.castSucc × pSpec.Challenge i × + QueryLog (oSpec + [pSpec.Challenge]ₒ)) × σ => P (projTranscriptChallenge x))] : + Pr[fun x => P (projTranscriptChallenge x) | + (simulateQ (impl.addLift challengeQueryImpl : QueryImpl _ (StateT σ ProbComp)) + (do + let ⟨⟨transcript, _⟩, proveQueryLog⟩ ← prover.runWithLogToRound i.1.castSucc stmt wit + let challenge ← liftComp (pSpec.getChallenge i) (oSpec + [pSpec.Challenge]ₒ) + return (transcript, challenge, proveQueryLog))).run s] = + Pr[fun x => P x | + (simulateQ (impl.addLift challengeQueryImpl : QueryImpl _ (StateT σ ProbComp)) + (do + let ⟨transcript, _⟩ ← prover.runToRound i.1.castSucc stmt wit + let challenge ← liftComp (pSpec.getChallenge i) (oSpec + [pSpec.Challenge]ₒ) + return (transcript, challenge))).run' s] := by + simp only at * + have h_eq : (fun x => P (projTranscriptChallenge (T := pSpec.Transcript i.1.castSucc) + (C := pSpec.Challenge i) (L := QueryLog (oSpec + [pSpec.Challenge]ₒ)) (S := σ) x)) = + P ∘ projTranscriptChallenge (T := pSpec.Transcript i.1.castSucc) + (C := pSpec.Challenge i) (L := QueryLog (oSpec + [pSpec.Challenge]ₒ)) (S := σ) := by + ext x + simp only [Function.comp_apply] + rw [h_eq] + rw [← probEvent_map (f := projTranscriptChallenge (T := pSpec.Transcript i.1.castSucc) + (C := pSpec.Challenge i) (L := QueryLog (oSpec + [pSpec.Challenge]ₒ)) (S := σ))] + congr 1 + simp only [StateT.run'_eq] + simp only [← Prover.runWithLogToRound_discard_log_eq_runToRound] + simp only [simulateQ_bind, liftComp_query, bind_pure_comp, StateT.run_bind, Function.comp_apply, + simulateQ_map, + simulateQ_query, StateT.run_map, map_bind, Functor.map_map] + rw [bind_map_left] + +omit [oSpec.Fintype] in +/-- Variant of `probEvent_soundness_goal_unroll_log` with explicit predicate matching + the exact shape in `unroll_rbrKnowledgeSoundness`. -/ +theorem probEvent_soundness_goal_unroll_log' + [∀ i, Fintype (pSpec.Challenge i)] [∀ i, Inhabited (pSpec.Challenge i)] + [(oSpec + [pSpec.Challenge]ₒ).Fintype] + (impl : QueryImpl oSpec (StateT σ ProbComp)) + (prover : Prover oSpec StmtIn WitIn StmtOut WitOut pSpec) + (i : pSpec.ChallengeIdx) (stmt : StmtIn) (wit : WitIn) (s : σ) + (P : pSpec.Transcript i.1.castSucc → pSpec.Challenge i → Prop) + [DecidablePred (fun x : pSpec.Transcript i.1.castSucc × pSpec.Challenge i => P x.1 x.2)] + [DecidablePred (fun x : (pSpec.Transcript i.1.castSucc × pSpec.Challenge i × + QueryLog (oSpec + [pSpec.Challenge]ₒ)) × σ => P x.1.1 x.1.2.1)] : + Pr[fun x : (pSpec.Transcript i.1.castSucc × pSpec.Challenge i × + QueryLog (oSpec + [pSpec.Challenge]ₒ)) × σ => P x.1.1 x.1.2.1 | + (simulateQ (impl.addLift challengeQueryImpl : QueryImpl _ (StateT σ ProbComp)) + (do + let ⟨⟨transcript, _⟩, proveQueryLog⟩ ← prover.runWithLogToRound i.1.castSucc stmt wit + let challenge ← liftComp (pSpec.getChallenge i) (oSpec + [pSpec.Challenge]ₒ) + return (transcript, challenge, proveQueryLog))).run s] = + Pr[fun x : pSpec.Transcript i.1.castSucc × pSpec.Challenge i => P x.1 x.2 | + (simulateQ (impl.addLift challengeQueryImpl : QueryImpl _ (StateT σ ProbComp)) + (do + let ⟨transcript, _⟩ ← prover.runToRound i.1.castSucc stmt wit + let challenge ← liftComp (pSpec.getChallenge i) (oSpec + [pSpec.Challenge]ₒ) + return (transcript, challenge))).run' s] := by + have h := probEvent_soundness_goal_unroll_log (ι := ι) (oSpec := oSpec) + (StmtIn := StmtIn) (WitIn := WitIn) (StmtOut := StmtOut) (WitOut := WitOut) + (n := n) (pSpec := pSpec) (σ := σ) (impl := impl) (prover := prover) + (i := i) (stmt := stmt) (wit := wit) (s := s) (P := fun x => P x.1 x.2) + exact h + +end ProbEventSimplification + +section SoundnessUnrolling + +open OracleSpec OracleComp ProtocolSpec ProbComp + +variable {ι : Type} {oSpec : OracleSpec ι} + {StmtIn WitIn StmtOut WitOut : Type} + {n : ℕ} {pSpec : ProtocolSpec n} + [∀ i, SampleableType (pSpec.Challenge i)] + [∀ i, Fintype (pSpec.Challenge i)] [∀ i, Inhabited (pSpec.Challenge i)] + [∀ i, OracleInterface (pSpec.Message i)] + {σ : Type} + +/-- **Unroll Soundness Computation: 1 Round (P → V)** + +Unrolls `runToRound 1` when dir 0 = P_to_V (one prover message at index 0). For pSpecBatching +the challenge is at index 1; use `soundness_unroll_runToRound_2_pSpec_2` to unroll through it. + +**Usage:** `rw [soundness_unroll_runToRound_1_P_to_V_pSpec_2]` -/ +theorem soundness_unroll_runToRound_1_P_to_V_pSpec_2 + {pSpec : ProtocolSpec 2} + (prover : Prover oSpec StmtIn WitIn StmtOut WitOut pSpec) + (stmtIn : StmtIn) (witIn : WitIn) + (hDir0 : pSpec.dir 0 = .P_to_V) : + prover.runToRound 1 stmtIn witIn = + do + let msg0_state1 ← prover.sendMessage ⟨0, hDir0⟩ (prover.input (stmtIn, witIn)) + let transcript := ProtocolSpec.FullTranscript.mk1 msg0_state1.1 + return (transcript, msg0_state1.2) := by + simp only [Prover.runToRound] + have h_one_eq : (1 : Fin 3) = (1 : Fin 2).castSucc := rfl + rw! (castMode := .all) [h_one_eq, Fin.induction_init] + conv_lhs => + rw [Fin.induction_one'] + simp only [Fin.castSucc_zero] + rw [Prover.processRound_P_to_V (h := hDir0)] + simp only + dsimp only [ChallengeIdx, Fin.isValue, Fin.castSucc_zero, Fin.succ_zero_eq_one, Challenge, + Nat.reduceAdd, Fin.reduceLast] + simp only [pure_bind] + congr 1 + unfold FullTranscript.mk1 + funext i + unfold Transcript.concat + congr 1; congr 1 + funext x + fin_cases x + rfl + +/-- **Unroll Soundness Computation: 1 Round (V → P)** + +Variant when the first message (index 0) is verifier-to-prover: unrolls `runToRound 1` into +explicit `getChallenge` and `receiveChallenge` calls. Useful for ProtocolSpec 2 where dir 0 = V_to_P. + +**Usage:** `rw [soundness_unroll_runToRound_1_V_to_P_pSpec_2]` -/ +theorem soundness_unroll_runToRound_1_V_to_P_pSpec_2 + {pSpec : ProtocolSpec 2} + (prover : Prover oSpec StmtIn WitIn StmtOut WitOut pSpec) + (stmtIn : StmtIn) (witIn : WitIn) + (hDir0 : pSpec.dir 0 = .V_to_P) : + prover.runToRound 1 stmtIn witIn = + do + let challenge ← pSpec.getChallenge ⟨0, hDir0⟩ + let receiveChallengeFn ← prover.receiveChallenge ⟨0, hDir0⟩ (prover.input (stmtIn, witIn)) + let state1 := receiveChallengeFn challenge + let transcript := ProtocolSpec.FullTranscript.mk1 challenge + return (transcript, state1) := by + simp only [Prover.runToRound] + have h_one_eq : (1 : Fin 3) = (1 : Fin 2).castSucc := rfl + rw! (castMode := .all) [h_one_eq, Fin.induction_init] + conv_lhs => + rw [Fin.induction_one'] + simp only [Fin.castSucc_zero] + rw [Prover.processRound_V_to_P (h := hDir0)] + simp only + dsimp only [ChallengeIdx, Fin.isValue, Fin.castSucc_zero, Fin.succ_zero_eq_one, Challenge, + Nat.reduceAdd, Fin.reduceLast] + simp only [pure_bind] + congr 1 + unfold FullTranscript.mk1 + funext i + unfold Transcript.concat + congr 1; + funext receiveChallengeFn + congr 1; congr 1; + funext x + fin_cases x + rfl + +theorem soundness_unroll_runToRound_0_pSpec_1_V_to_P + {pSpec : ProtocolSpec 1} + (prover : Prover oSpec StmtIn WitIn StmtOut WitOut pSpec) + (stmtIn : StmtIn) (witIn : WitIn) : + prover.runToRound 0 stmtIn witIn = + pure (default, prover.input (stmtIn, witIn)) := by + simp only [Prover.runToRound] + rfl + +/-- **Unroll Soundness Computation: 2 Rounds (P → V, V → P)** + +Unrolls the computation leading up to the second challenge (Index 2). +Useful for 5-move protocols or 2-round reductions. +-/ +theorem soundness_unroll_runToRound_2_pSpec_2 + {pSpec : ProtocolSpec 2} -- Restrict to n=2 context or generally n >= 2 + (prover : Prover oSpec StmtIn WitIn StmtOut WitOut pSpec) + (stmtIn : StmtIn) (witIn : WitIn) + (hDir0 : pSpec.dir 0 = .P_to_V) (hDir1 : pSpec.dir 1 = .V_to_P) : + prover.runToRound 2 stmtIn witIn = + do + let ⟨msg0, state1⟩ ← prover.sendMessage ⟨0, hDir0⟩ (prover.input (stmtIn, witIn)) + let r1 ← pSpec.getChallenge ⟨1, hDir1⟩ + let receiveChallengeFn ← prover.receiveChallenge ⟨1, hDir1⟩ state1 + let state2 := receiveChallengeFn r1 + let transcript := ProtocolSpec.FullTranscript.mk2 msg0 r1 + return (transcript, state2) := by + simp [Prover.runToRound, Fin.induction_two', Prover.processRound_P_to_V (h := hDir0), + Prover.processRound_V_to_P (h := hDir1), ProtocolSpec.FullTranscript.mk2_eq_snoc_snoc] + +end SoundnessUnrolling + +section ProbEventToPrNotation + +open ProbabilityTheory +open scoped ProbabilityTheory + +/-- **Convert probEvent notation to Pr notation for PMF** + +Converts `[P | pmf]` (where `pmf : PMF α`) to `Pr_{ let x ← pmf }[P x]`. + +This bridges VCVio's `probEvent` notation with ArkLib's `Pr_` notation, +enabling the use of probability tools like Schwartz-Zippel. + +**Note**: `[P | pmf]` where `pmf : PMF α` is interpreted as `pmf.toOuterMeasure {x | P x}`. +-/ +theorem probEvent_PMF_eq_Pr {α : Type} (pmf : PMF α) (P : α → Prop) [DecidablePred P] : + (pmf.toOuterMeasure {x | P x}) = Pr_{ let x ← pmf }[P x] := by + -- Both sides compute the probability that P holds + -- LHS: pmf.toOuterMeasure {x | P x} = ∑' x, if P x then pmf x else 0 + -- RHS: Pr_{ let x ← pmf }[P x] = (do let x ← pmf; return P x).val True + -- = ∑' x, pmf x * (if P x then 1 else 0) + simp only [PMF.toOuterMeasure_apply] + rw [prob_tsum_form_singleton] + congr 1 + funext x + simp only [Set.indicator_apply, Set.mem_setOf_eq, mul_ite, mul_one, mul_zero] + +/-- **Convert probOutput on OracleComp to PMF value** + +If `evalDist oa = OptionT.lift pmf` for some `pmf : PMF α`, then `[= x | oa] = pmf x`. + +This is useful when an `OracleComp` evaluates to a pure `PMF` (no failure probability). +-/ +theorem probOutput_eq_PMF_apply + {ι : Type} {spec : OracleSpec ι} [spec.Fintype] [spec.Inhabited] + {α : Type} (oa : OracleComp spec α) (pmf : PMF α) (x : α) + (h : evalDist oa = OptionT.lift pmf) : + Pr[= x | oa] = pmf x := by + have h' : evalDist oa = liftM pmf := by + exact (show evalDist oa = liftM pmf from by + simpa [OptionT.liftM_def] using h) + exact (evalDist_eq_liftM_iff (mx := oa) (p := pmf)).1 h' x + +open Classical in +/-- **Convert probOutput on uniform OracleComp to Pr_ notation** + +If an `OracleComp` evaluates to uniform sampling from a finite type `L`, +then `[= x | oa]` equals the uniform probability `1/|L|` for any `x : L`. + +This can be converted to `Pr_` notation: `[= x | oa] = Pr_{ let y ← $ᵖ L }[y = x]`. +-/ +theorem probOutput_uniform_eq_Pr + {ι : Type} {spec : OracleSpec ι} [spec.Fintype] [spec.Inhabited] + {L : Type} [Fintype L] [Nonempty L] [DecidableEq L] + (oa : OracleComp spec L) (x : L) + (h : evalDist oa = OptionT.lift (PMF.uniformOfFintype L)) : + Pr[= x | oa] = Pr_{ let y ← $ᵖ L }[y = x] := by + classical + rw [probOutput_eq_PMF_apply oa (PMF.uniformOfFintype L) x h] + simp + +/-- **Convert probOutput on uniform OracleComp to Pr_ notation (using $ᵗ notation)** + +If `oa = $ᵗ L` (uniform sampling from a finite type `L`), +then `[= x | oa]` equals the uniform probability `1/|L|` for any `x : L`. + +This can be converted to `Pr_` notation: `[= x | $ᵗ L] = Pr_{ let y ← $ᵖ L }[y = x]`. + +This version uses the existing `evalDist_uniformOfFintype` lemma to derive the hypothesis. + +**Note**: The `[Inhabited L]` requirement is necessary because `evalDist_uniformOfFintype` requires it. +For field types `L`, this is automatically satisfied since `Field L` implies `Inhabited L` (via `Zero`). +-/ +theorem probOutput_uniformOfFintype_eq_Pr + {L : Type} [Fintype L] [Nonempty L] [DecidableEq L] [SampleableType L] [Inhabited L] + (x : L) : + Pr[= x | $ᵗ L] = Pr_{ let y ← $ᵖ L }[y = x] := by + refine probOutput_uniform_eq_Pr ($ᵗ L) x ?_ + simpa [OptionT.liftM_def] using (evalDist_uniformSample (α := L)) + +open Classical in +/-- **Convert sum of uniform probabilities back to Pr_ notation** + +If we have a sum over `x` where each term is `Pr_{ let y ← $ᵖ L }[y = x]` when `P x` holds, +this equals `Pr_{ let y ← $ᵖ L }[P y]`. + +This is the inverse of expanding `Pr_` notation into a sum. +-/ +theorem tsum_uniform_Pr_eq_Pr + {L : Type} [Fintype L] [Nonempty L] [DecidableEq L] + (P : L → Prop) [DecidablePred P] : + (∑' x : L, if P x then Pr_{ let y ← $ᵖ L }[y = x] else 0) = Pr_{ let y ← $ᵖ L }[P y] := by + classical + simp + +/-- **Convert probEvent on StateT.run' to tsum form** + +Converts `[P | (comp : StateT σ ProbComp α).run' s]` to a sum form using `probEvent_eq_tsum_ite`. + +**Note**: `ProbComp = OracleComp unifSpec`. The `probEvent` notation measures `Option.some '' {x | P x}` +in `PMF (Option α)`, which is equivalent to summing `[= x | comp.run' s]` over `α` where `P x` holds. + +This is useful for further manipulation, e.g., applying probability bounds. Note that we cannot +directly convert to `Pr_` notation because `evalDist` returns `PMF (Option α)`, not `PMF α`. +-/ +theorem probEvent_StateT_run'_eq_tsum + {σ α : Type} (comp : StateT σ ProbComp α) (s : σ) (P : α → Prop) [DecidablePred P] : + Pr[P | comp.run' s] = ∑' x : α, if P x then Pr[= x | comp.run' s] else 0 := by + simpa using (probEvent_eq_tsum_ite (mx := comp.run' s) (p := P)) + +end ProbEventToPrNotation + end OracleReduction diff --git a/ArkLib/OracleReduction/Security/RoundByRound.lean b/ArkLib/OracleReduction/Security/RoundByRound.lean index 434689a29..4ce243856 100644 --- a/ArkLib/OracleReduction/Security/RoundByRound.lean +++ b/ArkLib/OracleReduction/Security/RoundByRound.lean @@ -417,6 +417,22 @@ theorem rbrKnowledgeSoundnessOneShot_implies_rbrKnowledgeSoundness -- TODO: Complete this proof sorry +/-- If a verifier is RBR knowledge sound with error ε₁, and ε₂ is pointwise equal to ε₁, + then it is RBR knowledge sound with error ε₂. Use this to state soundness with a + "flat" or Fin-like error definition that you prove equal to the composed error. -/ +theorem rbrKnowledgeSoundness_of_eq_error + {relIn : Set (StmtIn × WitIn)} {relOut : Set (StmtOut × WitOut)} + {verifier : Verifier oSpec StmtIn StmtOut pSpec} + {ε₁ ε₂ : pSpec.ChallengeIdx → ℝ≥0} + (h_ε : ∀ i, ε₂ i = ε₁ i) + (h : verifier.rbrKnowledgeSoundness init impl relIn relOut ε₁) : + verifier.rbrKnowledgeSoundness init impl relIn relOut ε₂ := by + unfold rbrKnowledgeSoundness at h ⊢ + obtain ⟨WitMid, extractor, kSF, h_bound⟩ := h + refine ⟨WitMid, extractor, kSF, fun stmtIn witIn prover i => ?_⟩ + rw [h_ε i] + exact h_bound stmtIn witIn prover i + end RoundByRound end Verifier @@ -467,6 +483,19 @@ def rbrKnowledgeSoundness (rbrKnowledgeError : pSpec.ChallengeIdx → ℝ≥0) : Prop := verifier.toVerifier.rbrKnowledgeSoundness init impl relIn relOut rbrKnowledgeError +/-- If an oracle verifier is RBR knowledge sound with error ε₁ and ε₂ is pointwise equal to ε₁, + then it is RBR knowledge sound with error ε₂. Use this to state soundness with a + "flat" or Fin-like error definition that you prove equal to the composed error. -/ +theorem rbrKnowledgeSoundness_of_eq_error + {relIn : Set ((StmtIn × ∀ i, OStmtIn i) × WitIn)} + {relOut : Set ((StmtOut × ∀ i, OStmtOut i) × WitOut)} + {verifier : OracleVerifier oSpec StmtIn OStmtIn StmtOut OStmtOut pSpec} + {ε₁ ε₂ : pSpec.ChallengeIdx → ℝ≥0} + (h_ε : ∀ i, ε₂ i = ε₁ i) + (h : verifier.rbrKnowledgeSoundness init impl relIn relOut ε₁) : + verifier.rbrKnowledgeSoundness init impl relIn relOut ε₂ := + Verifier.rbrKnowledgeSoundness_of_eq_error (init := init) (impl := impl) (h_ε := h_ε) (h := h) + end OracleVerifier end OracleProtocol diff --git a/ArkLib/ProofSystem/Binius/BinaryBasefold/Basic.lean b/ArkLib/ProofSystem/Binius/BinaryBasefold/Basic.lean index 0c1279936..bf0825eaf 100644 --- a/ArkLib/ProofSystem/Binius/BinaryBasefold/Basic.lean +++ b/ArkLib/ProofSystem/Binius/BinaryBasefold/Basic.lean @@ -742,11 +742,33 @@ lemma projectToMidSumcheckPoly_at_last_eval (x := x) congr 1 -- this auto rw using h_m and h_t +/-- At `Fin.last ℓ`, the projected sumcheck polynomial is exactly the constant polynomial +equal to the product of the evaluations. This does NOT require an infinite field. -/ +lemma projectToMidSumcheckPoly_at_last_eq + (t : MultilinearPoly L ℓ) + (m : MultilinearPoly L ℓ) + (challenges : Fin ℓ → L) : + (projectToMidSumcheckPoly (L := L) (ℓ := ℓ) (t := t) (m := m) + (i := Fin.last ℓ) (challenges := challenges)).val = + MvPolynomial.C (m.val.eval challenges * t.val.eval challenges) := by + -- The domain Fin (ℓ - ℓ) is empty, so both sides are constant polynomials + -- We prove equality by showing they have the same constant coefficient + have h_dim : ℓ - ↑(Fin.last ℓ) = 0 := Nat.sub_self ℓ + -- Since Fin (ℓ - ℓ) is empty (isomorphic to Fin 0), use isEmpty instance + haveI : IsEmpty (Fin (ℓ - ↑(Fin.last ℓ))) := by + rw [h_dim] + infer_instance + rw [MvPolynomial.eq_C_of_isEmpty + (projectToMidSumcheckPoly (L := L) (ℓ := ℓ) (t := t) (m := m) + (i := Fin.last ℓ) (challenges := challenges)).val] + simp only [Fin.val_last, ← constantCoeff_eq] + rw [←projectToMidSumcheckPoly_at_last_eval (x := 0)] + simp only [Fin.val_last, MvPolynomial.eval_zero] + end SumcheckOperations variable {r : ℕ} [NeZero r] variable {L : Type} [Field L] [Fintype L] [DecidableEq L] [CharP L 2] - -- [SampleableType L] => not used variable (𝔽q : Type) [Field 𝔽q] [Fintype 𝔽q] [DecidableEq 𝔽q] [h_Fq_char_prime : Fact (Nat.Prime (ringChar 𝔽q))] [hF₂ : Fact (Fintype.card 𝔽q = 2)] variable [Algebra 𝔽q L] @@ -804,6 +826,17 @@ def OracleStatement (ϑ : ℕ) [NeZero ϑ] (i : Fin (ℓ + 1)) : let sDomainIdx := oraclePositionToDomainIndex ℓ ϑ j exact (sDomain 𝔽q β h_ℓ_add_R_rate) ⟨sDomainIdx, by omega⟩ → L +/-- First oracle witness consistency: the witness polynomial t, when projected to level 0 and + evaluated on the initial domain S^(0), must be close within unique decoding radius to f^(0) -/ +def firstOracleWitnessConsistencyProp (t : MultilinearPoly L ℓ) + (f₀ : sDomain 𝔽q β h_ℓ_add_R_rate 0 → L) : Prop := + let P₀: L[X]_(2 ^ ℓ) := polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) + (fun ω => t.val.eval (bitsOfIndex ω)) + -- The constraint: P_0 evaluated on S^(0) is close within unique decoding radius to f^(0) + pair_UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (h_i := by + simp only [Fin.coe_ofNat_eq_mod, zero_mod, _root_.zero_le]) (f := f₀) + (g := polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := 0) (P := P₀)) + omit [CharP L 2] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero 𝓡] hdiv in /-- **Oracle Access Congruence**: Proves equality of oracle evaluations `oStmtIn j x = oStmtIn j' x'` -/ @@ -838,8 +871,9 @@ structure Witness (i : Fin (ℓ + 1)) where H : L⦃≤ 2⦄[X Fin (ℓ - i)] -- Hᵢ f: (sDomain 𝔽q β h_ℓ_add_R_rate) ⟨i, by omega⟩ → L -- fᵢ -/-- The extractor that recovers the multilinear polynomial t from f^(i) -/ -noncomputable def extractMLP (i : Fin ℓ) (f : (sDomain 𝔽q β h_ℓ_add_R_rate) ⟨i, by omega⟩ → L) : +/-- The extractor that recovers the multilinear polynomial t from f^(i). +In the current protocol flow, call sites decode only the first oracle (`i = 0`). -/ +def extractMLP (i : Fin ℓ) (f : (sDomain 𝔽q β h_ℓ_add_R_rate) ⟨i, by omega⟩ → L) : Option (L⦃≤ 1⦄[X Fin (ℓ - i)]) := by set domain_size := Fintype.card (sDomain 𝔽q β h_ℓ_add_R_rate ⟨i, by omega⟩) set d := Code.distFromCode (u := f) @@ -873,8 +907,12 @@ noncomputable def extractMLP (i : Fin ℓ) (f : (sDomain 𝔽q β h_ℓ_add_R_ra match berlekamp_welch_result with | none => exact none -- Decoder failed | some P => - -- 5. Check if degree < 2^ℓ (unique decoding condition) - if hp_deg_lt: P.natDegree ≥ 2^(ℓ - i.val) then + -- 5. **post-decoding check** : Check if P's degree < 2^ℓ and `f` is UDR-Close to + -- the encoding of `P` + let isUDRClose := pair_UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨i, by omega⟩) + (h_i := by dsimp only; omega) (f := f) (g := polyToOracleFunc 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := ⟨i, by omega⟩) (P := P)) + if hP_valid: P.natDegree ≥ 2^(ℓ - i.val) ∨ (¬isUDRClose) then exact none -- Outside unique decoding radius else -- 6. Convert P(X) from monomial basis to novel polynomial basis @@ -886,7 +924,8 @@ noncomputable def extractMLP (i : Fin ℓ) (f : (sDomain 𝔽q β h_ℓ_add_R_ra · simp only [hi, tsub_self, pow_zero, cast_one] by_cases hp_p_eq_0: P = 0 · simp only [hp_p_eq_0, degree_zero]; omega - · simp only [hi, tsub_self, pow_zero, ge_iff_le, not_le, lt_one_iff] at hp_deg_lt + · simp only [hi, tsub_self, pow_zero, ge_iff_le, not_or, not_le, lt_one_iff, + not_not] at hP_valid have h_deg_p: P.degree = 0 := by omega simp only [h_deg_p] omega @@ -896,35 +935,77 @@ noncomputable def extractMLP (i : Fin ℓ) (f : (sDomain 𝔽q β h_ℓ_add_R_ra simp only [degree_zero, cast_pow, cast_ofNat, gt_iff_lt] -- ⊢ ⊥ < 2 ^ (ℓ - ↑i) have h_deg_ne_bot : 2 ^ (ℓ - ↑i) ≠ ⊥ := by - exact not_isBot_iff_ne_bot.mp fun a ↦ hp_deg_lt (a P.natDegree) + exact not_isBot_iff_ne_bot.mp fun a ↦ hP_valid (Or.inl (a P.natDegree)) exact compareOfLessAndEq_eq_lt.mp rfl · have h := Polynomial.natDegree_lt_iff_degree_lt (p:=P) (n:=2 ^ (ℓ - ↑i)) (hp:=by exact hp_p_eq_0) rw [←h]; omega - let P_bounded : L⦃<2^(ℓ - i.val)⦄[X] := ⟨P, h_deg_bound⟩ -- Get monomial coefficients of P(X) let monomial_coeffs : Fin (2^(ℓ - i.val)) → L := fun i => P.coeff i.val -- Convert to novel polynomial basis coefficients using change of basis -- The changeOfBasisMatrix A has A[j,i] = coeff of X^i in novel basis vector X_j -- So we need A⁻¹ to convert monomial coeffs → novel coeffs - let novel_coeffs : Option (Fin (2^(ℓ - i.val)) → L) := - let h_ℓ_le_r : ℓ ≤ r := by - -- ℓ + 𝓡 < r implies ℓ < r, hence ℓ ≤ r - have : ℓ < r := by omega - exact Nat.le_of_lt this - some (AdditiveNTT.monomialToNovelCoeffs 𝔽q β (ℓ - i.val) (by omega) monomial_coeffs) - match novel_coeffs with - | none => exact none - | some t_coeffs => - -- Interpret novel coeffs as Lagrange cosefficients on Boolean hypercube - -- and reconstruct the multilinear polynomial using MLE - let hypercube_evals : (Fin (ℓ - i.val) → Fin 2) → L := fun w => - -- Map Boolean hypercube point w to its linear index - let w_index : Fin (2^(ℓ - i.val)) := Nat.binaryFinMapToNat - (n:=ℓ - i.val) (m:=w) (h_binary:=by intro j; simp only [Nat.cast_id]; omega) - t_coeffs w_index - let t_multilinear_mv := MvPolynomial.MLE hypercube_evals - exact some ⟨t_multilinear_mv, MLE_mem_restrictDegree hypercube_evals⟩ + -- NOTE: We intentionally use the base-basis map `monomialToNovelCoeffs` here + -- (not `getINovelCoeffs`): downstream specs at `i = 0` are phrased with + -- `polynomialFromNovelCoeffsF₂` / `bitsOfIndex`, i.e. the base novel basis. + let t_coeffs : Fin (2^(ℓ - i.val)) → L := + AdditiveNTT.monomialToNovelCoeffs 𝔽q β (ℓ - i.val) (by omega) monomial_coeffs + -- Interpret novel coeffs as Lagrange cosefficients on Boolean hypercube + -- and reconstruct the multilinear polynomial using MLE + let hypercube_evals : (Fin (ℓ - i.val) → Fin 2) → L := fun w => + -- Map Boolean hypercube point w to its linear index + let w_index : Fin (2^(ℓ - i.val)) := Nat.binaryFinMapToNat + (n:=ℓ - i.val) (m:=w) (h_binary:=by intro j; simp only [Nat.cast_id]; omega) + t_coeffs w_index + let t_multilinear_mv := MvPolynomial.MLE hypercube_evals + exact some ⟨t_multilinear_mv, MLE_mem_restrictDegree hypercube_evals⟩ + +/-- For index 0, `extractMLP 0 f = some tpoly` iff `f` is pair-UDR-close to the oracle function +of the multilinear polynomial `tpoly` (i.e. the polynomial-as-oracle from novel coeffs of tpoly). +Forward: decoder succeeds only when within UDR. Backward: within UDR the decoded codeword +is `polyToOracleFunc (polynomialFromNovelCoeffsF₂ tpoly)`. -/ +lemma extractMLP_eq_some_iff_pair_UDRClose (f : (sDomain 𝔽q β h_ℓ_add_R_rate) ⟨0, by omega⟩ → L) + (tpoly : MultilinearPoly L ℓ) : + (extractMLP 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) 0 f = some tpoly) ↔ + pair_UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) + (h_i := by simp only [Fin.coe_ofNat_eq_mod, zero_mod, _root_.zero_le]) + (f := f) + (g := polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := 0) + (P := polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) + (fun ω => tpoly.val.eval (bitsOfIndex ω)))) := by + sorry + +/-- If a block starting at index `0` is compliant in the sense of `isCompliant`, then the +Berlekamp–Welch decoder `extractMLP` at index `0` succeeds on the source oracle. + +Mathematically: `isCompliant` gives fiberwise-closeness of the source oracle to the +appropriate code, which implies UDR-closeness, and hence decoder success. -/ +lemma extractMLP_some_of_isCompliant_at_zero + {destIdx : Fin r} {steps : ℕ} [NeZero steps] + (zero_Idx : Fin r) (h_zero_Idx : zero_Idx.val = 0) + (h_destIdx : destIdx = 0 + steps) + (h_destIdx_le : destIdx ≤ ℓ) + (f_i : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) zero_Idx) + (f_next : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx) + (challenges : Fin steps → L) + (h_compl : + isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := zero_Idx) (steps := steps) + (destIdx := destIdx) (h_destIdx := by omega) (h_destIdx_le := h_destIdx_le) + (f_i := f_i) (f_i_plus_steps := f_next) (challenges := challenges)) : + ∃ tpoly : MultilinearPoly L ℓ, + extractMLP 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) 0 + (fun x => f_i (cast (by + simp only [Fin.coe_ofNat_eq_mod, zero_mod, Fin.mk_zero']; + have h_eq := sDomain_eq_of_eq 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) + (j := zero_Idx) (h := by apply Fin.eq_of_val_eq; simp only [Fin.coe_ofNat_eq_mod, + zero_mod, h_zero_Idx]) + rw [h_eq]) x)) = some tpoly := by + classical + -- From compliance we get fiberwise-closeness of `f_i` to the appropriate codeword, + -- which implies UDR-closeness, and therefore decoder success via + -- `extractMLP_eq_some_iff_pair_UDRClose`. + sorry def dummyLastWitness : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) := { @@ -934,7 +1015,7 @@ def dummyLastWitness : } /-- The initial statement for the commitment phase contains the evaluation claim s = t(r) -/ -structure InitialStatement where +structure MLPEvalStatement (L : Type) (ℓ : ℕ) where -- Original evaluation claim: s = t(r) t_eval_point : Fin ℓ → L -- r = (r_0, ..., r_{ℓ-1}) => shared input original_claim : L -- s = t(r) => the original claim to verify @@ -1053,7 +1134,6 @@ def getLastOracle {oracleFrontierIdx : Fin (ℓ + 1)} {destIdx : Fin r} have h_lt : getLastOracleDomainIndex ℓ ϑ oracleFrontierIdx < r := by omega have h_eq : destIdx = ⟨getLastOracleDomainIndex ℓ ϑ oracleFrontierIdx, h_lt⟩ := Fin.eq_of_val_eq (by omega) - -- subst h_eq fun y => res (cast (by rw [h_eq]) y) section SecurityRelations @@ -1080,27 +1160,18 @@ def getNextOracle (i : Fin (ℓ + 1)) (oStmt : ∀ j, (OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i) j) (j : Fin (toOutCodewordsCount ℓ ϑ i)) (hj : j.val + 1 < toOutCodewordsCount ℓ ϑ i) {destDomainIdx : Fin r} (h_destDomainIdx : destDomainIdx = j.val * ϑ + ϑ) : - OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destDomainIdx := by - -- ⟨j.val * ϑ + ϑ, by - -- apply Nat.lt_succ_of_le; - -- let h_k_next_le_i := oracle_block_k_next_le_i (ℓ := ℓ) (ϑ := ϑ) (i := i) (j := j) (hj := hj) - -- calc _ ≤ i.val := h_k_next_le_i - -- _ ≤ ℓ := Fin.is_le i - -- ⟩ := by - let res := oStmt ⟨j.val + 1, hj⟩ - have h: j.val * ϑ + ϑ = (j.val + 1) * ϑ := by - rw [Nat.add_mul, one_mul] - dsimp only [OracleStatement] at res - have h_lt : (↑j + 1) * ϑ < r := by omega - have h_eq : destDomainIdx = ⟨(↑j + 1) * ϑ, h_lt⟩ := Fin.eq_of_val_eq (by simp only; omega) - subst h_eq - -- rw! [h] - exact res + OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destDomainIdx := + let res := oStmt ⟨j.val + 1, hj⟩ + have h : j.val * ϑ + ϑ = (j.val + 1) * ϑ := by rw [Nat.add_mul, one_mul] + have h_lt : (↑j + 1) * ϑ < r := by omega + have h_eq : destDomainIdx = ⟨(↑j + 1) * ϑ, h_lt⟩ := + Fin.eq_of_val_eq (by simp only; omega) + fun y => res (cast (by rw [h_eq]) y) /-- Folding consistency for round i (where i is the oracleIdx) -/ def oracleFoldingConsistencyProp (i : Fin (ℓ + 1)) (challenges : Fin i → L) - (oStmt : ∀ j, (OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i) j) - (includeFinalFiberwiseClose : Bool) : Prop := + (oStmt : ∀ j, (OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i) j) : Prop := + -- (includeFinalFiberwiseClose : Bool) : Prop := -- TODO: check index of this (∀ (j : Fin (toOutCodewordsCount ℓ ϑ i)) (hj : j.val + 1 < toOutCodewordsCount ℓ ϑ i), -- let k is j.val * ϑ @@ -1122,22 +1193,21 @@ def oracleFoldingConsistencyProp (i : Fin (ℓ + 1)) (challenges : Fin i → L) (challenges := getFoldingChallenges (r := r) (𝓡 := 𝓡) i challenges (k := j.val * ϑ) (h := h_k_next_le_i)) ) - ∧ - (if includeFinalFiberwiseClose then - -- the last oracle is fiberwise-close to its code - let curDomainIdx : Fin r := ⟨getLastOracleDomainIndex ℓ ϑ i, by omega⟩ - let destDomainIdx : Fin r := ⟨getLastOracleDomainIndex ℓ ϑ i + ϑ, by - have h_le := oracle_index_add_steps_le_ℓ ℓ ϑ (i := i) (j := getLastOraclePositionIndex ℓ ϑ i) - dsimp only [oraclePositionToDomainIndex] - omega - ⟩ - fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (i := curDomainIdx) (steps := ϑ) - (destIdx := destDomainIdx) (by rfl) - (by dsimp only [destDomainIdx]; simp only [oracle_index_add_steps_le_ℓ]) - (f := getLastOracle (h_destIdx := by rfl) (oracleFrontierIdx := i) 𝔽q β - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oStmt := oStmt)) - else True) + -- ∧ + -- (if includeFinalFiberwiseClose then + -- -- the last oracle is fiberwise-close to its code + -- let curDomainIdx : Fin r := ⟨getLastOracleDomainIndex ℓ ϑ i, by omega⟩ + -- let destDomainIdx : Fin r := ⟨getLastOracleDomainIndex ℓ ϑ i + ϑ, by + -- have h_le := oracle_index_add_steps_le_ℓ ℓ ϑ (i := i) (j := getLastOraclePositionIndex ℓ ϑ i) + -- dsimp only [oraclePositionToDomainIndex] + -- omega + -- ⟩ + -- fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + -- (i := curDomainIdx) (steps := ϑ) + -- (destIdx := destDomainIdx) (by rfl) (by dsimp only [destDomainIdx]; simp only [oracle_index_add_steps_le_ℓ]) + -- (f := getLastOracle (h_destIdx := by rfl) + -- (oracleFrontierIdx := i) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oStmt := oStmt)) + -- else True) def BBF_eq_multiplier (r : Fin ℓ → L) : MultilinearPoly L ℓ := ⟨MvPolynomial.eqPolynomial r, by simp only [eqPolynomial_mem_restrictDegree]⟩ @@ -1154,9 +1224,9 @@ def getMidCodewords {i : Fin (ℓ + 1)} (t : L⦃≤ 1⦄[X Fin ℓ]) -- origina let f₀ : (sDomain 𝔽q β h_ℓ_add_R_rate 0) → L := fun x => P₀.val.eval x.val let fᵢ := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (steps := i) (destIdx := ⟨i, by omega⟩) - (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, zero_mod, zero_add]) - (h_destIdx_le := by simp only; omega) - (f := f₀) (r_challenges := challenges) + (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, zero_mod, zero_add]) (h_destIdx_le := by simp only; omega) + (f := f₀) + (r_challenges := challenges) fun x => fᵢ x -- TODO: double check this? @@ -1278,17 +1348,12 @@ def witnessStructuralInvariant {i : Fin (ℓ + 1)} (stmt : Statement (L := L) Co def sumcheckConsistencyProp {k : ℕ} (sumcheckTarget : L) (H : L⦃≤ 2⦄[X Fin (k)]) : Prop := sumcheckTarget = ∑ x ∈ (univ.map 𝓑) ^ᶠ (k), H.val.eval x -/-- First oracle witness consistency: the witness polynomial t, when projected to level 0 and - evaluated on the initial domain S^(0), must be close within unique decoding radius to f^(0) -/ -def firstOracleWitnessConsistencyProp (t : MultilinearPoly L ℓ) - (f₀ : sDomain 𝔽q β h_ℓ_add_R_rate 0 → L) : Prop := - let P₀: L[X]_(2 ^ ℓ) := polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) - (fun ω => t.val.eval (bitsOfIndex ω)) - -- The constraint: P_0 evaluated on S^(0) is close within unique decoding radius to f^(0) - pair_UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (h_i := by - simp only [Fin.coe_ofNat_eq_mod, zero_mod, _root_.zero_le]) (f := f₀) - (g := polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (domainIdx := 0) (P := P₀)) +lemma firstOracleWitnessConsistencyProp_unique (t₁ t₂ : MultilinearPoly L ℓ) + (f₀ : sDomain 𝔽q β h_ℓ_add_R_rate 0 → L) + (h₁ : firstOracleWitnessConsistencyProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) t₁ f₀) + (h₂ : firstOracleWitnessConsistencyProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) t₂ f₀) : + t₁ = t₂ := by + sorry /-- The bad folding event of `fᵢ` exists RIGHT AFTER the V's challenge of sumcheck round `i+ϑ-1`, this is the last point that `fᵢ` is the last oracle being sent so far and both @@ -1302,7 +1367,7 @@ noncomputable def foldingBadEventAtBlock have h_ϑ: ϑ > 0 := by exact pos_of_neZero ϑ -- TODO: check this let curOracleDomainIdx : Fin r := ⟨oraclePositionToDomainIndex (positionIdx := j), by omega⟩ - if hj: curOracleDomainIdx + ϑ ≤ oracleIdx.val then + if hj: curOracleDomainIdx + ϑ ≤ stmtIdx.val then let f_k := oStmt j let destIdx : Fin r := ⟨oraclePositionToDomainIndex (positionIdx := j) + ϑ, by have h_le := oracle_index_add_steps_le_ℓ ℓ ϑ (i := oracleIdx.val) (j := j) @@ -1310,156 +1375,220 @@ noncomputable def foldingBadEventAtBlock omega ⟩ Binius.BinaryBasefold.foldingBadEvent 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (i := curOracleDomainIdx) (steps := ϑ) (destIdx := destIdx) (by rfl) - (by dsimp only [destIdx]; simp only [oracle_index_add_steps_le_ℓ]) (f_i := f_k) - (r_challenges := getFoldingChallenges (r := r) (𝓡 := 𝓡) stmtIdx challenges (k := j.val * ϑ) - (h := by - have h_le := oracle_index_add_steps_le_ℓ ℓ ϑ (i := oracleIdx.val) (j := j) - simp only [curOracleDomainIdx] at hj - have h := OracleFrontierIndex.val_le_i (i := stmtIdx) (oracleIdx := oracleIdx) - exact Nat.le_trans hj h - )) + (i := curOracleDomainIdx) (steps := ϑ) (destIdx := destIdx) (by rfl) (by dsimp only [destIdx]; simp only [oracle_index_add_steps_le_ℓ]) (f_i := f_k) (r_challenges := + getFoldingChallenges (r := r) (𝓡 := 𝓡) stmtIdx challenges (k := j.val * ϑ) (h := by + simp only [curOracleDomainIdx] at hj + exact hj + )) else False +/-- For non-latest oracle positions (where j*ϑ + ϑ ≤ i.val), the bad event with +extended challenges (Fin.snoc chal r_new) at stmtIdx = i.succ equals the bad event +with original challenges (chal) at stmtIdx = i.castSucc. + +This is because: +1. Both have the same oracleIdx.val (= i.val), so the oracle statement is identical. +2. The guard is satisfied in both cases (j*ϑ + ϑ ≤ i.val ≤ i.val and ≤ i.val+1). +3. The getFoldingChallenges accesses indices < j*ϑ + ϑ ≤ i.val, where Fin.snoc + agrees with the original challenges. -/ +lemma foldingBadEventAtBlock_snoc_castSucc_eq (i : Fin ℓ) + (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (ϑ := ϑ) (i := i.castSucc) j) + (challenges : Fin i.castSucc → L) (r_new : L) + (j : Fin (toOutCodewordsCount ℓ ϑ i.castSucc)) + (hj_le : j.val * ϑ + ϑ ≤ i.castSucc.val) : + foldingBadEventAtBlock 𝔽q β (stmtIdx := i.succ) + (oracleIdx := OracleFrontierIndex.mkFromStmtIdxCastSuccOfSucc i) + (oStmt := oStmt) + (challenges := Fin.snoc challenges r_new) j = + foldingBadEventAtBlock 𝔽q β (stmtIdx := i.castSucc) + (oracleIdx := OracleFrontierIndex.mkFromStmtIdx i.castSucc) + (oStmt := oStmt) + (challenges := challenges) j := by + unfold foldingBadEventAtBlock + simp only [OracleFrontierIndex.val_mkFromStmtIdxCastSuccOfSucc, + Fin.val_castSucc, OracleFrontierIndex.val_mkFromStmtIdx, + Fin.val_succ] + -- Both guards are satisfied since j*ϑ + ϑ ≤ i.val ≤ i.val + 1 + have h_guard_succ : oraclePositionToDomainIndex (positionIdx := j) + ϑ ≤ i.val + 1 := by + simp only [Fin.val_castSucc] at ⊢ hj_le + omega + have h_guard_cast : oraclePositionToDomainIndex (positionIdx := j) + ϑ ≤ i.val := by + simp only [Fin.val_castSucc] at ⊢ hj_le + omega + simp only [h_guard_succ, h_guard_cast, ↓reduceDIte] + -- Now show the foldingBadEvent calls are equal by showing getFoldingChallenges agree + congr 1 + unfold getFoldingChallenges + ext cId + simp only [Fin.snoc] + split + · rfl + · exfalso + rename_i h_lt + simp only [not_lt] at h_lt + simp only at h_guard_cast + omega + attribute [irreducible] foldingBadEventAtBlock open Classical in -def badEventExistsProp +def blockBadEventExistsProp (stmtIdx : Fin (ℓ + 1)) (oracleIdx : OracleFrontierIndex stmtIdx) (oStmt : ∀ j, (OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) (i := oracleIdx.val) j)) (challenges : Fin stmtIdx → L) : Prop := ∃ j, foldingBadEventAtBlock 𝔽q β (stmtIdx := stmtIdx) (oracleIdx := oracleIdx) (oStmt := oStmt) (challenges := challenges) j -def oracleWitnessConsistency +def incrementalBadEventExistsProp (stmtIdx : Fin (ℓ + 1)) (oracleIdx : OracleFrontierIndex stmtIdx) - (stmt : Statement (L := L) (Context := Context) stmtIdx) - (wit : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmtIdx) - (oStmt : ∀ j, (OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - ϑ (i := oracleIdx.val) j)) : Prop := - let witnessStructuralInvariant: Prop := witnessStructuralInvariant (i:=stmtIdx) 𝔽q β (mp := mp) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmt wit - let firstOracleConsistency: Prop := firstOracleWitnessConsistencyProp 𝔽q β - wit.t (getFirstOracle 𝔽q β oStmt) - let oracleFoldingConsistency: Prop := oracleFoldingConsistencyProp 𝔽q β (i := oracleIdx.val) - (challenges := Fin.take (m := oracleIdx.val) (v := stmt.challenges) - (h := by simp only [Fin.val_fin_le, OracleFrontierIndex.val_le_i])) - (oStmt := oStmt) (includeFinalFiberwiseClose := true) - witnessStructuralInvariant ∧ firstOracleConsistency ∧ - oracleFoldingConsistency - + (oStmt : ∀ j, (OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + (i := oracleIdx.val) j)) (challenges : Fin stmtIdx → L) : Prop := + ∃ j : Fin (toOutCodewordsCount ℓ ϑ oracleIdx.val), + -- Number of challenges available for block j + let curOracleDomainIdx : Fin r := ⟨oraclePositionToDomainIndex (positionIdx := j), by omega⟩ + let k : ℕ := min ϑ (stmtIdx.val - curOracleDomainIdx.val) + have h1 := oracle_index_add_steps_le_ℓ ℓ ϑ (i := oracleIdx.val) (j := j) + have h2 : ℓ + 𝓡 < r := h_ℓ_add_R_rate + have _ : 𝓡 > 0 := pos_of_neZero 𝓡 + let midIdx : Fin r := ⟨curOracleDomainIdx.val + k, by omega⟩ + let destIdx : Fin r := ⟨curOracleDomainIdx.val + ϑ, by + dsimp only [oraclePositionToDomainIndex, curOracleDomainIdx]; omega⟩ + Binius.BinaryBasefold.incrementalFoldingBadEvent 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (block_start_idx := curOracleDomainIdx) (k := k) + (h_k_le := Nat.min_le_left ϑ (stmtIdx.val - curOracleDomainIdx.val)) + (midIdx := midIdx) (destIdx := destIdx) (h_midIdx := rfl) (h_destIdx := rfl) + (h_destIdx_le := oracle_index_add_steps_le_ℓ ℓ ϑ (i := oracleIdx.val) (j := j)) + (f_block_start := oStmt j) + (r_challenges := fun cId => challenges ⟨curOracleDomainIdx.val + cId.val, by + -- Proof that curOracleDomainIdx + cId < stmtIdx.val + have h_k_le_stmt : k ≤ stmtIdx.val - curOracleDomainIdx.val := + Nat.min_le_right ϑ (stmtIdx.val - curOracleDomainIdx.val) + have h_cId_lt_k : cId.val < k := cId.isLt + omega + ⟩) + +/-- At the terminal frontier (`stmtIdx = oracleIdx = Fin.last ℓ`), the global bad-event +predicate and incremental bad-event predicate coincide. -/ +lemma badEventExistsProp_iff_incrementalBadEventExistsProp_last + (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) + (challenges : Fin (Fin.last ℓ) → L) : + blockBadEventExistsProp 𝔽q β + (stmtIdx := Fin.last ℓ) (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (Fin.last ℓ)) + (oStmt := oStmt) (challenges := challenges) ↔ + incrementalBadEventExistsProp 𝔽q β + (stmtIdx := Fin.last ℓ) (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (Fin.last ℓ)) + (oStmt := oStmt) (challenges := challenges) := by + constructor + · intro h_bad + rcases h_bad with ⟨j, h_j_bad⟩ + refine ⟨j, ?_⟩ + sorry + · intro h_inc_bad + rcases h_inc_bad with ⟨j, h_j_inc_bad⟩ + refine ⟨j, ?_⟩ + sorry + +def badSumcheckEventProp (r_i' : L) (h_i h_star : L⦃≤ 2⦄[X]) := + h_i ≠ h_star ∧ h_i.val.eval r_i' = h_star.val.eval r_i' section SingleStepRelationPreservationLemmas -/-- Oracle embedding for commit step: maps existing oracles via Sum.inl, -and the new oracle to Sum.inr ⟨0, rfl⟩ -/ -def commitStepEmbed (i : Fin ℓ) (hCR : isCommitmentRound ℓ ϑ i) : - Fin (toOutCodewordsCount ℓ ϑ i.succ) ↪ - (Fin (toOutCodewordsCount ℓ ϑ i.castSucc) ⊕ Fin 1) := ⟨fun j => by - classical - if hj : j.val < toOutCodewordsCount ℓ ϑ i.castSucc then - exact Sum.inl ⟨j.val, by omega⟩ - else - exact Sum.inr ⟨0, by omega⟩ -, by - intro a b h_ab_eq - simp only at h_ab_eq - split_ifs at h_ab_eq with h_ab_eq_l h_ab_eq_r - · simp only [Sum.inl.injEq, Fin.mk.injEq] at h_ab_eq; apply Fin.eq_of_val_eq; exact h_ab_eq - · have ha_lt : a < toOutCodewordsCount ℓ ϑ i.succ := by omega - have hb_lt : b < toOutCodewordsCount ℓ ϑ i.succ := by omega - conv_rhs at ha_lt => rw [toOutCodewordsCount_succ_eq ℓ ϑ i] - conv_rhs at hb_lt => rw [toOutCodewordsCount_succ_eq ℓ ϑ i] - simp only [hCR, ↓reduceIte] at ha_lt hb_lt - have h_a : a = toOutCodewordsCount ℓ ϑ i.castSucc := by omega - have h_b : b = toOutCodewordsCount ℓ ϑ i.castSucc := by omega - omega -⟩ - -/-- Oracle statement type equality for commit step -/ -def commitStepHEq (i : Fin ℓ) (hCR : isCommitmentRound ℓ ϑ i) : - ∀ (oracleIdx : Fin (toOutCodewordsCount ℓ ϑ i.succ)), - OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.succ oracleIdx = - match commitStepEmbed (r := r) (𝓡 := 𝓡) i hCR oracleIdx with - | Sum.inl j => OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc j - | Sum.inr _ => sDomain 𝔽q β h_ℓ_add_R_rate ⟨i.succ, by omega⟩ → L := fun oracleIdx => by - unfold OracleStatement commitStepEmbed - simp only [Function.Embedding.coeFn_mk] - by_cases hlt : oracleIdx.val < toOutCodewordsCount ℓ ϑ i.castSucc - · simp only [hlt, ↓reduceDIte] - · simp only [hlt, ↓reduceDIte,] - have hOracleIdx_lt : oracleIdx.val < toOutCodewordsCount ℓ ϑ i.succ := by omega - simp only [toOutCodewordsCount_succ_eq ℓ ϑ i, hCR, ↓reduceIte] at hOracleIdx_lt - have hOracleIdx : oracleIdx = toOutCodewordsCount ℓ ϑ i.castSucc := by omega - simp_rw [hOracleIdx] - have h := toOutCodewordsCount_mul_ϑ_eq_i_succ ℓ ϑ (i := i) (hCR := hCR) - rw! [h]; rfl - section FoldStepPreservationLemmas variable {Context : Type} {mp : SumcheckMultiplierParam L ℓ Context} end FoldStepPreservationLemmas -lemma oracleWitnessConsistency_relay_preserved - (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ i) - (stmt : Statement (L := L) Context i.succ) - (wit : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ) +/-- blockBadEventExistsProp is preserved under relay step oracle remapping. + Key insight: hNCR means no new oracle block is completed, so bad events are the same. -/ +lemma incrementalBadEventExistsProp_relay_preserved (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ i) + (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc j) + (challenges : Fin i.succ → L) : + incrementalBadEventExistsProp 𝔽q β i.succ (OracleFrontierIndex.mkFromStmtIdxCastSuccOfSucc i) + oStmt challenges ↔ + incrementalBadEventExistsProp 𝔽q β i.succ (OracleFrontierIndex.mkFromStmtIdx i.succ) + (mapOStmtOutRelayStep 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR oStmt) challenges := by + sorry + +/-- oracleFoldingConsistencyProp is preserved under relay step oracle remapping. -/ +lemma oracleFoldingConsistencyProp_relay_preserved (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ i) + (challenges : Fin i.succ.val → L) (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc j) : - oracleWitnessConsistency 𝔽q β (mp := mp) i.succ - (oracleIdx := OracleFrontierIndex.mkFromStmtIdxCastSuccOfSucc i) stmt wit oStmt = - oracleWitnessConsistency 𝔽q β (mp := mp) i.succ - (oracleIdx := OracleFrontierIndex.mkFromStmtIdx i.succ) stmt wit - (mapOStmtOutRelayStep 𝔽q β i hNCR oStmt) := by - unfold oracleWitnessConsistency - -- All four components (witnessStructuralInvariant, sumCheckConsistency, - -- firstOracleConsistency, oracleFoldingConsistency) are preserved during relay - have h_oracle_size_eq: toOutCodewordsCount ℓ ϑ i.castSucc = toOutCodewordsCount ℓ ϑ i.succ := by - simp only [toOutCodewordsCount_succ_eq ℓ ϑ i, hNCR, ↓reduceIte] - congr 1 + oracleFoldingConsistencyProp 𝔽q β (i := i.castSucc) (Fin.init challenges) oStmt ↔ + oracleFoldingConsistencyProp 𝔽q β (i := i.succ) challenges + (mapOStmtOutRelayStep 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR oStmt) := by sorry - -- -- firstOracleConsistency: getFirstOracle is preserved - -- · unfold getFirstOracle - -- simp only [mapOStmtOutRelayStep, h_oracle_size_eq] - -- -- oracleFoldingConsistency: preserved by similar reasoning to - -- -- nonDoomedFoldingProp_relay_preserved - -- · unfold oracleFoldingConsistencyProp - -- apply propext - -- constructor <;> intro h j hj - -- · -- Forward direction - -- have h_j_mapped : j.val < toOutCodewordsCount ℓ ϑ i.castSucc := by - -- omega - -- let j_orig : Fin (toOutCodewordsCount ℓ ϑ i.castSucc) := ⟨j.val, h_j_mapped⟩ - -- have hj_orig : j_orig.val + 1 < toOutCodewordsCount ℓ ϑ i.castSucc := by - -- simp [j_orig]; omega - -- have h_spec := h j_orig hj_orig - -- unfold mapOStmtOutRelayStep getNextOracle - -- simp only [h_oracle_size_eq] - -- convert h_spec using 2 - -- · unfold getFoldingChallenges; ext cId - -- have h_take_init : Fin.take (m := i.succ) (h := by omega) stmt.challenges - -- = Fin.init stmt.challenges := by - -- ext k; simp [Fin.take, Fin.init] - -- have h_take_init_alt : Fin.take (m := i.succ) (h := by omega) stmt.challenges = - -- Fin.init stmt.challenges := by - -- ext k; simp [Fin.take, Fin.init] - -- rw [h_take_init] - -- simp [Fin.init, Fin.val_castSucc, Fin.castSucc_mk, Fin.val_succ] - -- · rfl - -- · rfl - -- · -- Backward direction - -- let j_new : Fin (toOutCodewordsCount ℓ ϑ i.succ) := ⟨j.val, by omega⟩ - -- have hj_new : j_new.val + 1 < toOutCodewordsCount ℓ ϑ i.succ := by simp [j_new]; omega - -- have h_spec := h j_new hj_new - -- unfold mapOStmtOutRelayStep getNextOracle at h_spec - -- simp only [h_oracle_size_eq] at h_spec - -- convert h_spec using 2 - -- · unfold getFoldingChallenges; ext cId - -- have h_take_init : Fin.take (m := i.succ) (h := by omega) stmt.challenges = - -- Fin.init stmt.challenges := by - -- ext k; simp [Fin.take, Fin.init] - -- rw [h_take_init] - -- simp [Fin.init, Fin.val_castSucc, Fin.castSucc_mk, Fin.val_succ] - -- · rfl - -- · rfl + +section CommitStepPreservationLemmas +/-! +## Commit Step Preservation Lemmas (Backward Direction) + +These lemmas show that properties at round 1 (after oracle commit message) imply +properties at round 0 (before oracle commit message). + +Key structure: +- Round 1: `oracleIdx = mkFromStmtIdx i.succ`, `oStmt = snoc_oracle oStmtIn newOracle` +- Round 0: `oracleIdx = mkFromStmtIdxCastSuccOfSucc i`, `oStmt = oStmtIn` + +The backward direction works because: +1. For bad events: The newly committed oracle can't have a bad event yet (needs ϑ more + challenges for its folding to be analyzed). So any bad event at round 1 must be for + an older oracle block that's also active at round 0. +2. For consistency: Round 1 checks more oracle blocks (including the new one). If all + blocks are consistent at round 1, then the subset checked at round 0 is consistent. + And `snoc_oracle` returns `oStmtIn j` for j < old_count, so the oracles match. +-/ + +/-- Bad event preservation for commit step (backward direction). + +If a bad event exists at round 1 (with synchronized oracle index and extended oracle +statement), then a bad event exists at round 0 (with lagging oracle index and original +oracle statement). + +Key insight: At round 1, the newly committed oracle at position `old_count` cannot have +a bad event because `foldingBadEventAtBlock` requires `curOracleDomainIdx + ϑ ≤ oracleIdx.val`, +but for the new oracle: `old_count * ϑ = i.val + 1` (commitment round property), so +`old_count * ϑ + ϑ = i.val + 1 + ϑ > i.val + 1 = oracleIdx.val`, making the condition false. +Therefore any bad event at round 1 must be for an older block, which is also active at round 0. -/ +lemma incrementalBadEventExistsProp_commit_step_backward (i : Fin ℓ) (hCR : isCommitmentRound ℓ ϑ i) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc j) + (newOracle : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (domainIdx := ⟨i.val + 1, by omega⟩)) + (challenges : Fin i.succ → L) : + incrementalBadEventExistsProp 𝔽q β i.succ (OracleFrontierIndex.mkFromStmtIdx i.succ) + (snoc_oracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_destIdx := rfl) + oStmtIn newOracle) challenges → + incrementalBadEventExistsProp 𝔽q β i.succ (OracleFrontierIndex.mkFromStmtIdxCastSuccOfSucc i) + oStmtIn challenges := by + sorry + +/-- Oracle witness consistency preservation for commit step (backward direction). + +If oracle-witness consistency holds at round 1 (with synchronized oracle index and +extended oracle statement including the new oracle), then it holds at round 0 (with +lagging oracle index and original oracle statement). + +Key insight: Round 1 checks consistency for oracle blocks 0..new_count-1, while round 0 +checks blocks 0..old_count-1 (where new_count = old_count + 1 for commitment rounds). +Since `snoc_oracle oStmtIn newOracle j = oStmtIn j` for j < old_count, consistency +for the subset at round 0 follows from consistency at round 1. + +Components: +1. `witnessStructuralInvariant`: Only depends on `stmtIdx` (same at both rounds) +2. `firstOracleWitnessConsistencyProp`: `getFirstOracle (snoc_oracle ...) = getFirstOracle oStmtIn` +3. `oracleFoldingConsistencyProp`: Fewer blocks at round 0, all using same oracle functions -/ +lemma oracleFoldingConsistencyProp_commit_step_backward (i : Fin ℓ) (hCR : isCommitmentRound ℓ ϑ i) + (challenges : Fin i.succ.val → L) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc j) + (newOracle : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (domainIdx := ⟨i.val + 1, by omega⟩)) : + oracleFoldingConsistencyProp 𝔽q β (i := i.succ) challenges + (snoc_oracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_destIdx := rfl) + oStmtIn newOracle) → + oracleFoldingConsistencyProp 𝔽q β (i := i.castSucc) (Fin.init challenges) oStmtIn := by + sorry + +end CommitStepPreservationLemmas end SingleStepRelationPreservationLemmas /-- Before V's challenge of the `i-th` foldStep, we ignore the bad-folding-event @@ -1473,12 +1602,16 @@ def masterKStateProp (stmtIdx : Fin (ℓ + 1)) (oStmt : ∀ j, (OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (i := oracleIdx.val) j)) (localChecks : Prop := True) : Prop := - let oracleWitnessConsistency: Prop := oracleWitnessConsistency 𝔽q β (mp := mp) - stmtIdx oracleIdx stmt wit oStmt - let badEventExists := badEventExistsProp 𝔽q β stmtIdx oracleIdx + let structural := witnessStructuralInvariant 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmt wit + let initial := firstOracleWitnessConsistencyProp 𝔽q β wit.t (getFirstOracle 𝔽q β oStmt) + let oracleFoldingConsistency: Prop := oracleFoldingConsistencyProp 𝔽q β (i := oracleIdx.val) + (challenges := Fin.take (m := oracleIdx.val) (v := stmt.challenges) + (h := by simp only [Fin.val_fin_le, OracleFrontierIndex.val_le_i])) + (oStmt := oStmt) + let badEventExists := incrementalBadEventExistsProp 𝔽q β stmtIdx oracleIdx (challenges := stmt.challenges) (oStmt := oStmt) - let core := badEventExists ∨ oracleWitnessConsistency - localChecks ∧ core + let good := localChecks ∧ structural ∧ initial ∧ oracleFoldingConsistency + badEventExists ∨ good def roundRelationProp (i : Fin (ℓ + 1)) (input : (Statement (L := L) Context i × @@ -1508,18 +1641,47 @@ def foldStepRelOutProp (i : Fin ℓ) stmt wit oStmt (localChecks := sumCheckConsistency) +def finalSumcheckStepOracleConsistencyProp {h_le : ϑ ≤ ℓ} + (stmtOut : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmtOut : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ + (Fin.last ℓ) j) : Prop := + let j := getLastOraclePositionIndex (ℓ := ℓ) (ϑ := ϑ) (Fin.last ℓ) -- actually `j = ℓ / ϑ - 1` + let k := j.val * ϑ -- k = getLastOracleDomainIndex (Fin.last ℓ) + have h_k: k = ℓ - ϑ := by + dsimp only [k, j] + rw [getLastOraclePositionIndex_last] + rw [Nat.sub_mul, Nat.one_mul] + rw [Nat.div_mul_cancel (hdiv.out)] + let f_k := oStmtOut j + let challenges : Fin ϑ → L := fun cId => stmtOut.challenges ⟨k + cId, by + simp only [Fin.val_last, k, j] + rw [getLastOraclePositionIndex_last, Nat.sub_mul, Nat.one_mul, Nat.div_mul_cancel (hdiv.out)] + rw [Nat.sub_add_eq_sub_sub_rev (h1:=by omega) (h2:=by omega)]; omega + ⟩ + -- **NOTE**: we must have this final oracle compliance check between the + -- last explicit oracle and the virtual oracle (fun x => c) at the final sumcheck step + -- because the virtual oracle is not availabe to be in commit steps of the interaction rounds + let finalOracleFoldingConsistency: Prop := by + -- folding consistency between two adjacent oracles `j` & `j + ϑ` + exact isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨k, by omega⟩) (steps := ϑ) (destIdx := ⟨k + ϑ, by omega⟩) (by rfl) (by simp only; omega) (f_i := f_k) + (f_i_plus_steps := fun x => stmtOut.final_constant) (challenges := challenges) + -- If oracleFoldingConsistency is true, then we can extract the original + -- well-formed poly `t` and derive witnesses that satisfy the relations at any state + oracleFoldingConsistencyProp 𝔽q β (i := Fin.last ℓ) + (challenges := stmtOut.challenges) (oStmt := oStmtOut) + ∧ finalOracleFoldingConsistency + /-- This is a special case of nonDoomedFoldingProp for `i = ℓ`, where we support the consistency between the last oracle `ℓ - ϑ` and the final constant `c`. This definition has form similar to masterKState where there is no localChecks. -/ -def finalFoldingStateProp {h_le : ϑ ≤ ℓ} +def finalSumcheckStepFoldingStateProp {h_le : ϑ ≤ ℓ} (input : (FinalSumcheckStatementOut (L := L) (ℓ := ℓ) × (∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j))) : Prop := - let stmt := input.1 - let oStmt := input.2 - -- let f_ℓ: (sDomain 𝔽q β h_ℓ_add_R_rate) ⟨ℓ, by omega⟩ → L := fun x => stmt.final_constant + let stmtOut := input.1 + let oStmtOut := input.2 let j := getLastOraclePositionIndex (ℓ := ℓ) (ϑ := ϑ) (Fin.last ℓ) -- actually `j = ℓ / ϑ - 1` let k := j.val * ϑ -- k = getLastOracleDomainIndex (Fin.last ℓ) have h_k: k = ℓ - ϑ := by @@ -1527,37 +1689,21 @@ def finalFoldingStateProp {h_le : ϑ ≤ ℓ} rw [getLastOraclePositionIndex_last] rw [Nat.sub_mul, Nat.one_mul] rw [Nat.div_mul_cancel (hdiv.out)] - let f_k := oStmt j - let challenges : Fin ϑ → L := fun cId => stmt.challenges ⟨k + cId, by + let f_k := oStmtOut j + let challenges : Fin ϑ → L := fun cId => stmtOut.challenges ⟨k + cId, by simp only [Fin.val_last, k, j] rw [getLastOraclePositionIndex_last, Nat.sub_mul, Nat.one_mul, Nat.div_mul_cancel (hdiv.out)] rw [Nat.sub_add_eq_sub_sub_rev (h1:=by omega) (h2:=by omega)]; omega ⟩ have h_k_add_ϑ: k + ϑ = ℓ := by rw [h_k]; apply Nat.sub_add_cancel; omega - let finalOracleFoldingConsistency: Prop := by - -- folding consistency between two adjacent oracles `j` & `j + ϑ` - exact isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨k, by omega⟩) (steps := ϑ) - (destIdx := ⟨k + ϑ, by omega⟩) (by rfl) (by simp only; omega) (f_i := f_k) - (f_i_plus_steps := fun x => stmt.final_constant) (challenges := challenges) - -- If oracleFoldingConsistency is true, then we can extract the original - -- well-formed poly `t` and derive witnesses that satisfy the relations at any state let oracleFoldingConsistency: Prop := - (oracleFoldingConsistencyProp 𝔽q β (i := Fin.last ℓ) - (challenges := stmt.challenges) (oStmt := oStmt) - (includeFinalFiberwiseClose := false)) - -- Note: we ignore the fiberwise-closeness of last oracle since it's - -- available in finalOracleFoldingConsistency - ∧ finalOracleFoldingConsistency - let finalFoldingBadEvent : Prop := - Binius.BinaryBasefold.foldingBadEvent 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (i := ⟨k, by omega⟩) (steps := ϑ) (destIdx := ⟨k + ϑ, by omega⟩) - (by rfl) (by simp only; omega) (f_i := f_k) - (r_challenges := challenges) + finalSumcheckStepOracleConsistencyProp 𝔽q β (h_le := h_le) (stmtOut := stmtOut) + (oStmtOut := oStmtOut) -- All bad folding events are fully formed across the sum-check rounds, - -- no new bad event at the final sumcheck step - let foldingBadEventExists : Prop := (badEventExistsProp 𝔽q β (stmtIdx := Fin.last ℓ) + -- no new bad event needed at the final sumcheck step + let foldingBadEventExists : Prop := (blockBadEventExistsProp 𝔽q β (stmtIdx := Fin.last ℓ) (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (Fin.last ℓ)) - (oStmt := oStmt) (challenges := stmt.challenges)) ∨ finalFoldingBadEvent + (oStmt := oStmtOut) (challenges := stmtOut.challenges)) oracleFoldingConsistency ∨ foldingBadEventExists /-- **Relaxed fold step output relation for RBR Knowledge Soundness**. @@ -1597,7 +1743,7 @@ def finalSumcheckRelOutProp (Unit))) : Prop := -- Final oracle consistency and bad events - finalFoldingStateProp 𝔽q β + finalSumcheckStepFoldingStateProp 𝔽q β (h_le := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out)) (input := input.1) @@ -1718,13 +1864,13 @@ def strictFoldStepRelOutProp (i : Fin ℓ) /-- **Strict final folding state property** (for Completeness). -This is the strict version of `finalFoldingStateProp` that: +This is the strict version of `finalSumcheckStepFoldingStateProp` that: - Removes all bad event tracking - Uses exact code membership and equality instead of proximity-based checks - Ensures deterministic preservation with probability 1 Used only for Perfect Completeness proofs. -/ -def strictFinalFoldingStateProp (t : MultilinearPoly L ℓ) {h_le : ϑ ≤ ℓ} +def strictfinalSumcheckStepFoldingStateProp (t : MultilinearPoly L ℓ) {h_le : ϑ ≤ ℓ} (input : (FinalSumcheckStatementOut (L := L) (ℓ := ℓ) × (∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j))) : Prop := @@ -1804,7 +1950,7 @@ def strictFinalSumcheckRelOutProp (∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j)) × (Unit))) : Prop := -- Final oracle consistency with exact equality - ∃ (t : MultilinearPoly L ℓ), strictFinalFoldingStateProp 𝔽q β (t := t) + ∃ (t : MultilinearPoly L ℓ), strictfinalSumcheckStepFoldingStateProp 𝔽q β (t := t) (h_le := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out)) (input := input.1) diff --git a/ArkLib/ProofSystem/Binius/BinaryBasefold/Code.lean b/ArkLib/ProofSystem/Binius/BinaryBasefold/Code.lean index 330f03447..e678455e2 100644 --- a/ArkLib/ProofSystem/Binius/BinaryBasefold/Code.lean +++ b/ArkLib/ProofSystem/Binius/BinaryBasefold/Code.lean @@ -116,8 +116,12 @@ lemma BBF_CodeDistance_eq (i : Fin r) (h_i : i ≤ ℓ) : /-- Disagreement set Δ : The set of points where two functions disagree. For functions f^(i) and g^(i), this is {y ∈ S^(i) | f^(i)(y) ≠ g^(i)(y)}. -/ def disagreementSet (i : Fin r) + {destIdx : Fin r} (h_destIdx : destIdx = i.val) (f g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : - Finset ((sDomain 𝔽q β h_ℓ_add_R_rate) i) := {y | f y ≠ g y} + Finset ((sDomain 𝔽q β h_ℓ_add_R_rate) destIdx) := + have h_destIdx_eq_i : destIdx = i := Fin.ext h_destIdx + {(y : (sDomain 𝔽q β h_ℓ_add_R_rate) destIdx) | + f (cast (by subst h_destIdx_eq_i; rfl) y) ≠ g (cast (by subst h_destIdx_eq_i; rfl) y)} /-- Fiber-wise disagreement set Δ^(i) : The set of points y ∈ S^(i+ϑ) for which functions f^(i) and g^(i) are not identical when restricted to the entire fiber @@ -131,6 +135,46 @@ def fiberwiseDisagreementSet (i : Fin r) {destIdx : Fin r} (steps : ℕ) {y | ∃ x, (iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate (i := i) (k := steps) h_destIdx h_destIdx_le x) = y ∧ f x ≠ g x} +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ [NeZero 𝓡] in +lemma fiberwiseDisagreementSet_congr_sourceDomain_index (sourceIdx₁ sourceIdx₂ : Fin r) {destIdx : Fin r} (steps : ℕ) + (h_sourceIdx_eq : sourceIdx₁ = sourceIdx₂) + (h_destIdx : destIdx = sourceIdx₁.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) sourceIdx₁) : + -- have h_sourceIdx_eq : sourceIdx₁ = sourceIdx₂ := Fin.ext h_sourceIdx_eq_sourceIdx₂ + let Δ_fiber₁ := fiberwiseDisagreementSet 𝔽q β sourceIdx₁ steps h_destIdx h_destIdx_le f g + let Δ_fiber₂ := fiberwiseDisagreementSet 𝔽q β sourceIdx₂ steps (by omega) h_destIdx_le (fun x => f (cast (by subst h_sourceIdx_eq; rfl) x)) (fun x => g (cast (by subst h_sourceIdx_eq; rfl) x)) + Δ_fiber₁ = Δ_fiber₂ := by + subst h_sourceIdx_eq + rfl + +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ [NeZero 𝓡] in +/-- When `steps = 0`, the fiberwise disagreement set (projecting to `S^{i+0} = S^i`) +equals the ordinary pointwise disagreement set. +Both sides are stated with `destIdx := i` so they share the same `Finset` type. -/ +@[simp] +lemma fiberwiseDisagreementSet_steps_zero_eq_disagreementSet + (i destIdx : Fin r) (h_destIdx : destIdx = i.val + 0) (h_destIdx_le : destIdx ≤ ℓ) + (f g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : + fiberwiseDisagreementSet 𝔽q β i (steps := 0) (destIdx := destIdx) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) f g = + disagreementSet 𝔽q β (i := i) (destIdx := destIdx) (h_destIdx := h_destIdx) f g := by + -- iteratedQuotientMap at k = 0 evaluates intermediateNormVpoly (k=0) = X, i.e. the identity + -- have iqm_id : ∀ x : (sDomain 𝔽q β h_ℓ_add_R_rate) i, + -- iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate i (k := 0) (destIdx := destIdx) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) x = x := by + -- intro x + -- apply Subtype.ext + -- simp only [iteratedQuotientMap, intermediateNormVpoly, Fin.foldl_zero, Polynomial.eval_X, + -- Subtype.coe_eta] + -- ext y + -- simp only [fiberwiseDisagreementSet, disagreementSet, Finset.mem_filter, Finset.mem_univ, + -- true_and, ne_eq] + -- constructor + -- · rintro ⟨x, hxy, hfg⟩ + -- rw [iqm_id x] at hxy + -- exact hxy ▸ hfg + -- · intro hfg + -- exact ⟨y, iqm_id y, hfg⟩ + sorry + def pair_fiberwiseDistance (i : Fin r) {destIdx : Fin r} (steps : ℕ) (h_destIdx : destIdx = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) (f g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : ℕ := @@ -139,7 +183,7 @@ def pair_fiberwiseDistance (i : Fin r) {destIdx : Fin r} (steps : ℕ) /-- Fiber-wise distance d^(i) : The minimum size of the fiber-wise disagreement set between f^(i) and any codeword in C^(i). -/ def fiberwiseDistance (i : Fin r) {destIdx : Fin r} (steps : ℕ) - (h_destIdx : destIdx = i + steps) (h_destIdx_le : destIdx ≤ ℓ) + (h_destIdx : destIdx = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) (f : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : ℕ := -- The minimum size of the fiber-wise disagreement set between f^(i) and any codeword in C^(i) -- d^(i)(f^(i), C^(i)) := min_{g^(i) ∈ C^(i)} |Δ^(i)(f^(i), g^(i))| @@ -152,23 +196,23 @@ def fiberwiseDistance (i : Fin r) {destIdx : Fin r} (steps : ℕ) /-- Fiberwise closeness : f^(i) is fiberwise close to C^(i) if 2 * d^(i)(f^(i), C^(i)) < d_{i+steps} -/ -def fiberwiseClose (i : Fin r) {destIdx : Fin r} (steps : ℕ) [NeZero steps] - (h_destIdx : destIdx = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) - (f : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : Prop := +def fiberwiseClose (i : Fin r) {destIdx : Fin r} (steps : ℕ) + (h_destIdx : destIdx = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : Prop := 2 * (fiberwiseDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i) (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f)) < (BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := destIdx): ℕ∞) -def pair_fiberwiseClose (i : Fin r) {destIdx : Fin r} (steps : ℕ) [NeZero steps] - (h_destIdx : destIdx = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) - (f g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : Prop := +def pair_fiberwiseClose (i : Fin r) {destIdx : Fin r} (steps : ℕ) + (h_destIdx : destIdx = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : Prop := 2 * pair_fiberwiseDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i) steps h_destIdx h_destIdx_le (f := f) (g := g) < (BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := destIdx): ℕ∞) /-- Hamming UDR-closeness : f is close to C in Hamming distance if `2 * d(f, C) < d_i` -/ def UDRClose (i : Fin r) (h_i : i ≤ ℓ) - (f : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : Prop := + (f : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : Prop := 2 * Δ₀(f, (BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i)) < BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i) @@ -176,6 +220,39 @@ def pair_UDRClose (i : Fin r) (h_i : i ≤ ℓ) (f g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : Prop := 2 * Δ₀(f, g) < BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i) +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ [NeZero 𝓡] in +/-- When `steps = 0`, `pair_fiberwiseDistance` equals the Hamming distance. -/ +@[simp] +lemma pair_fiberwiseDistance_steps_zero_eq_hammingDist + (i : Fin r) (h_i_le : i ≤ ℓ) + (f g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : + pair_fiberwiseDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i) (destIdx := i) (steps := 0) (h_destIdx := by omega) (h_destIdx_le := by omega) (f := f) (g := g) = hammingDist f g := by + rw [pair_fiberwiseDistance, fiberwiseDisagreementSet_steps_zero_eq_disagreementSet 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i) (destIdx := i) (h_destIdx := by omega) (h_destIdx_le := by omega) f g] + simp only [disagreementSet, cast_eq, ne_eq, card_filter, ite_not, hammingDist] + +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ [NeZero 𝓡] in +/-- When `steps = 0`, fiberwise closeness coincides with UDR closeness. -/ +@[simp] +lemma fiberwiseClose_steps_zero_iff_UDRClose + (i destIdx : Fin r) (h_destIdx : destIdx = i.val + 0) (h_destIdx_le : destIdx ≤ ℓ) + (f : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) : + fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i (steps := 0) (destIdx := destIdx) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) f ↔ + UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i (h_i := by omega) f := by + sorry + +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ [NeZero 𝓡] in +lemma fiberwiseClose_congr_sourceDomain_index (sourceIdx₁ sourceIdx₂ : Fin r) {destIdx : Fin r} (steps : ℕ) + (h_sourceIdx_eq : sourceIdx₁ = sourceIdx₂) + (h_destIdx : destIdx = sourceIdx₁.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) sourceIdx₁) : + -- have h_sourceIdx_eq : sourceIdx₁ = sourceIdx₂ := Fin.ext h_sourceIdx_eq_sourceIdx₂ + let Δ_fiber₁ := fiberwiseClose 𝔽q β sourceIdx₁ steps h_destIdx h_destIdx_le f + let Δ_fiber₂ := fiberwiseClose 𝔽q β sourceIdx₂ steps (by omega) h_destIdx_le (fun x => f (cast (by subst h_sourceIdx_eq; rfl) x)) + Δ_fiber₁ = Δ_fiber₂ := by + subst h_sourceIdx_eq + rfl + section ConstantFunctions omit [CharP L 2] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero ℓ] [NeZero 𝓡] in @@ -252,7 +329,6 @@ lemma UDRClose_iff_within_UDR_radius (i : Fin r) (h_i : i ≤ ℓ) omega /-- Unique closest codeword in the unique decoding radius of a function f -/ -@[reducible, simp] def UDRCodeword (i : Fin r) (h_i : i ≤ ℓ) (f : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) (h_within_radius : UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i h_i f) : @@ -992,10 +1068,9 @@ def fold_error_containment (i : Fin r) {destIdx : Fin r} (steps : ℕ) (i := i) (steps := steps) h_destIdx h_destIdx_le (f := f) (g := f_bar) let folded_f := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (steps := steps) (i := i) h_destIdx h_destIdx_le (f := f) (r_challenges := r_challenges) - let folded_f_bar := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (steps := steps) - (i := i) h_destIdx h_destIdx_le (f := f_bar) (r_challenges := r_challenges) - let folded_Δ_set := disagreementSet 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (i := destIdx) (f := folded_f) (g := folded_f_bar) + let folded_f_bar := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (steps := steps) (i := i) + h_destIdx h_destIdx_le (f := f_bar) (r_challenges := r_challenges) + let folded_Δ_set := disagreementSet 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := destIdx) (destIdx := destIdx) (h_destIdx := rfl) (f := folded_f) (g := folded_f_bar) folded_Δ_set ⊆ fiberwise_Δ_set /-! **Lemma 4.18.** For each `i ∈ {0, steps, ..., ℓ-steps}`, if `f⁽ⁱ⁾` is `UDR-close`, then, for @@ -1118,7 +1193,8 @@ def foldingBadEvent (i : Fin r) {destIdx : Fin r} (steps : ℕ) (f := f_bar_i) (r_challenges := r_challenges) -- The Bad Condition: FiberDisagreements ⊈ FoldedDisagreements ¬ (fiberwiseDisagreementSet 𝔽q β i steps h_destIdx h_destIdx_le (f := f_i) (g := f_bar_i) ⊆ - disagreementSet 𝔽q β (i := destIdx) (f := folded_f_i) (g := folded_f_bar_i)) + disagreementSet 𝔽q β (i := destIdx) (destIdx := destIdx) (h_destIdx := rfl) (f := folded_f_i) (g := folded_f_bar_i)) + else -- Case 2 : The oracle `f_i` is fiber-wise "far" from the code. -- Folding a "far" function should result in another "far" function. @@ -1126,4 +1202,132 @@ def foldingBadEvent (i : Fin r) {destIdx : Fin r} (steps : ℕ) UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := destIdx) (h_i := by omega) (f := folded_f_i) +open Classical in +/-- **Definition 4.19.2** (Incremental Bad Events). +For block start index `block_start_idx`, block size `ϑ`, and **partial step count** +`k ≤ ϑ`, with `destIdx = block_start_idx + ϑ` (the block destination), `E(block_start_idx, k)` is defined as follows: + +- If `k = 0`: Returns `False` (no challenges consumed yet). +- Case 1 (fiberwise close at block level): + `Δ⁽ⁱ⁾(f, f̄) ⊄ Δ⁽ⁱ⁺ᵏ⁾(fold_k(f), fold_k(f̄))` + where both sides are projected to `S^{i+ϑ}` (the block destination). +- Case 2 (fiberwise far at block level): + `d⁽ⁱ⁺ᵏ⁾(fold_k(f), C⁽ⁱ⁺ᵏ⁾) < d_{i+ϑ}/2` + where `d⁽ⁱ⁺ᵏ⁾` is the fiberwise distance projected to `S^{i+ϑ}`. + +When `k = ϑ`, this coincides with `foldingBadEvent`. -/ +def incrementalFoldingBadEvent + (block_start_idx : Fin r) {midIdx destIdx : Fin r} (k : ℕ) + (h_k_le : k ≤ ϑ) (h_midIdx : midIdx = block_start_idx + k) + (h_destIdx : destIdx = block_start_idx + ϑ) (h_destIdx_le : destIdx ≤ ℓ) + (f_block_start : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) block_start_idx) + (r_challenges : Fin k → L) : Prop := + have h_ik_le : block_start_idx.val + k ≤ ℓ := by omega + + -- midIdx + (ϑ - k) = block_start_idx + ϑ = destIdx + have h_midIdx_to_block : destIdx = midIdx + (ϑ - k) := by omega + + let folded_f_block_start := iterated_fold 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k) (destIdx := midIdx) + (h_destIdx := h_midIdx) (h_destIdx_le := by omega) + (f := f_block_start) (r_challenges := r_challenges) + + if h_is_close : (fiberwiseClose 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := ϑ) + h_destIdx h_destIdx_le (f := f_block_start)) then + -- Case 1 : fiberwise close (block-level classification) + let f_bar_block_start := UDRCodeword 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + block_start_idx (f := f_block_start) + (h_within_radius := UDRClose_of_fiberwiseClose 𝔽q β + block_start_idx ϑ h_destIdx h_destIdx_le + f_block_start h_is_close) + + let folded_f_bar_block_start := iterated_fold 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k) + (h_destIdx := h_midIdx) (h_destIdx_le := by omega) + (f := f_bar_block_start) (r_challenges := r_challenges) + + -- Bad: Δ⁽ⁱ⁾(f, f̄) ⊄ Δ⁽ⁱ⁺ᵏ⁾(fold_k(f), fold_k(f̄)) + -- Both projected to S^{i+ϑ} = S^destIdx + ¬ (fiberwiseDisagreementSet 𝔽q β + block_start_idx ϑ h_destIdx h_destIdx_le + f_block_start f_bar_block_start + ⊆ + fiberwiseDisagreementSet 𝔽q β + midIdx (ϑ - k) h_midIdx_to_block h_destIdx_le + folded_f_block_start folded_f_bar_block_start) + else + -- Case 2 : fiberwise far (block-level classification) + -- Bad: d⁽ⁱ⁺ᵏ⁾(fold_k(f), C⁽ⁱ⁺ᵏ⁾) < d_{i+ϑ}/2 + -- projected to S^{i+ϑ}, threshold = d_{destIdx} + fiberwiseClose 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := midIdx) (steps := ϑ - k) + (h_destIdx := h_midIdx_to_block) + (h_destIdx_le := h_destIdx_le) + (f := folded_f_block_start) + +omit [CharP L 2] in +/-- When all folding steps have been applied (`k = ϑ`), the incremental bad event +coincides with the full `foldingBadEvent`. -/ +@[simp] +lemma incrementalFoldingBadEvent_of_k_eq_0_is_false + (block_start_idx : Fin r) (k : ℕ) (h_k : k = 0) {midIdx destIdx : Fin r} + (h_midIdx : midIdx.val = block_start_idx.val) (h_destIdx : destIdx = block_start_idx + ϑ) + (h_destIdx_le : destIdx ≤ ℓ) + (f_block_start : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) block_start_idx) + (r_challenges : Fin k → L) : + ¬(incrementalFoldingBadEvent 𝔽q β (block_start_idx := block_start_idx) (ϑ := ϑ) (k := k) (h_k_le := by omega) + (midIdx := midIdx) (destIdx := destIdx) + (h_midIdx := by omega) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + f_block_start r_challenges):= by + subst h_k + unfold incrementalFoldingBadEvent + simp only [tsub_zero] + by_cases h_close : fiberwiseClose 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := ϑ) + h_destIdx h_destIdx_le (f := f_block_start) + · simp only [h_close, ↓reduceDIte] + rw [iterated_fold_zero_steps 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx := midIdx) (h_destIdx := by omega) (h_destIdx_le := by omega)] + rw [iterated_fold_zero_steps 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx := midIdx) (h_destIdx := by omega) (h_destIdx_le := by omega)] + rw [fiberwiseDisagreementSet_congr_sourceDomain_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (sourceIdx₁ := block_start_idx) (sourceIdx₂ := midIdx) (h_sourceIdx_eq := by omega)] + simp only [subset_refl, not_true_eq_false, not_false_eq_true] + · simp only [h_close, ↓reduceDIte] + rw [fiberwiseClose_congr_sourceDomain_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (sourceIdx₁ := block_start_idx) (sourceIdx₂ := midIdx) (h_sourceIdx_eq := by omega)] at h_close + rw [iterated_fold_zero_steps 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx := midIdx) (h_destIdx := by omega) (h_destIdx_le := by omega) (f := f_block_start) (r_challenges := r_challenges)] + exact h_close + +omit [CharP L 2] in +/-- When all folding steps have been applied (`k = ϑ`), the incremental bad event +coincides with the full `foldingBadEvent`. -/ +lemma incrementalFoldingBadEvent_eq_foldingBadEvent_of_k_eq_ϑ + (block_start_idx : Fin r) {midIdx destIdx : Fin r} + (h_midIdx : midIdx.val = destIdx.val) (h_destIdx : destIdx = block_start_idx + ϑ) + (h_destIdx_le : destIdx ≤ ℓ) + (f_block_start : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) block_start_idx) + (r_challenges : Fin ϑ → L) : + incrementalFoldingBadEvent 𝔽q β block_start_idx ϑ (h_k_le := le_refl ϑ) + (midIdx := midIdx) (destIdx := destIdx) + (h_midIdx := by omega) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + f_block_start r_challenges ↔ + foldingBadEvent 𝔽q β block_start_idx ϑ h_destIdx h_destIdx_le + f_block_start r_challenges := by + unfold incrementalFoldingBadEvent foldingBadEvent + simp only [show ϑ ≠ 0 from NeZero.ne ϑ, ↓reduceDIte, Nat.sub_self] + have h_midIdx_eq_destIdx : midIdx = destIdx := by omega + subst h_midIdx_eq_destIdx + by_cases h_close : fiberwiseClose 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := ϑ) + h_destIdx h_destIdx_le (f := f_block_start) + · simp only [h_close, ↓reduceDIte] + rw [fiberwiseDisagreementSet_steps_zero_eq_disagreementSet] + · simp only [h_close, ↓reduceDIte] + rw [fiberwiseClose_steps_zero_iff_UDRClose] + end SoundnessTools diff --git a/ArkLib/ProofSystem/Binius/BinaryBasefold/CoreInteractionPhase.lean b/ArkLib/ProofSystem/Binius/BinaryBasefold/CoreInteractionPhase.lean index 425da1518..d8d869c3c 100644 --- a/ArkLib/ProofSystem/Binius/BinaryBasefold/CoreInteractionPhase.lean +++ b/ArkLib/ProofSystem/Binius/BinaryBasefold/CoreInteractionPhase.lean @@ -6,6 +6,8 @@ Authors: Chung Thai Nguyen, Quang Dao import ArkLib.ProofSystem.Binius.BinaryBasefold.Steps import ArkLib.OracleReduction.Cast +import ArkLib.OracleReduction.Composition.Sequential.General +import ArkLib.OracleReduction.ProtocolSpec.SeqCompose -- Note: should filter errors when doing compilation @@ -62,44 +64,51 @@ variable {Context : Type} {mp : SumcheckMultiplierParam L ℓ Context} -- Sumche /-! ### Helper Lemmas for Fin Equality and Type Congruence -/ +/-! Fin equality for 0 * ϑ = 0 -/ omit [NeZero ℓ] [NeZero ϑ] hdiv in -/-- Fin equality for 0 * ϑ = 0 -/ lemma fin_zero_mul_eq (h : 0 * ϑ < ℓ + 1) : (⟨0 * ϑ, h⟩ : Fin (ℓ + 1)) = 0 := by ext; simp only [zero_mul, Fin.coe_ofNat_eq_mod, Nat.zero_mod] +/-! Statement equality from Fin equality -/ omit [Field L] [Fintype L] [DecidableEq L] [CharP L 2] [SampleableType L] [NeZero ℓ] in -/-- Statement equality from Fin equality -/ lemma Statement.of_fin_eq {i j : Fin (ℓ + 1)} (h : i = j) : Statement (L := L) (ℓ := ℓ) Context i = Statement (L := L) (ℓ := ℓ) Context j := by subst h; rfl +/-! OracleStatement index type equality from Fin equality -/ omit [NeZero ℓ] [NeZero ϑ] hdiv in -/-- OracleStatement index type equality from Fin equality -/ lemma OracleStatement.idx_eq {i j : Fin (ℓ + 1)} (h : i = j) : Fin (toOutCodewordsCount ℓ ϑ i) = Fin (toOutCodewordsCount ℓ ϑ j) := by subst h; rfl +/-! OracleStatement function HEq from Fin equality -/ omit [CharP L 2] [SampleableType L] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero 𝓡] hdiv in -/-- OracleStatement function HEq from Fin equality -/ lemma OracleStatement.heq_of_fin_eq {i j : Fin (ℓ + 1)} (h : i = j) : HEq (fun k => OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i k) (fun k => OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) j k) := by subst h; rfl +/-! Witness equality from Fin equality -/ omit [CharP L 2] [SampleableType L] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero ℓ] [NeZero 𝓡] in -/-- Witness equality from Fin equality -/ lemma Witness.of_fin_eq {i j : Fin (ℓ + 1)} (h : i = j) : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ) i = Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ) j := by subst h; rfl +/-! Relation equality from Fin equality -/ omit [CharP L 2] [SampleableType L] [DecidableEq 𝔽q] h_β₀_eq_1 in -/-- Relation equality from Fin equality -/ lemma strictRoundRelation.of_fin_eq {i j : Fin (ℓ + 1)} (h : i = j) : strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) i ≍ strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) j := by subst h; rfl +/-! Round relation equality from Fin equality -/ +omit [CharP L 2] [SampleableType L] in +lemma roundRelation.of_fin_eq {i j : Fin (ℓ + 1)} (h : i = j) : + roundRelation (mp := mp) 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i ≍ + roundRelation (mp := mp) 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) j := by + subst h; rfl + section FoldRelayRound -- foldRound + relay @[reducible] @@ -137,7 +146,7 @@ def foldRelayOracleReduction (i : Fin ℓ) variable {σ : Type} {init : ProbComp σ} {impl : QueryImpl []ₒ (StateT σ ProbComp)} -/-- Perfect completeness of the non-commitment round reduction follows by append composition +/-! Perfect completeness of the non-commitment round reduction follows by append composition of the fold-round and the transfer-round reductions. -/ theorem foldRelayOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ i) @@ -158,40 +167,60 @@ theorem foldRelayOracleReduction_perfectCompleteness (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) (init := init) (impl := impl) hInit i · -- Perfect completeness of relayOracleReduction exact relayOracleReduction_perfectCompleteness (L := L) 𝔽q β (ϑ := ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) (init := init) (impl := impl) hInit i hNCR - -/-- RBR Knowledge Soundness of the non-commitment round verifier via append composition - of fold-round and transfer-round RBR KS. -/ + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) + (init := init) (impl := impl) hInit i hNCR + +/-! Flat form of RBR knowledge error for fold+relay: case split on challenge index + instead of Sum.elim. Equal to the append-composed error (see foldRelayKnowledgeError_eq). -/ +def foldRelayKnowledgeError (i : Fin ℓ) + (j : (pSpecFoldRelay (L := L)).ChallengeIdx) : ℝ≥0 := + match ChallengeIdx.sumEquiv.symm j with + | Sum.inl j₁ => foldKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i j₁ + | Sum.inr j₂ => relayKnowledgeError j₂ + +lemma foldRelayKnowledgeError_eq (i : Fin ℓ) + (j : (pSpecFoldRelay (L := L)).ChallengeIdx) : + foldRelayKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i j = + Sum.elim (foldKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) + relayKnowledgeError (ChallengeIdx.sumEquiv.symm j) := by + unfold foldRelayKnowledgeError + cases ChallengeIdx.sumEquiv.symm j with + | inl _ => rfl + | inr _ => rfl + +/-! RBR KS for Fold+Relay block: append then convert to flat error. -/ theorem foldRelayOracleVerifier_rbrKnowledgeSoundness (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ i) : - (foldRelayOracleVerifier 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) i hNCR).rbrKnowledgeSoundness - init impl + (foldRelayOracleVerifier 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) i hNCR).rbrKnowledgeSoundness (init := init) (impl := impl) (relIn := roundRelation 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) i.castSucc (mp := mp)) (relOut := roundRelation 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) i.succ (mp := mp)) - (rbrKnowledgeError := fun m => foldKnowledgeError 𝔽q β (ϑ:=ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i ⟨m, by - match m with - | ⟨0, h0⟩ => nomatch h0 - | ⟨1, h1⟩ => rfl - ⟩) := by - unfold foldRelayOracleVerifier pSpecFoldRelay - suffices h : OracleVerifier.rbrKnowledgeSoundness init impl (roundRelation 𝔽q β i.castSucc) - (roundRelation 𝔽q β i.succ) - ((foldOracleVerifier 𝔽q β i).append (relayOracleVerifier 𝔽q β i hNCR)) - (Sum.elim (foldKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) - relayKnowledgeError ∘ ChallengeIdx.sumEquiv.symm) by - convert h using 1 - funext m - simp only [Function.comp, ChallengeIdx.sumEquiv, Equiv.symm] - dsimp - split - · congr 1; ext; simp - · omega - exact OracleVerifier.append_rbrKnowledgeSoundness _ _ - (foldOracleVerifier_rbrKnowledgeSoundness 𝔽q β i) - (relayOracleVerifier_rbrKnowledgeSoundness 𝔽q β i hNCR) + (rbrKnowledgeError := foldRelayKnowledgeError 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) := by + have hAppend := OracleVerifier.append_rbrKnowledgeSoundness + (init := init) (impl := impl) + (rel₁ := roundRelation (mp := mp) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i.castSucc) + (rel₂ := foldStepRelOut (mp := mp) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i) + (rel₃ := roundRelation (mp := mp) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i.succ) + (V₁ := foldOracleVerifier 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i) + (V₂ := relayOracleVerifier 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR) + (rbrKnowledgeError₁ := foldKnowledgeError 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) + (rbrKnowledgeError₂ := relayKnowledgeError) + (h₁ := foldOracleVerifier_rbrKnowledgeSoundness (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) (init := init) (impl := impl) i) + (h₂ := relayOracleVerifier_rbrKnowledgeSoundness (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) (init := init) (impl := impl) i hNCR) + exact OracleVerifier.rbrKnowledgeSoundness_of_eq_error + (h_ε := fun j => foldRelayKnowledgeError_eq 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i j) (h := hAppend) end FoldRelayRound -- foldRound + relay @@ -230,7 +259,7 @@ def foldCommitOracleReduction (i : Fin ℓ) variable {σ : Type} {init : ProbComp σ} {impl : QueryImpl []ₒ (StateT σ ProbComp)} -/-- Perfect completeness for Fold+Commitment block by append composition. -/ +/-! Perfect completeness for Fold+Commitment block by append composition. -/ theorem foldCommitOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : Fin ℓ) (hCR : isCommitmentRound ℓ ϑ i) [(i : pSpecFold.ChallengeIdx) → Fintype ((pSpecFold (L := L)).Challenge i)] @@ -254,43 +283,59 @@ theorem foldCommitOracleReduction_perfectCompleteness (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) (init := init) (impl := impl) hInit i · -- Perfect completeness of commitOracleReduction exact commitOracleReduction_perfectCompleteness (L := L) 𝔽q β (ϑ := ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) (init := init) (impl := impl) hInit i hCR - -/-- RBR KS for Fold+Commitment block by append composition. -/ + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) + (init := init) (impl := impl) hInit i hCR + +/-! Flat form of RBR knowledge error for fold+commit: case split on challenge index + instead of Sum.elim. Equal to the append-composed error (see foldCommitKnowledgeError_eq). -/ +def foldCommitKnowledgeError (i : Fin ℓ) + (j : (pSpecFoldCommit 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i).ChallengeIdx) : ℝ≥0 := + match ChallengeIdx.sumEquiv.symm j with + | Sum.inl j₁ => foldKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i j₁ + | Sum.inr j₂ => commitKnowledgeError 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) j₂ + +lemma foldCommitKnowledgeError_eq (i : Fin ℓ) + (j : (pSpecFoldCommit 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i).ChallengeIdx) : + foldCommitKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i j = + Sum.elim (foldKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) + (commitKnowledgeError 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (ChallengeIdx.sumEquiv.symm j) := by + unfold foldCommitKnowledgeError + cases ChallengeIdx.sumEquiv.symm j with + | inl _ => rfl + | inr _ => rfl + +/-! RBR KS for Fold+Commitment block: append then convert to flat error. -/ theorem foldCommitOracleVerifier_rbrKnowledgeSoundness (i : Fin ℓ) (hCR : isCommitmentRound ℓ ϑ i) : - (foldCommitOracleVerifier 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) i hCR).rbrKnowledgeSoundness - init impl + (foldCommitOracleVerifier 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) i hCR).rbrKnowledgeSoundness (init := init) (impl := impl) (relIn := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) i.castSucc ) (relOut := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) i.succ ) - (rbrKnowledgeError := fun _ => foldKnowledgeError 𝔽q β (ϑ:=ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i ⟨1, by rfl⟩ - ) := by - unfold foldCommitOracleVerifier pSpecFoldCommit - have herr : (fun _ => foldKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - i ⟨1, by rfl⟩) = - (Sum.elim (foldKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) - (commitKnowledgeError 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) ∘ - (ChallengeIdx.sumEquiv (pSpec₁ := pSpecFold (L := L)) - (pSpec₂ := pSpecCommit 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i)).symm) := by - funext m - simp only [Function.comp, ChallengeIdx.sumEquiv, Equiv.symm] - dsimp - split - · simp [foldKnowledgeError] - · next hlt => - exfalso - have hv := m.1.isLt - have hp := m.2 - simp only [ProtocolSpec.append, Fin.vappend_eq_append, Fin.append, Fin.addCases, - Direction.not_P_to_V_eq_V_to_P] at hp - split at hp <;> simp_all <;> omega - rw [herr] - exact OracleVerifier.append_rbrKnowledgeSoundness _ _ - (foldOracleVerifier_rbrKnowledgeSoundness 𝔽q β i) - (commitOracleVerifier_rbrKnowledgeSoundness 𝔽q β i hCR) + (rbrKnowledgeError := foldCommitKnowledgeError 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) := by + have hAppend := OracleVerifier.append_rbrKnowledgeSoundness + (init := init) (impl := impl) + (rel₁ := roundRelation (mp := mp) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i.castSucc) + (rel₂ := foldStepRelOut (mp := mp) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i) + (rel₃ := roundRelation (mp := mp) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i.succ) + (V₁ := foldOracleVerifier 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i) + (V₂ := commitOracleVerifier 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) i hCR) + (rbrKnowledgeError₁ := foldKnowledgeError 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) + (rbrKnowledgeError₂ := commitKnowledgeError 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (h₁ := foldOracleVerifier_rbrKnowledgeSoundness (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) (init := init) (impl := impl) i) + (h₂ := commitOracleVerifier_rbrKnowledgeSoundness (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) (init := init) (impl := impl) i hCR) + exact OracleVerifier.rbrKnowledgeSoundness_of_eq_error + (h_ε := fun j => foldCommitKnowledgeError_eq (i := i) (j := j)) (h := hAppend) end FoldCommitRound @@ -316,7 +361,8 @@ def nonLastSingleBlockOracleVerifier (bIdx : Fin (ℓ / ϑ - 1)) := (V := fun i => by have hNCR : ¬ isCommitmentRound ℓ ϑ ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_fin_ℓ_pred_lt_ℓ bIdx i⟩ := isNeCommitmentRound (r:=r) (ℓ:=ℓ) (𝓡:=𝓡) (ϑ:=ϑ) bIdx (x:=i.val) (hx:=by omega) - exact foldRelayOracleVerifier (L:=L) 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + exact foldRelayOracleVerifier (L:=L) 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_fin_ℓ_pred_lt_ℓ bIdx i⟩ hNCR ) let h1 : ↑bIdx * ϑ + (ϑ - 1) < ℓ := by @@ -328,10 +374,9 @@ def nonLastSingleBlockOracleVerifier (bIdx : Fin (ℓ / ϑ - 1)) := change ↑bIdx * ϑ + fv.val < ℓ + 0 apply bIdx_mul_ϑ_add_i_lt_ℓ_succ let h1_succ : ↑bIdx * ϑ + (ϑ - 1) < ℓ + 1 := by omega - - let lastOracleVerifier := foldCommitOracleVerifier 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + let lastOracleVerifier := foldCommitOracleVerifier 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (i := ⟨bIdx * ϑ + (ϑ - 1), h1⟩) (hCR:=isCommitmentRoundOfNonLastBlock (𝓡:=𝓡) (r:=r) bIdx) - let nonLastSingleBlockOracleVerifier := OracleVerifier.append (oSpec:=[]ₒ) (Stmt₁:=Statement (L := L) (ℓ := ℓ) Context ⟨bIdx * ϑ, by @@ -344,7 +389,8 @@ def nonLastSingleBlockOracleVerifier (bIdx : Fin (ℓ / ϑ - 1)) := (Stmt₃:=Statement (L := L) Context ⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩) (OStmt₁:=OracleStatement 𝔽q β ϑ ⟨bIdx * ϑ, Nat.lt_of_add_right_lt h1_succ⟩) (OStmt₂:=OracleStatement 𝔽q β ϑ ⟨bIdx * ϑ + (ϑ - 1), h1_succ⟩) - (OStmt₃:=OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩ : Fin (ℓ+1))) + (OStmt₃:=OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩ : Fin (ℓ+1))) (pSpec₁:=pSpecFoldRelaySequence (L:=L) (n:=ϑ - 1)) (pSpec₂:=pSpecFoldCommit 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨bIdx * ϑ + (ϑ - 1), h1⟩) (V₁:= firstFoldRelayRoundsOracleVerifier.castOutSimple (h_stmt := by rfl) (h_ostmt := by rfl)) @@ -354,44 +400,50 @@ def nonLastSingleBlockOracleVerifier (bIdx : Fin (ℓ / ϑ - 1)) := (StmtOut₁ := Statement Context (⟨↑bIdx * ϑ + (ϑ - 1), h1⟩ : Fin ℓ).succ) (StmtOut₂ := Statement (L := L) Context ⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩) (OStmtIn₁ := (OracleStatement 𝔽q β ϑ (⟨↑bIdx * ϑ + (ϑ - 1), h1⟩ : Fin ℓ).castSucc)) - (OStmtIn₂ := OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨bIdx * ϑ + (ϑ - 1), h1_succ⟩) + (OStmtIn₂ := OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨bIdx * ϑ + (ϑ - 1), h1_succ⟩) (OStmtOut₁ := OracleStatement 𝔽q β ϑ (⟨↑bIdx * ϑ + (ϑ - 1), h1⟩ : Fin ℓ).succ) - (OStmtOut₂ := OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩ : Fin (ℓ+1))) + (OStmtOut₂ := OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩) (pSpec := pSpecFoldCommit 𝔽q β ⟨↑bIdx * ϑ + (ϑ - 1), h1⟩) (h_stmtIn := by apply Statement.of_fin_eq - simp [Fin.castSucc, Fin.eta]) + simp? [Fin.castSucc, Fin.eta]) (h_stmtOut := by apply Statement.of_fin_eq - ext; simp [Fin.val_succ] - rw [Nat.add_assoc, Nat.sub_add_cancel (by exact NeZero.one_le), Nat.add_mul, Nat.one_mul]) + ext; simp? [Fin.val_succ] + rw [Nat.add_assoc, Nat.sub_add_cancel (by exact NeZero.one_le), + Nat.add_mul, Nat.one_mul]) (h_idxIn := by apply OracleStatement.idx_eq - simp [Fin.castSucc, Fin.eta]) + simp? [Fin.castSucc, Fin.eta]) (h_idxOut := by apply OracleStatement.idx_eq - ext; simp [Fin.val_succ] - rw [Nat.add_assoc, Nat.sub_add_cancel (by exact NeZero.one_le), Nat.add_mul, Nat.one_mul]) + ext; simp? [Fin.val_succ] + rw [Nat.add_assoc, Nat.sub_add_cancel (by exact NeZero.one_le), + Nat.add_mul, Nat.one_mul]) (h_ostmtIn := by apply OracleStatement.heq_of_fin_eq - simp [Fin.castSucc, Fin.eta]) + simp? [Fin.castSucc, Fin.eta]) (h_ostmtOut := by apply OracleStatement.heq_of_fin_eq - ext; simp [Fin.val_succ] - rw [Nat.add_assoc, Nat.sub_add_cancel (by exact NeZero.one_le), Nat.add_mul, Nat.one_mul]) + ext; simp only [Fin.succ_mk] + rw [Nat.add_assoc, Nat.sub_add_cancel (by exact NeZero.one_le), + Nat.add_mul, Nat.one_mul]) (h_Oₛᵢ := by apply instOracleStatementBinaryBasefold_heq_of_fin_eq ext; simp only [Fin.castSucc, Fin.castAdd_mk]) ) - nonLastSingleBlockOracleVerifier def nonLastBlocksOracleVerifier : OracleVerifier []ₒ (StmtIn := Statement (L := L) (ℓ := ℓ) Context ⟨0 * ϑ, by omega⟩) (OStmtIn := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ ⟨0 * ϑ, by omega⟩) - (StmtOut := Statement (L := L) (ℓ := ℓ) Context ⟨(ℓ / ϑ - 1) * ϑ, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) - (OStmtOut := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ ⟨(ℓ / ϑ - 1) * ϑ, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) + (StmtOut := Statement (L := L) (ℓ := ℓ) Context ⟨(ℓ / ϑ - 1) * ϑ, by + apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) + (OStmtOut := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ ⟨(ℓ / ϑ - 1) * ϑ, by + apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) (pSpec := pSpecNonLastBlocks 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) := let stmt : Fin (ℓ / ϑ - 1 + 1) → Type := fun i => Statement (L := L) (ℓ := ℓ) Context ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩ @@ -426,8 +478,8 @@ def lastBlockOracleVerifier := (V := fun i => by have hNCR : ¬ isCommitmentRound ℓ ϑ ⟨bIdx * ϑ + i, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ i⟩ := lastBlockIdx_isNeCommitmentRound i - exact foldRelayOracleVerifier (L:=L) 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) - ⟨bIdx * ϑ + i, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ i⟩ hNCR + exact foldRelayOracleVerifier (L:=L) 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) ⟨bIdx * ϑ + i, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ i⟩ hNCR ) exact OracleVerifier.castInOut (V := cur) (StmtIn₂ := Statement (L := L) (ℓ := ℓ) Context ⟨bIdx * ϑ, by @@ -442,11 +494,11 @@ def lastBlockOracleVerifier := (pSpec := pSpecLastBlock (L:=L) (ϑ:=ϑ)) (h_stmtIn := by apply Statement.of_fin_eq - ext; simp) + ext; simp?) (h_stmtOut := by apply Statement.of_fin_eq ext - simp [Fin.val_last] + simp? [Fin.val_last] have : bIdx * ϑ + ϑ = ℓ := by have h_div : ϑ ∣ ℓ := hdiv.out have h_mod : ℓ % ϑ = 0 := Nat.mod_eq_zero_of_dvd h_div @@ -456,11 +508,11 @@ def lastBlockOracleVerifier := simp only [this]) (h_idxIn := by apply OracleStatement.idx_eq - ext; simp) + ext; simp?) (h_idxOut := by apply OracleStatement.idx_eq ext - simp [Fin.val_last] + simp? [Fin.val_last] have : bIdx * ϑ + ϑ = ℓ := by have h_div : ϑ ∣ ℓ := hdiv.out have h_mod : ℓ % ϑ = 0 := Nat.mod_eq_zero_of_dvd h_div @@ -470,11 +522,11 @@ def lastBlockOracleVerifier := simp only [this]) (h_ostmtIn := by apply OracleStatement.heq_of_fin_eq - ext; simp) + ext; simp?) (h_ostmtOut := by apply OracleStatement.heq_of_fin_eq ext - simp [Fin.val_last] + simp only [Fin.val_last] have : bIdx * ϑ + ϑ = ℓ := by have h_div : ϑ ∣ ℓ := hdiv.out have h_mod : ℓ % ϑ = 0 := Nat.mod_eq_zero_of_dvd h_div @@ -485,15 +537,14 @@ def lastBlockOracleVerifier := (h_Oₛᵢ := by apply instOracleStatementBinaryBasefold_heq_of_fin_eq ext; simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, add_zero]) - V @[reducible] def sumcheckFoldOracleVerifier := - let nonLastBlocksOracleVerifier := nonLastBlocksOracleVerifier (L:=L) 𝔽q β (mp := mp) (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) - - let lastOracleVerifier := lastBlockOracleVerifier 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) - + let nonLastBlocksOracleVerifier := nonLastBlocksOracleVerifier (L := L) + 𝔽q β (mp := mp) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) + let lastOracleVerifier := lastBlockOracleVerifier 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) let sumcheckFoldOV : OracleVerifier []ₒ (StmtIn := Statement (L := L) (ℓ := ℓ) Context 0) (OStmtIn := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ 0) @@ -520,7 +571,6 @@ def sumcheckFoldOracleVerifier := (h_Oₛᵢ := by apply instOracleStatementBinaryBasefold_heq_of_fin_eq ext; simp only [zero_mul, Fin.coe_ofNat_eq_mod, Nat.zero_mod]) - sumcheckFoldOV end composedOracleVerifiers @@ -529,7 +579,8 @@ section composedOracleRedutions def nonLastSingleBlockOracleReduction (bIdx : Fin (ℓ / ϑ - 1)) := let stmt : Fin (ϑ - 1 + 1) → Type := - fun i => Statement (L := L) (ℓ := ℓ) Context ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_cast_lt_ℓ_succ bIdx i⟩ + fun i => Statement (L := L) (ℓ := ℓ) Context + ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_cast_lt_ℓ_succ bIdx i⟩ let oStmt := fun i: Fin (ϑ - 1 + 1) => OracleStatement 𝔽q β ϑ ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_cast_lt_ℓ_succ bIdx i⟩ let wit := fun i: Fin (ϑ - 1 + 1) => @@ -547,7 +598,6 @@ def nonLastSingleBlockOracleReduction (bIdx : Fin (ℓ / ϑ - 1)) := exact foldRelayOracleReduction (L:=L) 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (i:=⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_fin_ℓ_pred_lt_ℓ bIdx i⟩) hNCR ) - let h1 : ↑bIdx * ϑ + (ϑ - 1) < ℓ := by let fv: Fin ϑ := ⟨ϑ - 1, by have h := NeZero.one_le (n:=ϑ) @@ -557,11 +607,9 @@ def nonLastSingleBlockOracleReduction (bIdx : Fin (ℓ / ϑ - 1)) := change ↑bIdx * ϑ + fv.val < ℓ + 0 apply bIdx_mul_ϑ_add_i_lt_ℓ_succ let h1_succ : ↑bIdx * ϑ + (ϑ - 1) < ℓ + 1 := by omega - let lastOracleReduction := foldCommitOracleReduction 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (i := ⟨bIdx * ϑ + (ϑ - 1), h1⟩) (hCR:=isCommitmentRoundOfNonLastBlock (𝓡:=𝓡) (r:=r) bIdx) - let nonLastSingleBlockOracleReduction := OracleReduction.append (oSpec:=[]ₒ) (Stmt₁:=Statement (L := L) (ℓ := ℓ) Context ⟨bIdx * ϑ, by @@ -592,7 +640,8 @@ def nonLastSingleBlockOracleReduction (bIdx : Fin (ℓ / ϑ - 1)) := (OStmt₃:=OracleStatement 𝔽q β ϑ ⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩) (pSpec₁:=pSpecFoldRelaySequence (L:=L) (n:=ϑ - 1)) (pSpec₂:=pSpecFoldCommit 𝔽q β ⟨bIdx * ϑ + (ϑ - 1), h1⟩) - (R₁:=firstFoldRelayRoundsOracleReduction.castOutSimple (h_stmt := by rfl) (h_ostmt := by rfl) (h_wit := by rfl) + (R₁:=firstFoldRelayRoundsOracleReduction.castOutSimple (h_stmt := by rfl) (h_ostmt := by rfl) + (h_wit := by rfl) ) (R₂:= OracleReduction.castInOut (R := lastOracleReduction) (StmtIn₁ := (Statement Context (⟨↑bIdx * ϑ + (ϑ - 1), h1⟩ : Fin ℓ).castSucc)) @@ -600,43 +649,44 @@ def nonLastSingleBlockOracleReduction (bIdx : Fin (ℓ / ϑ - 1)) := (StmtOut₁ := Statement Context (⟨↑bIdx * ϑ + (ϑ - 1), h1⟩ : Fin ℓ).succ) (StmtOut₂ := Statement (L := L) Context ⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩) (OStmtIn₁ := (OracleStatement 𝔽q β ϑ (⟨↑bIdx * ϑ + (ϑ - 1), h1⟩ : Fin ℓ).castSucc)) - (OStmtIn₂ := OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨bIdx * ϑ + (ϑ - 1), h1_succ⟩) + (OStmtIn₂ := OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨bIdx * ϑ + (ϑ - 1), h1_succ⟩) (OStmtOut₁ := OracleStatement 𝔽q β ϑ (⟨↑bIdx * ϑ + (ϑ - 1), h1⟩ : Fin ℓ).succ) - (OStmtOut₂ := OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩ : Fin (ℓ+1))) + (OStmtOut₂ := OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩ : Fin (ℓ+1))) (pSpec := pSpecFoldCommit 𝔽q β ⟨↑bIdx * ϑ + (ϑ - 1), h1⟩) (h_stmtIn := by apply Statement.of_fin_eq - simp [Fin.castSucc, Fin.eta]) + simp only [Fin.castSucc, Fin.castAdd_mk]) (h_stmtOut := by apply Statement.of_fin_eq - ext; simp [Fin.val_succ] + ext; simp only [Fin.succ_mk] rw [Nat.add_assoc, Nat.sub_add_cancel (by exact NeZero.one_le), Nat.add_mul, Nat.one_mul]) (h_idxIn := by apply OracleStatement.idx_eq - simp [Fin.castSucc, Fin.eta]) + simp only [Fin.castSucc, Fin.castAdd_mk]) (h_idxOut := by apply OracleStatement.idx_eq - ext; simp [Fin.val_succ] + ext; simp only [Fin.succ_mk] rw [Nat.add_assoc, Nat.sub_add_cancel (by exact NeZero.one_le), Nat.add_mul, Nat.one_mul]) (h_ostmtIn := by apply OracleStatement.heq_of_fin_eq - simp [Fin.castSucc, Fin.eta]) + simp only [Fin.castSucc, Fin.castAdd_mk]) (h_ostmtOut := by apply OracleStatement.heq_of_fin_eq - ext; simp [Fin.val_succ] + ext; simp only [Fin.succ_mk] rw [Nat.add_assoc, Nat.sub_add_cancel (by exact NeZero.one_le), Nat.add_mul, Nat.one_mul]) (h_witIn := by apply Witness.of_fin_eq - simp [Fin.castSucc, Fin.eta]) + simp only [Fin.castSucc, Fin.castAdd_mk]) (h_witOut := by apply Witness.of_fin_eq - ext; simp [Fin.val_succ] + ext; simp only [Fin.succ_mk] rw [Nat.add_assoc, Nat.sub_add_cancel (by exact NeZero.one_le), Nat.add_mul, Nat.one_mul]) (h_Oₛᵢ := by apply instOracleStatementBinaryBasefold_heq_of_fin_eq - ext; simp only [Fin.castSucc, Fin.castAdd_mk, Nat.add_left_cancel_iff]) + ext; simp only [Fin.castSucc, Fin.castAdd_mk]) ) - nonLastSingleBlockOracleReduction def nonLastBlocksOracleReduction : @@ -644,10 +694,12 @@ def nonLastBlocksOracleReduction : (StmtIn := Statement (L := L) (ℓ := ℓ) Context ⟨0 * ϑ, by omega⟩) (OStmtIn := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ ⟨0 * ϑ, by omega⟩) (WitIn := Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ) ⟨0 * ϑ, by omega⟩) - - (StmtOut := Statement (L := L) (ℓ:=ℓ) Context ⟨(ℓ / ϑ - 1) * ϑ, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) - (OStmtOut := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ ⟨(ℓ / ϑ - 1) *ϑ, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) - (WitOut := Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ) ⟨(ℓ / ϑ - 1) * ϑ, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) + (StmtOut := Statement (L := L) (ℓ:=ℓ) Context ⟨(ℓ / ϑ - 1) * ϑ, by + apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) + (OStmtOut := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ ⟨(ℓ / ϑ - 1) *ϑ, by + apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) + (WitOut := Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ) ⟨(ℓ / ϑ - 1) * ϑ, by + apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) (pSpec := pSpecNonLastBlocks 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) := let stmt : Fin (ℓ / ϑ - 1 + 1) → Type := fun i => Statement (L := L) (ℓ := ℓ) Context ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩ @@ -665,7 +717,7 @@ def nonLastBlocksOracleReduction : res def lastBlockOracleReduction := - have h_le: ϑ ≤ ℓ := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ); exact hdiv.out + have h_le : ϑ ≤ ℓ := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ); exact hdiv.out let bIdx := ℓ / ϑ - 1 let stmt : Fin (ϑ + 1) → Type := fun i => Statement (L := L) (ℓ := ℓ) Context ⟨bIdx * ϑ + i, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (hx:=by omega)⟩ @@ -707,11 +759,11 @@ def lastBlockOracleReduction := (pSpec := pSpecLastBlock (L:=L) (ϑ:=ϑ)) (h_stmtIn := by apply Statement.of_fin_eq - ext; simp) + ext; simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, add_zero]) (h_stmtOut := by apply Statement.of_fin_eq ext - simp [Fin.val_last] + simp only [Fin.val_last] have : bIdx * ϑ + ϑ = ℓ := by have h_div : ϑ ∣ ℓ := hdiv.out have h_mod : ℓ % ϑ = 0 := Nat.mod_eq_zero_of_dvd h_div @@ -719,14 +771,14 @@ def lastBlockOracleReduction := dsimp only [bIdx]; rw [Nat.sub_mul, h_mul, Nat.one_mul] omega - simp [this]) + simp only [this]) (h_idxIn := by apply OracleStatement.idx_eq - ext; simp) + ext; simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, add_zero]) (h_idxOut := by apply OracleStatement.idx_eq ext - simp [Fin.val_last] + simp only [Fin.val_last] have : bIdx * ϑ + ϑ = ℓ := by have h_div : ϑ ∣ ℓ := hdiv.out have h_mod : ℓ % ϑ = 0 := Nat.mod_eq_zero_of_dvd h_div @@ -734,40 +786,39 @@ def lastBlockOracleReduction := dsimp only [bIdx] rw [Nat.sub_mul, h_mul, Nat.one_mul] omega - simp [this]) + simp only [this]) (h_ostmtIn := by apply OracleStatement.heq_of_fin_eq - ext; simp) + ext; simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, add_zero]) (h_ostmtOut := by apply OracleStatement.heq_of_fin_eq ext - simp [Fin.val_last] + simp only [Fin.val_last] have : bIdx * ϑ + ϑ = ℓ := by have h_div : ϑ ∣ ℓ := hdiv.out have h_mod : ℓ % ϑ = 0 := Nat.mod_eq_zero_of_dvd h_div have h_mul : ℓ / ϑ * ϑ = ℓ := Nat.div_mul_cancel (Nat.dvd_of_mod_eq_zero h_mod) - dsimp [bIdx] + dsimp only [bIdx] rw [Nat.sub_mul, h_mul, Nat.one_mul] omega - simp [this]) + simp only [this]) (h_witIn := by apply Witness.of_fin_eq - ext; simp) + ext; simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, add_zero]) (h_witOut := by apply Witness.of_fin_eq ext - simp [Fin.val_last] + simp only [Fin.val_last] have : bIdx * ϑ + ϑ = ℓ := by have h_div : ϑ ∣ ℓ := hdiv.out have h_mod : ℓ % ϑ = 0 := Nat.mod_eq_zero_of_dvd h_div have h_mul : ℓ / ϑ * ϑ = ℓ := Nat.div_mul_cancel (Nat.dvd_of_mod_eq_zero h_mod) rw [Nat.sub_mul, h_mul, Nat.one_mul] omega - simp [this]) + simp only [this]) (h_Oₛᵢ := by apply instOracleStatementBinaryBasefold_heq_of_fin_eq ext; simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, add_zero]) - V def sumcheckFoldOracleReduction : OracleReduction []ₒ @@ -785,10 +836,10 @@ def sumcheckFoldOracleReduction : OracleReduction []ₒ let wit := fun i: Fin (ℓ / ϑ - 1 + 1) => Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ) ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩ - let nonLastSingleBlockOracleReduction := nonLastBlocksOracleReduction (L:=L) 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (ϑ := ϑ) - - let lastOracleReduction := lastBlockOracleReduction 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) - + let nonLastSingleBlockOracleReduction := nonLastBlocksOracleReduction (L:=L) 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (ϑ := ϑ) + let lastOracleReduction := lastBlockOracleReduction 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (OracleReduction.append (oSpec:=[]ₒ) (pSpec₁ := pSpecNonLastBlocks 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (pSpec₂ := pSpecLastBlock (L:=L) (ϑ:=ϑ)) @@ -822,7 +873,7 @@ section SecurityProps variable {σ : Type} {init : ProbComp σ} {impl : QueryImpl []ₒ (StateT σ ProbComp)} -/-- Perfect completeness for a single non-last block -/ +/-! Perfect completeness for a single non-last block -/ lemma nonLastSingleBlockOracleReduction_perfectCompleteness (hInit : NeverFail init) (bIdx : Fin (ℓ / ϑ - 1)) : OracleReduction.perfectCompleteness (init := init) (impl := impl) @@ -853,7 +904,6 @@ lemma nonLastSingleBlockOracleReduction_perfectCompleteness apply bIdx_mul_ϑ_add_i_lt_ℓ_succ (m:=1) ⟩) (impl := impl) (init := init) - · -- Perfect completeness of the fold+relay sequence part (`seqCompose`), output-cast is rfl apply OracleReduction.castInOut_perfectCompleteness (h_stmtIn := by @@ -880,15 +930,12 @@ lemma nonLastSingleBlockOracleReduction_perfectCompleteness simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, add_zero]) (h_relOut := by rfl) (impl := impl) (init := init) - -- ⊢ OracleReduction.perfectCompleteness init impl (strictRoundRelation 𝔽q β ⟨↑bIdx * ϑ + ↑0, ⋯⟩) - -- (strictRoundRelation 𝔽q β ⟨↑bIdx * ϑ + (ϑ - 1), ⋯⟩) - -- (OracleReduction.seqCompose (fun i ↦ Statement Context ⟨↑bIdx * ϑ + ↑i, ⋯⟩) - -- (fun i ↦ OracleStatement 𝔽q β ϑ ⟨↑bIdx * ϑ + ↑i, ⋯⟩) (fun i ↦ Witness 𝔽q β ⟨↑bIdx * ϑ + ↑i, ⋯⟩) fun i ↦ - -- foldRelayOracleReduction 𝔽q β ⟨↑bIdx * ϑ + ↑i, ⋯⟩ ⋯) let stmt : Fin (ϑ - 1 + 1) → Type := - fun i => Statement (L := L) (ℓ := ℓ) Context ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_cast_lt_ℓ_succ bIdx i⟩ + fun i => Statement (L := L) (ℓ := ℓ) Context + ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_cast_lt_ℓ_succ bIdx i⟩ let oStmt := fun i: Fin (ϑ - 1 + 1) => - OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_cast_lt_ℓ_succ bIdx i⟩ + OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_cast_lt_ℓ_succ bIdx i⟩ let wit := fun i: Fin (ϑ - 1 + 1) => Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ) ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_cast_lt_ℓ_succ bIdx i⟩ @@ -899,7 +946,8 @@ lemma nonLastSingleBlockOracleReduction_perfectCompleteness (OStmt := oStmt) (Wit := wit) (R := fun i => by - have hNCR : ¬ isCommitmentRound ℓ ϑ ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_fin_ℓ_pred_lt_ℓ bIdx i⟩ + have hNCR : ¬ isCommitmentRound ℓ ϑ + ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_fin_ℓ_pred_lt_ℓ bIdx i⟩ := isNeCommitmentRound (r:=r) (ℓ:=ℓ) (𝓡:=𝓡) (ϑ:=ϑ) bIdx (x:=i.val) (hx:=by omega) exact foldRelayOracleReduction (L:=L) 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (i:=⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_fin_ℓ_pred_lt_ℓ bIdx i⟩) hNCR @@ -954,7 +1002,6 @@ lemma nonLastSingleBlockOracleReduction_perfectCompleteness apply strictRoundRelation.of_fin_eq; simp only [Fin.succ_mk, Fin.mk.injEq, Nat.add_mul]; omega) (impl := impl) (init := init) - let h1 : ↑bIdx * ϑ + (ϑ - 1) < ℓ := by let fv: Fin ϑ := ⟨ϑ - 1, by have h := NeZero.one_le (n:=ϑ) @@ -964,13 +1011,12 @@ lemma nonLastSingleBlockOracleReduction_perfectCompleteness change ↑bIdx * ϑ + fv.val < ℓ + 0 apply bIdx_mul_ϑ_add_i_lt_ℓ_succ let h1_succ : ↑bIdx * ϑ + (ϑ - 1) < ℓ + 1 := by omega - exact foldCommitOracleReduction_perfectCompleteness 𝔽q β (mp := mp) (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (init := init) (impl := impl) (hCR := isCommitmentRoundOfNonLastBlock (𝓡:=𝓡) (r:=r) bIdx) (i := ⟨bIdx * ϑ + (ϑ - 1), h1⟩) (hInit := hInit) -/-- Perfect completeness for the last block -/ +/-! Perfect completeness for the last block -/ lemma lastBlockOracleReduction_perfectCompleteness (hInit : NeverFail init) : OracleReduction.perfectCompleteness (init := init) (impl := impl) (relIn := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) @@ -1026,7 +1072,6 @@ lemma lastBlockOracleReduction_perfectCompleteness (hInit : NeverFail init) : apply Fin.eq_of_val_eq; simp only [Fin.val_last, Nat.sub_mul]; rw [Nat.div_mul_cancel (by exact hdiv.out), Nat.one_mul]; omega) (impl := impl) (init := init) - let bIdx := ℓ / ϑ - 1 let stmt : Fin (ϑ + 1) → Type := fun i => Statement (L := L) (ℓ := ℓ) Context ⟨bIdx * ϑ + i, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (hx:=by omega)⟩ @@ -1034,7 +1079,6 @@ lemma lastBlockOracleReduction_perfectCompleteness (hInit : NeverFail init) : ⟨bIdx * ϑ + i, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (hx:=by omega)⟩ let wit := fun i: Fin (ϑ + 1) => Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ) ⟨bIdx * ϑ + i, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (hx:=by omega)⟩ - let foldRelayRoundsPerfectCompleteness := OracleReduction.seqCompose_perfectCompleteness (oSpec := []ₒ) (m := ϑ) (Stmt := stmt) @@ -1060,7 +1104,7 @@ lemma lastBlockOracleReduction_perfectCompleteness (hInit : NeverFail init) : (hInit := hInit) (i := ⟨bIdx * ϑ + i, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ i⟩) (hNCR := hNCR) exact res -/-- Perfect completeness for the core interaction oracle reduction -/ +/-! Perfect completeness for the core interaction oracle reduction -/ theorem sumcheckFoldOracleReduction_perfectCompleteness (hInit : NeverFail init) : OracleReduction.perfectCompleteness (pSpec := pSpecSumcheckFold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) @@ -1073,7 +1117,6 @@ theorem sumcheckFoldOracleReduction_perfectCompleteness (hInit : NeverFail init) (init := init) (impl := impl) := by unfold sumcheckFoldOracleReduction pSpecSumcheckFold - let stmt : Fin (ℓ / ϑ - 1 + 1) → Type := fun i => Statement (L := L) (ℓ := ℓ) Context ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩ let oStmt := fun i: Fin (ℓ / ϑ - 1 + 1) => @@ -1081,14 +1124,14 @@ theorem sumcheckFoldOracleReduction_perfectCompleteness (hInit : NeverFail init) let wit := fun i: Fin (ℓ / ϑ - 1 + 1) => Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ) ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩ - apply OracleReduction.castInOut_perfectCompleteness (pSpec := pSpecSumcheckFold 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (StmtIn₁ := stmt 0) (StmtIn₂ := Statement (L := L) (ℓ := ℓ) Context 0) (ιₛᵢ₁ := Fin (toOutCodewordsCount ℓ ϑ ⟨0 * ϑ, by omega⟩)) (ιₛᵢ₂ := Fin (toOutCodewordsCount ℓ ϑ 0)) - (OStmtIn₁ := fun i => OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨0 * ϑ, by omega⟩ i) + (OStmtIn₁ := fun i => OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨0 * ϑ, by omega⟩ i) (OStmtIn₂ := fun i => OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ 0 i) (WitIn₁ := wit 0) (WitIn₂ := Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ) 0) @@ -1100,10 +1143,14 @@ theorem sumcheckFoldOracleReduction_perfectCompleteness (hInit : NeverFail init) (OStmtOut₂ := fun i => OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) i) (WitOut₁ := Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ) (Fin.last ℓ)) (WitOut₂ := Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ) (Fin.last ℓ)) - (relIn₁ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) ⟨0 * ϑ, by omega⟩) - (relIn₂ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) 0) - (relOut₁ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (Fin.last ℓ)) - (relOut₂ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (Fin.last ℓ)) + (relIn₁ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) ⟨0 * ϑ, by omega⟩) + (relIn₂ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) 0) + (relOut₁ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) (Fin.last ℓ)) + (relOut₂ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) (Fin.last ℓ)) (h_stmtIn := by apply Statement.of_fin_eq apply fin_zero_mul_eq) @@ -1130,15 +1177,19 @@ theorem sumcheckFoldOracleReduction_perfectCompleteness (hInit : NeverFail init) (impl := impl) (init := init) apply OracleReduction.append_perfectCompleteness - (rel₁ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) ⟨0 * ϑ, by omega⟩) + (rel₁ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) ⟨0 * ϑ, by omega⟩) (rel₂ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) ⟨(ℓ / ϑ - 1) * ϑ, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) - (rel₃ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (Fin.last ℓ)) + (rel₃ := strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) (Fin.last ℓ)) (pSpec₁ := pSpecNonLastBlocks 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (pSpec₂ := pSpecLastBlock (L:=L) (ϑ:=ϑ)) - (R₁ := nonLastBlocksOracleReduction 𝔽q β (ϑ:=ϑ) (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑)) - (R₂ := lastBlockOracleReduction 𝔽q β (ϑ:=ϑ) (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑)) + (R₁ := nonLastBlocksOracleReduction 𝔽q β (ϑ:=ϑ) (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑)) + (R₂ := lastBlockOracleReduction 𝔽q β (ϑ:=ϑ) (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑)) (impl := impl) (init := init) · @@ -1148,9 +1199,11 @@ theorem sumcheckFoldOracleReduction_perfectCompleteness (hInit : NeverFail init) (Stmt := fun i : Fin (ℓ / ϑ - 1 + 1) => Statement (L := L) (ℓ := ℓ) Context ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩) (OStmt := fun i : Fin (ℓ / ϑ - 1 + 1) => - OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩) + OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩) (Wit := fun i : Fin (ℓ / ϑ - 1 + 1) => - Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ) ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩) + Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ) + ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩) (rel := fun i => strictRoundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩) (pSpec := fun (bIdx: Fin (ℓ / ϑ - 1)) => pSpecFullNonLastBlock 𝔽q β (ϑ:=ϑ) bIdx) @@ -1166,17 +1219,351 @@ theorem sumcheckFoldOracleReduction_perfectCompleteness (hInit : NeverFail init) exact lastBlockOracleReduction_perfectCompleteness 𝔽q β (ϑ:=ϑ) (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (init := init) (impl := impl) hInit -def NBlockMessages := 2 * (ϑ - 1) + 3 - -def sumcheckFoldKnowledgeError := fun j : (pSpecSumcheckFold 𝔽q β (ϑ:=ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).ChallengeIdx => - if hj: (j.val % NBlockMessages (ϑ:=ϑ)) % 2 = 1 then - foldKnowledgeError 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - ⟨j / NBlockMessages (ϑ:=ϑ) * ϑ + ((j % NBlockMessages (ϑ:=ϑ)) / 2 + 1), by - sorry⟩ ⟨1, rfl⟩ - else 0 -- this case never happens - -/-- Round-by-round knowledge soundness for the sumcheck fold oracle verifier -/ +/-! RBR knowledge error for last block: seqCompose of foldRelay over ϑ rounds. -/ +def lastBlockRbrKnowledgeError (k : (pSpecLastBlock (L := L) (ϑ := ϑ)).ChallengeIdx) : ℝ≥0 := + let ij := seqComposeChallengeIdxToSigma k + foldRelayKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨(ℓ / ϑ - 1) * ϑ + ij.1, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ ij.1⟩ ij.2 + +/-! RBR KS for last block verifier (seqCompose of foldRelay then castInOut). -/ +theorem lastBlockOracleVerifier_rbrKnowledgeSoundness : + OracleVerifier.rbrKnowledgeSoundness init impl + (roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) ⟨(ℓ / ϑ - 1) * ϑ, by + apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) + (roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (Fin.last ℓ)) + (lastBlockOracleVerifier 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)) + (rbrKnowledgeError := lastBlockRbrKnowledgeError (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) := by + have h_ϑ_le_ℓ : ϑ ≤ ℓ := Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (by exact hdiv.out) + apply OracleVerifier.castInOut_rbrKnowledgeSoundness + (relIn₁ := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) ⟨(ℓ / ϑ - 1) * ϑ + (0 : Fin (ϑ + 1)), by + apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) + (relOut₁ := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + ⟨(ℓ / ϑ - 1) * ϑ + (Fin.last ϑ : Fin (ϑ + 1)), by + apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=ϑ) (hx:=by omega)⟩) + (h_stmtIn := by + apply Statement.of_fin_eq + apply Fin.eq_of_val_eq + simp only [Nat.sub_mul, one_mul, Fin.coe_ofNat_eq_mod, Nat.zero_mod, add_zero]) + (h_stmtOut := by + apply Statement.of_fin_eq + apply Fin.eq_of_val_eq + simp only [Fin.val_last, Nat.sub_mul] + rw [Nat.div_mul_cancel (by exact hdiv.out), Nat.one_mul] + omega) + (h_idxIn := by + apply OracleStatement.idx_eq + apply Fin.eq_of_val_eq + simp only [Nat.sub_mul, one_mul, Fin.coe_ofNat_eq_mod, Nat.zero_mod, add_zero]) + (h_idxOut := by + apply OracleStatement.idx_eq + apply Fin.eq_of_val_eq + simp only [Fin.val_last, Nat.sub_mul] + rw [Nat.div_mul_cancel (by exact hdiv.out), Nat.one_mul] + omega) + (h_ostmtIn := by + apply OracleStatement.heq_of_fin_eq + apply Fin.eq_of_val_eq + simp only [Nat.sub_mul, one_mul, Fin.coe_ofNat_eq_mod, Nat.zero_mod, add_zero]) + (h_ostmtOut := by + apply OracleStatement.heq_of_fin_eq + apply Fin.eq_of_val_eq + simp only [Fin.val_last, Nat.sub_mul] + rw [Nat.div_mul_cancel (by exact hdiv.out), Nat.one_mul] + omega) + (h_witIn := by rfl) + (h_witOut := by + apply Witness.of_fin_eq + apply Fin.eq_of_val_eq + simp only [Fin.val_last, Nat.sub_mul] + rw [Nat.div_mul_cancel (by exact hdiv.out), Nat.one_mul] + omega) + (h_Oₛᵢ := by + apply instOracleStatementBinaryBasefold_heq_of_fin_eq + apply Fin.eq_of_val_eq + simp only [Nat.sub_mul, one_mul, Fin.coe_ofNat_eq_mod, Nat.zero_mod, add_zero]) + (h_relIn := by + apply roundRelation.of_fin_eq + apply Fin.eq_of_val_eq + simp only [Nat.sub_mul, one_mul, Fin.coe_ofNat_eq_mod, Nat.zero_mod, add_zero]) + (h_relOut := by + apply roundRelation.of_fin_eq + apply Fin.eq_of_val_eq + simp only [Fin.val_last, Nat.sub_mul] + rw [Nat.div_mul_cancel (by exact hdiv.out), Nat.one_mul] + omega) + (ε := lastBlockRbrKnowledgeError (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + let bIdx := ℓ / ϑ - 1 + let stmt : Fin (ϑ + 1) → Type := fun i => Statement (L := L) (ℓ:=ℓ) Context + ⟨bIdx * ϑ + i, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (hx:=by omega)⟩ + let oStmt := fun i: Fin (ϑ + 1) => OracleStatement 𝔽q β ϑ + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨bIdx * ϑ + i, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (hx:=by omega)⟩ + let foldRelayRoundsRbrKnowledgeSoundness := OracleVerifier.seqCompose_rbrKnowledgeSoundness + (oSpec := []ₒ) (m := ϑ) + (Stmt := stmt) + (OStmt := oStmt) + (pSpec := fun i => pSpecFoldRelay (L:=L)) + (V := fun i => by + have hNCR : ¬ isCommitmentRound ℓ ϑ ⟨bIdx * ϑ + i, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ i⟩ := + lastBlockIdx_isNeCommitmentRound i + exact foldRelayOracleVerifier (L:=L) 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + ⟨bIdx * ϑ + i, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ i⟩ hNCR + ) + (rel := fun i ↦ + roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + (⟨↑bIdx * ϑ + ↑i, by + apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (hx:=by omega)⟩ : Fin (ℓ + 1))) + (rbrKnowledgeError := fun i => + foldRelayKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨bIdx * ϑ + i, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ i⟩) + (init := init) (impl := impl) + have hCur : + OracleVerifier.rbrKnowledgeSoundness init impl + (roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + ⟨bIdx * ϑ + (0 : Fin (ϑ + 1)), by + apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) + (roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + ⟨bIdx * ϑ + (Fin.last ϑ : Fin (ϑ + 1)), by + apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=ϑ) (hx:=by omega)⟩) + (OracleVerifier.seqCompose stmt oStmt (fun i => by + have hNCR : ¬ isCommitmentRound ℓ ϑ ⟨bIdx * ϑ + i, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ i⟩ := + lastBlockIdx_isNeCommitmentRound i + exact foldRelayOracleVerifier (L:=L) 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + ⟨bIdx * ϑ + i, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ i⟩ hNCR)) + (fun combinedIdx => + let ij := seqComposeChallengeIdxToSigma combinedIdx + foldRelayKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨bIdx * ϑ + ij.1, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ ij.1⟩ ij.2) := by + apply foldRelayRoundsRbrKnowledgeSoundness + intro (i : Fin ϑ) + have hNCR : ¬ isCommitmentRound ℓ ϑ ⟨bIdx * ϑ + i, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ i⟩ := + lastBlockIdx_isNeCommitmentRound i + exact foldRelayOracleVerifier_rbrKnowledgeSoundness (L := L) 𝔽q β (ϑ := ϑ) (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (init := init) (impl := impl) + (i := ⟨bIdx * ϑ + i, lastBlockIdx_mul_ϑ_add_fin_lt_ℓ i⟩) hNCR + exact OracleVerifier.rbrKnowledgeSoundness_of_eq_error (h := hCur) (h_ε := by + intro k + simp? [lastBlockRbrKnowledgeError, bIdx, stmt, oStmt]) + +/-! The commitment-round index inside a non-last block. -/ +def nonLastSingleBlockCommitIdx (bIdx : Fin (ℓ / ϑ - 1)) : Fin ℓ := + ⟨bIdx * ϑ + (ϑ - 1), by + let fv : Fin ϑ := ⟨ϑ - 1, by + have h := NeZero.one_le (n := ϑ) + exact Nat.sub_one_lt_of_lt h + ⟩ + change bIdx.val * ϑ + fv.val < ℓ + 0 + apply bIdx_mul_ϑ_add_i_lt_ℓ_succ + ⟩ + +/-! RBR knowledge error for the fold-relay prefix inside one non-last block. -/ +def nonLastSingleBlockFoldRelayRbrKnowledgeError (bIdx : Fin (ℓ / ϑ - 1)) + (k : (pSpecFoldRelaySequence (L := L) (n := ϑ - 1)).ChallengeIdx) : ℝ≥0 := + let ij := seqComposeChallengeIdxToSigma k + foldRelayKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨bIdx * ϑ + ij.1, bIdx_mul_ϑ_add_i_fin_ℓ_pred_lt_ℓ bIdx ij.1⟩ ij.2 + +/-! RBR knowledge error for one non-last block (fold-relay prefix + fold-commit suffix). -/ +def nonLastSingleBlockRbrKnowledgeError (bIdx : Fin (ℓ / ϑ - 1)) + (k : (pSpecFullNonLastBlock 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) bIdx).ChallengeIdx) : ℝ≥0 := + Sum.elim + (nonLastSingleBlockFoldRelayRbrKnowledgeError (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) bIdx) + (foldCommitKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := nonLastSingleBlockCommitIdx (ℓ := ℓ) (ϑ := ϑ) bIdx)) + (ChallengeIdx.sumEquiv.symm k) + +/-! RBR KS for one non-last block verifier. -/ +theorem nonLastSingleBlockOracleVerifier_rbrKnowledgeSoundness + (bIdx : Fin (ℓ / ϑ - 1)) : + (nonLastSingleBlockOracleVerifier 𝔽q β (mp := mp) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) bIdx).rbrKnowledgeSoundness init impl + (relIn := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + ⟨bIdx * ϑ, by + apply Nat.lt_trans (m:=ℓ) (h₁:=by + change bIdx.val * ϑ + (⟨0, by exact Nat.pos_of_neZero ϑ⟩ : Fin ϑ).val < ℓ + 0 + apply bIdx_mul_ϑ_add_i_lt_ℓ_succ + ) (by omega) + ⟩) + (relOut := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) ⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩) + (rbrKnowledgeError := nonLastSingleBlockRbrKnowledgeError (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) bIdx) := by + unfold nonLastSingleBlockOracleVerifier nonLastSingleBlockRbrKnowledgeError + apply OracleVerifier.append_rbrKnowledgeSoundness + (init := init) (impl := impl) + (rel₁ := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + ⟨bIdx * ϑ, by + apply Nat.lt_trans (m:=ℓ) (h₁:=by + change bIdx.val * ϑ + (⟨0, by exact Nat.pos_of_neZero ϑ⟩ : Fin ϑ).val < ℓ + 0 + apply bIdx_mul_ϑ_add_i_lt_ℓ_succ + ) (by omega) + ⟩) + (rel₂ := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + ⟨bIdx * ϑ + (ϑ - 1), by + let fv: Fin ϑ := ⟨ϑ - 1, by + have h := NeZero.one_le (n:=ϑ) + exact Nat.sub_one_lt_of_lt h + ⟩ + change ↑bIdx * ϑ + fv.val < ℓ + 1 + apply bIdx_mul_ϑ_add_i_lt_ℓ_succ (m:=1) + ⟩) + (rel₃ := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) ⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩) + (rbrKnowledgeError₁ := nonLastSingleBlockFoldRelayRbrKnowledgeError (L := L) 𝔽q β + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) bIdx) + (rbrKnowledgeError₂ := foldCommitKnowledgeError 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := nonLastSingleBlockCommitIdx (ℓ := ℓ) (ϑ := ϑ) bIdx)) + · let stmt : Fin (ϑ - 1 + 1) → Type := + fun i => Statement (L := L) (ℓ := ℓ) Context + ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_cast_lt_ℓ_succ bIdx i⟩ + let oStmt := fun i: Fin (ϑ - 1 + 1) => + OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_cast_lt_ℓ_succ bIdx i⟩ + let hSeq := OracleVerifier.seqCompose_rbrKnowledgeSoundness + (oSpec := []ₒ) (m := ϑ - 1) + (Stmt := stmt) + (OStmt := oStmt) + (pSpec := fun _ => pSpecFoldRelay (L:=L)) + (V := fun i => by + have hNCR : ¬ isCommitmentRound ℓ ϑ ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_fin_ℓ_pred_lt_ℓ bIdx i⟩ + := isNeCommitmentRound (r:=r) (ℓ:=ℓ) (𝓡:=𝓡) (ϑ:=ϑ) bIdx (x:=i.val) (hx:=by omega) + exact foldRelayOracleVerifier (L:=L) 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_fin_ℓ_pred_lt_ℓ bIdx i⟩ hNCR + ) + (rel := fun i ↦ + roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + (⟨↑bIdx * ϑ + ↑i, bIdx_mul_ϑ_add_i_cast_lt_ℓ_succ bIdx i⟩ : Fin (ℓ + 1))) + (rbrKnowledgeError := fun i => + foldRelayKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_fin_ℓ_pred_lt_ℓ bIdx i⟩) + (init := init) (impl := impl) + have hSeq' := hSeq (by + intro i + have hNCR : ¬ isCommitmentRound ℓ ϑ ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_fin_ℓ_pred_lt_ℓ bIdx i⟩ := + isNeCommitmentRound (r:=r) (ℓ:=ℓ) (𝓡:=𝓡) (ϑ:=ϑ) bIdx (x:=i.val) (hx:=by omega) + exact foldRelayOracleVerifier_rbrKnowledgeSoundness (L := L) 𝔽q β (ϑ := ϑ) (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (init := init) (impl := impl) + (i := ⟨bIdx * ϑ + i, bIdx_mul_ϑ_add_i_fin_ℓ_pred_lt_ℓ bIdx i⟩) hNCR + ) + exact OracleVerifier.rbrKnowledgeSoundness_of_eq_error (h := by + simpa using hSeq') (h_ε := by + intro k + unfold nonLastSingleBlockFoldRelayRbrKnowledgeError + rfl) + · have h_ϑ_gt_zero : ϑ > 0 := Nat.pos_of_neZero ϑ + have h_idxOut_eq : + (nonLastSingleBlockCommitIdx (ℓ := ℓ) (ϑ := ϑ) bIdx).succ = + (⟨(bIdx + 1) * ϑ, bIdx_succ_mul_ϑ_lt_ℓ_succ bIdx⟩ : Fin (ℓ + 1)) := by + ext + change (bIdx.val * ϑ + (ϑ - 1)) + 1 = (bIdx.val + 1) * ϑ + rw [Nat.add_assoc, Nat.sub_add_cancel (NeZero.one_le (n := ϑ))] + rw [Nat.add_mul, Nat.one_mul] + apply OracleVerifier.castInOut_rbrKnowledgeSoundness + (h_stmtIn := by + apply Statement.of_fin_eq + simp only [Fin.castSucc_mk]) + (h_stmtOut := by + apply Statement.of_fin_eq + exact h_idxOut_eq) + (h_idxIn := by + apply OracleStatement.idx_eq + simp only [Fin.castSucc_mk]) + (h_idxOut := by + apply OracleStatement.idx_eq + exact h_idxOut_eq) + (h_ostmtIn := by + apply OracleStatement.heq_of_fin_eq + simp only [Fin.castSucc_mk]) + (h_ostmtOut := by + apply OracleStatement.heq_of_fin_eq + exact h_idxOut_eq) + (h_witIn := by + rfl) + (h_witOut := by + apply Witness.of_fin_eq + exact h_idxOut_eq) + (h_Oₛᵢ := by + apply instOracleStatementBinaryBasefold_heq_of_fin_eq + simp only [Fin.castSucc_mk]) + (h_relIn := by + apply roundRelation.of_fin_eq + simp only [Fin.castSucc_mk]) + (h_relOut := by + apply roundRelation.of_fin_eq + exact h_idxOut_eq) + (ε := foldCommitKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := nonLastSingleBlockCommitIdx (ℓ := ℓ) (ϑ := ϑ) bIdx)) + exact foldCommitOracleVerifier_rbrKnowledgeSoundness (L := L) 𝔽q β (ϑ := ϑ) (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (init := init) (impl := impl) + (i := nonLastSingleBlockCommitIdx (ℓ := ℓ) (ϑ := ϑ) bIdx) + (hCR := isCommitmentRoundOfNonLastBlock (r:=r) (𝓡:=𝓡) bIdx) + +/-! RBR knowledge error for non-last blocks: seqCompose over non-last blocks. -/ +def nonLastBlocksRbrKnowledgeError + (k : (pSpecNonLastBlocks 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).ChallengeIdx) : + ℝ≥0 := + let ij := seqComposeChallengeIdxToSigma k + nonLastSingleBlockRbrKnowledgeError (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ij.1 ij.2 + +/-! RBR KS for non-last blocks verifier (seqCompose of nonLastSingleBlock). -/ +theorem nonLastBlocksOracleVerifier_rbrKnowledgeSoundness : + (nonLastBlocksOracleVerifier 𝔽q β (mp := mp) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)).rbrKnowledgeSoundness init impl + (relIn := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) ⟨0 * ϑ, by omega⟩) + (relOut := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + ⟨(ℓ / ϑ - 1) * ϑ, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) + (rbrKnowledgeError := nonLastBlocksRbrKnowledgeError (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) := by + unfold nonLastBlocksOracleVerifier nonLastBlocksRbrKnowledgeError + simp only + refine OracleVerifier.seqCompose_rbrKnowledgeSoundness + (oSpec := []ₒ) + (Stmt := fun i => Statement (L := L) (ℓ := ℓ) Context ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩) + (OStmt := fun i => + OracleStatement 𝔽q β ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩) + (pSpec := fun (bIdx : Fin (ℓ / ϑ - 1)) => pSpecFullNonLastBlock 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) bIdx) + (V := fun bIdx => nonLastSingleBlockOracleVerifier (L := L) 𝔽q β (mp := mp) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) bIdx) + (rel := fun i => roundRelation (mp := mp) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) ⟨i * ϑ, blockIdx_mul_ϑ_lt_ℓ_succ i⟩) + (rbrKnowledgeError := fun bIdx => + nonLastSingleBlockRbrKnowledgeError (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) bIdx) + (init := init) (impl := impl) ?_ + intro bIdx + exact nonLastSingleBlockOracleVerifier_rbrKnowledgeSoundness (L := L) 𝔽q β (ϑ := ϑ) + (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (init := init) (impl := impl) bIdx + +/-! RBR knowledge error for sumcheck-fold: append of non-last blocks and last block. -/ +def sumcheckFoldKnowledgeError (j : (pSpecSumcheckFold 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).ChallengeIdx) : ℝ≥0 := + Sum.elim + (f := nonLastBlocksRbrKnowledgeError (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (g := lastBlockRbrKnowledgeError (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (ChallengeIdx.sumEquiv.symm j) + +/-! Round-by-round knowledge soundness for the sumcheck fold oracle verifier. + Proof: append (nonLastBlocks, lastBlock) has RBR KS by append_rbrKnowledgeSoundness; + then castInOut preserves it; finally rbrKnowledgeSoundness_of_eq_error gives the flat + sumcheckFoldKnowledgeError. The error equality (flat = Sum.elim form) remains. -/ theorem sumcheckFoldOracleVerifier_rbrKnowledgeSoundness : (sumcheckFoldOracleVerifier 𝔽q β (mp := mp) (𝓑 := 𝓑)).rbrKnowledgeSoundness init impl (pSpec := pSpecSumcheckFold 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) @@ -1184,9 +1571,60 @@ theorem sumcheckFoldOracleVerifier_rbrKnowledgeSoundness : (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) 0 ) (relOut := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (Fin.last ℓ) ) - (rbrKnowledgeError := sumcheckFoldKnowledgeError 𝔽q β (ϑ:=ϑ)) := by + (rbrKnowledgeError := sumcheckFoldKnowledgeError (L := L) 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) := by unfold sumcheckFoldOracleVerifier pSpecSumcheckFold - sorry + have hAppend := OracleVerifier.append_rbrKnowledgeSoundness + (init := init) (impl := impl) + (rel₁ := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + ⟨0 * ϑ, by omega⟩) + (rel₂ := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑:=𝓑) ⟨(ℓ / ϑ - 1) * ϑ, by apply lastBlockIdx_mul_ϑ_add_x_lt_ℓ_succ (x:=0) (hx:=by omega)⟩) + (rel₃ := roundRelation (mp := mp) 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (Fin.last ℓ)) + (V₁ := nonLastBlocksOracleVerifier 𝔽q β (mp := mp) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)) + (V₂ := lastBlockOracleVerifier 𝔽q β (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)) + (rbrKnowledgeError₁ := nonLastBlocksRbrKnowledgeError (L := L) 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (rbrKnowledgeError₂ := lastBlockRbrKnowledgeError (L := L) 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (h₁ := by + simpa using (nonLastBlocksOracleVerifier_rbrKnowledgeSoundness (L := L) 𝔽q β + (ϑ := ϑ) (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) + (init := init) (impl := impl))) + (h₂ := by + simpa using (lastBlockOracleVerifier_rbrKnowledgeSoundness (L := L) 𝔽q β + (ϑ := ϑ) (mp := mp) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) + (init := init) (impl := impl))) + apply OracleVerifier.castInOut_rbrKnowledgeSoundness + (h_stmtIn := by + apply Statement.of_fin_eq + apply fin_zero_mul_eq) + (h_stmtOut := by rfl) + (h_idxIn := by + apply OracleStatement.idx_eq + apply fin_zero_mul_eq) + (h_idxOut := by rfl) + (h_ostmtIn := by + apply OracleStatement.heq_of_fin_eq + apply fin_zero_mul_eq) + (h_ostmtOut := by rfl) + (h_witIn := by + apply Witness.of_fin_eq + apply fin_zero_mul_eq) + (h_witOut := by rfl) + (h_Oₛᵢ := by + apply instOracleStatementBinaryBasefold_heq_of_fin_eq + ext + simp only [zero_mul, Fin.coe_ofNat_eq_mod, Nat.zero_mod]) + (h_relIn := by + apply roundRelation.of_fin_eq + apply fin_zero_mul_eq) + (h_relOut := by rfl) + (ε := sumcheckFoldKnowledgeError (L := L) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (hRbrKs := hAppend) end SecurityProps @@ -1195,7 +1633,7 @@ end ComponentReductions section CoreInteractionPhaseReduction -/-- The final oracle verifier that composes sumcheckFold with finalSumcheckStep -/ +/-! The final oracle verifier that composes sumcheckFold with finalSumcheckStep -/ @[reducible] def coreInteractionOracleVerifier := OracleVerifier.append (oSpec:=[]ₒ) @@ -1207,10 +1645,11 @@ def coreInteractionOracleVerifier := (OStmt₃ := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ)) (pSpec₁ := pSpecSumcheckFold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (pSpec₂ := pSpecFinalSumcheckStep (L:=L)) - (V₁ := sumcheckFoldOracleVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (mp := BBF_SumcheckMultiplierParam)) + (V₁ := sumcheckFoldOracleVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + (mp := BBF_SumcheckMultiplierParam)) (V₂ := finalSumcheckVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑)) -/-- The final oracle reduction that composes sumcheckFold with finalSumcheckStep -/ +/-! The final oracle reduction that composes sumcheckFold with finalSumcheckStep -/ @[reducible] def coreInteractionOracleReduction := OracleReduction.append (oSpec:=[]ₒ) @@ -1225,12 +1664,13 @@ def coreInteractionOracleReduction := (OStmt₃ := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ)) (pSpec₁ := pSpecSumcheckFold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (pSpec₂ := pSpecFinalSumcheckStep (L:=L)) - (R₁ := sumcheckFoldOracleReduction 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (mp := BBF_SumcheckMultiplierParam)) + (R₁ := sumcheckFoldOracleReduction 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) + (mp := BBF_SumcheckMultiplierParam)) (R₂ := finalSumcheckOracleReduction 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑)) variable {σ : Type} {init : ProbComp σ} {impl : QueryImpl []ₒ (StateT σ ProbComp)} -/-- Perfect completeness for the core interaction oracle reduction -/ +/-! Perfect completeness for the core interaction oracle reduction -/ theorem coreInteractionOracleReduction_perfectCompleteness (hInit : NeverFail init) [(j : pSpecFold.ChallengeIdx) → Fintype ((pSpecFold (L := L)).Challenge j)] [(j : pSpecFold.ChallengeIdx) → Inhabited ((pSpecFold (L := L)).Challenge j)] @@ -1266,7 +1706,7 @@ def coreInteractionOracleRbrKnowledgeError (j : (pSpecCoreInteraction 𝔽q β ( (g := fun i => finalSumcheckKnowledgeError (L := L) i) (ChallengeIdx.sumEquiv.symm j) -/-- Round-by-round knowledge soundness for the core interaction oracle verifier -/ +/-! Round-by-round knowledge soundness for the core interaction oracle verifier -/ theorem coreInteractionOracleVerifier_rbrKnowledgeSoundness : (coreInteractionOracleVerifier 𝔽q β (𝓑 := 𝓑)).rbrKnowledgeSoundness init impl (pSpec := pSpecCoreInteraction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) diff --git a/ArkLib/ProofSystem/Binius/BinaryBasefold/Prelude.lean b/ArkLib/ProofSystem/Binius/BinaryBasefold/Prelude.lean index 7809be102..9a21d83d5 100644 --- a/ArkLib/ProofSystem/Binius/BinaryBasefold/Prelude.lean +++ b/ArkLib/ProofSystem/Binius/BinaryBasefold/Prelude.lean @@ -410,7 +410,12 @@ lemma coe_fin_pow_two_eq_bitsOfIndex {n : ℕ} (k : Fin (2 ^ n)) : · simp only [cast_one] · simp only [cast_zero] -lemma multilinear_eval_eq_sum_bool_hypercube (challenges : Fin ℓ → L) (t : ↥L⦃≤ 1⦄[X Fin ℓ]) : +/-- **Multilinear extension over the Boolean hypercube**: +as the sum of its values on all Boolean vertices `bitsOfIndex x`, weighted by +`multilinearWeight challenges x`, the standard multilinear “eq” polynomial. +i.e., `t(challenges) = ∑ x ∈ {0, 1}, eq(challenges, x) * t(x)`. +-/ +theorem multilinear_eval_eq_sum_bool_hypercube (challenges : Fin ℓ → L) (t : ↥L⦃≤ 1⦄[X Fin ℓ]) : t.val.eval challenges = ∑ (x : Fin (2^ℓ)), (multilinearWeight (r := challenges) (i := x)) * (t.val.eval (bitsOfIndex x) : L) := by sorry @@ -589,6 +594,45 @@ lemma qMap_total_fiber_congr_steps ⟨x.val, by subst h_steps_eq; exact x.is_lt⟩ := by subst h_steps_eq; rfl +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero ℓ] in +lemma qMap_total_fiber_congr_source + {sourceIdx₁ sourceIdx₂ : Fin r} (steps : ℕ) {destIdx : Fin r} + (h_sourceIdx_eq : sourceIdx₁ = sourceIdx₂) + (h_destIdx : destIdx = sourceIdx₁.val + steps) + (h_destIdx_le : destIdx ≤ ℓ) + (y : sDomain 𝔽q β h_ℓ_add_R_rate (i := destIdx)) : + qMap_total_fiber 𝔽q β (i := sourceIdx₁) (steps := steps) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (y := y) = + cast (by subst h_sourceIdx_eq; rfl) (qMap_total_fiber 𝔽q β (i := sourceIdx₂) + (steps := steps) (h_destIdx := by omega) (h_destIdx_le := h_destIdx_le) (y := y)) := by + subst h_sourceIdx_eq; rfl + +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero ℓ] in +lemma qMap_total_fiber_congr_source_apply + {sourceIdx₁ sourceIdx₂ : Fin r} (steps : ℕ) {destIdx : Fin r} + (h_sourceIdx_eq : sourceIdx₁ = sourceIdx₂) + (h_destIdx : destIdx = sourceIdx₁.val + steps) + (h_destIdx_le : destIdx ≤ ℓ) + (y : sDomain 𝔽q β h_ℓ_add_R_rate (i := destIdx)) (x : Fin (2 ^ steps)) : + qMap_total_fiber 𝔽q β (i := sourceIdx₁) (steps := steps) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (y := y) x = + cast (by subst h_sourceIdx_eq; rfl) (qMap_total_fiber 𝔽q β (i := sourceIdx₂) + (steps := steps) (h_destIdx := by omega) (h_destIdx_le := h_destIdx_le) (y := y) x) := by + subst h_sourceIdx_eq; rfl + +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero ℓ] in +lemma qMap_total_fiber_congr_dest + {sourceIdx : Fin r} (steps : ℕ) {destIdx₁ destIdx₂ : Fin r} + (h_destIdx_congr : destIdx₁ = destIdx₂) + (h_destIdx : destIdx₁ = sourceIdx.val + steps) + (h_destIdx_le : destIdx₁ ≤ ℓ) + (y : sDomain 𝔽q β h_ℓ_add_R_rate (i := destIdx₁)) : + qMap_total_fiber 𝔽q β (i := sourceIdx) (steps := steps) (destIdx := destIdx₁) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (y := y) = + qMap_total_fiber 𝔽q β (i := sourceIdx) + (steps := steps) (destIdx := destIdx₂) (h_destIdx := by omega) (h_destIdx_le := by omega) (y := cast (by subst h_destIdx_congr; rfl) y) := by + subst h_destIdx_congr; rfl + /- TODO : state that the fiber of y is the set of all 2 ^ steps points in the larger domain S⁽ⁱ⁾ that get mapped to y by the series of quotient maps q⁽ⁱ⁾, ..., q⁽ⁱ⁺steps⁻¹⁾. -/ @@ -851,6 +895,54 @@ lemma qMap_total_fiber_basis_sum_repr (i : Fin r) {destIdx : Fin r} (steps : ℕ apply Subtype.ext -- convert to equality in Subtype embedding rw [hx_val_sum] +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero ℓ] in +theorem qMap_total_fiber_injective (i : Fin r) {destIdx : Fin r} (steps : ℕ) + (h_destIdx : destIdx = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (y : sDomain 𝔽q β h_ℓ_add_R_rate (i := destIdx)) : + Function.Injective (qMap_total_fiber 𝔽q β (i := i) (steps := steps) + h_destIdx h_destIdx_le (y := y)) := by + intro k₁ k₂ h_eq + let basis_x := sDomain_basis 𝔽q β h_ℓ_add_R_rate i (Sdomain_bound (by omega)) + set fiberMap := qMap_total_fiber 𝔽q β (i := i) (steps := steps) + h_destIdx h_destIdx_le (y := y) + have h_coeffs_eq : basis_x.repr (fiberMap k₁) = basis_x.repr (fiberMap k₂) := by + rw [h_eq] + have h_bits_eq : ∀ j : Fin steps, + Nat.getBit (k := j) (n := k₁.val) = Nat.getBit (k := j) (n := k₂.val) := by + intro j + have h_coeff_j_eq : basis_x.repr (fiberMap k₁) ⟨j, by omega⟩ + = basis_x.repr (fiberMap k₂) ⟨j, by omega⟩ := by rw [h_coeffs_eq] + rw [qMap_total_fiber_repr_coeff 𝔽q β (i := i) (steps := steps) + h_destIdx h_destIdx_le (y := y) (j := ⟨j, by omega⟩)] + at h_coeff_j_eq + rw [qMap_total_fiber_repr_coeff 𝔽q β (i := i) (steps := steps) + h_destIdx h_destIdx_le (y := y) (k := k₂) (j := ⟨j, by omega⟩)] + at h_coeff_j_eq + simp only [fiber_coeff, Fin.is_lt, ↓reduceDIte] at h_coeff_j_eq + by_cases hbitj_k₁ : Nat.getBit (k := j) (n := k₁.val) = 0 + · simp only [hbitj_k₁, ↓reduceIte, left_eq_ite_iff, zero_ne_one, imp_false, + Decidable.not_not] at ⊢ h_coeff_j_eq + simp only [h_coeff_j_eq] + · simp only [hbitj_k₁, ↓reduceIte, right_eq_ite_iff, one_ne_zero, + imp_false] at ⊢ h_coeff_j_eq + have b1 : Nat.getBit (k := j) (n := k₁.val) = 1 := by + have h := Nat.getBit_eq_zero_or_one (k := j) (n := k₁.val) + simp only [hbitj_k₁, false_or] at h + exact h + have b2 : Nat.getBit (k := j) (n := k₂.val) = 1 := by + have h := Nat.getBit_eq_zero_or_one (k := j) (n := k₂.val) + simp only [h_coeff_j_eq, false_or] at h + exact h + simp only [b1, b2] + apply Fin.eq_of_val_eq + apply eq_iff_eq_all_getBits.mpr + intro k + by_cases h_k : k < steps + · simp only [h_bits_eq ⟨k, by omega⟩] + · conv_lhs => rw [Nat.getBit_of_lt_two_pow] + conv_rhs => rw [Nat.getBit_of_lt_two_pow] + simp only [h_k, ↓reduceIte] + omit [CharP L 2] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero ℓ] in theorem card_qMap_total_fiber (i : Fin r) {destIdx : Fin r} (steps : ℕ) (h_destIdx : destIdx = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) @@ -858,62 +950,9 @@ theorem card_qMap_total_fiber (i : Fin r) {destIdx : Fin r} (steps : ℕ) Fintype.card (Set.image (qMap_total_fiber 𝔽q β (i := i) (steps := steps) h_destIdx h_destIdx_le (y := y)) Set.univ) = 2 ^ steps := by - -- The cardinality of the image of a function equals the cardinality of its domain - -- if it is injective. rw [Set.card_image_of_injective Set.univ] - -- The domain is `Fin (2 ^ steps)`, which has cardinality `2 ^ steps`. - · -- ⊢ Fintype.card ↑Set.univ = 2 ^ steps - simp only [Fintype.card_setUniv, Fintype.card_fin] - · -- prove that `qMap_total_fiber` is an injective function. - intro k₁ k₂ h_eq - -- Assume two indices `k₁` and `k₂` produce the same point `x`. - let basis_x := sDomain_basis 𝔽q β h_ℓ_add_R_rate i (Sdomain_bound (by omega)) - -- If the points are equal, their basis representations must be equal. - set fiberMap := qMap_total_fiber 𝔽q β (i := i) (steps := steps) - h_destIdx h_destIdx_le (y := y) - have h_coeffs_eq : basis_x.repr (fiberMap k₁) = basis_x.repr (fiberMap k₂) := by - rw [h_eq] - -- The first `steps` coefficients are determined by the bits of `k₁` and `k₂`. - -- If the coefficients are equal, the bits must be equal. - have h_bits_eq : ∀ j : Fin steps, - Nat.getBit (k := j) (n := k₁.val) = Nat.getBit (k := j) (n := k₂.val) := by - intro j - have h_coeff_j_eq : basis_x.repr (fiberMap k₁) ⟨j, by omega⟩ - = basis_x.repr (fiberMap k₂) ⟨j, by omega⟩ := by rw [h_coeffs_eq] - rw [qMap_total_fiber_repr_coeff 𝔽q β (i := i) (steps := steps) - h_destIdx h_destIdx_le (y := y) (j := ⟨j, by omega⟩)] - at h_coeff_j_eq - rw [qMap_total_fiber_repr_coeff 𝔽q β (i := i) (steps := steps) - h_destIdx h_destIdx_le (y := y) (k := k₂) (j := ⟨j, by omega⟩)] - at h_coeff_j_eq - simp only [fiber_coeff, Fin.is_lt, ↓reduceDIte] at h_coeff_j_eq - by_cases hbitj_k₁ : Nat.getBit (k := j) (n := k₁.val) = 0 - · simp only [hbitj_k₁, ↓reduceIte, left_eq_ite_iff, zero_ne_one, imp_false, - Decidable.not_not] at ⊢ h_coeff_j_eq - simp only [h_coeff_j_eq] - · simp only [hbitj_k₁, ↓reduceIte, right_eq_ite_iff, one_ne_zero, - imp_false] at ⊢ h_coeff_j_eq - have b1 : Nat.getBit (k := j) (n := k₁.val) = 1 := by - have h := Nat.getBit_eq_zero_or_one (k := j) (n := k₁.val) - simp only [hbitj_k₁, false_or] at h - exact h - have b2 : Nat.getBit (k := j) (n := k₂.val) = 1 := by - have h := Nat.getBit_eq_zero_or_one (k := j) (n := k₂.val) - simp only [h_coeff_j_eq, false_or] at h - exact h - simp only [b1, b2] - -- Extract the j-th coefficient from h_coeffs_eq and show it implies the bits are equal. - -- If all the bits of two numbers are equal, the numbers themselves are equal. - apply Fin.eq_of_val_eq - -- ⊢ ∀ {n : ℕ} {i j : Fin n}, ↑i = ↑j → i = j - apply eq_iff_eq_all_getBits.mpr - intro k - by_cases h_k : k < steps - · simp only [h_bits_eq ⟨k, by omega⟩] - · -- The bits at positions ≥ steps must be deterministic - conv_lhs => rw [Nat.getBit_of_lt_two_pow] - conv_rhs => rw [Nat.getBit_of_lt_two_pow] - simp only [h_k, ↓reduceIte] + · simp only [Fintype.card_setUniv, Fintype.card_fin] + · exact qMap_total_fiber_injective 𝔽q β i steps h_destIdx h_destIdx_le y omit [CharP L 2] in /-- The images of `qMap_total_fiber` over distinct quotient points `y₁ ≠ y₂` are @@ -1527,6 +1566,27 @@ lemma iterated_fold_transitivity lhs = rhs := by sorry -- admitted for brevity, relies on a lemma like `Fin.dfoldl_add` +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero ℓ] in +/-- **First-step decomposition**: `iterated_fold(i, steps+1, f, r₀ :: r_rest)` equals +`iterated_fold(i+1, steps, fold(f, r₀), r_rest)`. +Dual to `iterated_fold_last` which decomposes from the last step. -/ +lemma iterated_fold_first (i : Fin r) {midIdx destIdx : Fin r} (steps : ℕ) + (h_midIdx : midIdx.val = i.val + 1) (h_destIdx : destIdx.val = i.val + (steps + 1)) + (h_destIdx_le : destIdx ≤ ℓ) + (f : sDomain 𝔽q β h_ℓ_add_R_rate (i := i) → L) + (r_challenges : Fin (steps + 1) → L) : + iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := i) (steps := steps + 1) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (f := f) (r_challenges := r_challenges) = + iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := midIdx) (steps := steps) (h_destIdx := by omega) + (h_destIdx_le := h_destIdx_le) + (f := fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i) + (destIdx := midIdx) (h_destIdx := h_midIdx) + (h_destIdx_le := by omega) f (r_challenges 0)) + (r_challenges := fun j => r_challenges j.succ) := by + sorry + /-- **Definition 4.6** : the single-step vector-matrix-vector multiplication form of `fold` -/ def fold_single_matrix_mul_form (i : Fin r) {destIdx : Fin r} (h_destIdx : destIdx = i.val + 1) (h_destIdx_le : destIdx ≤ ℓ) @@ -2263,11 +2323,12 @@ lemma constantIntermediateEvaluationPoly_eval_eq_const simp only [Polynomial.eval_C, mul_one] omit [CharP L 2] in -/-- When folding from level 0 all the way to level ℓ, the resulting function is constant. -/ -lemma iterated_fold_to_level_ℓ_is_constant +/-- When folding from level 0 all the way to level ℓ, the resulting function is constant +with value `t(challenges)`. -/ +lemma iterated_fold_to_level_ℓ_eval (t : MultilinearPoly L ℓ) (destIdx : Fin r) (h_destIdx : destIdx.val = ℓ) (challenges : Fin ℓ → L) : - let P₀: L[X]_(2 ^ ℓ) := polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) + let P₀ : L[X]_(2 ^ ℓ) := polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) (fun ω => t.val.eval (bitsOfIndex ω)) let f₀ := polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := 0) (P := P₀) let f_ℓ := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (steps := ℓ) @@ -2275,28 +2336,96 @@ lemma iterated_fold_to_level_ℓ_is_constant (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega) (h_destIdx_le := by omega) f₀ challenges - ∀ x y, f_ℓ x = f_ℓ y := by - intro P₀ f₀ f_ℓ x y + f_ℓ = fun _ => t.val.eval challenges := by + intro P₀ f₀ f_ℓ + funext x let coeffs := fun (ω : Fin (2 ^ ℓ)) => t.val.eval (bitsOfIndex ω) have h_f_ℓ_eq_poly := iterated_fold_advances_evaluation_poly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (steps := ℓ) (destIdx := destIdx) (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega) (h_destIdx_le := by omega) (coeffs := coeffs) (r_challenges := challenges) -- h_f_ℓ_eq_poly says: f_ℓ = polyToOracleFunc P_ℓ where - -- P_ℓ = intermediateEvaluationPoly with new_coeffs - -- Step 2: When destIdx = ℓ, we have ℓ - ℓ = 0, so new_coeffs : Fin (2^0) = Fin 1 → L - -- This means P_ℓ is a constant polynomial (only one coefficient) + -- P_ℓ = intermediateEvaluationPoly with new_coeffs dsimp only [f_ℓ, f₀, P₀, polynomialFromNovelCoeffsF₂] + -- Rewrite f_ℓ in terms of the intermediate polynomial at level ℓ. -- unfold polyToOracleFunc - rw [←intermediate_poly_P_base 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_ℓ := by omega)] - simp only at h_f_ℓ_eq_poly; - rw [h_f_ℓ_eq_poly] - -- f_ℓ x = polyToOracleFunc P_ℓ x = P_ℓ.eval x.val - -- Since P_ℓ is constant, P_ℓ.eval x.val = P_ℓ.eval (𝓑 0) for all x - -- We need to show that intermediateEvaluationPoly with Fin 1 coefficients is constant + rw [←intermediate_poly_P_base 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (h_ℓ := by omega) (coeffs := coeffs)] + -- Now f_ℓ x = (polyToOracleFunc P_ℓ) x = P_ℓ.eval x.val, and P_ℓ is constant. + -- Evaluate both sides at x: + have h_eq := congr_fun (h := h_f_ℓ_eq_poly) (a := x) + conv_lhs => rw [h_eq] + + -- Use the lemma that the intermediate polynomial at level ℓ is the constant t(challenges). dsimp only [polyToOracleFunc] - rw [constantIntermediateEvaluationPoly_eval_eq_const] - omega + conv_rhs => rw [multilinear_eval_eq_sum_bool_hypercube] + let new_coeffs : Fin (2 ^ (ℓ - destIdx.val)) → L := fun j => + ∑ m : Fin (2 ^ ℓ), + multilinearWeight (r := challenges) (i := m) * coeffs ⟨j.val * 2 ^ ℓ + m.val, by + have h_j : j.val = 0 := by + have hj_lt := j.isLt + simp only [h_destIdx, tsub_self, pow_zero, Nat.lt_one_iff] at hj_lt + exact hj_lt + rw [h_j, zero_mul, zero_add] + exact m.isLt⟩ + change Polynomial.eval (↑x) + (intermediateEvaluationPoly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := destIdx) (h_i := by omega) new_coeffs) + = ∑ x, multilinearWeight challenges x * (MvPolynomial.eval (bitsOfIndex x)) ↑t + + have h_const_eval : + Polynomial.eval (↑x) + (intermediateEvaluationPoly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := destIdx) (h_i := by omega) new_coeffs) + = + Polynomial.eval (0 : L) + (intermediateEvaluationPoly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := destIdx) (h_i := by omega) new_coeffs) := by + exact constantIntermediateEvaluationPoly_eval_eq_const 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx := destIdx) (coeffs := new_coeffs) + (h_destIdx := h_destIdx) (x := ↑x) (y := 0) + + rw [h_const_eval] + dsimp only [new_coeffs, intermediateEvaluationPoly] + rw [Finset.sum_eq_single (a := ⟨0, by + exact Nat.two_pow_pos (ℓ - destIdx.val)⟩) (h₀ := fun j _ hj_ne => by + have h_j_lt := j.isLt + simp only [h_destIdx, tsub_self, pow_zero, Nat.lt_one_iff, Fin.val_eq_zero_iff] at h_j_lt + simp only [Fin.mk_zero', ne_eq] at hj_ne + exfalso + exact hj_ne h_j_lt + ) (h₁ := fun h => by + simp only [Fin.mk_zero', Finset.mem_univ, not_true_eq_false] at h)] + rw [intermediateNovelBasisX_zero_eq_one 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := destIdx) (h_i := by omega)] + simp only [Polynomial.eval_C, mul_one] + apply Finset.sum_congr rfl + intro m hm + have h_idx_eq : (⟨0 * 2 ^ ℓ + m.val, by + have h_j : (0 : Fin (2 ^ (ℓ - destIdx.val))).val = 0 := by + simp only [Fin.val_zero] + rw [zero_mul, zero_add]; exact m.isLt⟩ : Fin (2 ^ ℓ)) = m := by + apply Fin.ext + simp only [zero_mul, zero_add] + rw [h_idx_eq] + +omit [CharP L 2] in +/-- When folding from level 0 all the way to level ℓ, the resulting function is constant. -/ +lemma iterated_fold_to_level_ℓ_is_constant + (t : MultilinearPoly L ℓ) (destIdx : Fin r) (h_destIdx : destIdx.val = ℓ) + (challenges : Fin ℓ → L) : + let P₀ : L[X]_(2 ^ ℓ) := polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) + (fun ω => t.val.eval (bitsOfIndex ω)) + let f₀ := polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := 0) (P := P₀) + let f_ℓ := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (steps := ℓ) + (destIdx := destIdx) + (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega) + (h_destIdx_le := by omega) + f₀ challenges + ∀ x y, f_ℓ x = f_ℓ y := by + intro P₀ f₀ f_ℓ x y + dsimp only [f_ℓ] + rw [iterated_fold_to_level_ℓ_eval 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_destIdx := by omega)] end FoldTheory diff --git a/ArkLib/ProofSystem/Binius/BinaryBasefold/QueryPhase.lean b/ArkLib/ProofSystem/Binius/BinaryBasefold/QueryPhase.lean index 65e2ae78e..5bbd41a7d 100644 --- a/ArkLib/ProofSystem/Binius/BinaryBasefold/QueryPhase.lean +++ b/ArkLib/ProofSystem/Binius/BinaryBasefold/QueryPhase.lean @@ -4,14 +4,12 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Chung Thai Nguyen, Quang Dao -/ import ArkLib.ProofSystem.Binius.BinaryBasefold.Spec +import ArkLib.ProofSystem.Binius.BinaryBasefold.Soundness import ArkLib.ProofSystem.Binius.BinaryBasefold.ReductionLogic import ArkLib.OracleReduction.Completeness import ArkLib.OracleReduction.Basic import ArkLib.Data.Misc.Basic -set_option maxHeartbeats 400000 -- Increase if needed -set_option profiler true - namespace Binius.BinaryBasefold.QueryPhase /-! @@ -46,374 +44,10 @@ variable {h_ℓ_add_R_rate : ℓ + 𝓡 < r} -- ℓ ∈ {1, ..., r-1} variable {𝓑 : Fin 2 ↪ L} variable [hdiv : Fact (ϑ ∣ ℓ)] -section IndexBounds - -omit [NeZero ℓ] in -@[simp] -lemma k_mul_ϑ_lt_ℓ {k : Fin (ℓ / ϑ)} : - ↑k * ϑ < ℓ := by - have h_mul_eq : (ℓ/ϑ) * ϑ = ℓ := Nat.div_mul_cancel hdiv.out - calc ↑k * ϑ < (ℓ / ϑ) * ϑ := Nat.mul_lt_mul_of_pos_right k.isLt (NeZero.pos ϑ) - _ = ℓ := h_mul_eq - -omit [NeZero ℓ] [NeZero ϑ] in -@[simp] -lemma k_succ_mul_ϑ_le_ℓ {k : Fin (ℓ / ϑ)} : (k.val + 1) * ϑ ≤ ℓ := by - have h_mul_eq : (ℓ/ϑ) * ϑ = ℓ := Nat.div_mul_cancel hdiv.out - calc (k.val + 1) * ϑ ≤ (ℓ / ϑ) * ϑ := Nat.mul_le_mul_right (k := ϑ) (h := by omega) - _ = ℓ := h_mul_eq - -omit [NeZero ℓ] [NeZero ϑ] in -@[simp] -lemma k_succ_mul_ϑ_le_ℓ_₂ {k : Fin (ℓ / ϑ)} : (k.val * ϑ + ϑ) ≤ ℓ := by - conv_lhs => enter [2]; rw [←Nat.one_mul ϑ] - rw [←Nat.add_mul]; - exact k_succ_mul_ϑ_le_ℓ; - -omit [NeZero r] [NeZero ℓ] [NeZero 𝓡] in -@[simp] -lemma lt_r_of_le_ℓ {h_ℓ_add_R_rate : ℓ + 𝓡 < r} {x : ℕ} (h : x ≤ ℓ) - : x < r := by omega - -omit [NeZero r] [NeZero ℓ] [NeZero 𝓡] in -@[simp] -lemma lt_r_of_lt_ℓ {h_ℓ_add_R_rate : ℓ + 𝓡 < r} {x : ℕ} (h : x < ℓ) - : x < r := by omega - -end IndexBounds -open scoped NNReal - -/-! -## Common Proximity Check Helpers - -These functions extract the shared logic between `queryOracleVerifier` -and `queryKnowledgeStateFunction` for proximity testing, allowing code reuse -and ensuring both implementations follow the same logic. --/ - -/-- Extract suffix starting at position `destIdx` from a full challenge. -/ -def extractSuffixFromChallenge (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) - (destIdx : Fin r) (h_destIdx_le : destIdx ≤ ℓ) : - sDomain 𝔽q β h_ℓ_add_R_rate destIdx := - iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate (i := ⟨0, by omega⟩) (k := destIdx.val) - (h_destIdx := by simp only [zero_add]) (h_destIdx_le := h_destIdx_le) (x := v) - -omit [CharP L 2] [SampleableType L] [DecidableEq 𝔽q] hF₂ [NeZero 𝓡] in -/-- **Congruence Lemma for Challenge Suffixes**: -Allows proving equality between two suffix extractions when the destination indices -are proven equal (`destIdx = destIdx'`), handling the necessary type casting. -/ -lemma extractSuffixFromChallenge_congr_destIdx - (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) - {destIdx destIdx' : Fin r} - (h_idx_eq : destIdx = destIdx') - (h_le : destIdx ≤ ℓ) - (h_le' : destIdx' ≤ ℓ) : - extractSuffixFromChallenge 𝔽q β v destIdx h_le = - cast (by rw [h_idx_eq]) (extractSuffixFromChallenge 𝔽q β v destIdx' h_le') := by - subst h_idx_eq; rfl - -omit [CharP L 2] [SampleableType L] [DecidableEq 𝔽q] h_β₀_eq_1 in -/-- **First Oracle Equals Polynomial Oracle Function**: -When `strictOracleFoldingConsistencyProp` holds, the first oracle (`getFirstOracle`) equals -the polynomial oracle function `f₀` derived from the multilinear polynomial `t`. -This follows from the consistency property for `j = 0`, where `iterated_fold` with 0 steps -is the identity function. -/ -lemma polyToOracleFunc_eq_getFirstOracle - (t : MultilinearPoly L ℓ) - (i : Fin (ℓ + 1)) - (challenges : Fin i → L) - (oStmt : - ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i j) - (h_consistency : - strictOracleFoldingConsistencyProp 𝔽q β (t := t) (i := i) - (challenges := challenges) (oStmt := oStmt)) : - let P₀ : L[X]_(2 ^ ℓ) := - polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) (fun ω => t.val.eval (bitsOfIndex ω)) - let f₀ := polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := 0) (P := P₀) - f₀ = getFirstOracle 𝔽q β oStmt := by - intro P₀ f₀ - -- Use strictOracleFoldingConsistencyProp for j = 0 - have h_pos : 0 < toOutCodewordsCount ℓ ϑ i := by - exact (instNeZeroNatToOutCodewordsCount ℓ ϑ i).pos - have h_first_oracle := h_consistency ⟨0, by omega⟩ - dsimp only [strictOracleFoldingConsistencyProp] at h_first_oracle - dsimp only [f₀, P₀, getFirstOracle] at h_first_oracle ⊢ - rw [h_first_oracle] - funext y - conv_rhs => - rw [iterated_fold_congr_steps_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (steps' := 0) - (h_destIdx := by simp only [Nat.zero_mod, zero_mul, Fin.coe_ofNat_eq_mod, add_zero]) - (h_destIdx_le := by simp only [zero_mul, zero_le]) - (h_steps_eq_steps' := by simp only [zero_mul])] - rw [iterated_fold_zero_steps 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) - (h_destIdx := by simp only [Nat.zero_mod, zero_mul, Fin.coe_ofNat_eq_mod])] - conv_rhs => simp only [cast_cast, cast_eq]; simp only [←fun_eta_expansion] - -/-- Decompose challenge v at position i into (fiberIndex, suffix). - This is the inverse of `Nat.joinBits` in some sense. - Uses loose indexing with `Fin r`. -/ -def decomposeChallenge (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) - (i : Fin r) {destIdx : Fin r} (steps : ℕ) - (h_destIdx : destIdx = i.val + steps) - (h_destIdx_le : destIdx ≤ ℓ) : - Fin (2^steps) × sDomain 𝔽q β h_ℓ_add_R_rate destIdx := - (extractMiddleFinMask 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (v := v) (i := i) (steps := steps), - extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (v := v) (destIdx := destIdx) (h_destIdx_le := h_destIdx_le)) - -/-- This proposition declaratively captures the iterative logic of the verifier. For each repetition -and each folding step, it asserts that the folded value of the function from level `i` must equal -the value of the function from the oracle of the next level `i+ϑ`. - Uses loose indexing with Fin r. -/ -def proximityChecksSpec (γ_challenges : - Fin γ_repetitions → sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) - (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) - (fold_challenges : Fin ℓ → L) (final_constant : L) : Prop := - ∀ rep : Fin γ_repetitions, - let v := γ_challenges rep - -- For all folding levels k = 0, ..., ℓ/ϑ - 1, we track c_cur through the iterations - ∀ k_val : Fin (ℓ / ϑ), - let i := k_val.val * ϑ - have h_k: k_val ≤ (ℓ/ϑ - 1) := by omega - have h_i_add_ϑ_le_ℓ : i + ϑ ≤ ℓ := by - calc i + ϑ = k_val * ϑ + ϑ := by omega - _ ≤ (ℓ/ϑ - 1) * ϑ + ϑ := by - apply Nat.add_le_add_right; apply Nat.mul_le_mul_right; omega - _ = ℓ/ϑ * ϑ := by - rw [Nat.sub_mul, one_mul, Nat.sub_add_cancel]; - conv_lhs => rw [←one_mul ϑ] - apply Nat.mul_le_mul_right; omega - _ ≤ ℓ := by apply Nat.div_mul_le_self; - let k_th_oracleIdx: Fin (toOutCodewordsCount ℓ ϑ (Fin.last ℓ)) := - ⟨k_val, by simp only [toOutCodewordsCount, Fin.val_last, - lt_self_iff_false, ↓reduceIte, add_zero, Fin.is_lt];⟩ - have h: k_th_oracleIdx.val * ϑ = i := by rw [show k_th_oracleIdx.val = k_val.val by rfl] - have h_i_lt_ℓ: i < ℓ := by - calc i ≤ ℓ - ϑ := by omega - _ < ℓ := by - apply Nat.sub_lt (by exact Nat.pos_of_neZero ℓ) (by exact Nat.pos_of_neZero ϑ) - -- Create the suffix `(v_{i+ϑ}, ..., v_{ℓ+R-1})` as an element of `S^(i+ϑ)` - let destIdx : Fin r := ⟨i + ϑ, by omega⟩ - let next_suffix_of_v : sDomain 𝔽q β h_ℓ_add_R_rate destIdx := extractSuffixFromChallenge 𝔽q β - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v:=v) (destIdx:=destIdx) (h_destIdx_le:=by omega) - - let next_suffix_of_v_fin : Fin (2 ^ (ℓ + 𝓡 - (i + ϑ))) := - ⟨sDomainToFin 𝔽q β h_ℓ_add_R_rate ⟨i + ϑ, by omega⟩ (by - apply Nat.lt_add_of_pos_right_of_le; simp only; omega) next_suffix_of_v, - by simp only [Fin.val_mk, Fin.is_lt]⟩ - - -- Create the fiber evaluation mapping by querying oracle f^(i) at all fiber points - let f_i_on_fiber : Fin (2^ϑ) → L := fun u => - let x: Fin (2 ^ (ℓ + 𝓡 - i)) := by - let fiber_point_num_repr := Nat.joinBits (low := u) (high := next_suffix_of_v_fin) - simp only at fiber_point_num_repr - have h: 2 ^ (ℓ + 𝓡 - (i + ϑ) + ϑ) = 2 ^ (ℓ + 𝓡 - i) := by - simp only [Nat.ofNat_pos, ne_eq, OfNat.ofNat_ne_one, not_false_eq_true, - pow_right_inj₀] - omega - rw [h] at fiber_point_num_repr - exact fiber_point_num_repr - let x_point := finToSDomain 𝔽q β h_ℓ_add_R_rate ⟨i, by omega⟩ (by - apply Nat.lt_add_of_pos_right_of_le; simp only; omega) x - oStmt k_th_oracleIdx x_point - - -- Compute the next value using localized fold matrix form - let cur_challenge_batch : Fin ϑ → L := fun j => fold_challenges ⟨i + j.val, by omega⟩ - - let c_next : L := - single_point_localized_fold_matrix_form 𝔽q β - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨i, by omega⟩) (steps := ϑ) - (destIdx := destIdx) (h_destIdx := by dsimp only [destIdx]) (h_destIdx_le := by omega) - (r_challenges := cur_challenge_batch) (y := next_suffix_of_v) - (fiber_eval_mapping := f_i_on_fiber) - - -- NOTE: at i, we do the consistency check FOR THE NEXT LEVEL (`i + ϑ`): - -- `c_next ?= f^(i + ϑ)(v_{i + ϑ}, ..., v_{ℓ+R-1})`, the final check is also covered - let consistency_check : Prop := - let oracle_point_idx := extractMiddleFinMask 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (v:=v) (i:=⟨i, by omega⟩) (steps:=ϑ) - let f_i_next_val := - if hk: k_val < ℓ / ϑ - 1 then - let x_next : sDomain 𝔽q β h_ℓ_add_R_rate ⟨i + ϑ, by omega⟩ := next_suffix_of_v - let ⟨x_next', hx_next'⟩ := x_next - oStmt ⟨k_val + 1, by rw [toOutCodewordsCount_last ℓ ϑ]; omega⟩ - (⟨x_next', by simpa [Nat.add_mul] using hx_next'⟩) - else final_constant - c_next = f_i_next_val - consistency_check - -/-- RBR knowledge error for the query phase. -Proximity testing error rate: `(1/2 + 1/(2 * 2^𝓡))^γ` -/ -def queryRbrKnowledgeError := fun _ : (pSpecQuery 𝔽q β γ_repetitions - (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).ChallengeIdx => - ((1/2 : ℝ≥0) + (1 : ℝ≥0) / (2 * 2^𝓡))^γ_repetitions - -/-- Oracle query helper: query a committed codeword at a given domain point. - Restricted to codeword indices where the oracle range is L. -/ -def queryCodeword (j : Fin (toOutCodewordsCount ℓ ϑ (Fin.last ℓ))) - (point : (sDomain 𝔽q β h_ℓ_add_R_rate) ⟨oraclePositionToDomainIndex ℓ ϑ j, by omega⟩) : - OptionT (OracleComp ([]ₒ + - ([OracleStatement 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ( Fin.last ℓ)]ₒ + - [(pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Message]ₒ))) L := - query (spec := [OracleStatement 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ)]ₒ) - ⟨⟨j, by omega⟩, point⟩ +open scoped NNReal ProbabilityTheory section FinalQueryRoundIOR -/-! -### IOR Implementation for the Final Query Round --/ -def getChallengeSuffix (k : Fin (ℓ / ϑ)) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) : - let i := k.val * ϑ - have h_i_add_ϑ_le_ℓ : i + ϑ ≤ ℓ := k_succ_mul_ϑ_le_ℓ_₂ (k := k) - let destIdx : Fin r := ⟨i + ϑ, by omega⟩ - sDomain 𝔽q β h_ℓ_add_R_rate destIdx := - have h_i_add_ϑ_le_ℓ := k_succ_mul_ϑ_le_ℓ_₂ (k := k) - extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (v:=v) (destIdx := ⟨k.val * ϑ + ϑ, by omega⟩) (h_destIdx_le:=by omega) - -def challengeSuffixToFin (k : Fin (ℓ / ϑ)) - (suffix : sDomain 𝔽q β h_ℓ_add_R_rate ⟨k.val * ϑ + ϑ, by - have := k_succ_mul_ϑ_le_ℓ_₂ (k := k) - omega⟩) : - Fin (2 ^ (ℓ + 𝓡 - (k.val * ϑ + ϑ))) := - let i := k.val * ϑ - have h_i_add_ϑ_le_ℓ : i + ϑ ≤ ℓ := k_succ_mul_ϑ_le_ℓ_₂ (k := k) - let destIdx : Fin r := ⟨i + ϑ, by omega⟩ - sDomainToFin 𝔽q β h_ℓ_add_R_rate (i := ⟨k.val * ϑ + ϑ, by omega⟩) (h_i := by - simp only [k_succ_mul_ϑ_le_ℓ_₂, Nat.lt_add_of_pos_right_of_le]) suffix - -/-- Return the point `f^(i)(u_0, ..., u_{ϑ-1}, v_{i+ϑ}, ..., v_{ℓ+R-1})` -for a fiber index `u ∈ B_ϑ`. -/ -noncomputable def getFiberPoint - (k : Fin (ℓ / ϑ)) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) (u : Fin (2 ^ ϑ)) : - (sDomain 𝔽q β h_ℓ_add_R_rate) (i := ⟨oraclePositionToDomainIndex ℓ ϑ (i := Fin.last ℓ) - (positionIdx := ⟨k, by simp only [toOutCodewordsCount_last, Fin.is_lt]⟩), - lt_r_of_lt_ℓ (x := k.val * ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h := k_mul_ϑ_lt_ℓ)⟩) := - let i := k.val * ϑ - have h_i_add_ϑ_le_ℓ : i + ϑ ≤ ℓ := k_succ_mul_ϑ_le_ℓ_₂ (k := k) - - -- TODO: should we make next_suffix_of_v_fin a separate def? - let destIdx : Fin r := ⟨i + ϑ, by omega⟩ - - let next_suffix_of_v_fin : Fin (2 ^ (ℓ + 𝓡 - (i + ϑ))) := - challengeSuffixToFin (k := k) (suffix := getChallengeSuffix (k := k) (v := v)) - - let fiber_point_num_repr := Nat.joinBits (low := u) (high := next_suffix_of_v_fin) - have h : 2 ^ (ℓ + 𝓡 - (i + ϑ) + ϑ) = 2 ^ (ℓ + 𝓡 - i) := by - simp only [Nat.ofNat_pos, ne_eq, OfNat.ofNat_ne_one, not_false_eq_true, pow_right_inj₀] - omega - let x : Fin (2 ^ (ℓ + 𝓡 - i)) := ⟨fiber_point_num_repr.val, by omega⟩ - let k_th_oracleIdx : Fin (toOutCodewordsCount ℓ ϑ (Fin.last ℓ)) := - ⟨k, by simp only [toOutCodewordsCount, Fin.val_last, lt_self_iff_false, ↓reduceIte, add_zero, - Fin.is_lt]⟩ - finToSDomain 𝔽q β h_ℓ_add_R_rate (i:=⟨i, by omega⟩) - (h_i:=by apply Nat.lt_add_of_pos_right_of_le; simp only; omega) (idx:=x) - -/-! -### Helper Functions for Verifier Logic - -These functions break down the verifier's proximity checking logic into composable blocks, -making it easier to prove properties about each component separately. --/ - -/-- Query all fiber points for a given folding step. - Returns a list of evaluations `f^(i)(u_0, ..., u_{ϑ-1}, v_{i+ϑ}, ..., v_{ℓ+R-1})` - for all `u ∈ B_ϑ`. - Note: `oStmtIn` is accessed via oracle queries in the OracleComp context. -/ -noncomputable def queryFiberPoints - (k : Fin (ℓ / ϑ)) - (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) : - OptionT - (OracleComp - ([]ₒ + ([OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ)]ₒ + - [(pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Message]ₒ))) - (Vector L (2^ϑ)) := do - let k_th_oracleIdx : Fin (toOutCodewordsCount ℓ ϑ (Fin.last ℓ)) := - ⟨k, by simp only [toOutCodewordsCount, Fin.val_last, lt_self_iff_false, ↓reduceIte, add_zero, - Fin.is_lt]⟩ - -- 2. Map over the Vector monadically - let results : Vector L (2^ϑ) ← - (⟨Array.finRange (2^ϑ), by simp only [Array.size_finRange]⟩ - : Vector (Fin (2^ϑ)) (2^ϑ)).mapM (fun (u : Fin (2^ϑ)) => do - queryCodeword 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (j := k_th_oracleIdx) - (point := getFiberPoint 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k) (v := v) (u := u))) - pure results - -/-- Check a single folding step: query fiber points, verify consistency, and compute next value. - Returns `(c_next, all_checks_passed)` where `c_next` is the computed folded value - and `all_checks_passed` indicates if all consistency checks passed. - Note: `oStmtIn` is accessed via oracle queries in the OracleComp context. -/ -noncomputable def checkSingleFoldingStep - (k_val : Fin (ℓ / ϑ)) (c_cur : L) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) - (stmt : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) : - OptionT (OracleComp ([]ₒ + ([OracleStatement 𝔽q β (ϑ:=ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ)]ₒ + [(pSpecQuery 𝔽q β - γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Message]ₒ))) L := do - let i := k_val.val * ϑ - have h_k: k_val ≤ (ℓ/ϑ - 1) := by omega - have h_i_add_ϑ_le_ℓ : i + ϑ ≤ ℓ := by - calc i + ϑ = k_val * ϑ + ϑ := by omega - _ ≤ (ℓ/ϑ - 1) * ϑ + ϑ := by - apply Nat.add_le_add_right; apply Nat.mul_le_mul_right; omega - _ = ℓ/ϑ * ϑ := by - rw [Nat.sub_mul, one_mul, Nat.sub_add_cancel]; - conv_lhs => rw [←one_mul ϑ] - apply Nat.mul_le_mul_right; omega - _ ≤ ℓ := by apply Nat.div_mul_le_self; - have h_i_lt_ℓ : i < ℓ := by - calc i ≤ ℓ - ϑ := by omega - _ < ℓ := by - apply Nat.sub_lt (by exact Nat.pos_of_neZero ℓ) (by exact Nat.pos_of_neZero ϑ) - let f_i_on_fiber ← - queryFiberPoints 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) k_val v - -- Check consistency if i > 0 - if h_i_pos : i > 0 then - let oracle_point_idx := extractMiddleFinMask 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (v:=v) (i:=⟨i, by omega⟩) (steps:=ϑ) - let f_i_val := f_i_on_fiber.get oracle_point_idx - guard (c_cur = f_i_val) - let destIdx : Fin r := ⟨i + ϑ, by omega⟩ - let next_suffix_of_v : sDomain 𝔽q β h_ℓ_add_R_rate destIdx := - getChallengeSuffix (k := k_val) (v := v) - let cur_challenge_batch : Fin ϑ → L := fun j => - stmt.challenges ⟨i + j.val, by simp only [Fin.val_last]; omega⟩ - let c_next : L := - single_point_localized_fold_matrix_form 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (i := ⟨i, by omega⟩) (steps := ϑ) (destIdx := destIdx) - (h_destIdx := by dsimp only [destIdx]) (h_destIdx_le := by omega) - (r_challenges := cur_challenge_batch) (y := next_suffix_of_v) - (fiber_eval_mapping := f_i_on_fiber.get) - return c_next - -/-- Check a single repetition: iterate through all folding steps and verify final consistency. - Returns `true` if all checks pass, `false` otherwise. - Note: `oStmtIn` is accessed via oracle queries in the OracleComp context. - - Uses `mut` + `for` loop for true early termination (stops immediately on first failure). - For proofs, we'll need to reason about the loop invariant that `c_cur` maintains the - correct accumulated value through iterations. -/ -noncomputable def checkSingleRepetition - (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) - (stmt : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) (final_constant : L) : - OptionT (OracleComp ([]ₒ + ([OracleStatement 𝔽q β (ϑ:=ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ)]ₒ + [(pSpecQuery 𝔽q β - γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Message]ₒ))) Unit := do - let mut c_cur : L := 0 -- Will be initialized in first iteration - - -- Iterate through the `ℓ/ϑ` adjacent pairs of oracles & validate local folding consistency - -- Early termination: stops immediately on first failure via OptionT failure - for k_val in List.finRange (ℓ / ϑ) do - let c_next ← checkSingleFoldingStep 𝔽q β (ϑ:=ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (γ_repetitions := γ_repetitions) - ⟨k_val, by omega⟩ c_cur v stmt - c_cur := c_next - -- Final check: c_ℓ ?= final_constant - guard (c_cur = final_constant) - /-! ### Oracle-Aware Reduction Logic for Query Phase @@ -429,7 +63,7 @@ This encapsulates the pure logic of the query phase: - `honestProverTranscript`: The honest transcript just receives the challenges - `proverOut`: The honest prover always outputs `(true, ())` -/ noncomputable def queryPhaseLogicStep : - CoreInteraction.OracleAwareReductionLogicStep + OracleAwareReductionLogicStep -- oSpec is the base/shared oracle (empty for query phase - no random oracles) -- The structure internally uses oSpec + ([OracleIn]ₒ + [pSpec.Message]ₒ) (oSpec := []ₒ) @@ -440,11 +74,9 @@ noncomputable def queryPhaseLogicStep : (StmtOut := Bool) (WitOut := Unit) (pSpec := pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) where - -- Relations completeness_relIn := strictFinalSumcheckRelOut 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) completeness_relOut := acceptRejectOracleRel - -- Verifier (Oracle-Aware): verifierCheck queries oracles and returns StmtOut -- Iterates through all γ_repetitions and checks each one verifierCheck := fun stmtIn transcript => do @@ -457,18 +89,14 @@ noncomputable def queryPhaseLogicStep : (h_ℓ_add_R_rate := h_ℓ_add_R_rate) v stmtIn stmtIn.final_constant return true -- StmtOut = Bool for QueryPhase - -- Pure output computation (deterministic) verifierOut := fun _stmtIn _transcript => true - -- Oracle embedding (no output oracles for query phase) embed := ⟨Empty.elim, fun a _ => Empty.elim a⟩ hEq := fun i => Empty.elim i - -- Honest prover transcript: just receives the challenges honestProverTranscript := fun stmtIn _witIn _oStmtIn challenges => FullTranscript.mk1 (challenges ⟨0, by rfl⟩) - -- Prover output: always outputs (true, ()) proverOut := fun _stmtIn _witIn _oStmtIn _transcript => ((true, fun i => Empty.elim i), ()) @@ -496,17 +124,13 @@ noncomputable def queryOracleProver : (pSpec := pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) where -- Prover state: tracks (stmtIn, oStmtIn, witIn) and optionally the challenges PrvState := queryPhaseProverState 𝔽q β (ϑ:=ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - input := fun ⟨⟨stmtIn, oStmtIn⟩, witIn⟩ => (stmtIn, oStmtIn, witIn) - sendMessage | ⟨0, h⟩ => nomatch h - receiveChallenge | ⟨0, _⟩ => fun ⟨stmtIn, oStmtIn, witIn⟩ => do -- V sends all γ challenges v₁, ..., v_γ pure (fun challenges => (stmtIn, oStmtIn, witIn, challenges)) - output := fun ⟨stmtIn, oStmtIn, witIn, challenges⟩ => do -- Build the transcript using the logic step's honestProverTranscript let transcript := FullTranscript.mk1 (pSpec := @@ -530,7 +154,6 @@ noncomputable def queryOracleVerifier : (StmtOut := Bool) (OStmtOut := fun _ : Empty => Unit) (pSpec := pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) where - verify := fun stmtIn challenges => do let transcript := FullTranscript.mk1 (pSpec := pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (challenges ⟨0, by rfl⟩) @@ -538,7 +161,6 @@ noncomputable def queryOracleVerifier : (h_ℓ_add_R_rate := h_ℓ_add_R_rate) let _ ← (logic.verifierCheck stmtIn transcript) pure (logic.verifierOut stmtIn transcript) - -- Use embed and hEq from the logic step embed := (queryPhaseLogicStep 𝔽q β (ϑ:=ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).embed @@ -599,7 +221,6 @@ lemma mem_support_queryFiberPoints let k_th_oracleIdx: Fin (toOutCodewordsCount ℓ ϑ (Fin.last ℓ)) := ⟨oraclePositionIdx, by simp only [toOutCodewordsCount, Fin.val_last, lt_self_iff_false, ↓reduceIte, add_zero, Fin.is_lt];⟩ - ∀ (fiberIndex : Fin (2 ^ ϑ)), f_i_on_fiber.get fiberIndex = (oStmtIn k_th_oracleIdx (getFiberPoint 𝔽q β oraclePositionIdx v fiberIndex)) := by @@ -610,7 +231,7 @@ lemma mem_support_queryFiberPoints set so := OracleInterface.simOracle2 []ₒ oStmtIn transcript.messages with h_so -- rw [simulateQ_liftComp] at h_fiber_mem unfold queryFiberPoints at h_fiber_mem - simp only [Message, MessageIdx, bind_pure, liftComp_id] at h_fiber_mem + simp only [bind_pure] at h_fiber_mem unfold queryCodeword at h_fiber_mem -- Simplify the simulation through liftComp/liftM -- simp_rw [← simulateQ_liftComp] at h_fiber_mem @@ -622,11 +243,13 @@ lemma mem_support_queryFiberPoints conv_rhs at h_fiber_mem => erw [simulateQ_liftComp] simp only [MessageIdx, Message, Fin.getElem_fin, Vector.getElem_mk, OptionT.run_monadLift, - simulateQ_map, simulateQ_query, OracleQuery.input_query, OracleQuery.cont_query, id_map, + simulateQ_map, OracleQuery.input_query, OracleQuery.cont_query, id_map, OptionT.mem_support_iff, toPFunctor_emptySpec, OptionT.support_run_eq, support_map, Set.mem_image, Option.some.injEq, exists_eq_right] + erw [simulateQ_query] erw [simulateQ_simOracle2_lift_liftComp_query_T1] - simp only [Fin.getElem_fin, Vector.getElem_mk, support_pure, Set.mem_singleton_iff] at h_fiber_mem + simp only [monadLift_self, LawfulApplicative.map_pure, support_pure, + Set.mem_singleton_iff] at h_fiber_mem simp only intro fiberIndex have h_res := h_fiber_mem fiberIndex @@ -635,7 +258,7 @@ lemma mem_support_queryFiberPoints simp only [Array.getElem_finRange, Fin.cast_mk, Fin.eta] /-! Simulated `queryFiberPoints` has zero failure probability. -/ -omit [CharP L 2] [SampleableType L] hF₂ in +omit [CharP L 2] [SampleableType L] [DecidableEq 𝔽q] hF₂ in lemma probFailure_simulateQ_queryFiberPoints_eq_zero (so : QueryImpl ([]ₒ + ([OracleStatement 𝔽q β (ϑ := ϑ) @@ -668,10 +291,7 @@ lemma probFailure_simulateQ_queryFiberPoints_eq_zero simp only [HasEvalPMF.probFailure_eq_zero, zero_add] rw [probOutput_eq_zero_iff] erw [simulateQ_map] - simp only [ofPFunctor_toPFunctor, List.get_eq_getElem, monadLift_self, - OracleQuery.input_query, OracleQuery.snd_query, id_eq, MessageIdx, Message, - OptionT.support_run, support_map, Set.mem_image, reduceCtorEq, and_false, exists_const, - not_false_eq_true])) + simp)) | some a => simp only [OptionT.mk] erw [simulateQ_pure, probFailure_pure] @@ -688,7 +308,8 @@ lemma query_phase_consistency_guard_safe (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) (witIn : Unit) - (h_relIn : strictFinalSumcheckRelOut 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ((stmtIn, oStmtIn), witIn)) + (h_relIn : strictFinalSumcheckRelOut 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ((stmtIn, oStmtIn), witIn)) -- Hypothesis: c_k is the correct iterated fold value up to this point (h_c_k_correct : let := k_mul_ϑ_lt_ℓ (k := k) @@ -696,56 +317,67 @@ lemma query_phase_consistency_guard_safe c_k = iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (steps := k.val * ϑ) (destIdx := ⟨k.val * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega) (f := getFirstOracle 𝔽q β oStmtIn) - (r_challenges := getFoldingChallenges (𝓡 := 𝓡) (r := r) (Fin.last ℓ) stmtIn.challenges 0 (by simp only [zero_add, Fin.val_last]; omega)) - (y := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) (destIdx := ⟨k.val * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega)) - (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add];) - ) + (r_challenges := getFoldingChallenges (𝓡 := 𝓡) (r := r) (Fin.last ℓ) + stmtIn.challenges 0 (by simp only [zero_add, Fin.val_last]; omega)) + (y := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) + (destIdx := ⟨k.val * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega)) + (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add])) -- Hypothesis: We are at a step > 0 where a check actually happens (h_k_pos : k.val * ϑ > 0) (challenges : (pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Challenges) -- Hypothesis: The fiber evaluations come from the simulated oracle query (h_fiber_mem : - let step := queryPhaseLogicStep 𝔽q β (ϑ := ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + let step := queryPhaseLogicStep 𝔽q β (ϑ := ϑ) γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) let transcript := step.honestProverTranscript stmtIn witIn oStmtIn challenges - let so := OracleInterface.simOracle2 []ₒ oStmtIn transcript.messages + let so := OracleInterface.simOracle2.{0, 0, 0, 0, 0} []ₒ oStmtIn transcript.messages some (f_i_on_fiber) ∈ - (simulateQ so ((queryFiberPoints 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) k v))).support) : + (simulateQ.{0, 0, 0} so + ((queryFiberPoints 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) k v))).support) : let := k_mul_ϑ_lt_ℓ (k := k) - c_k = f_i_on_fiber.get (extractMiddleFinMask 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) (i := ⟨k.val * ϑ, by omega⟩) (steps := ϑ)) := by - + c_k = f_i_on_fiber.get (extractMiddleFinMask 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (v := v) (i := ⟨k.val * ϑ, by omega⟩) (steps := ϑ)) := by have h_fiber_val := mem_support_queryFiberPoints 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oraclePositionIdx := k) v f_i_on_fiber stmtIn - oStmtIn () challenges h_fiber_mem + oStmtIn witIn challenges (h_fiber_mem := h_fiber_mem) simp only at h_fiber_val - rw [h_c_k_correct] simp only have h₁ : k.val * ϑ < ℓ := k_mul_ϑ_lt_ℓ (k := k) set destIdx : Fin r := ⟨k.val * ϑ, by omega⟩ with h_destIdx_eq conv_rhs => rw [h_fiber_val] - dsimp only [strictFinalSumcheckRelOut, strictFinalSumcheckRelOutProp, - strictFinalFoldingStateProp] at h_relIn + strictfinalSumcheckStepFoldingStateProp] at h_relIn simp only [Fin.val_last, exists_and_right, Subtype.exists] at h_relIn rcases h_relIn with ⟨exists_t_MLP, _⟩ rcases exists_t_MLP with ⟨t, h_t_mem_support, h_strictOracleFoldingConsistency⟩ dsimp only [strictOracleFoldingConsistencyProp] at h_strictOracleFoldingConsistency - -- Now extract the oStmtIn equality at position k - have h_oStmtIn_k_eq := h_strictOracleFoldingConsistency ⟨k.val, by simp only [toOutCodewordsCount_last, - Fin.is_lt]⟩ - + have h_oStmtIn_k_eq := h_strictOracleFoldingConsistency ⟨k.val, + by simp only [toOutCodewordsCount_last, Fin.is_lt]⟩ conv_rhs => rw [h_oStmtIn_k_eq] simp only - - have h_point_eq : extractSuffixFromChallenge 𝔽q β v ⟨↑k * ϑ, by omega⟩ (by simp only; omega) = getFiberPoint 𝔽q β k v (extractMiddleFinMask 𝔽q β v ⟨↑k * ϑ, by omega⟩ ϑ) := by + have h_point_eq : extractSuffixFromChallenge 𝔽q β v ⟨↑k * ϑ, by omega⟩ (by simp only; omega) = + getFiberPoint 𝔽q β k v (extractMiddleFinMask 𝔽q β v ⟨↑k * ϑ, by omega⟩ ϑ) := by + -- The key insight: getFiberPoint reconstructs a point in S^i by: + -- 1. Taking the suffix at i+ϑ + -- 2. Joining it with the fiber index u (the middle ϑ bits) + -- 3. Converting back to sDomain + -- When u = extractMiddleFinMask v i ϑ, this reconstructs exactly the suffix at i + -- Unfold definitions + dsimp only [getFiberPoint, getChallengeSuffix, challengeSuffixToFin, extractSuffixFromChallenge] + -- Both sides use iteratedQuotientMap, so we need to show they're applied to the same element + -- This requires showing that finToSDomain (joinBits u suffix_fin) = iteratedQuotientMap v + -- where u = extractMiddleFinMask and suffix_fin comes from the suffix at i+ϑ + -- This requires deep reasoning about the relationship between + -- joinBits, sDomainToFin, finToSDomain, and iteratedQuotientMap sorry - rw [h_point_eq] - - rw [polyToOracleFunc_eq_getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (t := ⟨t, h_t_mem_support⟩) (i := Fin.last ℓ) - (challenges := stmtIn.challenges) (oStmt := oStmtIn) - (h_consistency := h_strictOracleFoldingConsistency)] + rw [polyToOracleFunc_eq_getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (t := ⟨t, h_t_mem_support⟩) (i := Fin.last ℓ) + (challenges := stmtIn.challenges) (oStmt := oStmtIn) + (h_consistency := h_strictOracleFoldingConsistency)] /-- Lemma 2 (Preservation): @@ -769,41 +401,46 @@ lemma query_phase_step_preserves_fold (c_k : L) (s' : L) -- The next state (c_next) (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) - (h_relIn : strictFinalSumcheckRelOut 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ((stmtIn, oStmtIn), ())) + (h_relIn : strictFinalSumcheckRelOut 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ((stmtIn, oStmtIn), ())) (h_c_k_correct_of_k_pos : let := k_mul_ϑ_lt_ℓ (k := k) let := k_succ_mul_ϑ_le_ℓ (k := k) - if hk : k.val > 0 then + if _ : k.val > 0 then c_k = iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (steps := k.val * ϑ) (destIdx := ⟨k.val * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega) (f := getFirstOracle 𝔽q β oStmtIn) - (r_challenges := getFoldingChallenges (𝓡 := 𝓡) (r := r) (Fin.last ℓ) stmtIn.challenges 0 (by simp only [zero_add, Fin.val_last]; omega)) - (y := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) (destIdx := ⟨k.val * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega)) + (r_challenges := getFoldingChallenges (𝓡 := 𝓡) (r := r) (Fin.last ℓ) stmtIn.challenges + 0 (by simp only [zero_add, Fin.val_last]; omega)) + (y := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) + (destIdx := ⟨k.val * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega)) (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]) - else True - ) + else True) -- Hypothesis: s' is a valid output of the simulated step function (challenges : (pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Challenges) (h_s'_mem : let step := queryPhaseLogicStep 𝔽q β (ϑ := ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) let witIn : Unit := () let transcript := step.honestProverTranscript stmtIn witIn oStmtIn challenges - let so := OracleInterface.simOracle2 []ₒ oStmtIn transcript.messages + let so := OracleInterface.simOracle2.{0, 0, 0, 0, 0} []ₒ oStmtIn transcript.messages s' ∈ (OptionT.mk - (simulateQ so ((checkSingleFoldingStep 𝔽q β (γ_repetitions := γ_repetitions) (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) k c_k v stmtIn))).support)) : + (simulateQ.{0, 0, 0} so + ((checkSingleFoldingStep 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) k c_k v stmtIn))).support)) : let := k_succ_mul_ϑ_le_ℓ (k := k) s' = iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (steps := (k.val + 1) * ϑ) (destIdx := ⟨(k.val + 1) * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega) (f := getFirstOracle 𝔽q β oStmtIn) - (r_challenges := getFoldingChallenges (𝓡 := 𝓡) (r := r) (Fin.last ℓ) stmtIn.challenges 0 (by simp only [zero_add, Fin.val_last]; omega)) - (y := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) (destIdx := ⟨(k.val + 1) * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega)) (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add];) := by - + (r_challenges := getFoldingChallenges (𝓡 := 𝓡) (r := r) (Fin.last ℓ) stmtIn.challenges 0 + (by simp only [zero_add, Fin.val_last]; omega)) + (y := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) + (destIdx := ⟨(k.val + 1) * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega)) + (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add];) := by let step := queryPhaseLogicStep 𝔽q β (ϑ := ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) let witIn : Unit := () let transcript := step.honestProverTranscript stmtIn witIn oStmtIn challenges let so := OracleInterface.simOracle2 []ₒ oStmtIn transcript.messages - -- This is basically due to definition of s' -- First, convert h_s'_mem to equality form dsimp only [checkSingleFoldingStep] at h_s'_mem @@ -815,7 +452,6 @@ lemma query_phase_step_preserves_fold have h_ϑ_le_ℓ : ϑ ≤ ℓ := Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (by exact hdiv.out) let destIdx : Fin r := ⟨(k.val + 1) * ϑ, by omega⟩ let midIdx : Fin r := ⟨k.val * ϑ, by omega⟩ - by_cases h_k_pos : k.val > 0 · -- Case k > 0: The guard is present. -- **Simplify the monadic structure** @@ -829,102 +465,70 @@ lemma query_phase_step_preserves_fold erw [simulateQ_bind, support_bind] at h_s'_mem simp only [Set.mem_iUnion, exists_prop] at h_s'_mem rcases h_s'_mem with ⟨fiber_vec_Opt, h_fiber_vec_Opt_mem_support, h_s'_mem_support_guard⟩ - let k_fin_list : Fin (List.finRange (ℓ / ϑ)).length := ⟨k.val, by simp only [List.length_finRange, Fin.is_lt]⟩ - have h_k_fin_list_eq : k = ((List.finRange (ℓ / ϑ)).get k_fin_list) := by apply Fin.eq_of_val_eq; simp only [List.get_eq_getElem, List.getElem_finRange, Fin.eta, Fin.val_cast]; rfl - have h_probFailure_queryFiberPoints_eq_zero := by - apply probFailure_simulateQ_queryFiberPoints_eq_zero (γ_repetitions := γ_repetitions) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝔽q := 𝔽q) (β := β) + apply probFailure_simulateQ_queryFiberPoints_eq_zero (γ_repetitions := γ_repetitions) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝔽q := 𝔽q) (β := β) (so := so) (k := k_fin_list) (v := v) - have h_probOutput_none_queryFiberPoints_eq_zero := OptionT.probOutput_none_run_eq_zero_of_probFailure_eq_zero (hfail := h_probFailure_queryFiberPoints_eq_zero) - have h_fiber_vec_Opt_mem_support_eq := exists_eq_some_of_mem_support_of_probOutput_none_eq_zero (x := fiber_vec_Opt) (hx := h_fiber_vec_Opt_mem_support) (hnone := by - rw [h_k_fin_list_eq]; exact h_probOutput_none_queryFiberPoints_eq_zero) + simpa [so, transcript, h_k_fin_list_eq] using h_probOutput_none_queryFiberPoints_eq_zero) rcases h_fiber_vec_Opt_mem_support_eq with ⟨fiber_vec, h_fiber_vec_Opt_mem_support_eq⟩ rw [h_fiber_vec_Opt_mem_support_eq] at h_s'_mem_support_guard h_fiber_vec_Opt_mem_support - -- h_s'_eq : s' = the evaluation at y of the folded function from fiber_vec -- simp only [OptionT.simulateQ_map] at h_s'_mem_support_guard - have h_fiber_val := mem_support_queryFiberPoints 𝔽q β (γ_repetitions := γ_repetitions) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oraclePositionIdx := k) v fiber_vec stmtIn - oStmtIn () challenges h_fiber_vec_Opt_mem_support - + oStmtIn () challenges (by simpa using h_fiber_vec_Opt_mem_support) erw [simulateQ_bind, support_bind] at h_s'_mem_support_guard simp only [Function.comp_apply, Set.mem_iUnion, exists_prop] at h_s'_mem_support_guard - have h₁ : k.val * ϑ < ℓ := k_mul_ϑ_lt_ℓ (k := k) - -- 1. Simplify failure probability to just the guard condition -- simp only [h_i_pos, ↓reduceIte, OptionT.simulateQ_map] - have h_guard_pass : c_k = fiber_vec.get (extractMiddleFinMask 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) (i := ⟨k.val * ϑ, by omega⟩) (steps := ϑ)) := by - -- 1. Construct the correct index type for the lemma - -- 3. Unfold Rel to get the equality - -- unfold Rel checkSingleRepetition_foldRel at h_rel_k_c - -- have h_k_castSucc_ne_0 : ¬(k.castSucc = 0) := by - -- by_contra h_eq - -- have h_val_eq := Fin.val_eq_of_eq h_eq - -- simp only [Fin.val_castSucc, Fin.coe_ofNat_eq_mod, List.length_finRange, - -- Nat.zero_mod] at h_val_eq - -- have h_k_ne_0 : k.val ≠ 0 := by omega -- from h_i_pos.1 - -- -- h_val_eq : ↑k = 0 - -- -- h_k_ne_0 : ↑k ≠ 0 - -- exact h_k_ne_0 h_val_eq - -- rw [dif_neg h_k_castSucc_ne_0] at h_rel_k_c - -- simp only [Fin.val_castSucc] at h_rel_k_c - -- simp only [Fin.isValue, List.get_eq_getElem, List.getElem_finRange, Fin.eta, - -- Fin.coe_cast] - + have h_guard_pass : c_k = fiber_vec.get (extractMiddleFinMask 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) (i := ⟨k.val * ϑ, by omega⟩) (steps := ϑ)) := by have h_mul_gt_0 : k.val * ϑ > 0 := by simp only [gt_iff_lt, CanonicallyOrderedAdd.mul_pos] omega - - have h_k_eq_fin_cast : k = Fin.cast (by simp only [List.length_finRange, Fin.is_lt]) k_fin_list := by + have h_k_eq_fin_cast : k = Fin.cast (by simp only [List.length_finRange]) k_fin_list := by apply Fin.eq_of_val_eq; simp only [Fin.val_cast]; rfl - -- 4. Apply the lemma - have res := query_phase_consistency_guard_safe 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k) (v := v) (c_k := c_k) (f_i_on_fiber := fiber_vec) (stmtIn := stmtIn) (oStmtIn := oStmtIn) (witIn := witIn) (h_relIn := h_relIn) (h_c_k_correct := by - simp only [gt_iff_lt, h_k_pos, ↓reduceDIte] at h_c_k_correct_of_k_pos + have res := query_phase_consistency_guard_safe 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (k := k) (v := v) (c_k := c_k) (f_i_on_fiber := fiber_vec) (stmtIn := stmtIn) + (oStmtIn := oStmtIn) (witIn := witIn) (h_relIn := h_relIn) (h_c_k_correct := by + simp only at h_c_k_correct_of_k_pos + simp only [gt_iff_lt, h_k_pos] at h_c_k_correct_of_k_pos exact h_c_k_correct_of_k_pos - ) (h_k_pos := h_mul_gt_0) (γ_repetitions := γ_repetitions) (challenges := challenges) (h_fiber_mem := by - simp only [witIn] - exact h_fiber_vec_Opt_mem_support - ) + ) (h_k_pos := h_mul_gt_0) (γ_repetitions := γ_repetitions) (challenges := challenges) + (h_fiber_mem := by simp only [witIn]; exact h_fiber_vec_Opt_mem_support) exact res - simp only [h_guard_pass, ↓reduceIte] at h_s'_mem_support_guard erw [simulateQ_pure] at h_s'_mem_support_guard simp only [support_pure, Set.mem_singleton_iff, exists_eq_left, OptionT.simulateQ_pure, OptionT.support_OptionT_pure_run, Option.some.injEq] at h_s'_mem_support_guard - -- Step 1: Use symmetry of h_s'_eq rw [h_s'_mem_support_guard] dsimp only [getChallengeSuffix] -- extractSuffixFromChallenge arise here - have h_destIdx_eq : destIdx.val = k.val * ϑ + ϑ := by dsimp only [destIdx]; rw [Nat.add_mul, Nat.one_mul] - -- iterated_fold 𝔽q β 0 ((↑k + 1) * ϑ) ⋯ ⋯ (getFirstOracle 𝔽q β oStmtIn) - -- (getFoldingChallenges (Fin.last ℓ) stmtIn.challenges 0 ⋯) (extractSuffixFromChallenge 𝔽q β v ⟨(↑k + 1) * ϑ, ⋯⟩ ⋯) - + -- (getFoldingChallenges (Fin.last ℓ) stmtIn.challenges 0 ⋯) (extractSuffixFromChallenge + -- 𝔽q β v ⟨(↑k + 1) * ϑ, ⋯⟩ ⋯) set challenges_full := getFoldingChallenges (𝓡 := 𝓡) (r := r) (ϑ := (k.val + 1) * ϑ) (i := Fin.last ℓ) stmtIn.challenges (k := 0) (h := by simp only [zero_add, Fin.val_last, k_succ_mul_ϑ_le_ℓ]) with h_challenges_full_defs - set challenges_mid := getFoldingChallenges (𝓡 := 𝓡) (r := r) (ϑ := k.val * ϑ) (i := Fin.last ℓ) stmtIn.challenges (k := 0) (h := by simp only [zero_add, Fin.val_last]; omega) with h_challenges_mid_defs - - set challenges_last : Fin ϑ → L := (fun j ↦ stmtIn.challenges ⟨↑k * ϑ + ↑j, by simp only [Fin.val_last]; omega⟩) with h_challenges_last_defs - + set challenges_last : Fin ϑ → L := (fun j ↦ stmtIn.challenges ⟨↑k * ϑ + ↑j, by + simp only [Fin.val_last]; omega⟩) with h_challenges_last_defs set y_left := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) (destIdx := ⟨k.val * ϑ + ϑ, by omega⟩) (h_destIdx_le := by omega) with hy_left_defs set y_right := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) @@ -934,18 +538,20 @@ lemma query_phase_step_preserves_fold let k_oracle_idx : Fin (toOutCodewordsCount ℓ ϑ (Fin.last ℓ)) := ⟨k, by simp only [toOutCodewordsCount_last, Fin.is_lt]⟩ -- Prove that oraclePositionToDomainIndex matches midIdx - have h_domain_idx_eq : (oraclePositionToDomainIndex ℓ ϑ (i := Fin.last ℓ) (positionIdx := k_oracle_idx)).val = midIdx.val := by + have h_domain_idx_eq : (oraclePositionToDomainIndex ℓ ϑ (i := Fin.last ℓ) + (positionIdx := k_oracle_idx)).val = midIdx.val := by dsimp only [oraclePositionToDomainIndex, midIdx] - have h_sDomain_midIdx_eq : sDomain 𝔽q β h_ℓ_add_R_rate midIdx = sDomain 𝔽q β h_ℓ_add_R_rate ⟨(oraclePositionToDomainIndex ℓ ϑ (i := Fin.last ℓ) (positionIdx := k_oracle_idx)).val, by omega⟩ := by + have h_sDomain_midIdx_eq : sDomain 𝔽q β h_ℓ_add_R_rate midIdx = sDomain 𝔽q β h_ℓ_add_R_rate + ⟨(oraclePositionToDomainIndex ℓ ϑ (i := Fin.last ℓ) + (positionIdx := k_oracle_idx)).val, by omega⟩ := by apply sDomain_eq_of_eq; apply Fin.eq_of_val_eq; rw [h_domain_idx_eq] let f_mid : ↥(sDomain 𝔽q β h_ℓ_add_R_rate midIdx) → L := fun x => oStmtIn k_oracle_idx (cast (by rw [h_sDomain_midIdx_eq]) x) - set fiber_vec_actual_def := fiberEvaluations 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (i := midIdx) (steps := ϑ) (destIdx := ⟨k * ϑ + ϑ, by omega⟩) (h_destIdx := by simp only [Nat.add_right_cancel_iff]; rfl) + (i := midIdx) (steps := ϑ) (destIdx := ⟨k * ϑ + ϑ, by omega⟩) (h_destIdx := by + simp only [Nat.add_right_cancel_iff]; rfl) (h_destIdx_le := by omega) (f := f_mid) (y := y_left) with h_fiber_vec_actual_def - have h_fiber_vec_get : fiber_vec.get = fiber_vec_actual_def := by dsimp only [fiber_vec_actual_def]; unfold fiberEvaluations funext x @@ -956,48 +562,39 @@ lemma query_phase_step_preserves_fold dsimp only [getFirstOracle] dsimp only [f_mid] apply OracleStatement.oracle_eval_congr 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (oStmtIn := oStmtIn) (j' := k_oracle_idx) (j := ⟨k, by simp only [toOutCodewordsCount_last, Fin.is_lt]⟩) (h_j := by rfl) - -- ⊢ finToSDomain 𝔽q β h_ℓ_add_R_rate ⟨↑k * ϑ, ⋯⟩ ⋯ - -- ⟨↑(Nat.joinBits x (challengeSuffixToFin 𝔽q β k (extractSuffixFromChallenge 𝔽q β v ⟨↑k * ϑ + ϑ, ⋯⟩ ⋯))), ⋯⟩ = - -- cast ⋯ (cast ⋯ (qMap_total_fiber 𝔽q β midIdx ϑ h_destIdx_eq h₁ y x)) - sorry + (oStmtIn := oStmtIn) (j' := k_oracle_idx) (j := ⟨k, by + simp only [toOutCodewordsCount_last, Fin.is_lt]⟩) (h_j := by rfl) + rfl rw [h_fiber_vec_get]; dsimp only [fiber_vec_actual_def] - - -- single_point_localized_fold_matrix_form 𝔽q β ⟨↑k * ϑ, ⋯⟩ ϑ ⋯ ⋯ (fun j ↦ stmtIn.challenges ⟨↑k * ϑ + ↑j, ⋯⟩) - -- (extractSuffixFromChallenge 𝔽q β v ⟨↑k * ϑ + ϑ, ⋯⟩ ⋯) (fiberEvaluations 𝔽q β midIdx ϑ h_destIdx_eq h₁ f_mid y) = - -- iterated_fold 𝔽q β 0 ((↑k + 1) * ϑ) ⋯ ⋯ (getFirstOracle 𝔽q β oStmtIn) - -- (getFoldingChallenges (Fin.last ℓ) stmtIn.challenges 0 ⋯) y - have h_eq := single_point_localized_fold_matrix_form_eq_iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := midIdx) (steps := ϑ) (destIdx := ⟨k * ϑ + ϑ, by omega⟩) - (h_destIdx := by simp only [Nat.add_right_cancel_iff]; rfl) (h_destIdx_le := by omega) (f := f_mid) (r_challenges := fun j => stmtIn.challenges ⟨k.val * ϑ + j.val, by simp only [Fin.val_last]; omega⟩) (y := y_left) + have h_eq := single_point_localized_fold_matrix_form_eq_iterated_fold 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := midIdx) (steps := ϑ) + (destIdx := ⟨k * ϑ + ϑ, by omega⟩) (h_destIdx := by simp only [Nat.add_right_cancel_iff]; rfl) + (h_destIdx_le := by omega) (f := f_mid) (y := y_left) (r_challenges := + fun j => stmtIn.challenges ⟨k.val * ϑ + j.val, by simp only [Fin.val_last]; omega⟩) conv_lhs => rw [h_eq] - dsimp only [f_mid] -- Now rw the oStmtIn k_oracle_idx into the iterated_fold of f⁽⁰⁾ form -- Extract t and strictOracleFoldingConsistencyProp from h_relIn dsimp only [strictFinalSumcheckRelOut, strictFinalSumcheckRelOutProp, - strictFinalFoldingStateProp] at h_relIn + strictfinalSumcheckStepFoldingStateProp] at h_relIn simp only [Fin.val_last, exists_and_right, Subtype.exists] at h_relIn rcases h_relIn with ⟨exists_t_MLP, _⟩ rcases exists_t_MLP with ⟨t, h_t_mem_support, h_strictOracleFoldingConsistency⟩ dsimp only [strictOracleFoldingConsistencyProp] at h_strictOracleFoldingConsistency - -- Get the equality for k_oracle_idx: oStmtIn k_oracle_idx = iterated_fold from 0 to k.val * ϑ have h_f_mid_eq_iterated_fold := h_strictOracleFoldingConsistency k_oracle_idx conv_lhs => rw [h_f_mid_eq_iterated_fold] - let P₀: L[X]_(2 ^ ℓ) := polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) (fun ω => t.eval (bitsOfIndex ω)) let f₀ := polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := 0) (P := P₀) conv_lhs => dsimp only [midIdx] conv_lhs => simp only [cast_eq, Fin.val_last]; rw [←fun_eta_expansion] - conv_lhs => rw [iterated_fold_transitivity 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add, Nat.add_right_cancel_iff, mul_eq_mul_right_iff]; left; rfl )] dsimp only [k_oracle_idx] - -- Step 1: Align steps (k * ϑ + ϑ = (k + 1) * ϑ) have h_steps_eq : k.val * ϑ + ϑ = (k.val + 1) * ϑ := by rw [Nat.add_mul, Nat.one_mul] conv_lhs => @@ -1008,7 +605,6 @@ lemma query_phase_step_preserves_fold (h_steps_eq_steps' := h_steps_eq) (f := f₀) (r_challenges := Fin.append challenges_mid challenges_last) (y := y_left)] - -- Step 2: Align destIdx (⟨k * ϑ + ϑ, ...⟩ = ⟨(k + 1) * ϑ, ...⟩) conv_lhs => rw [iterated_fold_congr_dest_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) @@ -1018,16 +614,16 @@ lemma query_phase_step_preserves_fold simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega) (h_destIdx_le := by omega) (h_destIdx_eq_destIdx' := by apply Fin.eq_of_val_eq; omega) (f := f₀)] - -- Step 3: Align function (f₀ = getFirstOracle) have h_f₀_eq_getFirstOracle : f₀ = getFirstOracle 𝔽q β oStmtIn := by - exact polyToOracleFunc_eq_getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (t := ⟨t, h_t_mem_support⟩) (i := Fin.last ℓ) + exact polyToOracleFunc_eq_getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (t := ⟨t, h_t_mem_support⟩) (i := Fin.last ℓ) (challenges := stmtIn.challenges) (oStmt := oStmtIn) (h_consistency := h_strictOracleFoldingConsistency) conv_lhs => rw [h_f₀_eq_getFirstOracle] - -- Step 4: Align challenges - have h_challenges_eq : (fun (cIdx : Fin ((↑k + 1) * ϑ)) => Fin.append challenges_mid challenges_last ⟨cIdx.val, by omega⟩) = challenges_full := by + have h_challenges_eq : (fun (cIdx : Fin ((↑k + 1) * ϑ)) => Fin.append challenges_mid + challenges_last ⟨cIdx.val, by omega⟩) = challenges_full := by funext j dsimp only [Fin.append, Fin.addCases, challenges_full, challenges_mid, challenges_last] -- dsimp only [chalLeft, chalRight] @@ -1040,12 +636,14 @@ lemma query_phase_step_preserves_fold eq_rec_constant] congr 1; simp only [Fin.val_last, zero_add, Fin.mk.injEq]; omega conv_lhs => rw [h_challenges_eq] - have h_sDomain_eq : sDomain 𝔽q β h_ℓ_add_R_rate ⟨k.val * ϑ + ϑ, by omega⟩ = sDomain 𝔽q β h_ℓ_add_R_rate ⟨(↑k + 1) * ϑ, by omega⟩ := by + have h_sDomain_eq : sDomain 𝔽q β h_ℓ_add_R_rate ⟨k.val * ϑ + ϑ, by omega⟩ + = sDomain 𝔽q β h_ℓ_add_R_rate ⟨(↑k + 1) * ϑ, by omega⟩ := by apply sDomain_eq_of_eq; apply Fin.eq_of_val_eq; simp only; omega -- Step 5: Align points have h_y_eq : cast (by rw [h_sDomain_eq]) y_left = y_right := by dsimp only [y_left, y_right] - rw [←extractSuffixFromChallenge_congr_destIdx 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_idx_eq := by apply Fin.eq_of_val_eq; omega)] + rw [←extractSuffixFromChallenge_congr_destIdx 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (h_idx_eq := by apply Fin.eq_of_val_eq; omega)] conv_lhs => rw [h_y_eq] · -- Case k = 0: No guard. --------------------------------------------------------------------- @@ -1060,83 +658,70 @@ lemma query_phase_step_preserves_fold omega simp only [h_k_eq_0, zero_mul, zero_add] at h_s'_mem ⊢ simp only [MessageIdx, Message, gt_iff_lt, lt_self_iff_false, ↓reduceDIte, Fin.mk_zero', - Fin.val_last, bind_pure_comp, map_pure, OptionT.simulateQ_map, ReduceClaim.support_mk, + Fin.val_last, bind_pure_comp, ReduceClaim.support_mk, Set.mem_setOf_eq] at h_s'_mem - erw [support_bind] at h_s'_mem - simp only [Function.comp_apply, Set.mem_iUnion, exists_prop] at h_s'_mem - + erw [simulateQ_bind, support_bind] at h_s'_mem + simp only [Set.mem_iUnion, exists_prop] at h_s'_mem rcases h_s'_mem with ⟨fiber_vec_Opt, h_fiber_vec_Opt_mem_support, h_s'_mem_support_guard⟩ - let k_fin_list : Fin (List.finRange (ℓ / ϑ)).length := ⟨k.val, by simp only [List.length_finRange, Fin.is_lt]⟩ - have h_k_fin_list_eq : k = ((List.finRange (ℓ / ϑ)).get k_fin_list) := by apply Fin.eq_of_val_eq; simp only [List.get_eq_getElem, List.getElem_finRange, Fin.eta, Fin.val_cast]; rfl - have h_probFailure_queryFiberPoints_eq_zero := by - apply probFailure_simulateQ_queryFiberPoints_eq_zero (γ_repetitions := γ_repetitions) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝔽q := 𝔽q) (β := β) + apply probFailure_simulateQ_queryFiberPoints_eq_zero (γ_repetitions := γ_repetitions) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝔽q := 𝔽q) (β := β) (so := so) (k := k_fin_list) (v := v) - have h_probOutput_none_queryFiberPoints_eq_zero := OptionT.probOutput_none_run_eq_zero_of_probFailure_eq_zero (hfail := h_probFailure_queryFiberPoints_eq_zero) - - have h_exists_some_fiber_vec_of_fiber_vec_Opt := exists_eq_some_of_mem_support_of_probOutput_none_eq_zero + have h_exists_some_fiber_vec_of_fiber_vec_Opt := + exists_eq_some_of_mem_support_of_probOutput_none_eq_zero (x := fiber_vec_Opt) (hx := h_fiber_vec_Opt_mem_support) (hnone := by - rw [h_k_fin_list_eq]; exact h_probOutput_none_queryFiberPoints_eq_zero) + simpa [so, transcript, h_k_fin_list_eq] using h_probOutput_none_queryFiberPoints_eq_zero) rcases h_exists_some_fiber_vec_of_fiber_vec_Opt with ⟨fiber_vec, h_fiber_vec_Opt_eq_some⟩ rw [h_fiber_vec_Opt_eq_some] at h_s'_mem_support_guard h_fiber_vec_Opt_mem_support - -- **Simplify the monadic structure** - simp only [OptionT.support_OptionT_pure_run, Set.mem_singleton_iff, - Option.some.injEq] at h_s'_mem_support_guard - + simp only [LawfulApplicative.map_pure] at h_s'_mem_support_guard + erw [simulateQ_pure] at h_s'_mem_support_guard + simp only [support_pure, Set.mem_singleton_iff, Option.some.injEq] at h_s'_mem_support_guard -- h_s'_mem_support_guard : s' = single_point_localized_fold_matrix_form - have h_fiber_val := mem_support_queryFiberPoints 𝔽q β (γ_repetitions := γ_repetitions) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oraclePositionIdx := k) v fiber_vec stmtIn - oStmtIn () challenges h_fiber_vec_Opt_mem_support - + oStmtIn () challenges (by simpa using h_fiber_vec_Opt_mem_support) -- Step 1: Use symmetry of h_s'_eq rw [h_s'_mem_support_guard] - -- ⊢ single_point_localized_fold_matrix_form ... = iterated_fold ... - have h_destIdx_eq : destIdx.val = ϑ := by dsimp only [destIdx]; rw [h_k_eq_0, zero_add, one_mul] - -- iterated_fold 𝔽q β 0 ((↑k + 1) * ϑ) ⋯ ⋯ (getFirstOracle 𝔽q β oStmtIn) - -- (getFoldingChallenges (Fin.last ℓ) stmtIn.challenges 0 ⋯) (extractSuffixFromChallenge 𝔽q β v ⟨(↑k + 1) * ϑ, ⋯⟩ ⋯) - - let challenges_full := getFoldingChallenges (𝓡 := 𝓡) (r := r) (ϑ := (k.val + 1) * ϑ) (i := Fin.last ℓ) stmtIn.challenges + -- (getFoldingChallenges (Fin.last ℓ) stmtIn.challenges 0 ⋯) + -- (extractSuffixFromChallenge 𝔽q β v ⟨(↑k + 1) * ϑ, ⋯⟩ ⋯) + let challenges_full := getFoldingChallenges (𝓡 := 𝓡) (r := r) (ϑ := (k.val + 1) * ϑ) + (i := Fin.last ℓ) stmtIn.challenges (k := 0) (h := by simp only [zero_add, Fin.val_last]; omega) set y := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) (destIdx := ⟨(k.val + 1) * ϑ, by omega⟩) (h_destIdx_le := by omega) with hy_def - -- Step 2: Transform the RHS let rhs_to_mat_mul_form := iterated_fold_eq_matrix_form 𝔽q β (i := 0) (steps := (k.val + 1) * ϑ) (destIdx := destIdx) (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; rfl) - (h_destIdx_le := by omega) (f := getFirstOracle 𝔽q β oStmtIn) (r_challenges := challenges_full) + (h_destIdx_le := by omega) (f := getFirstOracle 𝔽q β oStmtIn) + (r_challenges := challenges_full) conv_rhs => rw [rhs_to_mat_mul_form] dsimp only [localized_fold_matrix_form] - -- Step 3: Unfold localized form conv_rhs => unfold localized_fold_matrix_form - -- 1. Simplify the index arithmetic for k=0 -- (k+1)*ϑ becomes ϑ -- simp? [Fin.mk_zero', Fin.val_last] -- 2. Unfold your helper definition -- This reveals that LHS suffix is exactly the RHS suffix dsimp only [getChallengeSuffix] - set fiber_vec_actual_def := fiberEvaluations 𝔽q β (i := 0) (steps := ϑ) (destIdx := destIdx) (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega) (h_destIdx_le := by omega) (f := getFirstOracle 𝔽q β oStmtIn) (y := y) with hright_def - have h_fiber_vec_get : fiber_vec.get = fiber_vec_actual_def := by dsimp only [fiber_vec_actual_def]; unfold fiberEvaluations funext x @@ -1146,19 +731,39 @@ lemma query_phase_step_preserves_fold conv_rhs => dsimp only [getFirstOracle] simp only [Fin.mk_zero'] + -- symm apply OracleStatement.oracle_eval_congr 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (oStmtIn := oStmtIn) (j := ⟨k, by - simp only [toOutCodewordsCount_last, Fin.is_lt]⟩) (j' := 0) (h_j := by - simp only [h_k_eq_0, Fin.mk_zero'];) - sorry - + (oStmtIn := oStmtIn) (j' := 0) (j := ⟨k, by + simp only [toOutCodewordsCount_last, Fin.is_lt]⟩) (h_j := by + apply Fin.eq_of_val_eq + simpa using h_k_eq_0) + have h_destIdx_eq : (⟨k.val * ϑ + ϑ, by omega⟩ : Fin r) = ⟨(k.val + 1) * ϑ, by omega⟩ := by + apply Fin.eq_of_val_eq + simp only [Nat.add_mul, one_mul] + simp only [id_eq, Fin.coe_ofNat_eq_mod, cast_cast] + have h_i_eq : (⟨k.val * ϑ, by omega⟩ : Fin r) = 0 := by + apply Fin.eq_of_val_eq + simp [h_mul_eq_0] + have hsrc_fun := qMap_total_fiber_congr_source_apply 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (steps := ϑ) (destIdx := destIdx) (sourceIdx₁ := (⟨k.val * ϑ, by omega⟩ : Fin r)) + (sourceIdx₂ := 0) (h_sourceIdx_eq := h_i_eq) + (h_destIdx := by dsimp only [destIdx]; rw [Nat.add_mul, Nat.one_mul]) + (h_destIdx_le := by omega) (y := y) (x := x) + rw [←hsrc_fun] + have hdest_congr := qMap_total_fiber_congr_dest 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (sourceIdx := ⟨k * ϑ, by omega⟩) (steps := ϑ) (destIdx₁ := ⟨k.val * ϑ + ϑ, by omega⟩) + (destIdx₂ := destIdx) (h_destIdx_congr := by omega) (h_destIdx := by dsimp only) + (h_destIdx_le := by omega) + rw [hdest_congr] + congr 1 + rw [←extractSuffixFromChallenge_congr_destIdx 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (v := v) (destIdx := destIdx) (destIdx' := ⟨k.val * ϑ + ϑ, by omega⟩) + (h_idx_eq := by omega) (h_le := by omega) (h_le' := by omega)] rw [h_fiber_vec_get] -- Step 4: Apply the congruence lemma of single_point_localized_fold_matrix_form - -- 1. Establish that the step counts are equal have h_steps_eq : ϑ = (↑k + 1) * ϑ := by simp only [h_k_eq_0, zero_add, one_mul] - -- 2. Apply the Step Congruence Lemma to the RHS -- We rewrite the RHS to use 'ϑ' instead of '(k+1)*ϑ' conv_rhs => rw [single_point_localized_fold_matrix_form_congr_steps_index 𝔽q β @@ -1173,11 +778,13 @@ lemma query_phase_step_preserves_fold have h_sDomain_eq : (sDomain 𝔽q β h_ℓ_add_R_rate ⟨↑k * ϑ + ϑ, by omega⟩) = (sDomain 𝔽q β h_ℓ_add_R_rate ⟨(↑k + 1) * ϑ, by omega⟩) := by apply sDomain_eq_of_eq; simp only [Fin.mk.injEq]; omega - conv_lhs => - rw [single_point_localized_fold_matrix_form_congr_dest_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx' := destIdx) (h_destIdx_eq_destIdx' := by + rw [single_point_localized_fold_matrix_form_congr_dest_index 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx' := destIdx) (h_destIdx_eq_destIdx' := by dsimp only [destIdx]; simp only [Nat.add_mul, Nat.one_mul])] - have h_y_eq : y = cast (by rw [h_sDomain_eq]) (extractSuffixFromChallenge 𝔽q β (v := v) (destIdx := ⟨k.val * ϑ + ϑ, by omega⟩) (h_destIdx_le := by simp only [k_succ_mul_ϑ_le_ℓ_₂])) := by + have h_y_eq : y = cast (by rw [h_sDomain_eq]) (extractSuffixFromChallenge 𝔽q β (v := v) + (destIdx := ⟨k.val * ϑ + ϑ, by omega⟩) + (h_destIdx_le := by simp only [k_succ_mul_ϑ_le_ℓ_₂])) := by rw [hy_def] rw [extractSuffixFromChallenge_congr_destIdx] simp only [Nat.add_mul, Nat.one_mul] @@ -1186,80 +793,67 @@ lemma query_phase_step_preserves_fold rw [qMap_total_fiber_congr_steps 𝔽q β (i := 0) (steps := ϑ) (steps' := (↑k + 1) * ϑ) (h_steps_eq := h_steps_eq) (y := y)] -/-- Lemma 3 (Completeness): +/-! Lemma 3 (Completeness): Proves that the fully folded value (result of `iterated_fold` at `ℓ`) equals the `final_constant` expected by the statement. -/ +omit [SampleableType L] [DecidableEq 𝔽q] in lemma query_phase_final_fold_eq_constant (v : sDomain 𝔽q β h_ℓ_add_R_rate 0) (c : L) (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) (witIn : Unit) - (h_relIn : strictFinalSumcheckRelOut 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ((stmtIn, oStmtIn), witIn)) + (h_relIn : strictFinalSumcheckRelOut 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ((stmtIn, oStmtIn), witIn)) -- Hypothesis: x is the result of folding all the way to ℓ (h_c_correct : have h_mul_eq : (ℓ / ϑ) * ϑ = ℓ := Nat.div_mul_cancel hdiv.out c = iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (steps := (ℓ / ϑ) * ϑ) (destIdx := ⟨(ℓ / ϑ) * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega) (f := getFirstOracle 𝔽q β oStmtIn) - (r_challenges := getFoldingChallenges (𝓡 := 𝓡) (r := r) (Fin.last ℓ) stmtIn.challenges 0 (by simp only [zero_add, Fin.val_last]; omega)) - (y := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) (destIdx := ⟨(ℓ / ϑ) * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega)) + (r_challenges := getFoldingChallenges (𝓡 := 𝓡) (r := r) (Fin.last ℓ) stmtIn.challenges 0 + (by simp only [zero_add, Fin.val_last]; omega)) + (y := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) + (destIdx := ⟨(ℓ / ϑ) * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega)) (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add];) ) : c = stmtIn.final_constant := by + classical dsimp only [strictFinalSumcheckRelOut, strictFinalSumcheckRelOutProp, - strictFinalFoldingStateProp] at h_relIn + strictfinalSumcheckStepFoldingStateProp] at h_relIn simp only [Fin.val_last, exists_and_right, Subtype.exists] at h_relIn - -- 2. Extract the existential witnesses - -- We pull out the polynomial 'a' (let's call it 'poly') and the two proofs (consistency & fold). - rw [h_c_correct] - rcases h_relIn with ⟨exists_t_MLP, h_final_oracle_fold_to_constant⟩ simp only at h_final_oracle_fold_to_constant - -- h_final_oracle_fold_to_constant : (iterated_fold 𝔽q β ⟨↑(getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ)) * ϑ, ⋯⟩ ϑ ⋯ ⋯ (getLastOracle 𝔽q β ⋯ oStmtIn) - -- fun cId ↦ stmtIn.challenges ⟨↑(getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ)) * ϑ + ↑cId, ⋯⟩) = - -- fun x ↦ stmtIn.final_constant - have h_final_oracle_fold_to_const_at_0 := congr_fun h_final_oracle_fold_to_constant 0 simp only at h_final_oracle_fold_to_const_at_0 rw [h_final_oracle_fold_to_const_at_0.symm] - - -- ⊢ x = - -- iterated_fold 𝔽q β ⟨↑(getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ)) * ϑ, ⋯⟩ ϑ ⋯ ⋯ (getLastOracle 𝔽q β ⋯ oStmtIn) - -- (fun cId ↦ stmtIn.challenges ⟨↑(getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ)) * ϑ + ↑cId, ⋯⟩) 0 - rcases exists_t_MLP with ⟨t, h_t_mem_support, h_strictOracleFoldingConsistency⟩ dsimp only [strictOracleFoldingConsistencyProp] at h_strictOracleFoldingConsistency - let lastOraclePositionIndex := getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ) have h_last_oracle_eq_t_evals_folded := h_strictOracleFoldingConsistency lastOraclePositionIndex have h_ϑ_pos : ϑ > 0 := Nat.pos_of_neZero ϑ have h_ϑ_le_ℓ : ϑ ≤ ℓ := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out) have h_ℓ_div_mul_eq_ℓ : (ℓ / ϑ) * ϑ = ℓ := Nat.div_mul_cancel hdiv.out - have h_lastOraclePosIdx_mul_add : (getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ)).val * ϑ + ϑ = ℓ := by + have h_lastOraclePosIdx_mul_add : + (getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ)).val * ϑ + ϑ = ℓ := by conv_rhs => rw [←h_ℓ_div_mul_eq_ℓ] rw [getLastOraclePositionIndex_last]; simp only rw [Nat.sub_mul, Nat.one_mul]; rw [Nat.sub_add_cancel (by rw [h_ℓ_div_mul_eq_ℓ]; omega)] - have h_first_oracle_eq_t_evals_folded := h_strictOracleFoldingConsistency ⟨0, by simp only [toOutCodewordsCount_last, Nat.div_pos_iff]; omega⟩ - dsimp only [getFirstOracle] - have h_getLastOracle_eq : oStmtIn lastOraclePositionIndex = getLastOracle (h_destIdx := by rfl) (oracleFrontierIdx := Fin.last ℓ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oStmt := oStmtIn) := by rfl rw [←h_getLastOracle_eq] rw [h_last_oracle_eq_t_evals_folded, h_first_oracle_eq_t_evals_folded] simp only [Fin.mk_zero', Fin.coe_ofNat_eq_mod] - have h_zero_mod : 0 % toOutCodewordsCount ℓ ϑ (Fin.last ℓ) * ϑ = 0 := by rw [toOutCodewordsCount_last]; simp only [Nat.zero_mod, zero_mul] - rw [iterated_fold_transitivity 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add, Nat.add_right_cancel_iff, mul_eq_mul_right_iff]; @@ -1268,14 +862,17 @@ lemma query_phase_final_fold_eq_constant rw [getLastOraclePositionIndex_last]; simp only [true_or] )] - - set chalLeft := (getFoldingChallenges (i := Fin.last ℓ) (𝓡 := 𝓡) (r := r) (challenges := stmtIn.challenges) (k := 0) (ϑ := ℓ/ϑ * ϑ) (by + set chalLeft := (getFoldingChallenges (i := Fin.last ℓ) (𝓡 := 𝓡) (r := r) + (challenges := stmtIn.challenges) (k := 0) (ϑ := ℓ/ϑ * ϑ) (by simp only [zero_add, Fin.val_last]; omega)) with h_chalLeft -- have h_concat_challenges_eq : - set chalRight := Fin.append (getFoldingChallenges (i := Fin.last ℓ) (𝓡 := 𝓡) (r := r) (challenges := stmtIn.challenges) (k := 0) (ϑ := lastOraclePositionIndex.val * ϑ) (by simp only [zero_add, Fin.val_last, oracle_index_le_ℓ])) + set chalRight := Fin.append (getFoldingChallenges (i := Fin.last ℓ) (𝓡 := 𝓡) (r := r) + (challenges := stmtIn.challenges) (k := 0) (ϑ := lastOraclePositionIndex.val * ϑ) + (by simp only [zero_add, Fin.val_last, oracle_index_le_ℓ])) (fun (cId : Fin ϑ) ↦ - stmtIn.challenges ⟨(getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ)) * ϑ + cId.val, by simp only [Fin.val_last, getLastOraclePositionIndex_last]; simp only [lastBlockIdx_mul_ϑ_add_fin_lt_ℓ]⟩) with h_chalLeft - + stmtIn.challenges ⟨(getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ)) * ϑ + cId.val, by + simp only [Fin.val_last, getLastOraclePositionIndex_last]; + simp only [lastBlockIdx_mul_ϑ_add_fin_lt_ℓ]⟩) with h_chalLeft have h_chalLeft_eq_chalRight_cast : chalLeft = fun cIdx : Fin (ℓ / ϑ * ϑ) => chalRight ⟨cIdx, by dsimp only [lastOraclePositionIndex] simp only [getLastOraclePositionIndex_last]; @@ -1287,7 +884,7 @@ lemma query_phase_final_fold_eq_constant · -- Case 1: cId < k_steps, so it's from the first part simp only [Fin.val_last] dsimp only [Fin.append, Fin.addCases] - simp only [h, ↓reduceDIte, getFoldingChallenges, Fin.val_last, Fin.coe_castLT, zero_add] + simp only [h, ↓reduceDIte, getFoldingChallenges, Fin.val_last, Fin.val_castLT, zero_add] · -- Case 2: cId >= k_steps, so it's from the second part simp only [Fin.val_last] dsimp only [Fin.append, Fin.addCases] @@ -1302,53 +899,70 @@ lemma query_phase_final_fold_eq_constant simp only [Nat.sub_mul, one_mul, not_lt, tsub_le_iff_right] at ⊢ h exact h rw [h_chalLeft_eq_chalRight_cast] - conv_lhs => -- 1. Locate the specific sub-term corresponding to the folding function -- This finds the lambda "fun y ↦ ..." pattern (fun y ↦ iterated_fold _ _ _ _ _ _ _ _ _) enter [y] rw [iterated_fold_congr_steps_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) - (steps := 0 % toOutCodewordsCount ℓ ϑ (Fin.last ℓ) * ϑ) (steps' := 0) (h_destIdx := by simp only [toOutCodewordsCount_last, - Nat.zero_mod, zero_mul, Fin.coe_ofNat_eq_mod, add_zero]) (h_destIdx_le := by simp only [toOutCodewordsCount_last, - Nat.zero_mod, zero_mul, zero_le]) (h_steps_eq_steps' := by omega)] - rw [iterated_fold_zero_steps 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (h_destIdx := by simp only [toOutCodewordsCount_last, + (steps := 0 % toOutCodewordsCount ℓ ϑ (Fin.last ℓ) * ϑ) (steps' := 0) (h_destIdx := by + simp only [toOutCodewordsCount_last, Nat.zero_mod, zero_mul, Fin.coe_ofNat_eq_mod, add_zero]) + (h_destIdx_le := by simp only [toOutCodewordsCount_last, Nat.zero_mod, zero_mul, zero_le]) + (h_steps_eq_steps' := by omega)] + rw [iterated_fold_zero_steps 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) + (h_destIdx := by simp only [toOutCodewordsCount_last, Nat.zero_mod, zero_mul, Fin.coe_ofNat_eq_mod])] conv_lhs => simp only [cast_cast, cast_eq]; simp only [←fun_eta_expansion] conv_lhs => - rw [←iterated_fold_congr_steps_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (steps := ↑lastOraclePositionIndex * ϑ + ϑ) (steps' := (ℓ / ϑ * ϑ)) (h_destIdx := by dsimp only [lastOraclePositionIndex]; simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega) (h_destIdx_le := by simp only; omega) (h_steps_eq_steps' := by dsimp only [lastOraclePositionIndex]; omega)] - - let P₀: L[X]_(2 ^ ℓ) := polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) (fun ω => t.eval (bitsOfIndex ω)) + rw [←iterated_fold_congr_steps_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) + (steps := ↑lastOraclePositionIndex * ϑ + ϑ) (steps' := (ℓ / ϑ * ϑ)) (h_destIdx := by + dsimp only [lastOraclePositionIndex]; + simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega) + (h_destIdx_le := by simp only; omega) (h_steps_eq_steps' := by + dsimp only [lastOraclePositionIndex]; omega)] + let P₀: L[X]_(2 ^ ℓ) := polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) + (fun ω => t.eval (bitsOfIndex ω)) let f₀ := polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := 0) (P := P₀) - set destIdx' : Fin r := ⟨(getLastOracleDomainIndex ℓ ϑ (Fin.last ℓ)).val + ϑ, by rw [getLastOracleDomainIndex]; simp only; omega⟩ with h_destIdx' - - let point := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) (destIdx := ⟨ℓ / ϑ * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega) - + let point := extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) + (destIdx := ⟨ℓ / ϑ * ϑ, by omega⟩) (h_destIdx_le := by simp only; omega) conv_lhs => - rw [iterated_fold_congr_dest_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (steps := ↑lastOraclePositionIndex * ϑ + ϑ) (destIdx := ⟨ℓ / ϑ * ϑ, by omega⟩) (destIdx' := destIdx') (h_destIdx := by dsimp only [lastOraclePositionIndex]; simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega) (h_destIdx_le := by simp only; omega) (h_destIdx_eq_destIdx' := by dsimp only [destIdx']; simp only [Fin.mk.injEq]; omega) (f := f₀) (r_challenges := chalRight) (y := point)] - - rw [iterated_fold_congr_steps_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (steps := ↑lastOraclePositionIndex * ϑ + ϑ) (steps' := ℓ) (h_destIdx := by - dsimp only [destIdx']; - simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add, Nat.add_right_cancel_iff, mul_eq_mul_right_iff]; omega) - (h_destIdx_le := by dsimp only [destIdx']; simp only [oracle_index_add_steps_le_ℓ]) (h_steps_eq_steps' := by omega)] - - rw [iterated_fold_congr_steps_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (steps := ↑lastOraclePositionIndex * ϑ + ϑ) (steps' := ℓ) (h_destIdx := by + rw [iterated_fold_congr_dest_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) + (steps := ↑lastOraclePositionIndex * ϑ + ϑ) (destIdx := ⟨ℓ / ϑ * ϑ, by omega⟩) + (destIdx' := destIdx') (h_destIdx := by + dsimp only [lastOraclePositionIndex]; + simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega) + (h_destIdx_le := by simp only; omega) (h_destIdx_eq_destIdx' := by + dsimp only [destIdx']; simp only [Fin.mk.injEq]; omega) (f := f₀) + (r_challenges := chalRight) (y := point)] + rw [iterated_fold_congr_steps_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) + (steps := ↑lastOraclePositionIndex * ϑ + ϑ) (steps' := ℓ) (h_destIdx := by + dsimp only [destIdx']; + simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add, Nat.add_right_cancel_iff, + mul_eq_mul_right_iff]; omega) + (h_destIdx_le := by dsimp only [destIdx']; simp only [oracle_index_add_steps_le_ℓ]) + (h_steps_eq_steps' := by omega)] + rw [iterated_fold_congr_steps_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) + (steps := ↑lastOraclePositionIndex * ϑ + ϑ) (steps' := ℓ) (h_destIdx := by dsimp only [destIdx']; - simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add, Nat.add_right_cancel_iff, mul_eq_mul_right_iff]; omega) - (h_destIdx_le := by dsimp only [destIdx']; simp only [oracle_index_add_steps_le_ℓ]) (h_steps_eq_steps' := by omega)] - + simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add, + Nat.add_right_cancel_iff, mul_eq_mul_right_iff]; omega) + (h_destIdx_le := by dsimp only [destIdx']; simp only [oracle_index_add_steps_le_ℓ]) + (h_steps_eq_steps' := by omega)] have h_sDomain_eq : (sDomain 𝔽q β h_ℓ_add_R_rate ⟨ℓ/ϑ * ϑ, by omega⟩) = (sDomain 𝔽q β h_ℓ_add_R_rate destIdx') := by apply sDomain_eq_of_eq; dsimp only [destIdx']; simp only [Fin.mk.injEq]; omega - - let res := iterated_fold_to_level_ℓ_is_constant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (t := ⟨t, h_t_mem_support⟩) (destIdx := destIdx') (h_destIdx := by omega) (challenges := fun (cIdx : Fin ℓ) => chalRight ⟨cIdx, by dsimp only [lastOraclePositionIndex]; omega⟩) (x := cast (by rw [h_sDomain_eq]) point) (y := 0) + let res := iterated_fold_to_level_ℓ_is_constant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (t := ⟨t, h_t_mem_support⟩) (destIdx := destIdx') (h_destIdx := by omega) + (challenges := fun (cIdx : Fin ℓ) => + chalRight ⟨cIdx, by dsimp only [lastOraclePositionIndex]; omega⟩) + (x := cast (by rw [h_sDomain_eq]) point) (y := 0) rw [res] /-- Relation used in the forIn loop of `checkSingleRepetition`: at index 0 the folded value is 0; - at index `oraclePositionIdx > 0` it equals `iterated_fold` up to that position with challenges from - `stmtIn` and suffix from `v`. -/ + at index `oraclePositionIdx > 0` it equals `iterated_fold` up to that position with challenges + from `stmtIn` and suffix from `v`. -/ @[reducible] def checkSingleRepetition_foldRel (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) @@ -1360,7 +974,8 @@ def checkSingleRepetition_foldRel if hk : oraclePositionIdx.val = 0 then val_folded_point = 0 -- Base case: initial value is 0 else - have h_toCodewordCount : toOutCodewordsCount ℓ ϑ (Fin.last ℓ) = ℓ / ϑ := toOutCodewordsCount_last ℓ ϑ + have h_toCodewordCount : toOutCodewordsCount ℓ ϑ (Fin.last ℓ) = ℓ / ϑ := + toOutCodewordsCount_last ℓ ϑ have h_le : oraclePositionIdx ≤ ℓ/ϑ := by have h := oraclePositionIdx.isLt simp only [List.length_finRange] at h @@ -1374,13 +989,15 @@ def checkSingleRepetition_foldRel extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v:=v) (destIdx:=destIdx) (h_destIdx_le:=by omega) val_folded_point = iterated_fold - (i := 0) (steps := oraclePositionIdx * ϑ) (destIdx := destIdx) (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; rfl) + (i := 0) (steps := oraclePositionIdx * ϑ) (destIdx := destIdx) (h_destIdx := by + simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; rfl) (h_destIdx_le := by rw [←h_mul] dsimp only [destIdx]; apply Nat.mul_le_mul_right; exact h_le ) (f := f₀) - (r_challenges := getFoldingChallenges (𝓡 := 𝓡) (r := r) (Fin.last ℓ) stmtIn.challenges 0 (by simp only [zero_add, Fin.val_last]; omega)) (y := suffix_point_from_v) + (r_challenges := getFoldingChallenges (𝓡 := 𝓡) (r := r) (Fin.last ℓ) stmtIn.challenges 0 + (by simp only [zero_add, Fin.val_last]; omega)) (y := suffix_point_from_v) /-- Safety of the simulated inner `forIn` loop used by `checkSingleRepetition_probFailure_eq_zero`. -/ @@ -1412,31 +1029,28 @@ lemma checkSingleRepetition_inner_forIn_probFailure_eq_zero Pr[⊥ | inner_forIn_block] = 0 := by intro step transcript so v f inner_forIn_block dsimp only [inner_forIn_block] - let Rel : Fin ((List.finRange (ℓ / ϑ)).length + 1) → L → Prop := - checkSingleRepetition_foldRel 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (stmtIn := stmtIn) (oStmtIn := oStmtIn) (v := v) - + checkSingleRepetition_foldRel 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) (v := v) -- For this proof, we define a trivial relation since the real invariant -- is complex and involves the correctness of folding operations -- a. Push liftComp inside the forIn loop (twice, for the two layers) -- Goal: simulateQ so (liftComp (liftComp (forIn ...))) -- Becomes: simulateQ so (forIn ... (fun x s => liftComp ...)) - -- **Applying indutive relation inference** apply probFailure_forIn_of_relations_simplified (rel := Rel) (h_start := by rfl) (h_step := by -- Inductive step: any INNER repetition never fails intro (k : Fin (List.finRange (ℓ / ϑ)).length) (c_k : L) h_rel_k_c -- simp only [List.get_eq_getElem, List.getElem_finRange] at * - -- Simplify k.succ ≠ 0 (always true) have h_succ_ne_zero : k.succ ≠ 0 := Fin.succ_ne_zero k - constructor · -- Part 1: checkSingleFoldingStep is safe (never fails) - - -- where the forInStep.yield has spec `OracleComp [OracleStatement 𝔽q β ϑ (Fin.last ℓ)]ₒ (ForInStep L)` + -- where the forInStep.yield has spec + -- `OracleComp [OracleStatement 𝔽q β ϑ (Fin.last ℓ)]ₒ (ForInStep L)` -- [⊥|simulateQ so - -- ((ForInStep.yield <$> checkSingleFoldingStep 𝔽q β ((List.finRange (ℓ / ϑ)).get k) c_k v stmtIn).liftComp + -- ((ForInStep.yield <$> checkSingleFoldingStep 𝔽q β + -- ((List.finRange (ℓ / ϑ)).get k) c_k v stmtIn).liftComp -- ([]ₒ ++ₒ -- ([OracleStatement 𝔽q β ϑ (Fin.last ℓ)]ₒ ++ₒ -- [fun i ↦ ![Fin γ_repetitions → ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0)] ↑i]ₒ)))] = @@ -1445,13 +1059,18 @@ lemma checkSingleRepetition_inner_forIn_probFailure_eq_zero -- rw [simulateQ_liftComp] rw [map_eq_bind_pure_comp] erw [probFailure_map] -- Pr[⊥ | f <$> mx] = Pr[⊥ | mx] **IMPORTANT** - -- ⊢ Pr[⊥ | simulateQ so (checkSingleFoldingStep 𝔽q β γ_repetitions ((List.finRange (ℓ / ϑ)).get k) c_k v stmtIn).run] = 0 + -- ⊢ Pr[⊥ | simulateQ so (checkSingleFoldingStep 𝔽q β γ_repetitions + -- ((List.finRange (ℓ / ϑ)).get k) c_k v stmtIn).run] = 0 dsimp only [checkSingleFoldingStep] erw [simulateQ_bind] erw [OptionT.probFailure_mk_do_bind_eq_zero_iff.{0, 0}] have h_probFailure_queryFiberPoints_eq_zero : Pr[⊥ | - OptionT.mk (simulateQ so (queryFiberPoints 𝔽q β γ_repetitions ((List.finRange (ℓ / ϑ)).get k) v))] = 0 := by - apply probFailure_simulateQ_queryFiberPoints_eq_zero (γ_repetitions := γ_repetitions) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝔽q := 𝔽q) (β := β) + OptionT.mk + (simulateQ so + (queryFiberPoints 𝔽q β γ_repetitions ((List.finRange (ℓ / ϑ)).get k) v))] = 0 := by + apply probFailure_simulateQ_queryFiberPoints_eq_zero + (γ_repetitions := γ_repetitions) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝔽q := 𝔽q) (β := β) (so := so) (k := k) (v := v) have h_probOutput_none_queryFiberPoints_eq_zero := OptionT.probOutput_none_run_eq_zero_of_probFailure_eq_zero @@ -1462,16 +1081,16 @@ lemma checkSingleRepetition_inner_forIn_probFailure_eq_zero HasEvalPMF.probFailure_eq_zero] · -- The guard and pure computation intro fiber_vec_opt h_fiber_vec_opt_mem_support - have h_fiber_vec_eq_some := exists_eq_some_of_mem_support_of_probOutput_none_eq_zero.{0, 0} (x := fiber_vec_opt) - (hx := h_fiber_vec_opt_mem_support) (hnone := h_probOutput_none_queryFiberPoints_eq_zero) + have h_fiber_vec_eq_some := + exists_eq_some_of_mem_support_of_probOutput_none_eq_zero.{0, 0} (x := fiber_vec_opt) + (hx := h_fiber_vec_opt_mem_support) + (hnone := h_probOutput_none_queryFiberPoints_eq_zero) rcases h_fiber_vec_eq_some with ⟨fiber_vec, rfl⟩ simp only [MessageIdx, List.get_eq_getElem, List.getElem_finRange, Fin.eta, Fin.val_cast, gt_iff_lt, CanonicallyOrderedAdd.mul_pos, Message, guard_eq, Fin.val_last, bind_pure_comp, - map_pure, dite_eq_ite] - + dite_eq_ite] have h_ϑ_pos : ϑ > 0 := by exact Nat.pos_of_neZero ϑ simp only [h_ϑ_pos, and_true] - by_cases h_i_pos : k.val > 0 · -- Case k > 0: guard (c_k = f_i_val) let k_idx : Fin (ℓ / ϑ) := ⟨k.val, by @@ -1483,10 +1102,12 @@ lemma checkSingleRepetition_inner_forIn_probFailure_eq_zero simp only [List.get_eq_getElem, List.getElem_finRange, Fin.eta] apply Fin.eq_of_val_eq simp only [Fin.val_cast]; rfl - -- 1. Simplify failure probability to just the guard condition simp only [h_i_pos, ↓reduceIte, OptionT.simulateQ_map] - have h_guard_pass : c_k = fiber_vec.get (extractMiddleFinMask 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) (i := ⟨k.val * ϑ, by omega⟩) (steps := ϑ)) := by + have h_guard_pass : + c_k = fiber_vec.get + (extractMiddleFinMask 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (v := v) (i := ⟨k.val * ϑ, by omega⟩) (steps := ϑ)) := by -- ⊢ c_k = f_i_on_fiber.get (extractMiddleFinMask ...) -- 1. Construct the correct index type for the lemma -- 3. Unfold Rel to get the equality @@ -1496,14 +1117,17 @@ lemma checkSingleRepetition_inner_forIn_probFailure_eq_zero rw [dif_neg h_k_castSucc_ne_0] at h_rel_k_c simp only [Fin.val_castSucc] at h_rel_k_c -- simp only [Fin.isValue, List.get_eq_getElem, List.getElem_finRange, Fin.eta, - -- Fin.coe_cast] - + -- Fin.val_cast] have h_mul_gt_0 : k.val * ϑ > 0 := by simp only [gt_iff_lt, CanonicallyOrderedAdd.mul_pos] omega - -- 4. Apply the lemma - have res := query_phase_consistency_guard_safe 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k_idx) (v := v) (c_k := c_k) (f_i_on_fiber := fiber_vec) (stmtIn := stmtIn) (oStmtIn := oStmtIn) (witIn := witIn) (h_relIn := h_relIn) (h_c_k_correct := h_rel_k_c) (h_k_pos := h_mul_gt_0) (γ_repetitions := γ_repetitions) (challenges := challenges) (h_fiber_mem := by + have res := query_phase_consistency_guard_safe 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k_idx) (v := v) (c_k := c_k) + (f_i_on_fiber := fiber_vec) (stmtIn := stmtIn) (oStmtIn := oStmtIn) + (witIn := witIn) (h_relIn := h_relIn) (h_c_k_correct := h_rel_k_c) + (h_k_pos := h_mul_gt_0) (γ_repetitions := γ_repetitions) + (challenges := challenges) (h_fiber_mem := by rw [h_k_idx_eq] exact h_fiber_vec_opt_mem_support ) @@ -1515,8 +1139,7 @@ lemma checkSingleRepetition_inner_forIn_probFailure_eq_zero erw [simulateQ_pure, probFailure_pure] · -- Part 2: Results in support satisfy the next relation intro s' h_s'_support - simp only [checkSingleRepetition_foldRel, dite_eq_ite, Fin.succ_ne_zero, ↓reduceIte, - Fin.val_succ, Rel] + simp only [checkSingleRepetition_foldRel, dite_eq_ite, Fin.val_succ, Rel] simp only [MessageIdx, List.get_eq_getElem, List.getElem_finRange, Fin.eta, support_map, Set.mem_image, OptionT.mem_support_iff, toPFunctor_emptySpec, OptionT.support_run, f] at h_s'_support @@ -1531,7 +1154,10 @@ lemma checkSingleRepetition_inner_forIn_probFailure_eq_zero exact h ⟩ -- Apply the preservation lemma - let res := query_phase_step_preserves_fold 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k_idx) (v := v) (c_k := c_k) (s' := x) (stmtIn := stmtIn) (oStmtIn := oStmtIn) (h_relIn := h_relIn) (challenges := challenges) (h_s'_mem := by + let res := query_phase_step_preserves_fold 𝔽q β (γ_repetitions := γ_repetitions) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k_idx) (v := v) (c_k := c_k) + (s' := x) (stmtIn := stmtIn) (oStmtIn := oStmtIn) (h_relIn := h_relIn) + (challenges := challenges) (h_s'_mem := by dsimp only [so] at h_x_support dsimp only [pSpecQuery] exact h_x_support @@ -1560,7 +1186,8 @@ lemma checkSingleRepetition_probFailure_eq_zero (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) (witIn : Unit) - (h_relIn : strictFinalSumcheckRelOut 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ((stmtIn, oStmtIn), witIn)) + (h_relIn : strictFinalSumcheckRelOut 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ((stmtIn, oStmtIn), witIn)) (rep : Fin γ_repetitions) (challenges : (pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Challenges) : let step := queryPhaseLogicStep 𝔽q β (ϑ := ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) @@ -1568,21 +1195,19 @@ lemma checkSingleRepetition_probFailure_eq_zero let so := OracleInterface.simOracle2.{0, 0, 0, 0, 0} []ₒ oStmtIn transcript.messages let v := (FullTranscript.mk1 (challenges ⟨0, by rfl⟩)).challenges ⟨0, by rfl⟩ rep Pr[⊥ | OptionT.mk.{0, 0} (simulateQ.{0, 0, 0} so - (checkSingleRepetition 𝔽q β (γ_repetitions := γ_repetitions) (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) v stmtIn stmtIn.final_constant).run)] = 0 := by - + (checkSingleRepetition 𝔽q β (γ_repetitions := γ_repetitions) (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) v stmtIn stmtIn.final_constant).run)] = 0 := by intro step transcript so v let f₀ := getFirstOracle 𝔽q β oStmtIn - let Rel : Fin ((List.finRange (ℓ / ϑ)).length + 1) → L → Prop := - checkSingleRepetition_foldRel 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (stmtIn := stmtIn) (oStmtIn := oStmtIn) (v := v) - + checkSingleRepetition_foldRel 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) (v := v) -- 1. Expand definition to expose the `forIn` and `guard` dsimp only [checkSingleRepetition] - -- 2. Distribute simulateQ and liftM over the Bind (>>=) -- This splits `simulateQ (Loop >>= Guard)` into `simulateQ Loop >>= simulateQ Guard` simp only [bind_pure_comp] - simp only [Fin.eta, map_pure, bind_pure_comp] + simp only [Fin.eta] -- erw [liftComp_bind] erw [simulateQ_bind] dsimp only [Function.comp_def] @@ -1600,23 +1225,23 @@ lemma checkSingleRepetition_probFailure_eq_zero rw [true_and] intro c h_c_support_inner_loop -- **if the inner for loop is passed, then the guard must be passed (given relIn)** - - simp only [MessageIdx, Message, OptionT.simulateQ_map, id_map'] at h_c_support_inner_loop - - set f : Fin (ℓ / ϑ) → L → OracleComp []ₒ (Option (ForInStep L)) := fun (a : Fin (ℓ / ϑ)) (b : L) ↦ + simp only [MessageIdx, Message, LawfulApplicative.map_pure, bind_pure_comp, + OptionT.simulateQ_map] at h_c_support_inner_loop + set f : Fin (ℓ / ϑ) → L → OracleComp []ₒ (Option (ForInStep L)) := + fun (a : Fin (ℓ / ϑ)) (b : L) ↦ ((ForInStep.yield <$> (simulateQ.{0, 0, 0} so (checkSingleFoldingStep 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) a b v stmtIn ).run )) : OptionT (OracleComp []ₒ) (ForInStep L)) with h_f_def - - set inner_forIn_block := ((forIn (List.finRange (ℓ / ϑ)) (0 : L) f) : OptionT (OracleComp []ₒ) L) with h_inner_forIn_block - + set inner_forIn_block := ((forIn (List.finRange (ℓ / ϑ)) (0 : L) f) : + OptionT (OracleComp []ₒ) L) with h_inner_forIn_block have h_probFailure_loop_eq_zero : Pr[⊥ | inner_forIn_block] = 0 := by - exact checkSingleRepetition_inner_forIn_probFailure_eq_zero 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (stmtIn := stmtIn) (oStmtIn := oStmtIn) (witIn := witIn) (h_relIn := h_relIn) (rep := rep) (challenges := challenges) - + exact checkSingleRepetition_inner_forIn_probFailure_eq_zero 𝔽q β + (γ_repetitions := γ_repetitions) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (stmtIn := stmtIn) (oStmtIn := oStmtIn) + (witIn := witIn) (h_relIn := h_relIn) (rep := rep) (challenges := challenges) have h_probOutput_inner_forIn_block_eq_none := OptionT.probOutput_none_run_eq_zero_of_probFailure_eq_zero (hfail := h_probFailure_loop_eq_zero) @@ -1624,7 +1249,6 @@ lemma checkSingleRepetition_probFailure_eq_zero (hx := h_c_support_inner_loop) (hnone := h_probOutput_inner_forIn_block_eq_none) rcases h_c_eq_some with ⟨c_val, rfl⟩ -- h_c_support_inner_loop : c ∈ forIn (List.finRange (ℓ / ϑ)) 0 f .support - -- ⊢ x = stmtIn.final_constant -- We reuse the SAME relation `Rel` and the SAME logic we used for safety! have h_c_eq_final_constant : c_val = stmtIn.final_constant := by @@ -1634,33 +1258,29 @@ lemma checkSingleRepetition_probFailure_eq_zero -- 1. Apply the helper lemma to transport the invariant to the end -- h_x_support : x ∈ -- (forIn (List.finRange (ℓ / ϑ)) 0 fun a b ↦ - -- simulateQ (QueryImpl.lift so) (checkSingleFoldingStep 𝔽q β a b v stmtIn) >>= pure ∘ ForInStep.yield).support - + -- simulateQ (QueryImpl.lift so) (checkSingleFoldingStep 𝔽q β a b v stmtIn) + -- >>= pure ∘ ForInStep.yield).support have h_rel_final : Rel ⟨ℓ/ϑ, by simp only [List.length_finRange, lt_add_iff_pos_right, zero_lt_one]⟩ c_val := by -- unfold OptionT at h_c_support_inner_loop - -- Apply the yield-only helper let relation_correct_of_mem_support := support_forIn_subset_rel_yield_only.{0} (m := OptionT (OracleComp []ₒ)) (l := List.finRange (ℓ/ϑ)) (rel := Rel) (f := f) (init := 0) (h_start := by rfl) (h_step := by -- simp only [←simulateQ_liftComp] - intro (k : Fin (List.finRange (ℓ / ϑ)).length) (c_k : L) h_rel_k_c iteration_output h_iteration_output_iteration + intro (k : Fin (List.finRange (ℓ / ϑ)).length) (c_k : L) h_rel_k_c iteration_output + h_iteration_output_iteration -- 1. Unpack support (extract c_next) - -- simp only [bind_pure_comp, map_pure, support_map, Set.mem_image] at h_iteration_output_iteration -- 1. Distribute simulateQ over >>= and pure -- This transforms: simulateQ (action >>= pure) -> (simulateQ action) >>= pure - simp only [MessageIdx, OptionT.simulateQ_map, List.get_eq_getElem, List.getElem_finRange, + simp only [MessageIdx, List.get_eq_getElem, List.getElem_finRange, Fin.eta, support_map, Set.mem_image, OptionT.mem_support_iff, toPFunctor_emptySpec, OptionT.support_run, f] at h_iteration_output_iteration - -- 2. Now the hypothesis is exactly: ∃ c_next, c_next ∈ support ∧ output = yield c_next -- Extract it just like before! rcases h_iteration_output_iteration with ⟨c_next, h_c_next_mem, h_iteration_output_eq⟩ rw [←h_iteration_output_eq] - dsimp only [OptionT.run] at h_c_next_mem - -- simp only [h_iteration_output_eq] constructor · rfl @@ -1671,7 +1291,10 @@ lemma checkSingleRepetition_probFailure_eq_zero simp only [List.length_finRange] at h_k_lt exact h_k_lt⟩ -- Apply preservation lemma (Exact same syntax as Part 2) - let res := query_phase_step_preserves_fold 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k_idx) (v := v) (c_k := c_k) (s' := c_next) (stmtIn := stmtIn) (oStmtIn := oStmtIn) (h_relIn := h_relIn) (challenges := challenges) (h_s'_mem := h_c_next_mem) + let res := query_phase_step_preserves_fold 𝔽q β (γ_repetitions := γ_repetitions) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k_idx) (v := v) (c_k := c_k) + (s' := c_next) (stmtIn := stmtIn) (oStmtIn := oStmtIn) (h_relIn := h_relIn) + (challenges := challenges) (h_s'_mem := h_c_next_mem) (h_c_k_correct_of_k_pos := by dsimp only [k_idx] dsimp only [Rel, checkSingleRepetition_foldRel] at h_rel_k_c @@ -1692,7 +1315,8 @@ lemma checkSingleRepetition_probFailure_eq_zero unfold Rel at h_rel_final -- Prove that the final index is not 0 have h_nonzero : (⟨ℓ/ϑ, by simp only [List.length_finRange, - lt_add_iff_pos_right, zero_lt_one]⟩ : Fin (List.length (List.finRange (ℓ / ϑ)) + 1)) ≠ 0 := by + lt_add_iff_pos_right, zero_lt_one]⟩ : + Fin (List.length (List.finRange (ℓ / ϑ)) + 1)) ≠ 0 := by simp only [ne_eq, Fin.mk_eq_zero, Nat.div_eq_zero_iff, not_or, not_lt] constructor · have h := Nat.pos_of_neZero (ϑ); omega @@ -1710,6 +1334,769 @@ lemma checkSingleRepetition_probFailure_eq_zero erw [simulateQ_pure.{0, 0, 0}] erw [probFailure_pure.{0, 0}] +/-- Pair-support projection wrapper of `support_simulateQ_run'_eq`. +`Prod.fst` of the stateful run support matches the spec support. -/ +lemma support_run_simulateQ_run_fst_eq {ι : Type} + {oSpec : OracleSpec ι} [oSpec.Fintype] [oSpec.Inhabited] {σ α : Type} + (impl : QueryImpl oSpec (StateT σ ProbComp)) + (oa : OracleComp oSpec (Option α)) (s : σ) + (hImplSupp : ∀ {β} (q : OracleQuery oSpec β) s, + Prod.fst <$> support ((QueryImpl.mapQuery impl q).run s) + = support (liftM q : OracleComp oSpec β)) : + Prod.fst <$> support (m := ProbComp) (α := Option α × σ) ((simulateQ impl oa) s) = + support (m := OracleComp oSpec) (α := Option α) oa := by + simpa [StateT.run'_eq, support_map] using + (support_simulateQ_run'_eq (impl := impl) (oa := oa) (s := s) + (hImplSupp := hImplSupp)) +/-! **Per-repetition support → logical** (extracted for reuse from completeness-style reasoning). +**Counterpart** of `checkSingleRepetition_probFailure_eq_zero` for the `OracleComp.support` case. +If `(ForInStep.yield PUnit.unit, state_post)` lies in the support of one iteration of the + verifier's forIn body (for a given `rep`), then the logical proximity check holds for that + repetition: `logical_checkSingleRepetition 𝔽q β oStmtIn (tr.challenges ⟨0, rfl⟩ rep) stmtIn + stmtIn.final_constant`. +-/ +omit [CharP L 2] [SampleableType L] in +lemma logical_checkSingleRepetition_of_mem_support_forIn_body {σ : Type} + (impl : QueryImpl []ₒ (StateT σ ProbComp)) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) + (tr : FullTranscript (pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate))) + (stmtIn : FinalSumcheckStatementOut) + (rep : Fin γ_repetitions) + (state_pre : σ) + (forIn_body : Fin γ_repetitions → PUnit → StateT σ ProbComp (Option (ForInStep PUnit))) + (h_forIn_body_eq : forIn_body = + fun (a : Fin γ_repetitions) (_ : PUnit.{1}) => + OptionT.mk (simulateQ impl ((((fun (_ : Unit) ↦ ForInStep.yield PUnit.unit) <$> + ((simulateQ.{0, 0, 0} (impl := OracleInterface.simOracle2 []ₒ oStmtIn tr.messages) + ((checkSingleRepetition 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ((FullTranscript.mk1 (tr.challenges ⟨0, rfl⟩)).challenges ⟨0, rfl⟩ a) + stmtIn stmtIn.final_constant) : + OptionT (OracleComp + ([]ₒ + ([OracleStatement 𝔽q β ϑ (Fin.last ℓ)]ₒ + + [(pSpecQuery 𝔽q β γ_repetitions).Message]ₒ))) Unit).run) : + OracleComp []ₒ (Option Unit))) : + OptionT (OracleComp []ₒ) (ForInStep PUnit.{1}))))) + (h_mem : ∃ (res : ForInStep PUnit.{1} × σ), (some res.1, res.2) ∈ + ((forIn_body rep PUnit.unit).run state_pre).support) : + logical_checkSingleRepetition 𝔽q β oStmtIn (tr.challenges ⟨0, rfl⟩ rep) stmtIn + stmtIn.final_constant := by + -- 1. Extract the witness res = (control_flow, state_post) + rcases h_mem with ⟨⟨res_flow, state_post_single_outer_repetition⟩, h_support⟩ + -- 2. Unfold the body definition + rw [h_forIn_body_eq] at h_support + set v := tr.challenges ⟨0, rfl⟩ rep with h_v + let Rel := checkSingleRepetition_foldRel 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) (v := v) + dsimp only [logical_checkSingleRepetition] + conv at h_support => + -- 1. Expand definition to expose the `forIn` and `guard` + dsimp only [checkSingleRepetition] + -- 2. Distribute simulateQ and liftM over the Bind (>>=) + -- This splits `simulateQ (Loop >>= Guard)` into `simulateQ Loop >>= simulateQ Guard` + dsimp only [liftM, monadLift, MonadLift.monadLift] + erw [simulateQ_bind, simulateQ_bind, simulateQ_bind] + erw [support_bind] + dsimp only [Function.comp_def] + simp only [Fin.isValue, id_map', + guard_eq, map_bind, simulateQ_bind, simulateQ_liftComp, StateT.run_bind, Function.comp_apply, + simulateQ_map, simulateQ_ite, simulateQ_pure, OptionT.simulateQ_failure, + StateT.run_map, support_bind, + support_map, Set.mem_iUnion, Set.mem_image, Prod.mk.injEq, Prod.exists, exists_eq_right_right, + exists_and_right, exists_and_left, exists_prop] + erw [support_bind] + simp only [Fin.isValue, id_map', + guard_eq, map_bind, simulateQ_bind, simulateQ_liftComp, StateT.run_bind, Function.comp_apply, + simulateQ_map, simulateQ_ite, simulateQ_pure, OptionT.simulateQ_failure, + StateT.run_map, support_bind, + support_map, Set.mem_iUnion, Set.mem_image, Prod.mk.injEq, Prod.exists, exists_eq_right_right, + exists_and_right, exists_and_left, exists_prop] + obtain ⟨output_final_guard, output_state_final_guard, exists_c_last, + h_final_yield_support_mem⟩ := h_support + -- c_last is the yielded folded value from the last inner iteration (i.e. γ_repetitions-1) + rcases exists_c_last with ⟨c_last, output_state_inner_forIn, ⟨h_mem_forIn_support, + h_mem_final_guard_support⟩⟩ + conv at h_mem_forIn_support => + simp only [Function.comp_def, simulateQ_pure, pure_bind] + rw [OptionT.simulateQ_forIn] + rw [OptionT.simulateQ_forIn_stateful_comp] + -- Bridge to the `OptionT` path lemma: extract a successful `c_last` from support. + obtain ⟨c_last_val, h_c_last_eq_some⟩ : ∃ c_last_val : L, c_last = some c_last_val := by + cases h_c : c_last with + | none => + exfalso + simp only [MessageIdx, h_c, Message, simulateQ_pure] at h_mem_final_guard_support + erw [support_pure] at h_mem_final_guard_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq] at h_mem_final_guard_support + obtain ⟨h_guard_none, _⟩ := h_mem_final_guard_support + have h_final_mem := h_final_yield_support_mem + simp only [h_guard_none, simulateQ_pure] at h_final_mem + erw [support_pure] at h_final_mem + simp only [Set.mem_singleton_iff, Prod.mk.injEq, reduceCtorEq, false_and] at h_final_mem + | some a => + exact ⟨a, rfl⟩ + have h_mem_forIn_support_some := by + simpa [h_c_last_eq_some] using h_mem_forIn_support + have h_ϑ_pos : ϑ > 0 := by exact Nat.pos_of_neZero ϑ + have h_ϑ_le_ℓ : ϑ ≤ ℓ := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out) + have h_ℓ_div_ϑ_ge_1 : ℓ/ϑ ≥ 1 := by exact (Nat.one_le_div_iff h_ϑ_pos).mpr h_ϑ_le_ℓ + have h_0_lt : 0 < (ℓ / ϑ) := by omega + have h_ℓ_div_mul_eq_ℓ : (ℓ / ϑ) * ϑ = ℓ := Nat.div_mul_cancel hdiv.out + have h_lastOraclePosIdx_mul_add : + (getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ)).val * ϑ + ϑ = ℓ := by + conv_rhs => rw [←h_ℓ_div_mul_eq_ℓ] + rw [getLastOraclePositionIndex_last]; simp only + rw [Nat.sub_mul, Nat.one_mul]; rw [Nat.sub_add_cancel (by rw [h_ℓ_div_mul_eq_ℓ]; omega)] + -- **Applying indutive relation inference** for the inner `forIn` only + let Rel' := fun (i : Fin ((List.finRange (ℓ / ϑ)).length + 1)) (c_next : Option L) (_s : σ) => + -- state i => at the end of the inner repetition `i-1` + -- which means at `i = 0`, value = True since nothing meaningful to check + logical_stepCondition 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oStmt := oStmtIn) + (k := ⟨i - 1, by + have hi := i.isLt; + simp only [List.length_finRange] at hi; omega + ⟩) (v := v) (stmt := stmtIn) (final_constant := stmtIn.final_constant) + ∧ ( + if hi : i > 0 then + have hi_lt := i.isLt; + have hi_lt₂ : i - 1 < ℓ / ϑ := by + simp only [List.length_finRange] at hi_lt; omega + let k : Fin (ℓ / ϑ) := ⟨i - 1, by omega⟩ + -- **NOTE**: At the end of repetition `k = i-1`, the value c_next which is + -- the evaluation on `S^{(k+1)*ϑ}` of the folded oracle function must be computed + -- let point := getChallengeSuffix 𝔽q β (List.finRange (ℓ / ϑ))[↑k] v; fiber_vec.get + let point := getChallengeSuffix 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v := v) (k := k) + let fiber_vec : Fin (2 ^ ϑ) → L := logical_queryFiberPoints 𝔽q β oStmtIn k v + let output_of_iteration_k : L := + (single_point_localized_fold_matrix_form 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨k.val * ϑ, by + exact lt_r_of_lt_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h := k_mul_ϑ_lt_ℓ (k := k)) + ⟩) (steps := ϑ) (destIdx := ⟨k.val * ϑ + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + exact k_succ_mul_ϑ_le_ℓ_₂ (k := k) + ⟩) (h_destIdx := by + simp only [Nat.add_right_cancel_iff]) + (h_destIdx_le := k_succ_mul_ϑ_le_ℓ_₂ (k := k)) + (r_challenges := fun j ↦ stmtIn.challenges ⟨↑k * ϑ + ↑j, by + simp only [Fin.val_last] + have h_le : k.val * ϑ + ϑ ≤ ℓ := k_succ_mul_ϑ_le_ℓ_₂ (k := k) + omega + ⟩) + (y := point) (fiber_eval_mapping := fiber_vec)) + some output_of_iteration_k = c_next + else True) + have h_ϑ_pos : ϑ > 0 := Nat.pos_of_neZero ϑ + -- inductive relation inference for the intermediate folding steps + have h_inductive_relations := _root_.OptionT.exists_rel_path_of_mem_support_forIn_stateful.{0} + (spec := []ₒ) (l := List.finRange (ℓ / ϑ)) (init := 0) (σ := σ) + (s := state_pre) (res := (c_last_val, output_state_inner_forIn)) + (h_mem := h_mem_forIn_support_some) (rel := Rel') (h_start := by + simp only [logical_stepCondition, logical_checkSingleFoldingStep, gt_iff_lt, + CanonicallyOrderedAdd.mul_pos, tsub_pos_iff_lt, dite_else_true, Fin.val_last, + Fin.coe_ofNat_eq_mod, List.length_finRange, Nat.zero_mod, zero_tsub, h_0_lt, ↓reduceDIte, + not_lt_zero', false_and, zero_mul, Fin.mk_zero', IsEmpty.forall_iff, lt_self_iff_false, + zero_add, and_self, Rel']; + ) + (h_step := by + intro k (c_cur : L) (s_curr : σ) h_rel_k res_step h_res_step_mem + -- c_cur is the yielded folded value from the previous inner iteration (i.e. k-1) + have h_k := k.isLt + simp only [List.length_finRange] at h_k + have h_k_succ_sub_1_lt : k.succ.val - 1 < ℓ / ϑ := by + simp only [Fin.val_succ, add_tsub_cancel_right]; omega + have h_k_sub_1_lt : k.val - 1 < ℓ / ϑ := by + omega + have h_k_succ_gt_0 : k.succ > 0 := by simp only [gt_iff_lt, Fin.succ_pos] + dsimp only [Rel', logical_stepCondition] at h_rel_k + simp only [Fin.val_castSucc, h_k_sub_1_lt, ↓reduceDIte] at h_rel_k + -- **Nested simulateQ structure** (do not simp the outer impl): + -- • Outer: `simulateQ impl (...)` comes from RoundByRound's toFun_full: the reduction runs + -- the verifier with a stateful oracle impl (black box). We do NOT unfold impl; we only + -- use that its support equals the spec (support_simulateQ_run'_eq). + -- • Inner: `simulateQ (simOracle2 []ₒ oStmtIn tr.messages) (...)` comes + -- from OracleVerifier.toVerifier (Basic.lean): verifier checks are run with + -- simOracle2 so oStmtIn and transcript answer the oracle queries. This inner layer + -- can be simplified further (unfold checkSingleFoldingStep, use simOracle2 lemmas). + set inner_base : OracleComp []ₒ (Option L) := + simulateQ (OracleInterface.simOracle2 []ₒ oStmtIn tr.messages) + (checkSingleFoldingStep 𝔽q β γ_repetitions ((List.finRange (ℓ / ϑ)).get k) + c_cur v stmtIn).run + set inner_oa : OptionT (OracleComp []ₒ) (ForInStep L) := + ForInStep.yield <$> (OptionT.mk inner_base) + have h_run'_supp_eq := OptionT.support_run_simulateQ_run'_eq (impl := impl) + (oa := inner_oa) + (s := s_curr) + (hImplSupp := by simp only [Set.fmap_eq_image, IsEmpty.forall_iff, implies_true]) + -- res_step ∈ (run s).support → res_step.1 ∈ (run' s).support = inner_oa.support + have h_fst_mem : + some res_step.1 ∈ ((simulateQ impl + inner_oa).run' s_curr).support := by + have h_run_mem : + (some res_step.1, res_step.2) ∈ + ((simulateQ impl inner_oa).run s_curr).support := by + simpa [inner_oa, inner_base, h_v] using h_res_step_mem + simp only [StateT.run', support_map, Set.mem_image] + exact ⟨(some res_step.1, res_step.2), h_run_mem, rfl⟩ + rw [h_run'_supp_eq] at h_fst_mem + have h_fst_mem_opt : res_step.1 ∈ support (inner_oa) := by + exact (OptionT.mem_support_iff (mx := inner_oa) (x := res_step.1)).2 h_fst_mem + have h_inner_step_mem : + ∃ c_next, + (some c_next) ∈ support inner_base ∧ ForInStep.yield c_next = res_step.1 := by + rcases (OptionT.mem_support_OptionT_map_some + (ma := OptionT.mk inner_base) (f := ForInStep.yield) (y := res_step.1)).1 + h_fst_mem_opt with + ⟨c_next, h_c_next_mem_mk, h_yield_eq⟩ + exact ⟨c_next, (OptionT.mem_support_mk (mx := inner_base) (x := c_next)).1 + h_c_next_mem_mk, h_yield_eq⟩ + rcases h_inner_step_mem with ⟨c_next, h_fst_mem, h_res_step1_eq⟩ + dsimp only [Rel', logical_stepCondition] + dsimp only [inner_base] at h_fst_mem + unfold checkSingleFoldingStep at h_fst_mem + erw [simulateQ_bind] at h_fst_mem + erw [simulateQ_bind, support_bind] at h_fst_mem + dsimp only [OptionT.run] at h_fst_mem + simp only [Set.mem_iUnion, exists_prop] at h_fst_mem + rcases h_fst_mem with ⟨fiber_vec_opt, h_fiber_vec_opt_mem_support, h_c_k_mem_output⟩ + have h_probFailure_queryFiberPoints_eq_zero := probFailure_simulateQ_queryFiberPoints_eq_zero + (𝔽q := 𝔽q) (β := β) (γ_repetitions := γ_repetitions) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (so := OracleInterface.simOracle2 []ₒ oStmtIn tr.messages) (k := k) (v := v) + have h_probOutput_none_queryFiberPoints_eq_zero := + OptionT.probOutput_none_run_eq_zero_of_probFailure_eq_zero + (hfail := h_probFailure_queryFiberPoints_eq_zero) + have h_fiber_vec_opt_mem_support_run : + fiber_vec_opt ∈ + (simulateQ (OracleInterface.simOracle2 []ₒ oStmtIn tr.messages) + (queryFiberPoints 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ((List.finRange (ℓ / ϑ)).get k) + v)).support := by + have h_fiber_vec_opt_mem_support' := h_fiber_vec_opt_mem_support + simp only [queryFiberPoints, support_bind, + Set.mem_iUnion, exists_prop] at h_fiber_vec_opt_mem_support' ⊢ + rcases h_fiber_vec_opt_mem_support' with ⟨i, h_i_mem, h_i_out⟩ + have h_eq : fiber_vec_opt = i := by + cases i <;> + simpa [simulateQ_pure, support_pure, Set.mem_singleton_iff] using h_i_out + subst h_eq + simpa [bind_pure_comp] using h_i_mem + have h_fiber_vec_opt_eq_some := exists_eq_some_of_mem_support_of_probOutput_none_eq_zero + (x := fiber_vec_opt) (hx := h_fiber_vec_opt_mem_support_run) + (hnone := h_probOutput_none_queryFiberPoints_eq_zero) + rcases h_fiber_vec_opt_eq_some with ⟨fiber_vec, h_fiber_vec_opt_eq_some⟩ + rw [h_fiber_vec_opt_eq_some] at h_fiber_vec_opt_mem_support_run h_c_k_mem_output + have h_fiber_val := mem_support_queryFiberPoints 𝔽q β γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oraclePositionIdx := ⟨k, h_k⟩) (v := v) + (f_i_on_fiber := fiber_vec) (stmtIn := stmtIn) (oStmtIn := oStmtIn) + (witIn := ()) (challenges := tr.challenges) + (h_fiber_mem := by + dsimp only [queryPhaseLogicStep] + have h_transcript : (FullTranscript.mk1 (pSpec := pSpecQuery 𝔽q β γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (tr.challenges ⟨0, rfl⟩)).messages + = tr.messages := by + -- funext j + simp only [MessageIdx, Fin.isValue, FullTranscript.mk1_eq_snoc] + unfold FullTranscript.messages Transcript.concat + funext x + obtain ⟨i, hi⟩ := x; fin_cases i; simp at hi + rw [h_transcript] + have h_k_fin_eq : (List.finRange (ℓ / ϑ)).get k = ⟨k, h_k⟩ := by + apply Fin.eq_of_val_eq + simp only [List.get_eq_getElem, List.getElem_finRange, Fin.eta, Fin.val_cast] + simpa only [MessageIdx, List.get_eq_getElem, List.getElem_finRange, Fin.eta] using + h_fiber_vec_opt_mem_support_run + ) + simp only at h_fiber_val + have h_fiber_val_eq : fiber_vec.get = fun (fiberIndex : Fin (2 ^ ϑ)) => oStmtIn ⟨k.val, by + simp only [toOutCodewordsCount_last]; omega⟩ + (getFiberPoint 𝔽q β ⟨↑k, h_k⟩ v fiberIndex) := by + funext fiberIndex + exact h_fiber_val fiberIndex + simp only [h_fiber_val] at h_c_k_mem_output + simp only [h_k_succ_sub_1_lt, h_k_succ_gt_0, ↓reduceDIte] + -- ⊢ logical_checkSingleFoldingStep 𝔽q β oStmtIn ⟨↑k.succ - 1, ⋯⟩ v stmtIn + dsimp only [logical_checkSingleFoldingStep] + by_cases h_k_gt_0 : k.val > 0 + · have h_gt : (k.succ.val - 1) * ϑ > 0 := by + have hk' : k.succ.val - 1 > 0 := by + simpa [Fin.val_succ, add_tsub_cancel_right] using h_k_gt_0 + exact Nat.mul_pos hk' h_ϑ_pos + simp only [MessageIdx, List.get_eq_getElem, List.getElem_finRange, Fin.eta, Fin.val_cast, + gt_iff_lt, h_k_gt_0, mul_pos_iff_of_pos_left, h_ϑ_pos, ↓reduceDIte, Message, guard_eq, + Fin.val_last, bind_pure_comp, OptionT.simulateQ_map] at h_c_k_mem_output + erw [simulateQ_ite] at h_c_k_mem_output + set V_check := (c_cur = oStmtIn ⟨k, by + simp only [toOutCodewordsCount_last]; omega⟩ ( + (getFiberPoint 𝔽q β ⟨↑k, h_k⟩ v (extractMiddleFinMask 𝔽q β v ⟨k.val * ϑ, by + have h := oracle_index_le_ℓ (i := Fin.last ℓ) + (j := ⟨k, by simpa only [toOutCodewordsCount_last] using h_k⟩) + simp only at h; omega⟩ ϑ)) + )) with h_V_check_def + have h_V_check_passed : V_check := by + by_contra h_V_check_false + rw [h_V_check_def] at h_V_check_false + simp only [h_V_check_false, ↓reduceIte, OptionT.simulateQ_failure, OptionT.map_failure, + OptionT.support_failure_run, Set.mem_singleton_iff, reduceCtorEq] at h_c_k_mem_output + rw [h_V_check_def] at h_V_check_passed + simp only [h_V_check_passed, ↓reduceIte] at h_c_k_mem_output + erw [simulateQ_pure, support_bind] at h_c_k_mem_output + simp only [support_pure, Set.mem_singleton_iff, Function.comp_apply, + Set.iUnion_iUnion_eq_left, OptionT.support_OptionT_pure_run, + Option.some.injEq] at h_c_k_mem_output + -- dsimp only [Functor.map] at h_c_k_mem_output + have h_k_cast_gt_0 : 0 < k.castSucc := by + simpa [gt_iff_lt, Fin.val_castSucc] using h_k_gt_0 + simp only [gt_iff_lt, h_k_cast_gt_0, ↓reduceDIte, Fin.val_last, + Option.some.injEq] at h_rel_k + simp only [h_gt, ↓reduceDIte] + simp only [Fin.val_succ, add_tsub_cancel_right] + -- Goal: LHS = RHS. We have h_c_k_mem_output.1 : b = (RHS as oStmtIn ... getFiberPoint ...). + conv_rhs => dsimp only [logical_queryFiberPoints]; + dsimp only [logical_queryFiberPoints] + -- ⊢ logical_computeFoldedValue 𝔽q β ⟨↑k - 1, ⋯⟩ v stmtIn (logical_queryFiberPoints 𝔽q β + -- oStmtIn ⟨↑k - 1, ⋯⟩ v) = oStmtIn ⟨↑k, ⋯⟩ (getFiberPoint 𝔽q β ⟨↑k, ⋯⟩ v + -- (extractMiddleFinMask 𝔽q β v ⟨↑k * ϑ, ⋯⟩ ϑ)) + dsimp only [logical_computeFoldedValue, logical_queryFiberPoints] + constructor + · -- V check in the current iteration passes + rw [←h_V_check_passed] + -- rw previous computation of c_cur (in previous iteration) + simp only [Fin.val_last, h_rel_k.2.symm] + rfl + · -- prove equality relation for the output of the current iteration (i.e. c_next) + simp only [ForInStep.state] + rw [h_c_k_mem_output] at h_res_step1_eq + rw [h_res_step1_eq.symm] + dsimp only [ForInStep.state] + rw [h_fiber_val_eq] + simp only [Nat.add_one_sub_one, Fin.val_last] + have h_k_fin_eq : (List.finRange (ℓ / ϑ)).get k = ⟨k, by omega⟩ := by + apply Fin.eq_of_val_eq; + simp only [List.get_eq_getElem, List.getElem_finRange, Fin.eta, Fin.val_cast] + let destIdx : Fin r := ⟨k.val * ϑ + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + have h_le : k.val * ϑ + ϑ ≤ ℓ := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) + (j := ⟨k.val, by simpa [toOutCodewordsCount_last] using h_k⟩) + exact h_le + ⟩ + conv_lhs => rw [single_point_localized_fold_matrix_form_congr_dest_index 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx' := destIdx) (h_destIdx_eq_destIdx' := by + simp only [add_tsub_cancel_right]; dsimp only [destIdx])] + conv_rhs => rw [single_point_localized_fold_matrix_form_congr_dest_index 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx' := destIdx) (h_destIdx_eq_destIdx' := by + simp only [List.getElem_finRange, Fin.eta, Fin.val_cast]; dsimp only [destIdx])] + congr 1; congr 1; + -- only challenges equality left + simp only [Nat.add_one_sub_one, cast_eq] + dsimp only [getChallengeSuffix] + apply extractSuffixFromChallenge_congr_destIdx 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (h_idx_eq := by + simp only [List.getElem_finRange, Fin.eta, Fin.val_cast]) (h_le := by + have h_main : k.val * ϑ + ϑ ≤ ℓ := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) + (j := ⟨k.val, by simpa [toOutCodewordsCount_last] using h_k⟩) + simpa [Fin.val_succ, add_tsub_cancel_right] using h_main + ) (h_le' := by + have h_main : k.val * ϑ + ϑ ≤ ℓ := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) + (j := ⟨k.val, by simpa [toOutCodewordsCount_last] using h_k⟩) + simpa [List.getElem_finRange, Fin.eta, Fin.val_cast] using h_main + ) + · have h_ne_gt : ¬ ((k.succ.val - 1) * ϑ > 0) := by + intro h_gt + have h_mul_pos : k.val * ϑ > 0 := by + simpa [Fin.val_succ, add_tsub_cancel_right] using h_gt + have hk_pos : k.val > 0 := by + exact Nat.pos_of_mul_pos_right h_mul_pos + exact h_k_gt_0 hk_pos + simp only [h_ne_gt, ↓reduceDIte, true_and] + simp only [Fin.val_succ, add_tsub_cancel_right, Nat.add_one_sub_one, Fin.val_last, + Option.some.injEq] + -- ⊢ single_point_localized_fold_matrix_form 𝔽q β ⟨(↑k.succ - 1) * ϑ, ⋯⟩ ϑ ⋯ ⋯ + -- (fun j ↦ stmtIn.challenges ⟨(↑k.succ - 1) * ϑ + ↑j, ⋯⟩) + -- (getChallengeSuffix 𝔽q β ⟨↑k.succ - 1, ⋯⟩ v) + -- (logical_queryFiberPoints 𝔽q β oStmtIn ⟨↑k.succ - 1, ⋯⟩ v) = + -- res_step.1.state + simp only [MessageIdx, List.get_eq_getElem, List.getElem_finRange, Fin.eta, Fin.val_cast, + gt_iff_lt, CanonicallyOrderedAdd.mul_pos, h_k_gt_0, false_and, ↓reduceDIte, Message, + Fin.val_last, bind_pure_comp, LawfulApplicative.map_pure] at h_c_k_mem_output + erw [simulateQ_pure, support_pure] at h_c_k_mem_output + simp only [Set.mem_singleton_iff, Option.some.injEq] at h_c_k_mem_output + rw [h_c_k_mem_output] at h_res_step1_eq + rw [h_res_step1_eq.symm] + dsimp only [ForInStep.state] + dsimp only [logical_queryFiberPoints] + rw [h_fiber_val_eq] + have h_k_fin_eq : (List.finRange (ℓ / ϑ)).get k = ⟨k, by omega⟩ := by + apply Fin.eq_of_val_eq; + simp only [List.get_eq_getElem, List.getElem_finRange, Fin.eta, Fin.val_cast] + let destIdx : Fin r := ⟨k.val * ϑ + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + have h_le : k.val * ϑ + ϑ ≤ ℓ := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) + (j := ⟨k.val, by simpa [toOutCodewordsCount_last] using h_k⟩) + exact h_le + ⟩ + conv_lhs => rw [single_point_localized_fold_matrix_form_congr_dest_index 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx' := destIdx) (h_destIdx_eq_destIdx' := by + apply Fin.eq_of_val_eq; + simp only; dsimp only [destIdx])] + conv_rhs => rw [single_point_localized_fold_matrix_form_congr_dest_index 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx' := destIdx) (h_destIdx_eq_destIdx' := by + simp only [List.getElem_finRange, Fin.eta, Fin.val_cast]; dsimp only [destIdx])] + congr 1; + -- only challenges equality left + simp only [cast_eq] + dsimp only [getChallengeSuffix] + apply extractSuffixFromChallenge_congr_destIdx 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (h_idx_eq := by + simp only [List.getElem_finRange, Fin.eta, Fin.val_cast]) (h_le := by + have h_main : k.val * ϑ + ϑ ≤ ℓ := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) + (j := ⟨k.val, by simpa [toOutCodewordsCount_last] using h_k⟩) + simpa [Fin.val_succ, add_tsub_cancel_right] using h_main + ) (h_le' := by + have h_main : k.val * ϑ + ϑ ≤ ℓ := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) + (j := ⟨k.val, by simpa [toOutCodewordsCount_last] using h_k⟩) + simpa [List.getElem_finRange, Fin.eta, Fin.val_cast] using h_main + ) + ) + (h_yield := by + intro k c_cur s_curr res_step h_res_step_mem + -- erw [OptionT.support_run] at h_res_step_mem + erw [simulateQ_bind] at h_res_step_mem + erw [simulateQ_bind, support_bind] at h_res_step_mem + dsimp only [OptionT.run] at h_res_step_mem + simp only [MessageIdx, Fin.isValue, Message, + Set.mem_iUnion, exists_prop, Prod.exists] at h_res_step_mem + rcases h_res_step_mem with + ⟨c_next_opt, output_state_next, _h_mem_support_cur_folding_step, h_res_step_mem_yield⟩ + cases h_c : c_next_opt with + | none => + have h_res_step1_mem : + some res_step.1 ∈ support (m := OracleComp []ₒ) (α := Option (ForInStep L)) + (simulateQ (OracleInterface.simOracle2 []ₒ oStmtIn tr.messages) + (pure (none : Option (ForInStep L)))) := by + have h_proj_mem : + some res_step.1 ∈ Prod.fst <$> support (m := ProbComp) + (α := Option (ForInStep L) × σ) + ((simulateQ impl + (simulateQ (OracleInterface.simOracle2 []ₒ oStmtIn tr.messages) + (pure (none : Option (ForInStep L))))) output_state_next) := by + refine ⟨(some res_step.1, res_step.2), ?_, rfl⟩ + simpa only [MessageIdx, simulateQ_pure, h_c] using h_res_step_mem_yield + have h_proj_eq := support_run_simulateQ_run_fst_eq (impl := impl) + (oa := simulateQ (OracleInterface.simOracle2 []ₒ oStmtIn tr.messages) + (pure (none : Option (ForInStep L)))) + (s := output_state_next) + (hImplSupp := by simp only [Set.fmap_eq_image, IsEmpty.forall_iff, implies_true]) + rw [h_proj_eq] at h_proj_mem + exact h_proj_mem + simp only [simulateQ_pure, support_pure, Set.mem_singleton_iff] at h_res_step1_mem + cases h_res_step1_mem + | some next => + have h_res_step1_mem : + some res_step.1 ∈ support (m := OracleComp []ₒ) (α := Option (ForInStep L)) + (simulateQ (OracleInterface.simOracle2 []ₒ oStmtIn tr.messages) + (pure (some (ForInStep.yield next)))) := by + have h_proj_mem : + some res_step.1 ∈ Prod.fst <$> support (m := ProbComp) + (α := Option (ForInStep L) × σ) + ((simulateQ impl + (simulateQ (OracleInterface.simOracle2 []ₒ oStmtIn tr.messages) + (pure (some (ForInStep.yield next))))) output_state_next) := by + refine ⟨(some res_step.1, res_step.2), ?_, rfl⟩ + simpa [h_c] using h_res_step_mem_yield + have h_proj_eq := support_run_simulateQ_run_fst_eq (impl := impl) + (oa := simulateQ (OracleInterface.simOracle2 []ₒ oStmtIn tr.messages) + (pure (some (ForInStep.yield next)))) + (s := output_state_next) + (hImplSupp := by simp only [Set.fmap_eq_image, IsEmpty.forall_iff, implies_true]) + rw [h_proj_eq] at h_proj_mem + exact h_proj_mem + simp only [simulateQ_pure, support_pure, Set.mem_singleton_iff] at h_res_step1_mem + exact ⟨next, by simpa using h_res_step1_mem⟩ + ) + -- extract the final guard relation from h_c_last_mem + set v_challenge := (FullTranscript.mk1 (pSpec := pSpecQuery 𝔽q β γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (tr.challenges ⟨0, rfl⟩)).challenges ⟨0, rfl⟩ + with h_v_challenge + intro (k : Fin (ℓ / ϑ + 1)) + dsimp only [logical_stepCondition] + by_cases h_k_lt : ↑k < ℓ / ϑ + · simp only [h_k_lt, ↓reduceDIte] + have h_pred_lt : k.val + 1 - 1 < ℓ / ϑ := by omega + have res := h_inductive_relations.2 + -- 1. Unpack the existence proof + rcases res with ⟨bs, ss, h_init, h_s_init, h_final_b, h_final_s, h_steps, h_rel_all⟩ + -- 2. Specialize the relation for the 'input' to the k-th iteration + -- Since k : Fin (ℓ / ϑ), it can be cast into Fin (ℓ / ϑ + 1) + have h_rel_for_k_th_level_guard := h_rel_all ⟨k + 1, by simp only [List.length_finRange]; omega⟩ + dsimp only [Rel', checkSingleRepetition_foldRel] at h_rel_for_k_th_level_guard + have h_res := h_rel_for_k_th_level_guard + simp only [logical_stepCondition, h_pred_lt, ↓reduceDIte, gt_iff_lt, Fin.val_last, + dite_else_true] at h_res + -- rw [h_v] at h_res + exact h_res.1 + · simp only [h_k_lt, ↓reduceDIte] + -- ⊢ logical_computeFoldedValue 𝔽q β ⟨ℓ / ϑ - 1, ⋯⟩ v stmtIn + -- (logical_queryFiberPoints 𝔽q β oStmtIn ⟨ℓ / ϑ - 1, ⋯⟩ v) = stmtIn.final_constant + have h_last_guard_relation := h_inductive_relations.1.2 + dsimp only [Rel', Rel, checkSingleRepetition_foldRel] at h_last_guard_relation + simp only [List.length_finRange, gt_iff_lt, Fin.val_last, + dite_else_true] at h_last_guard_relation + have h_lt : 0 < (⟨ℓ/ϑ, by simp only [List.length_finRange, lt_add_iff_pos_right, + zero_lt_one]⟩ : Fin ((List.finRange (ℓ / ϑ)).length + 1)) := by + change (0 : ℕ) < (ℓ / ϑ) + exact h_0_lt + dsimp only [logical_computeFoldedValue] + simp only [h_lt, forall_true_left] at h_last_guard_relation + obtain ⟨rfl⟩ := h_c_last_eq_some + simp only [Option.some.injEq] at h_last_guard_relation + simp only [MessageIdx, h_last_guard_relation.symm, Message] at h_mem_final_guard_support + erw [simulateQ_ite, simulateQ_ite, simulateQ_pure, simulateQ_pure] at h_mem_final_guard_support + have h_dest_le_final : (ℓ / ϑ - 1) * ϑ + ϑ ≤ ℓ := by + have h_dest_eq_final : (ℓ / ϑ - 1) * ϑ + ϑ = ℓ := by + calc + (ℓ / ϑ - 1) * ϑ + ϑ = ((ℓ / ϑ - 1) + 1) * ϑ := by + rw [Nat.add_mul, Nat.one_mul] + _ = (ℓ / ϑ) * ϑ := by + rw [Nat.sub_add_cancel (Nat.succ_le_of_lt h_0_lt)] + _ = ℓ := h_ℓ_div_mul_eq_ℓ + exact le_of_eq h_dest_eq_final + let destIdx : Fin r := ⟨(ℓ / ϑ - 1) * ϑ + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + exact h_dest_le_final + ⟩ + set fiber_vec := logical_queryFiberPoints 𝔽q β oStmtIn ⟨ℓ / ϑ - 1, by omega⟩ v + with h_fiber_vec_def + set single_point_localized_fold_matrix_form_val := + single_point_localized_fold_matrix_form 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + _ _ _ _ _ _ _ with h_single_point_localized_fold_matrix_form_val_def + conv at h_mem_final_guard_support => + rw [support_StateT_ite_apply] + erw [support_pure, support_pure] + enter [1] + -- dsimp only [getChallengeSuffix] + rw [h_last_guard_relation] + -- h_final_yield_support_mem + -- erw [extractSuffixFromChallenge_congr_destIdx 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + -- (h_idx_eq := by simp only [List.length_finRange]) (h_le := by exact h_dest_le_final) + -- (h_le' := by simpa [List.length_finRange] using h_dest_le_final)] + -- change single_point_localized_fold_matrix_form_val + have h_final_check_passed : c_last_val = stmtIn.final_constant := by + by_contra h_neq + simp only [h_neq, ↓reduceIte, Set.mem_singleton_iff, + Prod.mk.injEq] at h_mem_final_guard_support + -- h_mem_final_guard_support : + -- output_final_guard = none ∧ output_state_final_guard = output_state_inner_forIn + simp only [h_mem_final_guard_support, simulateQ_pure] at h_final_yield_support_mem + erw [support_pure] at h_final_yield_support_mem + simp only [Set.mem_singleton_iff, Prod.mk.injEq, reduceCtorEq, + false_and] at h_final_yield_support_mem + simp only [h_final_check_passed, ↓reduceIte, Set.mem_singleton_iff, + Prod.mk.injEq] at h_mem_final_guard_support -- pure equalities now + -- h_mem_final_guard_support : + -- output_final_guard = some () ∧ output_state_final_guard = output_state_inner_forIn + rw [←h_final_check_passed] + rw [←h_last_guard_relation] + dsimp only [single_point_localized_fold_matrix_form_val] + conv_lhs => + rw [single_point_localized_fold_matrix_form_congr_dest_index 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx' := destIdx) + (h_destIdx_eq_destIdx' := by dsimp only [destIdx]) (fiber_eval_mapping := fiber_vec)] + conv_rhs => rw [single_point_localized_fold_matrix_form_congr_dest_index 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx' := destIdx) (h_destIdx_eq_destIdx' := by + simp only [List.length_finRange]; dsimp only [destIdx]) (fiber_eval_mapping := fiber_vec)] + congr 1 + -- only challenges equality left + simp only [cast_eq] + dsimp only [getChallengeSuffix] + apply extractSuffixFromChallenge_congr_destIdx 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (h_idx_eq := by + simp only [List.length_finRange]) (h_le := by + exact h_dest_le_final) (h_le' := by + simpa [List.length_finRange] using h_dest_le_final) + +/-! Main lemma connecting verifier support to logical proximity checks. + This is the key lemma used in toFun_full of queryKnowledgeStateFunction. + The left side matches the hypothesis from StateT.run characterization: + (stmtOut, oStmtOut) ∈ support ((fun x ↦ x.1) <$> simulateQ impl (Verifier.run ...) s) + The right side gives us: + 1. stmtOut = true + 2. oStmtOut = mkVerifierOStmtOut ... + 3. ∀ rep, logical_checkSingleRepetition ... (the proximity checks spec) +-/ +omit [CharP L 2] [SampleableType L] in +lemma logical_consistency_checks_passed_of_mem_support_V_run {σ : Type} + (impl : QueryImpl []ₒ (StateT σ ProbComp)) + (stmtIn : FinalSumcheckStatementOut) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) + (tr : FullTranscript (pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate))) + (s : σ) (stmtOut : Bool) (oStmtOut : Empty → Unit) + (h_mem_V_run_support : + (stmtOut, oStmtOut) ∈ + OptionT.mk (Prod.fst <$> ((simulateQ.{0, 0, 0} impl + (Verifier.run (stmtIn, oStmtIn) tr + (queryOracleVerifier 𝔽q β (ϑ := ϑ) γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).toVerifier)) : + StateT σ ProbComp (Option (Bool × (Empty → Unit)))).run s).support) : + (stmtOut = true ∧ + oStmtOut = OracleVerifier.mkVerifierOStmtOut + (embed := (queryOracleVerifier 𝔽q β (ϑ := ϑ) γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).embed) + (hEq := (queryOracleVerifier 𝔽q β (ϑ := ϑ) γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).hEq) oStmtIn tr ∧ + ∀ (rep : Fin γ_repetitions), + logical_checkSingleRepetition 𝔽q β oStmtIn + (tr.challenges ⟨0, rfl⟩ rep) stmtIn stmtIn.final_constant) := by + -- dsimp only [OptionT.mk] at h_mem_V_run_support + conv at h_mem_V_run_support => + dsimp only [Verifier.run, OracleVerifier.toVerifier, queryOracleVerifier] + dsimp only [queryPhaseLogicStep] + -- Simplify the `(fun x ↦ x.1) <$> ...` part + -- Group the last two `bind` + rw [pure_bind]; rw [bind_assoc]; rw [pure_bind] + -- Distribute `simulateQ` over the `bind` + erw [simulateQ_bind, simulateQ_bind, simulateQ_bind] + -- Resolve the constant mappings + simp only [Function.comp_def, simulateQ_pure, pure_bind] + rw [OptionT.simulateQ_forIn] + rw [OptionT.simulateQ_forIn_stateful_comp] + conv at h_mem_V_run_support => + -- rw [simulateQ_forIn_stateful_comp (impl := impl) + -- (l := List.finRange γ_repetitions) (init := PUnit.unit)] + erw [OptionT.support_mk] + erw [support_map] + erw [Set.mem_image] + erw [support_bind] + enter [1, x] + simp only [MessageIdx, Message, Fin.isValue, FullTranscript.mk1_eq_snoc, bind_pure_comp, + OptionT.simulateQ_map, id_map', Set.mem_iUnion, + exists_prop, Prod.exists] + obtain ⟨x, hx_mem, hx_1_eq_stmtOut_oStmtOut⟩ := h_mem_V_run_support + -- Note: hx_mem now refers to the exact simulateQ (forIn ...) block + -- after the conv with OptionT.simulateQ_forIn + -- The structure is: hx_mem : ∃ a b, (a, b) ∈ (simulateQ impl (forIn ...)).support + -- where the forIn is exactly: forIn (List.finRange γ_repetitions) PUnit.unit (fun a b => ...) + let forIn_body : Fin γ_repetitions → PUnit.{1} → + StateT σ ProbComp (Option (ForInStep PUnit.{1})) := fun (a : Fin γ_repetitions) + (b : PUnit.{1}) => + simulateQ impl ( + (((fun (_ : Unit) ↦ ForInStep.yield PUnit.unit) <$> + ((simulateQ.{0, 0, 0} (impl := OracleInterface.simOracle2 []ₒ oStmtIn tr.messages) + ((checkSingleRepetition 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ((FullTranscript.mk1 (tr.challenges ⟨0, rfl⟩)).challenges ⟨0, rfl⟩ a) + stmtIn stmtIn.final_constant) : + OptionT (OracleComp + ([]ₒ + ([OracleStatement 𝔽q β ϑ (Fin.last ℓ)]ₒ + + [(pSpecQuery 𝔽q β γ_repetitions).Message]ₒ))) Unit).run) : + OracleComp []ₒ (Option Unit))) : + OptionT (OracleComp []ₒ) (ForInStep PUnit.{1})) + ) + let forIn_block : OptionT (StateT σ ProbComp) PUnit.{1} := + forIn (xs := List.finRange γ_repetitions) (b := PUnit.unit.{1}) (f := forIn_body) + -- let simulateQ_forIn_block := simulateQ impl forIn_block + -- Verify that hx_mem is about the exact simulateQ (forIn ...) block + conv at hx_mem => + enter [1, x, 1, b, 1, 1, 1, 1] + -- Unfold the set definitions to expose the structure + change (forIn_block) + conv at hx_mem => + enter [1, x_1, 1, b, 1] + change ((x_1, b) ∈ (((forIn_block >>= + (fun (u : Option PUnit.{1}) => (_ : StateT σ ProbComp (Option Bool)))) + : StateT σ ProbComp (Option Bool)).run s).support) + rw [OptionT.mem_support_StateT_bind_run (ma := forIn_block) (x := (x_1, b))] + rcases hx_mem with ⟨y, s', h_y_s'_mem_support_forIn_block, h_x_eq⟩ + -- simp only [StateT.run_pure, support_pure, Set.mem_singleton_iff] at h_x_eq -- TODO + have h_y_ne_none : y ≠ none := by + intro h_y_eq_none + simp only [h_y_eq_none, simulateQ_pure] at h_x_eq + erw [support_pure] at h_x_eq + simp only [Set.mem_singleton_iff] at h_x_eq + rw [Prod.mk_inj] at h_x_eq + rw [hx_1_eq_stmtOut_oStmtOut] at h_x_eq + simp only [reduceCtorEq, false_and] at h_x_eq + obtain ⟨y_val, h_y_eq⟩ := Option.ne_none_iff_exists.mp h_y_ne_none + obtain ⟨rfl⟩ := h_y_eq + simp only at h_x_eq + erw [simulateQ_pure, support_pure] at h_x_eq + rw [Set.mem_singleton_iff, Prod.mk_inj] at h_x_eq + -- **Now we have pure equalities of x.1 and x.2** + rcases h_y_s'_mem_support_forIn_block with ⟨z, s'', h_forIn_run_mem, h_pure⟩ + have h_z_ne_none : z ≠ none := by + intro h_z_eq_none + simp only [h_z_eq_none, simulateQ_pure, StateT.run_pure, support_pure, Set.mem_singleton_iff, + Prod.mk.injEq, reduceCtorEq, false_and] at h_pure + obtain ⟨z_val, h_z_eq⟩ := Option.ne_none_iff_exists.mp h_z_ne_none + obtain ⟨rfl⟩ := h_z_eq + erw [simulateQ_pure, support_pure] at h_pure + simp only [Set.mem_singleton_iff, Prod.mk.injEq, Option.some.injEq] at h_pure + -- **h_pure : y_val = true ∧ s' = s''** + dsimp only [forIn_block] at h_forIn_run_mem + -- 1. Apply the extraction lemma + have h_independent_support_mem_exists := OptionT.exists_path_of_mem_support_forIn_unit.{0} + (spec := []ₒ) (l := List.finRange γ_repetitions) (f := forIn_body) (s_init := s) + (s_final := s'') (u := z_val) + (h_yield := by + intro rep s_pre res_step h_res_step_mem + dsimp only [forIn_body] at h_res_step_mem + set oa : OracleComp []ₒ (Option Unit) := + ((simulateQ.{0, 0, 0} (impl := OracleInterface.simOracle2 []ₒ oStmtIn tr.messages) + ((checkSingleRepetition 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ((FullTranscript.mk1 (tr.challenges ⟨0, rfl⟩)).challenges ⟨0, rfl⟩ rep) + stmtIn stmtIn.final_constant) : + OptionT (OracleComp + ([]ₒ + ([OracleStatement 𝔽q β ϑ (Fin.last ℓ)]ₒ + + [(pSpecQuery 𝔽q β γ_repetitions).Message]ₒ))) Unit).run) : + OracleComp []ₒ (Option Unit)) + have h_fst_mem : some res_step.1 ∈ ((simulateQ impl + ((((fun (_ : Unit) ↦ ForInStep.yield PUnit.unit) <$> oa) : + OptionT (OracleComp []ₒ) (ForInStep PUnit)))).run' s_pre).support := by + rw [StateT.run', support_map] + exact Set.mem_image_of_mem Prod.fst h_res_step_mem + have h_run'_supp_eq := support_simulateQ_run'_eq (impl := impl) + (oa := ((((fun (_ : Unit) ↦ ForInStep.yield PUnit.unit) <$> oa) : + OptionT (OracleComp []ₒ) (ForInStep PUnit)))) + (s := s_pre) + (hImplSupp := by simp only [Set.fmap_eq_image, IsEmpty.forall_iff, implies_true]) + rw [h_run'_supp_eq] at h_fst_mem + erw [OptionT.mem_support_OptionT_run_map_some] at h_fst_mem + obtain ⟨u, _h_u_mem, h_eq⟩ := h_fst_mem + exact h_eq.symm + ) + (h_mem := h_forIn_run_mem) + set γ_challenges : Fin γ_repetitions → + sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩ := tr.challenges ⟨0, rfl⟩ with h_γ_challenges_def + rw [h_pure.1] at h_x_eq + rw [h_x_eq.1] at hx_1_eq_stmtOut_oStmtOut + simp only [Option.some.injEq, Prod.mk.injEq, Bool.true_eq] at hx_1_eq_stmtOut_oStmtOut + constructor + · exact hx_1_eq_stmtOut_oStmtOut.1 + · constructor + · exact hx_1_eq_stmtOut_oStmtOut.2.symm + · -- 2. Quantify over an arbitrary repetition + intro rep + -- ⊢ logical_checkSingleRepetition 𝔽q β oStmtIn (γ_challenges rep) + -- stmtIn stmtIn.final_constant + have h_rep_th_support_mem := h_independent_support_mem_exists rep + (by simp only [List.mem_finRange]) + rcases h_rep_th_support_mem with ⟨state_pre_repetition, state_post_repetition, + h_support_rep_ith_iteration⟩ + exact logical_checkSingleRepetition_of_mem_support_forIn_body 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (γ_repetitions := γ_repetitions) (σ := σ) (impl := impl) + (oStmtIn := oStmtIn) (tr := tr) (stmtIn := stmtIn) (rep := rep) + (state_pre := state_pre_repetition) (forIn_body := forIn_body) (h_forIn_body_eq := rfl) + (h_mem := by + use (ForInStep.yield PUnit.unit, state_post_repetition) + exact h_support_rep_ith_iteration + ) + /-- Strong completeness for the query phase logic step. This proves that for any valid input satisfying `strictFinalSumcheckRelOut`, @@ -1719,15 +2106,12 @@ theorem queryPhaseLogicStep_isStronglyComplete : (queryPhaseLogicStep 𝔽q β (ϑ:=ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).IsStronglyCompleteUnderSimulation := by intro stmtIn witIn oStmtIn challenges h_relIn - let f₀ := getFirstOracle 𝔽q β oStmtIn have h_ϑ_pos : ϑ > 0 := by exact Nat.pos_of_neZero ϑ have h_ϑ_le_ℓ : ϑ ≤ ℓ := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ); exact hdiv.out let step := queryPhaseLogicStep 𝔽q β (ϑ:=ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - -- 1. Generate the Honest Transcript (Deterministic given challenges) let transcript := step.honestProverTranscript stmtIn witIn oStmtIn challenges - -- 2. Define the honest oracle simulator -- simOracle2 oSpec t₁ t₂ : SimOracle.Stateless (oSpec + ([T₁]ₒ + [T₂]ₒ)) oSpec -- This answers queries to OracleIn using oStmtIn and queries to Messages using transcript @@ -1737,7 +2121,6 @@ theorem queryPhaseLogicStep_isStronglyComplete : -- 2. [fun b => b = true | verifierCheck ...] = 1 (always returns true) -- 3. completeness_relOut holds -- 4-5. Prover and verifier agree - -- Prove safety: verifier check never fails have h_guards_pass : Pr[⊥ | OptionT.mk (simulateQ so (step.verifierCheck stmtIn transcript))] = 0 := by @@ -1756,7 +2139,6 @@ theorem queryPhaseLogicStep_isStronglyComplete : -- rw [OptionT.support_run_eq] -- simp only [←probOutput_eq_zero_iff] -- erw [probOutput_none_OptionT_pure_eq_zero] - apply OptionT.probOutput_none_run_eq_zero_of_probFailure_eq_zero -- rw [probFailure_bind_eq_zero_iff] erw [OptionT.probFailure_mk_bind_eq_zero_iff] @@ -1766,7 +2148,6 @@ theorem queryPhaseLogicStep_isStronglyComplete : -- simp only [liftM_OptionT_eq, bind_pure_comp] set simulateQ_forIn_block : OracleComp []ₒ (Option PUnit.{1}) := simulateQ so _ with h_simulateQ_forIn_block - have h_probFailure_simulateQ_forIn_eq_0 : Pr[⊥ | OptionT.mk simulateQ_forIn_block] = 0 := by dsimp only [simulateQ_forIn_block] rw [OptionT.simulateQ_forIn] @@ -1780,7 +2161,7 @@ theorem queryPhaseLogicStep_isStronglyComplete : -- -- ⊢ Pr[=none | simulateQ_forIn_block] = 0 -- change (Pr[=none | simulateQ_forIn_block] = 0) -- 3. Now we are at the outer loop (forIn γ_repetitions). - -- Push simulateQ inside the loop using the lemma that `simulateQ distributes over the loop structure` + -- Push simulateQ inside the loop using the lemma that `simulateQ distributes over the loop` -- NOW apply the safety lemma -- The goal is: [⊥ | forIn ... (fun ... ↦ simulateQ so ...)] = 0 apply _root_.probFailure_forIn_eq_zero_of_body_safe @@ -1790,7 +2171,8 @@ theorem queryPhaseLogicStep_isStronglyComplete : -- rw [probFailure_bind_eq_zero_iff] conv => enter [2] - simp only [bind_pure_comp, map_pure, Function.comp_apply, simulateQ_pure, probFailure_pure, implies_true] + simp only [bind_pure_comp, map_pure, Function.comp_apply, simulateQ_pure, probFailure_pure, + implies_true] erw [OptionT.probFailure_mk] conv_lhs => enter [1]; @@ -1817,8 +2199,7 @@ theorem queryPhaseLogicStep_isStronglyComplete : rcases h_x_eq with ⟨val, h_x_eq⟩ rw [h_x_eq] rw [OptionT.probFailure_mk] - simp only [MessageIdx, Message, bind_pure_comp, map_pure, HasEvalPMF.probFailure_eq_zero, - zero_add] + simp only [MessageIdx, Message, bind_pure_comp, HasEvalPMF.probFailure_eq_zero, zero_add] erw [simulateQ_pure] simp only [probOutput_eq_zero_iff, support_pure, Set.mem_singleton_iff, reduceCtorEq, not_false_eq_true] @@ -1862,15 +2243,16 @@ theorem queryOracleProof_perfectCompleteness {σ : Type} -- dsimp only [queryOracleProof, queryOracleProver, queryOracleVerifier, dsimp only [OracleVerifier.toVerifier, FullTranscript.mk1] let step := (queryPhaseLogicStep 𝔽q β (ϑ:=ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - let strongly_complete : step.IsStronglyCompleteUnderSimulation := queryPhaseLogicStep_isStronglyComplete (L := L) - 𝔽q β (ϑ := ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - + let strongly_complete : step.IsStronglyCompleteUnderSimulation := + queryPhaseLogicStep_isStronglyComplete (L := L) + 𝔽q β (ϑ := ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) constructor -- GOAL 1: SAFETY - Prove the verifier never crashes ([⊥|...] = 0) · -- Peel off monadic layers to reach the core verifier logic -- ⊢ [⊥| do -- let challenge ← getChallenge -- (A) V samples v ← B_{ℓ+R} - -- let receiveChallengeFn ← pure (...) -- (B) P receives challenge (pure, never fails) + -- let receiveChallengeFn ← pure (...) -- (B) P receives challenge + -- (pure, never fails) -- let __discr ← proverOut ... -- (C) P computes output (pure, never fails) -- let verifierStmtOut ← simulateQ ... -- (D) V runs verifierCheck ← THIS IS THE KEY -- do @@ -1878,7 +2260,6 @@ theorem queryOracleProof_perfectCompleteness {σ : Type} -- pure verifierOut -- pure (...) -- ] = 0 - -- Step 1: Peel off the safe layers -- For each layer: -- A: neverFails_getChallenge or neverFails_query @@ -1901,9 +2282,9 @@ theorem queryOracleProof_perfectCompleteness {σ : Type} rw [OptionT.probFailure_lift] simp only [Fin.isValue, liftComp_eq_liftM, liftComp_id, HasEvalPMF.probFailure_eq_zero] rw [true_and] - intro h_receiveChallengeFn h_receiveChallengeFn_support - -- 1.B Handle the `(queryOracleReduction 𝔽q β γ_repetitions).prover.output (h_receiveChallengeFn chal)) ...` + -- 1.B Handle the `(queryOracleReduction 𝔽q β γ_repetitions).prover.output + -- (h_receiveChallengeFn chal)) ...` conv => enter [1]; simp only [ChallengeIdx, Challenge, Fin.isValue, Matrix.cons_val_zero, @@ -1920,13 +2301,13 @@ theorem queryOracleProof_perfectCompleteness {σ : Type} rw [liftComp_id] simp only [Fin.reduceLast, Fin.isValue] dsimp only [OptionT.lift]; - erw [support_bind]; dsimp only [liftM, monadLift, MonadLift.monadLift]; rw [support_liftComp]; erw [support_pure] + erw [support_bind]; dsimp only [liftM, monadLift, MonadLift.monadLift]; + rw [support_liftComp]; erw [support_pure] simp only [Fin.isValue, Challenge, Matrix.cons_val_zero, Set.mem_singleton_iff, support_pure, Set.iUnion_iUnion_eq_left, Option.some.injEq] -- pure equalities now - -- 1.C Handle the `let __discr ← proverOut ...` - -- Note: Use simp instead of rw to avoid typeclass diamond issues with FiniteRange instances + -- Note: Use simp instead of rw to avoid typeclass diamond issues with Fintype instances -- erw [probFailure_liftComp] -- split; simp only [ChallengeIdx, Challenge, MessageIdx, bind_pure_comp, liftComp_eq_liftM, @@ -1947,9 +2328,7 @@ theorem queryOracleProof_perfectCompleteness {σ : Type} conv => enter [x] rw [OptionT.support_run] - intro vStmtOut h_vStmtOut_mem_support - -- Apply the simulateQ safety lemma -- Can't apply probFailure_simulateQ_simOracle2_eq_zero here obtain ⟨h_V_check, h_rel, h_agree⟩ := strongly_complete @@ -1958,26 +2337,23 @@ theorem queryOracleProof_perfectCompleteness {σ : Type} match j with | 0 => exact chal ) - have h_transcript_eq : FullTranscript.mk1 ((FullTranscript.mk1 chal).challenges ⟨0, by rfl⟩) = FullTranscript.mk1 (pSpec := pSpecQuery 𝔽q β γ_repetitions) chal := by rfl rw [h_transcript_eq] - - have h_probOutput_none_V_check_eq_0 := OptionT.probOutput_none_run_eq_zero_of_probFailure_eq_zero (hfail := h_V_check) - + have h_probOutput_none_V_check_eq_0 := + OptionT.probOutput_none_run_eq_zero_of_probFailure_eq_zero (hfail := h_V_check) have h_vStmtOut_eq : ∃ val, vStmtOut = some (val) := by have h_exists_some := exists_eq_some_of_mem_support_of_probOutput_none_eq_zero (x := vStmtOut) (hx := h_vStmtOut_mem_support) (hnone := by dsimp only [step] at h_probOutput_none_V_check_eq_0 - dsimp only [queryOracleProof, queryOracleReduction, queryPhaseLogicStep, queryOracleVerifier, - OracleVerifier.toVerifier] at h_probOutput_none_V_check_eq_0 ⊢ + dsimp only [queryOracleProof, queryOracleReduction, queryPhaseLogicStep, + queryOracleVerifier, OracleVerifier.toVerifier] at h_probOutput_none_V_check_eq_0 ⊢ rw [h_transcript_eq] at h_probOutput_none_V_check_eq_0 ⊢ - simp only [MessageIdx, Message, Fin.isValue, liftM_OptionT_eq, bind_pure_comp, map_pure, - List.forIn_yield_eq_foldlM, id_map', List.foldlM_range, Functor.map_map, + simp only [MessageIdx, Message, Fin.isValue, bind_pure_comp, Functor.map_map, OptionT.simulateQ_map] - simp only [MessageIdx, Message, Fin.isValue, liftM_OptionT_eq, bind_pure_comp, map_pure, - List.forIn_yield_eq_foldlM, id_map', List.foldlM_range, OptionT.simulateQ_map] at h_probOutput_none_V_check_eq_0 + simp only [MessageIdx, Message, Fin.isValue, bind_pure_comp, + OptionT.simulateQ_map] at h_probOutput_none_V_check_eq_0 exact h_probOutput_none_V_check_eq_0 ) exact h_exists_some @@ -2005,18 +2381,14 @@ theorem queryOracleProof_perfectCompleteness {σ : Type} toPFunctor_add, toPFunctor_emptySpec, OptionT.support_run, ↓existsAndEq, and_true, true_and, exists_eq_right_right', liftM_pure, support_pure, exists_eq_left] dsimp only [monadLift, MonadLift.monadLift] - simp only [Fin.isValue, Challenge, Matrix.cons_val_one, Matrix.cons_val_zero, ChallengeIdx, - liftComp_eq_liftM, liftM_pure, liftComp_pure, support_pure, Set.mem_singleton_iff, - Fin.reduceLast, MessageIdx, Message, exists_eq_left] at hx_mem_support + simp only [Fin.isValue, Challenge, Matrix.cons_val_zero, ChallengeIdx, + liftComp_eq_liftM, Fin.reduceLast, MessageIdx] at hx_mem_support -- Step 2b: Extract the challenge r1 and the trace equations obtain ⟨r1, ⟨_h_r1_mem_challenge_support, h_trace_support⟩⟩ := hx_mem_support - rcases h_trace_support with ⟨prvWitOut, h_prvOut_mem_support, h_verOut_mem_support⟩ - conv at h_prvOut_mem_support => -- similar simplification as in commit step - dsimp only [queryOracleProof, queryOracleReduction, queryPhaseLogicStep, queryOracleProver, queryOracleVerifier, - OracleVerifier.toVerifier, - FullTranscript.mk1] + dsimp only [queryOracleProof, queryOracleReduction, queryPhaseLogicStep, queryOracleProver, + queryOracleVerifier, OracleVerifier.toVerifier, FullTranscript.mk1] dsimp only [liftM, monadLift, MonadLift.monadLift] rw [liftComp_id] rw [support_liftComp] @@ -2027,8 +2399,8 @@ theorem queryOracleProof_perfectCompleteness {σ : Type} -- rw [OptionT.simulateQ_simOracle2_liftM_query_T2] -- erw [_root_.bind_pure_simulateQ_comp] simp only - -- simp only [show OptionT.pure (m := (OracleComp ([]ₒ + ([OracleStatement 𝔽q β ϑ (Fin.last ℓ)]ₒ + - -- [pSpecFold.Message]ₒ)))) = pure by rfl] + -- simp only [show OptionT.pure (m := (OracleComp ([]ₒ + -- + ([OracleStatement 𝔽q β ϑ (Fin.last ℓ)]ₒ + [pSpecFold.Message]ₒ)))) = pure by rfl] change some (verStmtOut, verOStmtOut) ∈ (liftComp _ _).support rw [support_liftComp] dsimp only [Functor.map] @@ -2037,12 +2409,10 @@ theorem queryOracleProof_perfectCompleteness {σ : Type} Function.comp_apply, Set.iUnion_exists, Set.biUnion_and'] -- erw [support_pure] -- simp only [Set.mem_singleton_iff, Option.some.injEq, Prod.mk.injEq] - - rcases h_verOut_mem_support with ⟨VCheck_boolean, h_VCheck_boolean_mem_support, VOut_boolean, h_VOut_boolean_mem_support, h_VOut_mem_support⟩ - + rcases h_verOut_mem_support with ⟨VCheck_boolean, h_VCheck_boolean_mem_support, + VOut_boolean, h_VOut_boolean_mem_support, h_VOut_mem_support⟩ set V_check := step.verifierCheck stmtIn (FullTranscript.mk1 (msg0 := _)) with h_V_check_def - -- Apply the simulateQ safety lemma -- Can't apply probFailure_simulateQ_simOracle2_eq_zero here obtain ⟨h_V_check_not_fail, h_rel, h_agree⟩ := strongly_complete @@ -2051,7 +2421,6 @@ theorem queryOracleProof_perfectCompleteness {σ : Type} match j with | 0 => exact r1 ) - have h_VOut_boolean_eq_true : VOut_boolean = true := by match VCheck_boolean with -- VOut_boolean depends on VCheck_boolean | some a => @@ -2061,16 +2430,14 @@ theorem queryOracleProof_perfectCompleteness {σ : Type} dsimp only [queryPhaseLogicStep] at h_VOut_boolean_mem_support exact h_VOut_boolean_mem_support | none => - simp only [simulateQ_pure, support_pure, Set.mem_singleton_iff] at h_VOut_boolean_mem_support + simp only [simulateQ_pure, support_pure, Set.mem_singleton_iff] + at h_VOut_boolean_mem_support simp only [h_VOut_boolean_mem_support, support_pure, Set.mem_singleton_iff, reduceCtorEq] at h_VOut_mem_support ⊢ - simp only [h_VOut_boolean_eq_true, OptionT.support_OptionT_pure_run, Set.mem_singleton_iff, Option.some.injEq, Prod.mk.injEq] at h_VOut_mem_support -- pure equalities now - have prvStmtOut_eq := h_prvOut_mem_support obtain ⟨verStmtOut_eq, verOStmtOut_eq⟩ := h_VOut_mem_support - constructor · rw [verStmtOut_eq, verOStmtOut_eq]; exact h_rel @@ -2095,35 +2462,24 @@ noncomputable def queryRbrExtractor : extractMid := fun _ _ _ witMidSucc => witMidSucc extractOut := fun _ _ _ => () -def queryKStateProp {m : Fin (1 + 1)} +def queryKStateProp (m : Fin (1 + 1)) (tr : ProtocolSpec.Transcript m (pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate))) - (stmt : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) (witMid : Unit) - (oStmt : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) : Prop := -if h0 : m.val = 0 then - -- Same as last Kstate of finalSumcheck reduction - Binius.BinaryBasefold.finalSumcheckRelOutProp 𝔽q β - (input:=⟨⟨stmt, oStmt⟩, witMid⟩) -else - let r := stmt.ctx.t_eval_point - let s := stmt.ctx.original_claim - let challenges : Fin ℓ → L := stmt.challenges - let tr_so_far := (pSpecQuery 𝔽q β γ_repetitions - (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).take m m.is_le - let chalIdx : tr_so_far.ChallengeIdx := ⟨⟨0, - Nat.lt_of_succ_le (by omega)⟩, by simp only [Nat.reduceAdd]; rfl⟩ - let γ_challenges : Fin γ_repetitions → sDomain 𝔽q - β h_ℓ_add_R_rate ⟨0, by omega⟩ := ((ProtocolSpec.Transcript.equivMessagesChallenges (k:=m) - (pSpec:=pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - tr).2 chalIdx) - let fold_challenges := stmt.challenges - -- Checks available after message 1 (V -> P: γ challenges) - let proximityTestsCheck : Prop := - proximityChecksSpec 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (ϑ:=ϑ) γ_repetitions γ_challenges oStmt fold_challenges stmt.final_constant - proximityTestsCheck + match m with + | ⟨0, _⟩ => -- Same as last KState of finalSumcheck reduction (= relIn) + Binius.BinaryBasefold.finalSumcheckRelOutProp 𝔽q β + (input := ⟨⟨stmtIn, oStmtIn⟩, witMid⟩) + | ⟨1, _⟩ => -- After V sends γ challenges: proximity tests must pass + let γ_challenges : Fin γ_repetitions → sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩ := + tr.challenges ⟨0, rfl⟩ + let fold_challenges := stmtIn.challenges + logical_proximityChecksSpec 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (ϑ := ϑ) (γ_repetitions := γ_repetitions) (γ_challenges := γ_challenges) + (final_constant := stmtIn.final_constant) (oStmt := oStmtIn) (stmt := stmtIn) /-- The knowledge state function for the query phase -/ noncomputable def queryKnowledgeStateFunction {σ : Type} (init : ProbComp σ) @@ -2131,30 +2487,310 @@ noncomputable def queryKnowledgeStateFunction {σ : Type} (init : ProbComp σ) (queryOracleVerifier 𝔽q β (ϑ:=ϑ) γ_repetitions).KnowledgeStateFunction init impl (relIn := finalSumcheckRelOut 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ) (relOut := acceptRejectOracleRel) - (extractor := queryRbrExtractor 𝔽q β (ϑ:=ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) where - toFun := fun m ⟨stmt, oStmt⟩ tr witMid => + (extractor := queryRbrExtractor 𝔽q β (ϑ:=ϑ) + γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) where + toFun := fun m ⟨stmtIn, oStmtIn⟩ tr witMid => queryKStateProp 𝔽q β (ϑ:=ϑ) (γ_repetitions:=γ_repetitions) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (m:=m) (tr:=tr) (stmt:=stmt) (witMid:=witMid) (oStmt:=oStmt) - toFun_empty := fun stmt witMid => by simp only; rfl - toFun_next := fun m hDir stmt tr msg witMid h => by - fin_cases m; simp [pSpecQuery] at hDir - toFun_full := fun stmt tr witOut h => by - sorry + (m:=m) (tr:=tr) (stmtIn:=stmtIn) (witMid:=witMid) (oStmtIn:=oStmtIn) + toFun_empty := fun ⟨stmtIn, oStmtIn⟩ witMid => by rfl + toFun_next := fun m hDir ⟨stmtMid, oStmtMid⟩ tr msg witMid => by + simp only [ne_eq, reduceCtorEq, not_false_eq_true, Matrix.cons_val_fin_one, + Direction.not_V_to_P_eq_P_to_V] at hDir + toFun_full := fun ⟨stmtIn, oStmtIn⟩ tr witOut probEvent_relOut_gt_0 => by + -- h_relOut: ∃ stmtOut oStmtOut, verifier outputs (stmtOut, oStmtOut) with prob > 0 + -- and ((stmtOut, oStmtOut), witOut) ∈ foldStepRelOut + simp only [StateT.run'_eq, gt_iff_lt, probEvent_pos_iff, Prod.exists] at probEvent_relOut_gt_0 + rcases probEvent_relOut_gt_0 with ⟨stmtOut, oStmtOut, h_output_mem_V_run_support, h_relOut⟩ + have h_output_mem_V_run_support' : + some (stmtOut, oStmtOut) ∈ + (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (queryOracleVerifier 𝔽q β (ϑ := ϑ) γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).toVerifier)).run s).support := by + exact (OptionT.mem_support_iff + (mx := OptionT.mk (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (queryOracleVerifier 𝔽q β (ϑ := ϑ) γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).toVerifier)).run s)) + (x := (stmtOut, oStmtOut))).1 h_output_mem_V_run_support + simp only [support_bind, Set.mem_iUnion, exists_prop] at h_output_mem_V_run_support' + rcases h_output_mem_V_run_support' with ⟨s, hs_init, h_output_mem_V_run_support_with_s⟩ + -- Apply the main lemma connecting verifier support to logical proximity checks + have h_res := logical_consistency_checks_passed_of_mem_support_V_run + (impl := impl) (stmtIn := stmtIn) (oStmtIn := oStmtIn) (tr := tr) + (s := s) (stmtOut := stmtOut) (oStmtOut := oStmtOut) + (h_mem_V_run_support := by + rw [OptionT.mem_support_iff] + dsimp only [OptionT.mk, OptionT.run] + exact h_output_mem_V_run_support_with_s + ) + -- The lemma gives us: + exact h_res.2.2 -/-- Round-by-round knowledge soundness for the oracle verifier (query phase) -/ -theorem queryOracleVerifier_rbrKnowledgeSoundness [Fintype L] {σ : Type} (init : ProbComp σ) +/-- **Single Repetition Proximity Check Bound (Proposition 4.23)** + +For a single repetition of the proximity check, the probability that a non-compliant +oracle (not close to RS codeword) passes the fold consistency check is bounded by: + `(1/2) + 1/(2 * 2^𝓡)` + +**Preconditions (from Proposition 4.23 in DG25):** +- `h_not_oracleFoldingConsistent`: At least one oracle is non-compliant +- `h_no_bad_event`: No bad folding events occurred (Definition 4.19) + +This is the fundamental proximity testing bound used in the soundness proof. -/ +theorem prop_4_23_singleRepetition_proximityCheck_bound + (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) + (h_not_oracleFoldingConsistent : ¬ finalSumcheckStepOracleConsistencyProp 𝔽q β + (h_le := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out)) + (stmtOut := stmtIn) (oStmtOut := oStmtIn)) + (h_no_bad_event : ¬ blockBadEventExistsProp 𝔽q β (stmtIdx := Fin.last ℓ) + (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (Fin.last ℓ)) + (oStmt := oStmtIn) (challenges := stmtIn.challenges)) : + Pr_{ let v ← $ᵖ ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0) }[ + logical_checkSingleRepetition 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + oStmtIn v stmtIn stmtIn.final_constant ] ≤ + queryRbrKnowledgeError_singleRepetition (𝓡 := 𝓡) := by + -- Delegates to Soundness Prop 4.23 (Lemma 4.25 supplies the query-rejection property). + have h_res := + (Binius.BinaryBasefold.prop_4_23_singleRepetition_proximityCheck_bound + (stmtIn := stmtIn) (oStmtIn := oStmtIn) + (h_not_consistent := h_not_oracleFoldingConsistent) + (h_no_bad := h_no_bad_event) + (h_le := by + apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out))) + dsimp only [queryRbrKnowledgeError_singleRepetition] + simp only [one_div, mul_inv_rev, ENNReal.coe_add, ne_eq, OfNat.ofNat_ne_zero, + not_false_eq_true, ENNReal.coe_inv, ENNReal.coe_ofNat, ENNReal.coe_mul, pow_eq_zero_iff', + false_and, ENNReal.coe_pow, ge_iff_le] + simp only [one_div, ne_eq, OfNat.ofNat_ne_zero, not_false_eq_true, ENNReal.coe_inv, + ENNReal.coe_ofNat, ENNReal.coe_one] at h_res + rw [ENNReal.mul_inv (ha := by + left; simp only [ne_eq, OfNat.ofNat_ne_zero, not_false_eq_true]) + (hb := by + left; simp only [ne_eq, ENNReal.ofNat_ne_top, not_false_eq_true]) , mul_comm] at h_res + exact h_res + +theorem singleRepetition_proximityCheck_bound + (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) + (h_not_oracleFoldingConsistent : ¬ finalSumcheckStepOracleConsistencyProp 𝔽q β + (h_le := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out)) + (stmtOut := stmtIn) (oStmtOut := oStmtIn)) + (h_no_bad_event : ¬ blockBadEventExistsProp 𝔽q β (stmtIdx := Fin.last ℓ) + (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (Fin.last ℓ)) + (oStmt := oStmtIn) (challenges := stmtIn.challenges)) : + Pr_{ let v ← $ᵖ ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0) }[ + logical_checkSingleRepetition 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + oStmtIn v stmtIn stmtIn.final_constant ] ≤ + queryRbrKnowledgeError_singleRepetition (𝓡 := 𝓡) := by + -- This is Proposition 4.23 (DG25) specialized to a single repetition. + simpa using + (prop_4_23_singleRepetition_proximityCheck_bound (𝔽q := 𝔽q) (β := β) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) + (h_not_oracleFoldingConsistent := h_not_oracleFoldingConsistent) + (h_no_bad_event := h_no_bad_event)) + +open Classical in +/-! Round-by-round knowledge soundness for the oracle verifier (query phase). + +**Proof Strategy (RBR Extraction Failure Event):** + +The RBR extraction failure event is: `¬ KState(0) ∧ KState(1)`, i.e., + - `¬ finalSumcheckRelOutProp` (KState 0 = FALSE), AND + - `proximityChecksSpec` (KState 1 = TRUE) + +By De Morgan's law: + `¬ finalSumcheckRelOutProp = ¬ (oracleFoldingConsistency ∨ badEvent)` + `= ¬ oracleFoldingConsistency ∧ ¬ badEvent` + +This means: + - `¬ oracleFoldingConsistency`: Some oracle is NOT compliant (not close to correct folding) + - `¬ badEvent`: No bad events detected + +**Proposition 4.23 (DG25 - Assuming no bad events):** +If any of the adversary's oracles is not compliant (not close to RS codeword), +then the verifier accepts with at most negligible probability: + `Pr[V accepts] ≤ ((1/2) + 1/(2 * 2^𝓡))^γ_repetitions` + +This is exactly `queryRbrKnowledgeError`. -/ +theorem queryOracleVerifier_rbrKnowledgeSoundness {σ : Type} (init : ProbComp σ) (impl : QueryImpl []ₒ (StateT σ ProbComp)) : (queryOracleVerifier 𝔽q β (ϑ:=ϑ) γ_repetitions).rbrKnowledgeSoundness init impl (relIn := finalSumcheckRelOut 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ) (relOut := acceptRejectOracleRel) (rbrKnowledgeError := queryRbrKnowledgeError 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) := by - use fun _ => Unit - use queryRbrExtractor 𝔽q β (ϑ:=ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - use queryKnowledgeStateFunction 𝔽q β (ϑ:=ϑ) γ_repetitions init impl - intro stmtIn witIn prover j - sorry + classical + apply OracleReduction.unroll_rbrKnowledgeSoundness + (kSF := queryKnowledgeStateFunction 𝔽q β (ϑ:=ϑ) γ_repetitions init impl) + intro stmtIn_oStmtIn witIn prover j initState + let P := rbrExtractionFailureEvent + (kSF := queryKnowledgeStateFunction 𝔽q β (ϑ:=ϑ) γ_repetitions init impl) + (extractor := queryRbrExtractor 𝔽q β (ϑ:=ϑ) γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (i := j) (stmtIn := stmtIn_oStmtIn) + rw [OracleReduction.probEvent_soundness_goal_unroll_log' (pSpec := pSpecQuery 𝔽q β γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (P := P) (impl := impl) (prover := prover) (i := j) + (stmt := stmtIn_oStmtIn) (wit := witIn) (s := initState)] + have h_j_eq_1 : j = ⟨0, rfl⟩ := + match j with + | ⟨0, h0⟩ => rfl + subst h_j_eq_1 + conv_lhs => simp only [Fin.isValue, Fin.castSucc_zero]; + rw [OracleReduction.soundness_unroll_runToRound_0_pSpec_1_V_to_P + (prover := prover) (stmtIn := stmtIn_oStmtIn) (witIn := witIn)] + simp only [Fin.isValue, Challenge, Matrix.cons_val_zero, ChallengeIdx, + QueryImpl.addLift_def, QueryImpl.liftTarget_self, bind_pure_comp, + liftComp_eq_liftM, simulateQ_bind, simulateQ_map, StateT.run'_eq, + StateT.run_bind, StateT.run_map, map_bind, Functor.map_map] + rw [probEvent_bind_eq_tsum] + -- erw [simulateQ_simOracle2_lift_liftComp_query_T1] + -- conv => + -- enter [1] + -- erw [probEvent_map] + -- rw [OracleQuery.cont_apply] + -- erw [probEvent_bind_eq_tsum] + apply OracleReduction.ENNReal.tsum_mul_le_of_le_of_sum_le_one + · -- Bound the conditional probability for each transcript + intro x + -- rw [OracleComp.probEvent_map] + simp only [Fin.isValue, probEvent_map] + let q : OracleQuery + [(pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Challenge]ₒ + _ := query ⟨⟨0, by rfl⟩, ()⟩ + erw [OracleReduction.probEvent_StateT_run_ignore_state + (comp := simulateQ (impl.addLift challengeQueryImpl) (liftM (query q.input))) + (s := x.2) + (P := fun a => P (x.1.1) (q.cont a))] + rw [probEvent_eq_tsum_ite] + erw [simulateQ_query] + simp only [ChallengeIdx, Challenge, Fin.isValue, monadLift_self, + QueryImpl.addLift_def, QueryImpl.liftTarget_self, StateT.run'_eq, StateT.run_map, + Functor.map_map, ge_iff_le] + have h_L_inhabited : Inhabited L := ⟨0⟩ + conv_lhs => + enter [1, x_1, 2, 1, 2] + rw [addLift_challengeQueryImpl_input_run_eq_liftM_run (impl := impl) (q := q) (s := x.2)] + erw [StateT.run_monadLift, monadLift_self, liftComp_id] + rw [bind_pure_comp] + conv => + enter [1, 1, x_1, 2] + rw [Functor.map_map] + rw [← probEvent_eq_eq_probOutput] + rw [probEvent_map] + rw [OracleQuery.cont_apply] + dsimp only [MonadLift.monadLift] + rw [OracleQuery.cont_apply] + dsimp only [q] + simp_rw [OracleQuery.input_query, OracleQuery.snd_query] + conv_lhs => change (∑' (x_1 : (Fin γ_repetitions → ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0))), _) + simp only [Function.comp_id] + conv => + enter [1, 1, x_1, 2] + rw [probEvent_eq_eq_probOutput] + change Pr[=x_1 | $ᵗ (Fin γ_repetitions → ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0))] + rw [OracleReduction.probOutput_uniformOfFintype_eq_Pr (L := _) (x := x_1)] + rw [OracleReduction.tsum_uniform_Pr_eq_Pr + (L := (Fin γ_repetitions → ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0))) + (P := fun x_1 => P x.1.1 (q.2 x_1))] + -- Now the goal is in do-notation form, which is exactly what Pr_ notation expands to + -- Make this explicit using change + conv_lhs => change (∑' (x_1 : (Fin γ_repetitions → ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0))), _) + -- Now the goal is in do-notation form, which is exactly what Pr_ notation expands to + -- Make this explicit using change + change Pr_{ let y ← $ᵖ (Fin γ_repetitions → ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0)) }[(P x.1.1) y] ≤ + queryRbrKnowledgeError 𝔽q β γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨0, rfl⟩ + -- Factor over independent repetitions using the structure of rbrExtractionFailureEvent + -- + -- Key observations: + -- 1. P = rbrExtractionFailureEvent = ∃ witMid : Unit, ¬kSF 0 ... ∧ kSF 1 ... + -- 2. Since witMid : Unit, the existential is trivial (there's only ()) + -- 3. kSF 1 = logical_proximityChecksSpec = ∀ rep, single_check (challenges rep) + -- 4. The bound follows from: P y → ∀ rep, single_check (y rep) + -- So Pr[P y] ≤ Pr[∀ rep, single_check (y rep)] = Pr[single_check c]^γ + -- + -- Strategy: Use monotonicity of probability, then factor the forall + obtain ⟨stmtIn, oStmtIn⟩ := stmtIn_oStmtIn + -- Step 1: Define the single-repetition predicate + let single_P : ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0) → Prop := fun v => + logical_checkSingleRepetition 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmtIn v stmtIn + stmtIn.final_constant + -- Case split FIRST: if P is empty, handle directly; otherwise extract preconditions + by_cases h_P_nonempty : ∃ y, P x.1.1 y + case neg => + -- If no y satisfies P x.1.1 y, then Pr[P x.1.1 _] = 0 ≤ bound trivially + push_neg at h_P_nonempty + -- Show Pr[P x.1.1 _] = 0 using that P is never true + calc Pr_{ let y ← $ᵖ (Fin γ_repetitions → + ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0)) }[ P x.1.1 y ] + _ = Pr_{ let y ← $ᵖ (Fin γ_repetitions → + ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0)) }[ False ] := by + congr 1; ext y; + simp only [Fin.isValue, h_P_nonempty, PMF.monad_pure_eq_pure, PMF.monad_bind_eq_bind, + PMF.bind_const, PMF.pure_apply, eq_iff_iff, iff_false, ite_not] + _ = 0 := by + simp only [PMF.monad_pure_eq_pure, PMF.monad_bind_eq_bind, PMF.bind_const, PMF.pure_apply, + eq_iff_iff, iff_false, not_true_eq_false, ↓reduceIte] + _ ≤ _ := zero_le _ + case pos => + -- P is non-empty: extract preconditions from a witness + obtain ⟨y₀, h_P_y₀⟩ := h_P_nonempty + -- Step 2: Show P implies the forall form + have h_P_implies_forall : ∀ y, P x.1.1 y → (∀ rep : Fin γ_repetitions, single_P (y rep)) := by + intro y h_P + unfold rbrExtractionFailureEvent at h_P + rcases h_P with ⟨witMid, h_kSF_false_before, h_kSF_true_after⟩ + unfold queryKnowledgeStateFunction queryKStateProp logical_proximityChecksSpec + at h_kSF_true_after + exact h_kSF_true_after + -- Step 2b: Extract the preconditions from h_kSF_false_before via De Morgan + have h_preconditions : + (¬ finalSumcheckStepOracleConsistencyProp 𝔽q β + (h_le := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out)) + (stmtOut := stmtIn) (oStmtOut := oStmtIn)) ∧ + (¬ blockBadEventExistsProp 𝔽q β (stmtIdx := Fin.last ℓ) + (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (Fin.last ℓ)) + (oStmt := oStmtIn) (challenges := stmtIn.challenges)) := by + -- Use h_P_y₀ to extract preconditions + -- First substitute P with its definition + simp only [P] at h_P_y₀ + unfold rbrExtractionFailureEvent at h_P_y₀ + rcases h_P_y₀ with ⟨witMid, h_kSF_false_before, h_kSF_true_after⟩ + unfold queryKnowledgeStateFunction at h_kSF_false_before + simp only [Fin.castSucc_zero, queryRbrExtractor] at h_kSF_false_before + unfold queryKStateProp at h_kSF_false_before + simp only [Prod.mk.injEq] at h_kSF_false_before + unfold finalSumcheckRelOutProp finalSumcheckStepFoldingStateProp at h_kSF_false_before + simp only [Prod.fst, Prod.snd] at h_kSF_false_before + push_neg at h_kSF_false_before + exact h_kSF_false_before + obtain ⟨h_not_consistent, h_no_bad⟩ := h_preconditions + -- Step 3: Apply monotonicity + apply le_trans (prob_mono (D := $ᵖ (Fin γ_repetitions → + ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0))) (P x.1.1) + (fun y => ∀ rep : Fin γ_repetitions, single_P (y rep)) h_P_implies_forall) + -- Step 4: Factor independent repetitions + rw [prob_pow_of_forall_finFun (n := γ_repetitions) (P := single_P)] + -- Step 5: Bound single repetition using singleRepetition_proximityCheck_bound + have h_single_repetition_bound : + Pr_{ let v ← $ᵖ ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0) }[ single_P v ] ≤ + queryRbrKnowledgeError_singleRepetition (𝓡 := 𝓡) := + singleRepetition_proximityCheck_bound 𝔽q β stmtIn oStmtIn h_not_consistent h_no_bad + -- Step 6: Finalize exponential bound + unfold queryRbrKnowledgeError + exact ENNReal.pow_le_pow_left h_single_repetition_bound + · -- Prove: ∑' x, [=x|transcript computation] ≤ 1 + apply tsum_probOutput_le_one end FinalQueryRoundIOR end diff --git a/ArkLib/ProofSystem/Binius/BinaryBasefold/ReductionLogic.lean b/ArkLib/ProofSystem/Binius/BinaryBasefold/ReductionLogic.lean index b0c796148..7c8aae1e7 100644 --- a/ArkLib/ProofSystem/Binius/BinaryBasefold/ReductionLogic.lean +++ b/ArkLib/ProofSystem/Binius/BinaryBasefold/ReductionLogic.lean @@ -11,7 +11,8 @@ import ArkLib.Data.Misc.Basic set_option maxHeartbeats 400000 -- Increase if needed set_option profiler true -namespace Binius.BinaryBasefold.CoreInteraction +-- set_option profiler.threshold 50 -- Show anything taking over 10ms +namespace Binius.BinaryBasefold /-! ## Binary Basefold single steps - **Fold step** : @@ -58,6 +59,25 @@ def hEq {ιₒᵢ ιₒₒ : Type} {OracleIn : ιₒᵢ → Type} | Sum.inl j => OracleIn j | Sum.inr j => pSpec.Message j +/-- **RBR Extraction Failure Event**: Generic predicate for round-by-round knowledge soundness. + +This captures when the RBR extractor fails to produce a valid witness at round `i.1.castSucc`, +but a valid witness exists at round `i.1.succ`. This is the fundamental "bad event" that must +be bounded in all RBR knowledge soundness proofs. + +**Usage:** Instantiate with protocol-specific `kSF`, `extractor`, and transcript to get the -/ +@[reducible] +def rbrExtractionFailureEvent {ι : Type} {oSpec : OracleSpec ι} {StmtIn WitIn WitOut : Type} {n : ℕ} + {pSpec : ProtocolSpec n} {WitMid : Fin (n + 1) → Type} + (kSF : (m : Fin (n + 1)) → StmtIn → Transcript m pSpec → WitMid m → Prop) + (extractor : Extractor.RoundByRound oSpec StmtIn WitIn WitOut pSpec WitMid) + (i : pSpec.ChallengeIdx) (stmtIn : StmtIn) + (transcript : Transcript i.1.castSucc pSpec) (challenge : pSpec.Challenge i) : Prop := + ∃ witMid : WitMid i.1.succ, + ¬ kSF i.1.castSucc stmtIn transcript + (extractor.extractMid i.1 stmtIn (transcript.concat challenge) witMid) ∧ + kSF i.1.succ stmtIn (transcript.concat challenge) witMid + /-- The Pure Logic of an interactive reduction step. Parametrized by a 'Challenges' type that aggregates all verifier randomness. -/ structure ReductionLogicStep @@ -67,8 +87,8 @@ structure ReductionLogicStep (StmtOut WitOut : Type) {n : ℕ} (pSpec : ProtocolSpec n) where -- 1. The Specification (Relations) - now with indexed oracles - completeness_relIn : (StmtIn × (∀ i, OracleIn i)) × WitIn → Prop - completeness_relOut : (StmtOut × (∀ i, OracleOut i)) × WitOut → Prop + completeness_relIn : (StmtIn × (∀ i, OracleIn i)) × WitIn → Prop + completeness_relOut : (StmtOut × (∀ i, OracleOut i)) × WitOut → Prop -- 2. The Verifier (Pure Logic) verifierCheck : StmtIn → FullTranscript pSpec → Prop verifierOut : StmtIn → FullTranscript pSpec → StmtOut @@ -512,7 +532,8 @@ lemma snoc_oracle_eq_mkVerifierOStmtOut_commitStep (newOracle : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := ⟨i.val + 1, by omega⟩)) (transcript : FullTranscript (pSpecCommit 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i)) - (h_transcript_eq : transcript.messages ⟨0, rfl⟩ = newOracle) : + (h_transcript_eq : transcript.messages ⟨0, rfl⟩ = newOracle) + : snoc_oracle 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (destIdx := ⟨i.val + 1, by omega⟩) (h_destIdx := by rfl) oStmtIn newOracle = OracleVerifier.mkVerifierOStmtOut (commitStepLogic (mp := mp) 𝔽q β (ϑ := ϑ) @@ -847,11 +868,6 @@ The step consists of : - And `s_ℓ = ∑_{w ∈ B_0} t(r'_0, ..., r'_{ℓ-1}) = t(r'_0, ..., r'_{ℓ-1})` -/ -/-- Oracle interface instance for the final sumcheck step message -/ -instance : ∀ j, OracleInterface ((pSpecFinalSumcheckStep (L := L)).Message j) := fun j => - match j with - | ⟨0, _⟩ => OracleInterface.instDefault - /-- The Logic Instance for the final sumcheck step. This is a 1-message protocol where the prover sends the final constant c. -/ def finalSumcheckStepLogic : @@ -974,14 +990,14 @@ lemma iterated_fold_to_const_strict let P₀: L[X]_(2 ^ ℓ) := polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) (fun ω => witIn.t.val.eval (bitsOfIndex ω)) let f₀ := polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := 0) (P := P₀) - -- From strictOracleWitnessConsistency, we can construct strictFinalFoldingStateProp + -- From strictOracleWitnessConsistency, we can construct strictfinalSumcheckStepFoldingStateProp -- which contains strictFinalConstantConsistency, giving us the desired equality -- Extract components from h_strictOracleWitConsistency_In have h_wit_struct := h_strictOracleWitConsistency_In.1 have h_strict_oracle_folding := h_strictOracleWitConsistency_In.2 dsimp only [Fin.val_last, OracleFrontierIndex.val_mkFromStmtIdx, strictOracleFoldingConsistencyProp] at h_strict_oracle_folding - -- Construct the input for strictFinalFoldingStateProp + -- Construct the input for strictfinalSumcheckStepFoldingStateProp let stmtOut : FinalSumcheckStatementOut (L := L) (ℓ := ℓ) := { ctx := stmtIn.ctx, sumcheck_target := stmtIn.sumcheck_target, @@ -1168,7 +1184,7 @@ lemma iterated_fold_to_const_strict - Using `projectToMidSumcheckPoly_at_last`: - `witIn.H.val.eval (fun _ => 0) = eqTilde(...) * witIn.f ⟨0, ...⟩` 3. Combining these gives the verifier check equation. -/ -omit [CharP L 2] [SampleableType L] in +omit [CharP L 2] [SampleableType L] [DecidableEq 𝔽q] in lemma finalSumcheckStep_verifierCheck_passed (stmtIn : Statement (SumcheckBaseContext L ℓ) (Fin.last ℓ)) (witIn : Witness 𝔽q β (Fin.last ℓ)) @@ -1183,6 +1199,7 @@ lemma finalSumcheckStep_verifierCheck_passed let step := finalSumcheckStepLogic 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) let transcript := step.honestProverTranscript stmtIn witIn oStmtIn challenges step.verifierCheck stmtIn transcript := by + classical let step := finalSumcheckStepLogic 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) let transcript := step.honestProverTranscript stmtIn witIn oStmtIn challenges -- Simplify the verifier check to the equality we need to prove @@ -1210,70 +1227,12 @@ lemma finalSumcheckStep_verifierCheck_passed -- NOTE: this is important let h_c_eq : (transcript.messages ⟨0, rfl⟩) = witIn.t.val.eval stmtIn.challenges := by change witIn.f ⟨0, by simp only [zero_mem]⟩ = witIn.t.val.eval stmtIn.challenges - -- Since `f (f_ℓ)` is `getMidCodewords` of `t`, `f = fold(f₀, r') where f₀ = fun x => t.eval x` dsimp only [getMidCodewords, Fin.coe_ofNat_eq_mod] at h_f_eq_getMidCodewords_t rw [congr_fun h_f_eq_getMidCodewords_t ⟨0, by simp only [zero_mem]⟩] - -- ⊢ iterated_fold 𝔽q β 0 ℓ ⋯ - -- (fun x ↦ Polynomial.eval ↑x ↑(polynomialFromNovelCoeffsF₂ 𝔽q β ℓ ⋯ fun ω ↦ - -- (MvPolynomial.eval ↑↑ω) ↑witIn.t)) - -- stmtIn.challenges ⟨↑⟨0, ⋯⟩, ⋯⟩ = - -- (MvPolynomial.eval stmtIn.challenges) ↑witIn.t - -- have h_eq : @Fin.mk r (0 % ℓ) (isLt := by exact Nat.pos_of_ne_zero (by omega)) = 0 := by - -- simp only [Nat.zero_mod, Fin.mk_zero'] - let coeffs := fun (ω : Fin (2 ^ (ℓ - 0))) => witIn.t.val.eval (bitsOfIndex ω) - let res := iterated_fold_advances_evaluation_poly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (i := 0) (steps := Fin.last ℓ) (destIdx := ⟨↑(Fin.last ℓ), by omega⟩) (h_destIdx := by - simp only [Fin.val_last, Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]) - (h_destIdx_le := by simp only; omega) (coeffs := coeffs) (r_challenges := stmtIn.challenges) - unfold polyToOracleFunc at res - simp only at res - rw [intermediate_poly_P_base 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (h_ℓ := by omega) (coeffs := coeffs)] at res - dsimp only [polynomialFromNovelCoeffsF₂] - change iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) 0 ↑(Fin.last ℓ) - (destIdx := ⟨↑(Fin.last ℓ), by omega⟩) (by simp only [Fin.val_last, Fin.coe_ofNat_eq_mod, - Nat.zero_mod, zero_add]) (by simp only; omega) - (fun x ↦ - Polynomial.eval (↑x) (polynomialFromNovelCoeffs 𝔽q β ℓ (h_ℓ := by omega) coeffs)) - stmtIn.challenges ⟨0, by simp only [Fin.val_last, zero_mem]⟩ = - (MvPolynomial.eval stmtIn.challenges) (witIn.t.val) - rw [res] - -- (intermediateEvaluationPoly 𝔽q β h_ℓ_add_R_rate ⟨0 % ℓ + ℓ, ⋯⟩ fun j ↦ - -- ∑ x, multilinearWeight stmtIn.challenges x * coeffs ⟨↑j * 2 ^ ℓ + ↑x, ⋯⟩) = - -- (MvPolynomial.eval stmtIn.challenges) ↑witIn.t - dsimp only [intermediateEvaluationPoly] - -- have h_empty_univ : Fin (ℓ - (Fin.last ℓ)) = Fin 0 := by - -- simp only [Fin.val_last, tsub_self] - haveI : IsEmpty (Fin (ℓ - (Fin.last ℓ).val)) := by - simp only [Fin.val_last, Nat.sub_self] - infer_instance - conv_lhs => -- Eliminate the intermediateNovelBasisX terms - dsimp only [intermediateNovelBasisX] - simp only [Finset.univ_eq_empty, Finset.prod_empty] -- eliminate the finsum over (Fin 0) - simp only [map_mul, mul_one] - rw [←map_sum] -- bring the `C` out of the sum - have h_Fin_eq : Fin (2 ^ (ℓ - ↑(Fin.last ℓ))) = Fin 1 := by - simp only [Fin.val_last, tsub_self, pow_zero] - haveI : Unique (Fin (2 ^ (ℓ - (Fin.last ℓ).val))) := by - simp only [Fin.val_last, Nat.sub_self, pow_zero] - exact Fin.instUnique - have h_default : (@default (Fin (2 ^ (ℓ - ↑(Fin.last ℓ)))) Unique.instInhabited).val = 0 := by - have hlt := (@default (Fin (2 ^ (ℓ - ↑(Fin.last ℓ)))) Unique.instInhabited).isLt - simp only [Fin.val_last, Nat.sub_self, pow_zero] at hlt - exact Nat.lt_one_iff.mp hlt - simp only [Fintype.sum_unique, Fin.val_zero, h_default] - simp only [Fin.val_last, Nat.sub_zero, zero_mul, zero_add, Fin.eta, map_sum, map_mul] - dsimp only [Nat.sub_zero, Fin.isValue, coeffs] - simp only [←map_mul, ←map_sum] - letI : NeZero (Fin.last ℓ).val := { - out := by - have h_ℓ_pos : ℓ > 0 := by exact Nat.pos_of_neZero ℓ - rw [Fin.val_last]; omega - } - let res := multilinear_eval_eq_sum_bool_hypercube (challenges := stmtIn.challenges) - (t := witIn.t) - simp only [Fin.val_last] at res - rw [res, Polynomial.eval_C]; + let h_eval := iterated_fold_to_level_ℓ_eval 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (t := witIn.t) (destIdx := ⟨Fin.last ℓ, by omega⟩) + (h_destIdx := by simp only [Fin.val_last]) (challenges := stmtIn.challenges) + exact congr_fun (h := h_eval) ⟨0, by simp only [Fin.val_last, zero_mem]⟩ -- Apply `projectToMidSumcheckPoly_at_last` to connect H.eval with eqTilde * f(0) have h_H_eval_at_zero_eq_mul : witIn.H.val.eval (fun _ => (0 : L)) = eqTilde stmtIn.ctx.t_eval_point stmtIn.challenges * @@ -1301,8 +1260,9 @@ lemma finalSumcheckStep_verifierCheck_passed - Need to connect these properties to show the verifier check passes 2. **Relation Out**: Show that the output satisfies `finalSumcheckRelOut` - - This involves showing `finalFoldingStateProp` holds for the output + - This involves showing `finalSumcheckStepFoldingStateProp` holds for the output -/ +omit [DecidableEq 𝔽q] in lemma finalSumcheckStep_is_logic_complete : (finalSumcheckStepLogic 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)).IsStronglyComplete := by @@ -1343,8 +1303,8 @@ lemma finalSumcheckStep_is_logic_complete : -- Fact 2: Output relation holds (foldStepRelOut) simp only [finalSumcheckStepLogic, strictRoundRelation, strictRoundRelationProp, Fin.val_last, Prod.mk.eta, Set.mem_setOf_eq, strictFinalSumcheckRelOut, strictFinalSumcheckRelOutProp, - strictFinalFoldingStateProp, exists_and_right, Subtype.exists, Fin.isValue, MessageIdx, - Fin.eta, step] + strictfinalSumcheckStepFoldingStateProp, exists_and_right, Subtype.exists, + Fin.isValue, MessageIdx, Fin.eta, step] -- let r_i' := challenges ⟨1, rfl⟩ -- rw [h_verifierOStmtOut_eq]; dsimp only [strictOracleWitnessConsistency, Fin.val_last, OracleFrontierIndex.mkFromStmtIdx, @@ -1360,6 +1320,7 @@ lemma finalSumcheckStep_is_logic_complete : exact h_oracle_folding_In · -- Component 2: finalOracleFoldingConsistency funext y + classical let res := iterated_fold_to_const_strict 𝔽q β (𝓑 := 𝓑) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (stmtIn := stmtIn) (witIn := witIn) (oStmtIn := oStmtIn) (challenges := challenges) @@ -1375,4 +1336,4 @@ lemma finalSumcheckStep_is_logic_complete : end FinalSumcheckStep end SingleIteratedSteps end -end Binius.BinaryBasefold.CoreInteraction +end Binius.BinaryBasefold diff --git a/ArkLib/ProofSystem/Binius/BinaryBasefold/Soundness.lean b/ArkLib/ProofSystem/Binius/BinaryBasefold/Soundness.lean new file mode 100644 index 000000000..57079a1f7 --- /dev/null +++ b/ArkLib/ProofSystem/Binius/BinaryBasefold/Soundness.lean @@ -0,0 +1,4515 @@ +import ArkLib.ProofSystem.Binius.BinaryBasefold.Spec +import ArkLib.ProofSystem.Binius.BinaryBasefold.Code +import ArkLib.Data.Misc.Basic +import ArkLib.Data.Probability.Instances +import Mathlib.Data.Matrix.Mul +import Mathlib.LinearAlgebra.Matrix.Nondegenerate +import CompPoly.Fields.Binary.Tower.Prelude +import ArkLib.Data.CodingTheory.InterleavedCode +import ArkLib.Data.CodingTheory.ProximityGap.DG25 +import Mathlib.Algebra.Group.Action.Defs + +namespace Binius.BinaryBasefold + +/-! **Soundness foundational Lemmas (Lemmas 4.20 - 4.25)** +- Probability reasoning: using lemmas in `DG25.lean` +- Foundational definitions and lemmas: `ArkLib.Data.FieldTheory.AdditiveNTT.AdditiveNTT.lean` + and `ArkLib.ProofSystem.Binius.BinaryBasefold.Prelude` +- Raw proof specs: `ArkLib/ProofSystem/Binius/BinaryBasefold/SoundnessFoundationsSpec.md` +-/ + +set_option maxHeartbeats 400000 -- Large file with heavy tactics; avoid heartbeat timeouts + +open OracleSpec OracleComp ProtocolSpec Finset AdditiveNTT Polynomial MvPolynomial + Binius.BinaryBasefold +open scoped NNReal +open ReedSolomon Code BerlekampWelch Function +open Finset AdditiveNTT Polynomial MvPolynomial Nat Matrix +open ProbabilityTheory + +variable {r : ℕ} [NeZero r] +variable {L : Type} [Field L] [Fintype L] [DecidableEq L] [CharP L 2] +variable (𝔽q : Type) [Field 𝔽q] [Fintype 𝔽q] [DecidableEq 𝔽q] + [h_Fq_char_prime : Fact (Nat.Prime (ringChar 𝔽q))] [hF₂ : Fact (Fintype.card 𝔽q = 2)] +variable [Algebra 𝔽q L] +variable (β : Fin r → L) [hβ_lin_indep : Fact (LinearIndependent 𝔽q β)] + [h_β₀_eq_1 : Fact (β 0 = 1)] +variable {ℓ 𝓡 ϑ : ℕ} (γ_repetitions : ℕ) [NeZero ℓ] [NeZero 𝓡] [NeZero ϑ] -- Should we allow ℓ = 0? +variable {h_ℓ_add_R_rate : ℓ + 𝓡 < r} -- ℓ ∈ {1, ..., r-1} +variable {𝓑 : Fin 2 ↪ L} +noncomputable section +variable [SampleableType L] +variable [hdiv : Fact (ϑ ∣ ℓ)] + +section IndexBounds + +omit [NeZero ℓ] in +@[simp] +lemma k_mul_ϑ_lt_ℓ {k : Fin (ℓ / ϑ)} : + ↑k * ϑ < ℓ := by + have h_mul_eq : (ℓ/ϑ) * ϑ = ℓ := Nat.div_mul_cancel hdiv.out + calc ↑k * ϑ < (ℓ / ϑ) * ϑ := Nat.mul_lt_mul_of_pos_right k.isLt (NeZero.pos ϑ) + _ = ℓ := h_mul_eq + +omit [NeZero ℓ] [NeZero ϑ] in +@[simp] +lemma k_succ_mul_ϑ_le_ℓ {k : Fin (ℓ / ϑ)} : (k.val + 1) * ϑ ≤ ℓ := by + have h_mul_eq : (ℓ/ϑ) * ϑ = ℓ := Nat.div_mul_cancel hdiv.out + calc (k.val + 1) * ϑ ≤ (ℓ / ϑ) * ϑ := Nat.mul_le_mul_right (k := ϑ) (h := by omega) + _ = ℓ := h_mul_eq + +omit [NeZero ℓ] [NeZero ϑ] in +@[simp] +lemma k_succ_mul_ϑ_le_ℓ_₂ {k : Fin (ℓ / ϑ)} : (k.val * ϑ + ϑ) ≤ ℓ := by + conv_lhs => enter [2]; rw [←Nat.one_mul ϑ] + rw [←Nat.add_mul]; + exact k_succ_mul_ϑ_le_ℓ; + +omit [NeZero r] [NeZero ℓ] [NeZero 𝓡] in +@[simp] +lemma lt_r_of_le_ℓ {h_ℓ_add_R_rate : ℓ + 𝓡 < r} {x : ℕ} (h : x ≤ ℓ) + : x < r := by omega + +omit [NeZero r] [NeZero ℓ] [NeZero 𝓡] in +@[simp] +lemma lt_r_of_lt_ℓ {h_ℓ_add_R_rate : ℓ + 𝓡 < r} {x : ℕ} (h : x < ℓ) + : x < r := by omega + +end IndexBounds + +open scoped NNReal ProbabilityTheory + +omit [CharP L 2] [SampleableType L] in +/-- **Probability bound for the bad sumcheck event** (Schwartz-Zippel). +When the verifier challenge `r_i'` is uniform over `L`, the probability that two distinct +degree-≤2 round polynomials agree at `r_i'` is at most `2 / |L|`. -/ +lemma probability_bound_badSumcheckEventProp (h_i h_star : L⦃≤ 2⦄[X]) : + Pr_{ let r_i' ← $ᵖ L }[ badSumcheckEventProp r_i' h_i h_star ] ≤ + (2 : ℝ≥0) / Fintype.card L := by + unfold badSumcheckEventProp + by_cases h_ne : h_i ≠ h_star + · simp only [ne_eq, h_ne, not_false_eq_true, true_and, ENNReal.coe_ofNat] + exact prob_poly_agreement_degree_two (p := h_i) (q := h_star) (h_ne := h_ne) + · simp only [h_ne, false_and, ENNReal.coe_ofNat] + -- lhs is `Pr_{ let r_i' ← $ᵖ ... }[ False ]` + simp only [PMF.monad_pure_eq_pure, PMF.monad_bind_eq_bind, PMF.bind_const, PMF.pure_apply, + eq_iff_iff, iff_false, not_true_eq_false, ↓reduceIte, _root_.zero_le] + +namespace QueryPhase + +/-! +## Common Proximity Check Helpers + +These functions extract the shared logic between `queryOracleVerifier` +and `queryKnowledgeStateFunction` for proximity testing, allowing code reuse +and ensuring both implementations follow the same logic. +-/ + +/-- Extract suffix starting at position `destIdx` from a full challenge. -/ +def extractSuffixFromChallenge (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) + (destIdx : Fin r) (h_destIdx_le : destIdx ≤ ℓ) : + sDomain 𝔽q β h_ℓ_add_R_rate destIdx := + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate (i := ⟨0, by omega⟩) (k := destIdx.val) + (h_destIdx := by simp only [zero_add]) (h_destIdx_le := h_destIdx_le) (x := v) + +omit [CharP L 2] [SampleableType L] [DecidableEq 𝔽q] hF₂ [NeZero 𝓡] in +/-- **Congruence Lemma for Challenge Suffixes**: +Allows proving equality between two suffix extractions when the destination indices +are proven equal (`destIdx = destIdx'`), handling the necessary type casting. -/ +lemma extractSuffixFromChallenge_congr_destIdx + (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) + {destIdx destIdx' : Fin r} + (h_idx_eq : destIdx = destIdx') + (h_le : destIdx ≤ ℓ) + (h_le' : destIdx' ≤ ℓ) : + extractSuffixFromChallenge 𝔽q β v destIdx h_le = + cast (by rw [h_idx_eq]) (extractSuffixFromChallenge 𝔽q β v destIdx' h_le') := by + subst h_idx_eq; rfl + +omit [CharP L 2] [SampleableType L] [DecidableEq 𝔽q] h_β₀_eq_1 in +/-- **First Oracle Equals Polynomial Oracle Function**: +When `strictOracleFoldingConsistencyProp` holds, the first oracle (`getFirstOracle`) equals +the polynomial oracle function `f₀` derived from the multilinear polynomial `t`. +This follows from the consistency property for `j = 0`, where `iterated_fold` with 0 steps +is the identity function. -/ +lemma polyToOracleFunc_eq_getFirstOracle + (t : MultilinearPoly L ℓ) + (i : Fin (ℓ + 1)) + (challenges : Fin i → L) + (oStmt : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i j) + (h_consistency : strictOracleFoldingConsistencyProp 𝔽q β (t := t) (i := i) + (challenges := challenges) (oStmt := oStmt)) : + let P₀ : L[X]_(2 ^ ℓ) := + polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) (fun ω => t.val.eval (bitsOfIndex ω)) + let f₀ := polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := 0) (P := P₀) + f₀ = getFirstOracle 𝔽q β oStmt := by + intro P₀ f₀ + -- Use strictOracleFoldingConsistencyProp for j = 0 + have h_pos : 0 < toOutCodewordsCount ℓ ϑ i := by + exact (instNeZeroNatToOutCodewordsCount ℓ ϑ i).pos + have h_first_oracle := h_consistency ⟨0, by omega⟩ + dsimp only [strictOracleFoldingConsistencyProp] at h_first_oracle + dsimp only [f₀, P₀, getFirstOracle] at h_first_oracle ⊢ + rw [h_first_oracle] + funext y + conv_rhs => + rw [iterated_fold_congr_steps_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (steps' := 0) + (h_destIdx := by simp only [Nat.zero_mod, zero_mul, Fin.coe_ofNat_eq_mod, add_zero]) + (h_destIdx_le := by simp only [zero_mul, zero_le]) + (h_steps_eq_steps' := by simp only [zero_mul])] + rw [iterated_fold_zero_steps 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) + (h_destIdx := by simp only [Nat.zero_mod, zero_mul, Fin.coe_ofNat_eq_mod])] + conv_rhs => simp only [cast_cast, cast_eq]; simp only [←fun_eta_expansion] + +/-- Decompose challenge v at position i into (fiberIndex, suffix). + This is the inverse of `Nat.joinBits` in some sense. + Uses loose indexing with `Fin r`. -/ +def decomposeChallenge (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) + (i : Fin r) {destIdx : Fin r} (steps : ℕ) + (h_destIdx_le : destIdx ≤ ℓ) : + Fin (2^steps) × sDomain 𝔽q β h_ℓ_add_R_rate destIdx := + (extractMiddleFinMask 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v:=v) (i:=i) (steps:=steps), + extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (v:=v) + (destIdx:=destIdx) (h_destIdx_le:=h_destIdx_le)) + +-- TODO: KEY LEMMA for connecting fiber queries to challenge decomposition +-- TODO: Lemma connecting queryFiberPoints to extractMiddleFinMask + +def queryRbrKnowledgeError_singleRepetition := ((1/2 : ℝ≥0) + (1 : ℝ≥0) / (2 * 2^𝓡)) + +/-- RBR knowledge error for the query phase. +Proximity testing error rate: `(1/2 + 1/(2 * 2^𝓡))^γ` -/ +def queryRbrKnowledgeError := fun _ : (pSpecQuery 𝔽q β γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).ChallengeIdx => + (queryRbrKnowledgeError_singleRepetition (𝓡 := 𝓡))^γ_repetitions + +/-- Oracle query helper: query a committed codeword at a given domain point. + Restricted to codeword indices where the oracle range is L. -/ +def queryCodeword (j : Fin (toOutCodewordsCount ℓ ϑ (Fin.last ℓ))) + (point : (sDomain 𝔽q β h_ℓ_add_R_rate) ⟨oraclePositionToDomainIndex ℓ ϑ j, by omega⟩) : + OptionT (OracleComp ([]ₒ + + ([OracleStatement 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ( Fin.last ℓ)]ₒ + + [(pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Message]ₒ))) L := + query (spec := [OracleStatement 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ)]ₒ) + ⟨⟨j, by omega⟩, point⟩ + +section FinalQueryRoundIOR + +/-! +### IOR Implementation for the Final Query Round +-/ +def getChallengeSuffix (k : Fin (ℓ / ϑ)) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) : + let i := k.val * ϑ + have h_i_add_ϑ_le_ℓ : i + ϑ ≤ ℓ := k_succ_mul_ϑ_le_ℓ_₂ (k := k) + let destIdx : Fin r := ⟨i + ϑ, by omega⟩ + sDomain 𝔽q β h_ℓ_add_R_rate destIdx := + have h_i_add_ϑ_le_ℓ := k_succ_mul_ϑ_le_ℓ_₂ (k := k) + extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (v:=v) (destIdx := ⟨k.val * ϑ + ϑ, by omega⟩) (h_destIdx_le:=by omega) + +def challengeSuffixToFin (k : Fin (ℓ / ϑ)) + (suffix : sDomain 𝔽q β h_ℓ_add_R_rate ⟨k.val * ϑ + ϑ, by + have := k_succ_mul_ϑ_le_ℓ_₂ (k := k); omega⟩) : Fin (2 ^ (ℓ + 𝓡 - (k.val * ϑ + ϑ))) := + let i := k.val * ϑ + have h_i_add_ϑ_le_ℓ : i + ϑ ≤ ℓ := k_succ_mul_ϑ_le_ℓ_₂ (k := k) + let destIdx : Fin r := ⟨i + ϑ, by omega⟩ + sDomainToFin 𝔽q β h_ℓ_add_R_rate (i := ⟨k.val * ϑ + ϑ, by omega⟩) (h_i := by + simp only [k_succ_mul_ϑ_le_ℓ_₂, Nat.lt_add_of_pos_right_of_le]) suffix + +/-- Return the point `f^(i)(u_0, ..., u_{ϑ-1}, v_{i+ϑ}, ..., v_{ℓ+R-1})` +for a fiber index `u ∈ B_ϑ`. -/ +noncomputable def getFiberPoint + (k : Fin (ℓ / ϑ)) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) (u : Fin (2 ^ ϑ)) : + (sDomain 𝔽q β h_ℓ_add_R_rate) (i := ⟨oraclePositionToDomainIndex ℓ ϑ (i := Fin.last ℓ) + (positionIdx := ⟨k, by simp only [toOutCodewordsCount_last, Fin.is_lt]⟩), + lt_r_of_lt_ℓ (x := k.val * ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h := k_mul_ϑ_lt_ℓ)⟩) := + by + simpa [oraclePositionToDomainIndex] using + (qMap_total_fiber 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨k.val * ϑ, + lt_r_of_lt_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (x := k.val * ϑ) + (h := k_mul_ϑ_lt_ℓ (k := k))⟩) + (steps := ϑ) + (h_destIdx := by rfl) + (h_destIdx_le := by + exact k_succ_mul_ϑ_le_ℓ_₂ (k := k)) + (y := getChallengeSuffix 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k) (v := v)) + u) + +section MonadicOracleVerification +/-! +### Helper Functions for Verifier Logic + +These functions break down the verifier's proximity checking logic into composable blocks, +making it easier to prove properties about each component separately. +-/ + +/-- Query all fiber points for a given folding step. + Returns a list of evaluations `f^(i)(u_0, ..., u_{ϑ-1}, v_{i+ϑ}, ..., v_{ℓ+R-1})` + for all `u ∈ B_ϑ`. + Note: `oStmtIn` is accessed via oracle queries in the OracleComp context. -/ +noncomputable def queryFiberPoints + (k : Fin (ℓ / ϑ)) + (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) : + OptionT + (OracleComp + ([]ₒ + ([OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ)]ₒ + + [(pSpecQuery 𝔽q β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Message]ₒ))) + (Vector L (2^ϑ)) := do + let k_th_oracleIdx : Fin (toOutCodewordsCount ℓ ϑ (Fin.last ℓ)) := + ⟨k, by simp only [toOutCodewordsCount, Fin.val_last, lt_self_iff_false, ↓reduceIte, add_zero, + Fin.is_lt]⟩ + -- 2. Map over the Vector monadically + let results : Vector L (2^ϑ) ← (⟨Array.finRange (2^ϑ), by simp only [Array.size_finRange]⟩ + : Vector (Fin (2^ϑ)) (2^ϑ)).mapM (fun (u : Fin (2^ϑ)) => do + queryCodeword 𝔽q β (γ_repetitions := γ_repetitions) (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (j := k_th_oracleIdx) (point := + getFiberPoint 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k) (v := v) (u := u)) + ) + pure results + +/-- Check a single folding step: query fiber points, verify consistency, and compute next value. + Returns `(c_next, all_checks_passed)` where `c_next` is the computed folded value + and `all_checks_passed` indicates if all consistency checks passed. + Note: `oStmtIn` is accessed via oracle queries in the OracleComp context. -/ +noncomputable def checkSingleFoldingStep + (k_val : Fin (ℓ / ϑ)) (c_cur : L) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) + (stmt : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) : + OptionT (OracleComp ([]ₒ + ([OracleStatement 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ)]ₒ + [(pSpecQuery 𝔽q β + γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Message]ₒ))) L := do + let i := k_val.val * ϑ + have h_k: k_val ≤ (ℓ/ϑ - 1) := by omega + have h_i_add_ϑ_le_ℓ : i + ϑ ≤ ℓ := by + calc i + ϑ = k_val * ϑ + ϑ := by omega + _ ≤ (ℓ/ϑ - 1) * ϑ + ϑ := by + apply Nat.add_le_add_right; apply Nat.mul_le_mul_right; omega + _ = ℓ/ϑ * ϑ := by + rw [Nat.sub_mul, one_mul, Nat.sub_add_cancel]; + conv_lhs => rw [←one_mul ϑ] + apply Nat.mul_le_mul_right; omega + _ ≤ ℓ := by apply Nat.div_mul_le_self; + have h_i_lt_ℓ : i < ℓ := by + calc i ≤ ℓ - ϑ := by omega + _ < ℓ := by + apply Nat.sub_lt (by exact Nat.pos_of_neZero ℓ) (by exact Nat.pos_of_neZero ϑ) + let f_i_on_fiber ← queryFiberPoints 𝔽q β (γ_repetitions := γ_repetitions) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) k_val v + -- Check consistency if i > 0 + if h_i_pos : i > 0 then + let oracle_point_idx := extractMiddleFinMask 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (v:=v) (i:=⟨i, by omega⟩) (steps:=ϑ) + let f_i_val := f_i_on_fiber.get oracle_point_idx + guard (c_cur = f_i_val) + -- Compute next folded value + let destIdx : Fin r := ⟨i + ϑ, by omega⟩ + let next_suffix_of_v : sDomain 𝔽q β h_ℓ_add_R_rate destIdx := + getChallengeSuffix (k := k_val) (v := v) + let cur_challenge_batch : Fin ϑ → L := fun j => + stmt.challenges ⟨i + j.val, by simp only [Fin.val_last]; omega⟩ + -- c_next = folded value at step k (logical counterpart: `logical_computeFoldedValue`) + let c_next : L := single_point_localized_fold_matrix_form 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i:=⟨i, by omega⟩) (steps:=ϑ) (destIdx:=destIdx) (h_destIdx:=by dsimp only [destIdx]) + (h_destIdx_le:=by omega) (r_challenges:=cur_challenge_batch) (y:=next_suffix_of_v) + (fiber_eval_mapping := f_i_on_fiber.get) + return c_next + +/-- Check a single repetition: iterate through all folding steps and verify final consistency. + Returns `true` if all checks pass, `false` otherwise. + Note: `oStmtIn` is accessed via oracle queries in the OracleComp context. + Uses `mut` + `for` loop for true early termination (stops immediately on first failure). + For proofs, we'll need to reason about the loop invariant that `c_cur` maintains the + correct accumulated value through iterations. -/ +noncomputable def checkSingleRepetition + (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) + (stmt : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) (final_constant : L) : + OptionT (OracleComp ([]ₒ + ([OracleStatement 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ)]ₒ + [(pSpecQuery 𝔽q β + γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Message]ₒ))) Unit := do + let mut c_cur : L := 0 -- Will be initialized in first iteration + -- Iterate through the `ℓ/ϑ` adjacent pairs of oracles & validate local folding consistency + -- Early termination: stops immediately on first failure via `return false` + for k_val in List.finRange (ℓ / ϑ) do + let c_next ← checkSingleFoldingStep 𝔽q β (ϑ:=ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (γ_repetitions := γ_repetitions) + ⟨k_val, by omega⟩ c_cur v stmt + c_cur := c_next + -- Final check: c_ℓ ?= final_constant + guard (c_cur = final_constant) + +end MonadicOracleVerification + +section LogicalOracleVerification + +/-! +### Proximity check spec: logical defs (mirror monadic verifier exactly) + +Logical (non-monadic) versions that capture 100% of the monadic definitions. + +Key property from docstring: + if `i > 0` then `V` requires `c_i ?= f^(i)(v_i, ..., v_{ℓ+R-1})`. + `V` defines `c_{i+ϑ} := fold(f^(i), r'_i, ..., r'_{i+ϑ-1})(v_{i+ϑ}, ..., v_{ℓ+R-1})`. + `V` requires `c_ℓ ?= c`. + +The logical definitions mirror this exactly: +- `logical_queryFiberPoints` → Queries all `u` for a given step `k` (where `i = k·ϑ`) +- `logical_computeFoldedValue` → Computes `c_{i+ϑ}` via folding +- `logical_checkSingleFoldingStep` → Performs the guard check when `i > 0` +- `logical_checkSingleRepetition` → Enforces all guard checks and the final equality +- `logical_proximityChecksSpec` → Lifts to all `γ` repetitions + +### Correspondence with Monadic Implementation + +Each monadic function has a logical counterpart: +- `queryFiberPoints` ↔ `logical_queryFiberPoints` +- `checkSingleFoldingStep` ↔ `logical_checkSingleFoldingStep` + `logical_computeFoldedValue` +- `checkSingleRepetition` ↔ `logical_checkSingleRepetition` +-/ + +/-- Fiber evals for all u (logical; same as monadic `queryFiberPoints`). -/ +def logical_queryFiberPoints + (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) + (k : Fin (ℓ / ϑ)) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) : Fin (2 ^ ϑ) → L := + let k_th_oracleIdx : Fin (toOutCodewordsCount ℓ ϑ (Fin.last ℓ)) := + ⟨k.val, by simp only [toOutCodewordsCount, Fin.val_last, lt_self_iff_false, ↓reduceIte, + add_zero, Fin.is_lt]⟩ + fun u => oStmt k_th_oracleIdx (getFiberPoint 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) k v u) + +/-- Compute folded value at step `k` (same as `c_next` in monadic `checkSingleFoldingStep`). +This takes `f_i_on_fiber` - the list of `2^ϑ` fiber evaluations on oracle domain +`k*ϑ`, folds them into a single oracle evaluation on oracle domain `(k+1)*ϑ`, i.e. `c_{i+ϑ}`. -/ +def logical_computeFoldedValue + (k : Fin (ℓ / ϑ)) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) + (stmt : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (f_i_on_fiber : Fin (2 ^ ϑ) → L) : L := + let i := k.val * ϑ + have h_i_add_ϑ_le_ℓ : i + ϑ ≤ ℓ := k_succ_mul_ϑ_le_ℓ_₂ (k := k) + let destIdx : Fin r := ⟨i + ϑ, by omega⟩ + let next_suffix_of_v : sDomain 𝔽q β h_ℓ_add_R_rate destIdx := + getChallengeSuffix (k := k) (v := v) + let cur_challenge_batch : Fin ϑ → L := fun j => + stmt.challenges ⟨i + j.val, by simp only [Fin.val_last]; omega⟩ + single_point_localized_fold_matrix_form 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i, by omega⟩) (steps := ϑ) (destIdx := destIdx) (h_destIdx := by dsimp only [destIdx]) + (h_destIdx_le := by omega) (r_challenges := cur_challenge_batch) (y := next_suffix_of_v) + (fiber_eval_mapping := f_i_on_fiber) + +/-- Check a single folding step at k (logical; mirrors monadic `checkSingleFoldingStep`). + + Captures the guard check from docstring: + if `i > 0` then `V` requires `c_i ?= f^(i)(v_i, ..., v_{ℓ+R-1})` + Where c_i is the fold value from step k-1, and f^(i)(v_i,...) is the oracle + at position k evaluated at the "overlap" point. + Note: h_i_pos implies k > 0, so k-1 is valid. -/ +def logical_checkSingleFoldingStep + (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) + (k : Fin (ℓ / ϑ)) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) + (stmt : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) : Prop := + -- Index k represents + let i := k.val * ϑ + -- `k ∈ {0, 1, ..., ℓ/ϑ-1}`, `i ∈ {0, ϑ, 2ϑ, ..., ℓ-ϑ}` + -- **NOTE**: this definition is the + -- `c_i ?= f^(i)(v_i, ..., v_{ℓ+R-1})` check at inner repetition `k` + have h_i_add_ϑ_le_ℓ : i + ϑ ≤ ℓ := k_succ_mul_ϑ_le_ℓ_₂ (k := k) + let f_i_on_fiber := logical_queryFiberPoints 𝔽q β oStmt k v + -- Actually we only need value of one point of `f_i_on_fiber` for this check + -- This matches monadic: `guard (c_cur = f_i_val)` + if h_i_pos : i > 0 then + -- h_i_pos implies k > 0 (since i = k * ϑ and ϑ > 0) + have h_k_pos : k.val > 0 := Nat.pos_of_mul_pos_right h_i_pos + let k_prev : Fin (ℓ / ϑ) := ⟨k.val - 1, by omega⟩ + -- c_cur = fold value from step k-1 + let f_prev_on_fiber := logical_queryFiberPoints 𝔽q β oStmt k_prev v + -- In logical specification, we look backwards at oracle domain `(k-1)*ϑ` to query + -- the fiber evaluations `f_prev_on_fiber`, fold them to create `c_cur`. + -- In the monadic `checkSingleFoldingStep`, `c_cur` is automatically available. + let c_cur := logical_computeFoldedValue 𝔽q β k_prev v stmt f_prev_on_fiber + -- f_i_val = oracle value at overlap point + let oracle_point_idx := extractMiddleFinMask 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (v := v) (i := ⟨i, by omega⟩) (steps := ϑ) + let f_i_val := f_i_on_fiber oracle_point_idx + c_cur = f_i_val + else True + +/-- Logical check specific to step k. + If k is an intermediate index, it is the consistency of the folding step. + If k is the terminal index, it is the constant check. -/ +def logical_stepCondition (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) + (k : Fin (ℓ / ϑ + 1)) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) + (stmt : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) (final_constant : L) : Prop := + if h_k_lt : k.val < (ℓ / ϑ) then + -- Condition for `k ∈ {0, 1, ..., ℓ/ϑ-1}` + logical_checkSingleFoldingStep 𝔽q β oStmt ⟨k.val, h_k_lt⟩ v stmt + else + -- Condition for the final state k = `ℓ/ϑ` + have h_div_pos : ℓ / ϑ > 0 := + Nat.div_pos (Nat.le_of_dvd (Nat.pos_of_neZero ℓ) hdiv.out) (Nat.pos_of_neZero ϑ) + let k_last : Fin (ℓ / ϑ) := ⟨ℓ / ϑ - 1, by omega⟩ + let f_last_on_fiber := logical_queryFiberPoints 𝔽q β oStmt k_last v + logical_computeFoldedValue 𝔽q β k_last v stmt f_last_on_fiber = final_constant + +/-- Check a single repetition (logical; mirrors monadic `checkSingleRepetition`). + Captures: + 1. All guard checks pass: ∀ k, logical_checkSingleFoldingStep + 2. Final check: c_ℓ = final_constant (fold at last step equals final constant) -/ +def logical_checkSingleRepetition + (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) + (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) + (stmt : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) (final_constant : L) : Prop := + ∀ k : Fin (ℓ / ϑ + 1), + logical_stepCondition 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oStmt := oStmt) (k := k) (v := v) (stmt := stmt) (final_constant := final_constant) + +/-- Proximity checks spec: for all γ repetitions, `logical_checkSingleRepetition` holds. -/ +def logical_proximityChecksSpec + (γ_challenges : Fin γ_repetitions → sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) + (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) + (stmt : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) (final_constant : L) : Prop := + ∀ rep : Fin γ_repetitions, + logical_checkSingleRepetition 𝔽q β oStmt (γ_challenges rep) stmt final_constant + +lemma getFiberPoint_eq_qMap_total_fiber + (k : Fin (ℓ / ϑ)) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) + (u : Fin (2 ^ ϑ)) : + getFiberPoint 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) k v u = + qMap_total_fiber 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨k.val * ϑ, + lt_r_of_lt_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (x := k.val * ϑ) + (h := k_mul_ϑ_lt_ℓ (k := k))⟩) + (steps := ϑ) (h_destIdx := by rfl) + (h_destIdx_le := by exact k_succ_mul_ϑ_le_ℓ_₂ (k := k)) + (y := getChallengeSuffix 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k) (v := v)) u := by + unfold getFiberPoint + simp only [oraclePositionToDomainIndex, id_eq] + +lemma logical_queryFiberPoints_eq_fiberEvaluations + (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) + (k : Fin (ℓ / ϑ)) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) : + logical_queryFiberPoints 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt k v = + fiberEvaluations 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨k.val * ϑ, + lt_r_of_lt_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (x := k.val * ϑ) + (h := k_mul_ϑ_lt_ℓ (k := k))⟩) (steps := ϑ) + (h_destIdx := by rfl) (h_destIdx_le := by + exact k_succ_mul_ϑ_le_ℓ_₂ (k := k)) + (f := oStmt ⟨k.val, by + simp only [toOutCodewordsCount, Fin.val_last, lt_self_iff_false, ↓reduceIte, add_zero, + Fin.is_lt]⟩) + (y := getChallengeSuffix 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k) (v := v)) := by + funext u + simp only [logical_queryFiberPoints, fiberEvaluations] + rw [getFiberPoint_eq_qMap_total_fiber 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) k v u] + +lemma logical_computeFoldedValue_eq_iterated_fold + (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) + (k : Fin (ℓ / ϑ)) (v : sDomain 𝔽q β h_ℓ_add_R_rate ⟨0, by omega⟩) + (stmt : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) : + logical_computeFoldedValue 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) k v stmt + (logical_queryFiberPoints 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt k v) + = + iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨k.val * ϑ, + lt_r_of_lt_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (x := k.val * ϑ) + (h := k_mul_ϑ_lt_ℓ (k := k))⟩) (steps := ϑ) + (h_destIdx := by rfl) (h_destIdx_le := by + exact k_succ_mul_ϑ_le_ℓ_₂ (k := k)) + (f := oStmt ⟨k.val, by + simp only [toOutCodewordsCount, Fin.val_last, lt_self_iff_false, ↓reduceIte, add_zero, + Fin.is_lt]⟩) + (r_challenges := fun j => + stmt.challenges ⟨k.val * ϑ + j.val, by + have h_le : k.val * ϑ + ϑ ≤ ℓ := k_succ_mul_ϑ_le_ℓ_₂ (k := k) + have h_lt : k.val * ϑ + j.val < k.val * ϑ + ϑ := by + exact Nat.add_lt_add_left j.isLt (k.val * ϑ) + exact lt_of_lt_of_le h_lt h_le⟩) + (y := getChallengeSuffix 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (k := k) (v := v)) := by + simp only [logical_computeFoldedValue] + rw [iterated_fold_eq_matrix_form 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨k.val * ϑ, + lt_r_of_lt_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (x := k.val * ϑ) + (h := k_mul_ϑ_lt_ℓ (k := k))⟩) (steps := ϑ) + (h_destIdx := by rfl) (h_destIdx_le := by exact k_succ_mul_ϑ_le_ℓ_₂ (k := k)) + (f := oStmt ⟨k.val, by + simp only [toOutCodewordsCount, Fin.val_last, lt_self_iff_false, ↓reduceIte, add_zero, + Fin.is_lt]⟩) + (r_challenges := fun j => + stmt.challenges ⟨k.val * ϑ + j.val, by + have h_le : k.val * ϑ + ϑ ≤ ℓ := k_succ_mul_ϑ_le_ℓ_₂ (k := k) + have h_lt : k.val * ϑ + j.val < k.val * ϑ + ϑ := by + exact Nat.add_lt_add_left j.isLt (k.val * ϑ) + exact lt_of_lt_of_le h_lt h_le⟩)] + simpa [localized_fold_matrix_form, single_point_localized_fold_matrix_form, + logical_queryFiberPoints_eq_fiberEvaluations 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + oStmt k v] + +end LogicalOracleVerification + +end FinalQueryRoundIOR + +end QueryPhase + + +omit [Fintype L] [DecidableEq L] [CharP L 2] in +lemma multilinearWeight_bitsOfIndex_eq_indicator {n : ℕ} (j k : Fin (2 ^ n)) : + multilinearWeight (F := L) (r := bitsOfIndex k) (i := j) = if j = k then 1 else 0 := by + set r_k := bitsOfIndex (L := L) k with h_r_k + unfold multilinearWeight + -- NOTE: maybe we can generalize this into a lemma? + -- ⊢ (∏ j_1, if (↑j).testBit ↑j_1 = true then r_k j_1 else 1 - r_k j_1) = if j = k then 1 else 0 + dsimp only [bitsOfIndex, r_k] + simp_rw [Nat.testBit_eq_getBit] + by_cases h_eq : j = k + · simp only [h_eq, ↓reduceIte] + have h_eq: ∀ (x : Fin n), + ((if (x.val).getBit ↑k = 1 then if (x.val).getBit ↑k = 1 then (1 : L) else (0 : L) else 1 - if (x.val).getBit ↑k = 1 then (1 : L) else (0 : L))) = (1 : L) := by + intro x + by_cases h_eq : (x.val).getBit ↑k = 1 + · simp only [h_eq, ↓reduceIte] + · simp only [h_eq, ↓reduceIte, sub_zero] + simp_rw [h_eq] + simp only [prod_const_one] + · simp only [h_eq, ↓reduceIte] + -- ⊢ (∏ x, if (↑x).getBit ↑j = 1 then if (↑x).getBit ↑k = 1 then 1 else 0 else 1 - if (↑x).getBit ↑k = 1 then 1 else 0) = 0 + rw [Finset.prod_eq_zero_iff] + -- ⊢ ∃ a ∈ univ, + -- (if (↑a).getBit ↑j = 1 then if (↑a).getBit ↑k = 1 then 1 else 0 else 1 - if (↑a).getBit ↑k = 1 then 1 else 0) = 0 + let exists_bit_diff_idx := Nat.exist_bit_diff_if_diff (a := j) (b := k) (h_a_ne_b := h_eq) + rcases exists_bit_diff_idx with ⟨bit_diff_idx, h_bit_diff_idx⟩ + have h_getBit_of_j_lt_2 : Nat.getBit (k := bit_diff_idx.val) (n := j) < 2 := by + exact Nat.getBit_lt_2 (k := bit_diff_idx.val) (n := j) + have h_getBit_of_k_lt_2 : Nat.getBit (k := bit_diff_idx.val) (n := k) < 2 := by + exact Nat.getBit_lt_2 (k := bit_diff_idx.val) (n := k) + use bit_diff_idx + constructor + · simp only [mem_univ] + · by_cases h_bit_diff_of_j_eq_0 : Nat.getBit (k := bit_diff_idx.val) (n := j) = 0 + · simp only [h_bit_diff_of_j_eq_0, zero_ne_one, ↓reduceIte] + -- ⊢ (1 - if (↑bit_diff_idx).getBit ↑k = 1 then 1 else 0) = 0 + have h_bit_diff_of_k_eq_1 : Nat.getBit (k := bit_diff_idx.val) (n := k) = 1 := by + omega + simp only [h_bit_diff_of_k_eq_1, ↓reduceIte, sub_self] + · have h_bit_diff_of_j_eq_1 : + Nat.getBit (k := bit_diff_idx.val) (n := j) = 1 := by + omega + have h_bit_diff_of_k_eq_0 : Nat.getBit (k := bit_diff_idx.val) (n := k) = 0 := by + omega + simp only [h_bit_diff_of_j_eq_1, ↓reduceIte, h_bit_diff_of_k_eq_0, zero_ne_one] + +omit [Fintype L] [DecidableEq L] [CharP L 2] in +/-- **Key Property of Tensor Expansion with Binary Challenges**: +When `r = bitsOfIndex k`, the tensor expansion `challengeTensorExpansion n r` +is the indicator vector for index `k` (i.e., 1 at position `k`, 0 elsewhere). +This is a fundamental property used in both Proposition 4.20 and Lemma 4.21. -/ +lemma challengeTensorExpansion_bitsOfIndex_is_eq_indicator {n : ℕ} (k : Fin (2 ^ n)) : + -- Key Property: Tensor(r_k) is the indicator vector for k. + -- Tensor(r_k)[j] = 1 if j=k, 0 if j≠k. + challengeTensorExpansion n (r := bitsOfIndex (L := L) k) = fun j => if j = k then 1 else 0 := by + -- Let r_k be the bit-vector corresponding to index k + funext j + unfold challengeTensorExpansion + -- ⊢ multilinearWeight r_k j = if j = k then 1 else 0 + apply multilinearWeight_bitsOfIndex_eq_indicator + +section Lift_PreTensorCombine + +/-! **Interleaved Word Construction (Supporting definition for Lemma 4.21)** +Constructs the rows `f_j^{(i+steps)}` of the interleaved word. +For a fixed row index `j` and a domain point `y ∈ S^{i+steps}`, +the value is the `j`-th entry of the vector `M_y * fiber_vals`. +-- NOTE: the way we define `ι` as `sDomain 𝔽q β h_ℓ_add_R_rate ⟨i + steps, by omega⟩` instead of +`Fin` requires using the generic versions of code/proximity gap lemmas. +We don't have a unified mat-mul formula for this, because the `M_y` matrix varies over `y` -/ +def preTensorCombine_WordStack (i : Fin ℓ) (steps : ℕ) {destIdx : Fin r} + (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f_i : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) : + WordStack (A := L) (κ := Fin (2 ^ steps)) + (ι := sDomain 𝔽q β h_ℓ_add_R_rate destIdx) := fun j y => + -- 1. Calculate the folding matrix M_y + let M_y : Matrix (Fin (2 ^ steps)) (Fin (2 ^ steps)) L := + foldMatrix 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (y := y) + -- 2. Get the evaluation of f on the fiber of y + let fiber_vals : Fin (2 ^ steps) → L := + fiberEvaluations 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f_i) (y := y) + -- 3. The result is the j-th component of the matrix-vector product + (M_y *ᵥ fiber_vals) j + +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero ℓ] in +/-- **Folding with Binary Challenges selects a Matrix Row** +This lemma establishes the geometric link: +The `j`-th row of the `preTensorCombine` matrix product is exactly equal to +folding the function `f` using the bits of `j` as challenges. +This holds for ANY function `f`, not just codewords. +-/ +lemma preTensorCombine_row_eq_fold_with_binary_row_challenges + (i : Fin ℓ) (steps : ℕ) {destIdx : Fin r} + (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) + (rowIdx : Fin (2 ^ steps)) : + ∀ y : sDomain 𝔽q β h_ℓ_add_R_rate destIdx, + (preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le f) rowIdx y = + iterated_fold 𝔽q β ⟨i, by omega⟩ steps + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f) + (r_challenges := bitsOfIndex (L := L) (n := steps) rowIdx) (y := y) := by + intro y + -- 1. Expand the definition of preTensorCombine (The LHS) + -- LHS = (M_y * f_vals)[rowIdx] + dsimp [preTensorCombine_WordStack] + -- 2. Expand the matrix form of iterated_fold (The RHS) + -- RHS = Tensor(r) • (M_y * f_vals) + rw [iterated_fold_eq_matrix_form 𝔽q β (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le)] + unfold localized_fold_matrix_form single_point_localized_fold_matrix_form + -- 3. Use the Tensor Property + -- Tensor(bits(rowIdx)) is the indicator vector for rowIdx + let tensor := challengeTensorExpansion (L := L) steps (bitsOfIndex rowIdx) + have h_indicator : tensor = fun k => if k = rowIdx then 1 else 0 := + challengeTensorExpansion_bitsOfIndex_is_eq_indicator (L := L) rowIdx + simp only + -- 4. Simplify the Dot Product + -- (Indicator • Vector) is exactly Vector[rowIdx] + dsimp only [tensor] at h_indicator + rw [h_indicator] + rw [dotProduct] + simp only [boole_mul] + rw [Finset.sum_eq_single rowIdx] + · -- The term at rowIdx is (1 * val) + simp only [if_true] + · -- All other terms are 0 + intro b _ hb_ne + simp [hb_ne] + · -- rowIdx is in the domain + intro h_notin + exact (h_notin (Finset.mem_univ rowIdx)).elim + +omit [CharP L 2] in +lemma preTensorCombine_is_interleavedCodeword_of_codeword (i : Fin ℓ) (steps : ℕ) {destIdx : Fin r} + (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f : BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) : + (⋈|(preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le f)) ∈ + (BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx ^⋈ (Fin (2 ^ steps))) := by + -- 1. Interleaved Code Definition: "A word is in the interleaved code iff every row is in the base code" + set S_next := sDomain 𝔽q β h_ℓ_add_R_rate destIdx with h_S_next + set u := (⋈|(preTensorCombine_WordStack 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i steps h_destIdx h_destIdx_le f)) with h_u + set C_next := BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := destIdx) + simp only [InterleavedWord, InterleavedSymbol, ModuleCode, + instCodeInterleavableModuleCodeInterleavedSymbol, ModuleCode.moduleInterleavedCode, + interleavedCodeSet, SetLike.mem_coe, Submodule.mem_mk, AddSubmonoid.mem_mk, + AddSubsemigroup.mem_mk, Set.mem_setOf_eq] + -- ⊢ ∀ (k : Fin (2 ^ steps)), uᵀ k ∈ C_next + intro rowIdx + -- 2. Setup: Define the specific challenge 'r' corresponding to row index 'rowIdx' + let r_binary : Fin steps → L := bitsOfIndex rowIdx + -- 3. Geometric Equivalence: + -- Show that the `rowIdx`-th row of preTensorCombine is exactly `iterated_fold` of u with challenge r + -- We rely on Lemma 4.9 (Matrix Form) which states: M_y * vals = iterated_fold(u, r, y) + let preTensorCombine_Row: S_next → L := preTensorCombine_WordStack 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i steps + h_destIdx h_destIdx_le (f_i := f) rowIdx + let rowIdx_binary_folded_Row: S_next → L := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f) (r_challenges := r_binary) + have h_row_eq_fold : preTensorCombine_Row = rowIdx_binary_folded_Row := by + funext y + exact preTensorCombine_row_eq_fold_with_binary_row_challenges 𝔽q β i + steps h_destIdx h_destIdx_le f rowIdx y + have h_row_of_u_eq: (uᵀ rowIdx) = preTensorCombine_Row := by rfl + rw [←h_row_of_u_eq] at h_row_eq_fold + rw [h_row_eq_fold] + -- ⊢ rowIdx_binary_folded_Row ∈ C_next (i.e. lhs is of `fold(f, binary_rowIdx_challenges)` form) + unfold rowIdx_binary_folded_Row + exact iterated_fold_preserves_BBF_Code_membership 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f) (r_challenges := r_binary) + +/-! +-------------------------------------------------------------------------------- + SECTION: THE LIFT INFRASTRUCTURE + Constructing the inverse map from Interleaved Codewords back to Domain CodeWords +-------------------------------------------------------------------------------- +-/ + + +open Code.InterleavedCode in +def getRowPoly (i : Fin ℓ) (steps : ℕ) {destIdx : Fin r} + (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (V_codeword : ((BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx) + ^⋈(Fin (2 ^ steps)))) : Fin (2 ^ steps) → L⦃<2^(ℓ-destIdx.val)⦄[X] := fun j => by + -- 1. Extract polynomials P_j from V_codeword components + set S_next := sDomain 𝔽q β h_ℓ_add_R_rate destIdx with h_S_next + set C_next := BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx with h_C_next + let curRow := getRow (show InterleavedCodeword (A := L) (κ := Fin (2 ^ steps)) (ι := S_next) (C := C_next) from V_codeword) j + have h_V_in_C_next : curRow.val ∈ (C_next) := by + have h_V_mem := V_codeword.property + let res := Code.InterleavedCode.getRowOfInterleavedCodeword_mem_code (C := (C_next : Set (S_next → L))) + (κ := Fin (2 ^ steps)) (ι := S_next) (u := V_codeword) (rowIdx := j) + exact res + -- For each j, there exists a polynomial P_j of degree < 2^(ℓ - (i+steps)) + exact getBBF_Codeword_poly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := destIdx) curRow + +def getLiftCoeffs (i : Fin ℓ) (steps : ℕ) {destIdx : Fin r} + (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (V_codeword : ((BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx) + ^⋈(Fin (2 ^ steps)))) : Fin (2^(ℓ - i)) → L := fun coeff_idx => + -- intertwining novel coeffs of the rows of V_codeword + -- decompose `coeff_idx = colIdx * 2 ^ steps + rowIdx` as in paper, + -- i.e. traverse column by column + let colIdx : Fin (2 ^ (ℓ - destIdx.val)) := ⟨coeff_idx.val / (2 ^ steps), by + apply Nat.div_lt_of_lt_mul; + rw [← Nat.pow_add]; + convert coeff_idx.isLt using 2; omega + ⟩ + let rowIdx : Fin (2 ^ steps) := ⟨coeff_idx.val % (2 ^ steps), by + have h_coeff_idx_lt_two_pow_ℓ_i : coeff_idx.val < 2 ^ (ℓ - i) := by + exact coeff_idx.isLt + have h_coeff_idx_mod_two_pow_steps : coeff_idx.val % (2 ^ steps) < 2 ^ steps := by + apply Nat.mod_lt; simp only [gt_iff_lt, ofNat_pos, pow_pos] + exact h_coeff_idx_mod_two_pow_steps + ⟩ + let coeff := getINovelCoeffs 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := destIdx) (h_i := h_destIdx_le) (P := (getRowPoly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + i steps h_destIdx h_destIdx_le V_codeword) rowIdx) colIdx + coeff + +/-- Given an interleaved codeword `V ∈ C ⋈^ (2^steps)`, this method converts `2^steps` row polys +of `V` into a poly `P ∈ L[X]_(2^(ℓ-i))` that generates the fiber evaluator `g : S⁽ⁱ⁾ → L` +(this `g` produces the RHS vector in equality of **Lemma 4.9**). If we fold this function `g` using +**binary challenges** corresponding to each of the `2^steps` rows of `V`, let's say `j`, +we also folds `P` into the corresponding row polynomial `P_j` of the `j`-th row of `V` +(via **Lemma 4.13, aka iterated_fold_advances_evaluation_poly**). This works as a core engine for +proof of **Lemma 4.21**. -/ +def getLiftPoly (i : Fin ℓ) (steps : ℕ) {destIdx : Fin r} + (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (V_codeword : ((BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx) + ^⋈(Fin (2 ^ steps)))) : L⦃<2^(ℓ-i)⦄[X] := by + have h_ℓ_lt_r : ℓ < r := by + have h_pos : 0 < 𝓡 := Nat.pos_of_neZero (n := 𝓡) + exact lt_trans (Nat.lt_add_of_pos_right (n := ℓ) (k := 𝓡) h_pos) h_ℓ_add_R_rate + have h_i_lt_r : (i : Nat) < r := lt_trans i.isLt h_ℓ_lt_r + let iR : Fin r := ⟨i, h_i_lt_r⟩ + refine ⟨intermediateEvaluationPoly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := iR) (h_i := by + exact Nat.le_of_lt i.isLt) + (coeffs := getLiftCoeffs 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + i steps h_destIdx h_destIdx_le V_codeword), ?_⟩ + apply Polynomial.mem_degreeLT.mpr + exact degree_intermediateEvaluationPoly_lt (𝔽q := 𝔽q) (β := β) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := iR) (h_i := by + exact Nat.le_of_lt i.isLt) + (coeffs := getLiftCoeffs 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + i steps h_destIdx h_destIdx_le V_codeword) + +/-- **Lift Function (Inverse Folding)** +Constructs a function `f` on the domain `S^{(i)}` from an interleaved word `W` on `S^{(i+steps)}`. +For any point `x` in the larger domain, we identify its quotient `y` and its index in the fiber. +We recover the fiber values by applying `M_y⁻¹` to the column `W(y)`. +-/ +noncomputable def lift_interleavedCodeword (i : Fin ℓ) (steps : ℕ) {destIdx : Fin r} + (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (V_codeword : ((BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx) + ^⋈(Fin (2 ^ steps)))) : + BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ := by + let P : L[X]_(2 ^ (ℓ - ↑i)) := getLiftPoly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i steps + h_destIdx h_destIdx_le V_codeword + -- 3. Define g as evaluation of P + let g := getBBF_Codeword_of_poly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i, by omega⟩) (h_i := by + exact Nat.le_of_lt i.isLt) P + exact g + +omit [CharP L 2] in +/-- **Lemma 4.21 Helper**: Folding the "Lifted" polynomial `g` with binary challenges corresponding +to row index `j ∈ Fin(2^steps)`, results exactly in the `j`-th row polynomial `P_j`. +**Key insight**: **Binary folding** is a **(Row) Selector** +Proof strategy: applying `iterated_fold_advances_evaluation_poly` and +`intermediateEvaluationPoly_from_inovel_coeffs_eq_self`, then arithemetic equality for novel coeffs +computations in both sides. -/ +lemma folded_lifted_IC_eq_IC_row_polyToOracleFunc (i : Fin ℓ) (steps : ℕ) {destIdx : Fin r} + (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (V_codeword : ((BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx)^⋈(Fin (2 ^ steps)))) + (j : Fin (2 ^ steps)) : + let g := lift_interleavedCodeword 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + i steps h_destIdx h_destIdx_le V_codeword + let P_j := (getRowPoly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i steps h_destIdx h_destIdx_le + V_codeword) j + iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) g (bitsOfIndex j) = + polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := destIdx) P_j := by + -- 1. Unfold definitions to expose the underlying polynomials + -- dsimp only [lift_interleavedCodeword, getLiftPoly] + simp only + set g := lift_interleavedCodeword 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i steps + h_destIdx h_destIdx_le V_codeword with h_g + set P_j := (getRowPoly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i steps h_destIdx h_destIdx_le V_codeword) j + set P_G := getLiftPoly 𝔽q β i steps h_destIdx h_destIdx_le V_codeword with h_P_G -- due to def of `g` + have h_g : g = polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (domainIdx := ⟨i, by omega⟩) P_G := by rfl + -- unfold getLiftPoly at h_P_G + let novelCoeffs := getLiftCoeffs 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i steps + h_destIdx h_destIdx_le V_codeword + -- have h_P_G_eq: P_G = intermediateEvaluationPoly 𝔽q β h_ℓ_add_R_rate + -- (i := ⟨i, by omega⟩) novelCoeffs := by rfl + let h_fold_g_advances_P_G := iterated_fold_advances_evaluation_poly 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (r_challenges := bitsOfIndex j) (coeffs := novelCoeffs) + simp only at h_fold_g_advances_P_G + conv_lhs at h_fold_g_advances_P_G => -- make it matches the lhs goal + change iterated_fold 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (f := g) (bitsOfIndex j) + conv_lhs => rw [h_fold_g_advances_P_G] + -- ⊢ polyToOracleFunc 𝔽q β ⟨↑i + steps, ⋯⟩ + -- (intermediateEvaluationPoly 𝔽q β h_ℓ_add_R_rate ⟨↑i + steps, ⋯⟩ fun j_1 ↦ + -- ∑ x, multilinearWeight (bitsOfIndex j) x * novelCoeffs ⟨↑j_1 * 2 ^ steps + ↑x, ⋯⟩) = + -- polyToOracleFunc 𝔽q β ⟨↑i + steps, ⋯⟩ ↑P_j + have h_P_j_novel_form := intermediateEvaluationPoly_from_inovel_coeffs_eq_self 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := destIdx) (h_i := h_destIdx_le) (P := P_j) (hP_deg := by + have h_mem := P_j.property + rw [Polynomial.mem_degreeLT] at h_mem + exact h_mem ) + conv_rhs => rw [←h_P_j_novel_form] + -- polyToOracleFunc(intermediateEvaluationPoly(FOLDED novelCoeffs of P))) (via Lemma 4.13) + -- = polyToOracleFunc(intermediateEvaluationPoly(inovelCoeffs of P_j)) + unfold polyToOracleFunc intermediateEvaluationPoly novelCoeffs + simp only [map_sum, map_mul] + funext y + congr 1 + apply Finset.sum_congr rfl + intro x hx_mem_univ + rw [mul_eq_mul_right_iff]; left + -- **Arithemetic reasoning**: + -- ⊢ ∑ x_1, Polynomial.C (multilinearWeight (bitsOfIndex j) x_1) * + -- Polynomial.C (getLiftCoeffs 𝔽q β i steps ⟨↑x * 2 ^ steps + ↑x_1, ⋯⟩) = + -- Polynomial.C (getINovelCoeffs 𝔽q β h_ℓ_add_R_rate ⟨↑i + steps, ⋯⟩ (↑P_j) x) + -- 1. Combine the Ring Homomorphisms to pull C outside the sum + -- ∑ C(w) * C(v) -> C(∑ w * v) + simp_rw [←Polynomial.C_mul] + unfold getINovelCoeffs getLiftCoeffs + simp only [mul_add_mod_self_right, map_mul] + -- , ←Polynomial.C_sum] + -- 2. Use the Indicator Property of multilinearWeight with binary challenges + -- This logic should ideally be its own lemma: `weight_bits_eq_indicator` + have h_indicator : ∀ m : Fin (2^steps), multilinearWeight (F := L) (r := bitsOfIndex j) + (i := m) = if m = j then 1 else 0 := fun m => by + apply multilinearWeight_bitsOfIndex_eq_indicator (j := m) (k := j) + simp_rw [h_indicator] + -- 3. Collapse the Sum using Finset.sum_eq_single + rw [Finset.sum_eq_single j] + · -- Case: The Match (x_1 = j) + simp only [↓reduceIte, map_one, one_mul, Polynomial.C_inj] + unfold getINovelCoeffs + have h_idx_decomp : (x.val * 2 ^ steps + j.val) / 2^steps = x.val := by + have h_j_div_2_pow_steps : j.val / 2^steps = 0 := by + apply Nat.div_eq_of_lt; omega + rw [mul_comm] + have h_res := Nat.mul_add_div (m := 2 ^ steps) (x := x.val) (y := j.val) (m_pos := by + simp only [gt_iff_lt, ofNat_pos, pow_pos]) + simp only [h_j_div_2_pow_steps, add_zero] at h_res + exact h_res + congr 1 + · funext k + congr + · apply Nat.mod_eq_of_lt; omega + · simp_rw [h_idx_decomp] + · -- Case: The Mismatch (x_1 ≠ j) + intro m _ h_neq + simp only [h_neq, ↓reduceIte, map_zero, zero_mul] + · -- Case: Domain (Empty implies false, but we are in Fin (2^steps)) + intro h_absurd + exfalso; exact h_absurd (Finset.mem_univ j) + +omit [CharP L 2] in +open Code.InterleavedCode in +lemma preTensorCombine_of_lift_interleavedCodeword_eq_self (i : Fin ℓ) (steps : ℕ) {destIdx : Fin r} + (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (V_codeword : ((BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx) + ^⋈(Fin (2 ^ steps)))) : + let g := lift_interleavedCodeword 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + i steps h_destIdx h_destIdx_le V_codeword + (⋈|(preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le g)) = V_codeword.val := by + let S_next := sDomain 𝔽q β h_ℓ_add_R_rate destIdx + let C_next := BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx + set g := lift_interleavedCodeword 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i) + (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (V_codeword := V_codeword) + -- **FIRST**, + -- `∀ j : Fin (2^ϑ), (V_codeword j)` and `fold(g, bitsOfIndex j)` agree identically + -- over `S^{(i+ϑ)}` + -- the dotproduct between `M_y's j'th ROW` and `G = g's restriction to the fiber of y` + -- is actually the result of `fold(G, bitsOfIndex j)` + have h_agree_with_fold := preTensorCombine_row_eq_fold_with_binary_row_challenges 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i steps h_destIdx h_destIdx_le g + let eq_iff_all_rows_eq := (instInterleavedStructureInterleavedWord (A := L) (κ := Fin (2 ^ steps)) + (ι := S_next)).eq_iff_all_rows_eq (u := ⋈|preTensorCombine_WordStack 𝔽q β (i := i) + (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (↑g)) (v := V_codeword.val) + simp only + rw [eq_iff_all_rows_eq] + intro j + funext (y : S_next) -- compare the cells at (j, y) + set G := fiberEvaluations 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := g) (y := y) + simp only [InterleavedWord, Word, InterleavedSymbol, instInterleavedStructureInterleavedWord, + InterleavedWord.getRowWord, InterleavedWord.getSymbol, transpose_apply, WordStack, + instInterleavableWordStackInterleavedWord, interleave_wordStack_eq, ModuleCode, + instCodeInterleavableModuleCodeInterleavedSymbol.eq_1, ModuleCode.moduleInterleavedCode.eq_1, + interleavedCodeSet.eq_1] + -- ⊢ preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le (↑g) j = (↑V_codeword)ᵀ j + unfold preTensorCombine_WordStack + simp only + set M_y := foldMatrix 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) y + -- ⊢ (foldMatrix 𝔽q β ⟨↑i, ⋯⟩ steps ⋯ y *ᵥ fiberEvaluations 𝔽q β ⟨↑i, ⋯⟩ steps ⋯ (↑g) y) j + -- = ↑V_codeword y j + change (M_y *ᵥ G) j = V_codeword.val y j + let lhs_eq_fold := h_agree_with_fold j y + unfold preTensorCombine_WordStack at lhs_eq_fold + simp at lhs_eq_fold + rw [lhs_eq_fold] + -- ⊢ iterated_fold 𝔽q β ⟨↑i, ⋯⟩ steps ⋯ (↑g) (bitsOfIndex j) y = ↑V_codeword y j + -- **SECOND**, we prove that **the same row polynomial `P_j(X)` is used to generates** bot + -- `fold(g, bitsOfIndex j)` and `j'th row of V_codeword` + let curRow := getRow (show InterleavedCodeword (A := L) (κ := Fin (2 ^ steps)) + (ι := S_next) (C := C_next) from V_codeword) j + let P_j := getRowPoly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i steps h_destIdx h_destIdx_le V_codeword j + let lhs := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := g) + (r_challenges := bitsOfIndex j) + let rhs := curRow.val + let generatedRow : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx := + polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := destIdx) (P := P_j) + have h_left_eq_P_j_gen: lhs = generatedRow := by + unfold lhs generatedRow -- ⊢ iterated_fold 𝔽q β ⟨↑i, ⋯⟩ steps ⋯ (↑g) (bitsOfIndex j) + -- = polyToOracleFunc 𝔽q β ⟨↑i + steps, ⋯⟩ ↑P_j + apply folded_lifted_IC_eq_IC_row_polyToOracleFunc + have h_right_eq_P_j_eval: rhs = generatedRow := by + unfold rhs generatedRow + rw [getBBF_Codeword_poly_spec 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := destIdx) (u := curRow)]; rfl + conv_lhs => change lhs y + conv_rhs => change rhs y + rw [h_left_eq_P_j_gen, h_right_eq_P_j_eval] + +/-- TODO: **Lifting Equivalence Lemma**: `lift(preTensorCombine(f)) = f`. -/ + +def fiberDiff (i : Fin ℓ) (steps : ℕ) {destIdx : Fin r} + (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) + (y : sDomain 𝔽q β h_ℓ_add_R_rate destIdx) : Prop := + ∃ x, + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate (i := ⟨i, by omega⟩) (destIdx := destIdx) + (k := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) x = y ∧ + f x ≠ g x + +/-- **Distance Isomorphism Lemma** +The crucial logic for Lemma 4.21: +Two functions `f, g` differ on a specific fiber `y` IF AND ONLY IF +their tensor-combinations `U, V` differ at the column `y`. +This holds because `M_y` is a bijection. -/ +lemma fiberwise_disagreement_isomorphism (i : Fin ℓ) (steps : ℕ) {destIdx : Fin r} + (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) + (y : sDomain 𝔽q β h_ℓ_add_R_rate destIdx) : + fiberDiff 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i steps h_destIdx h_destIdx_le f g y ↔ + WordStack.getSymbol (preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le f) y ≠ + WordStack.getSymbol (preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le g) y := by + -- U_y = M_y * f_vals, V_y = M_y * g_vals + let M_y := foldMatrix 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) y + let f_vals := fiberEvaluations 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) f y + let g_vals := fiberEvaluations 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) g y + have h_det : M_y.det ≠ 0 := foldMatrix_det_ne_zero 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (y := y) + constructor + · -- Forward: Fiber different => Columns different + intro h_diff + -- If fiber is different, then vectors f_vals ≠ g_vals + have h_vec_diff : f_vals ≠ g_vals := by + rcases h_diff with ⟨x, h_gen_y, h_val_ne⟩ -- h_val_ne : f x ≠ g x + intro h_eq + let x_is_fiber_of_y := is_fiber_iff_generates_quotient_point 𝔽q β + (i := ⟨i, by omega⟩) (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (x := x) (y := y).mp (by exact id (Eq.symm h_gen_y)) + let x_fiberIdx : Fin (2 ^ steps) := + pointToIterateQuotientIndex 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (x := x) + have h_left_eval : f_vals x_fiberIdx = f x := by + unfold f_vals fiberEvaluations + rw [x_is_fiber_of_y] + have h_right_eval : g_vals x_fiberIdx = g x := by + unfold g_vals fiberEvaluations + rw [x_is_fiber_of_y] + let h_eval_eq := congrFun h_eq x_fiberIdx + rw [h_left_eval, h_right_eval] at h_eval_eq -- f x = g x + exact h_val_ne h_eval_eq + -- M_y is invertible, so M_y * u = M_y * v => u = v. Contrapositive: u ≠ v => M_y * u ≠ M_y * v + intro h_col_eq + apply h_vec_diff + -- ⊢ f_vals = g_vals + -- h_col_eq: WordStack.getSymbol (preTensorCombine_WordStack ... f) y = WordStack.getSymbol (preTensorCombine_WordStack ... g) y + -- This means: M_y *ᵥ f_vals = M_y *ᵥ g_vals + -- Rewrite as: M_y *ᵥ (f_vals - g_vals) = 0 + have h_mulVec_sub_eq_zero : M_y *ᵥ (f_vals - g_vals) = 0 := by + -- From h_col_eq and the definition of preTensorCombine_WordStack: + -- WordStack.getSymbol (preTensorCombine_WordStack ... f) y = M_y *ᵥ f_vals + -- WordStack.getSymbol (preTensorCombine_WordStack ... g) y = M_y *ᵥ g_vals + have h_f_col : WordStack.getSymbol (preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le f) y = M_y *ᵥ f_vals := by + ext j + simp only [WordStack.getSymbol, Matrix.transpose_apply] + rfl + have h_g_col : WordStack.getSymbol (preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le g) y = M_y *ᵥ g_vals := by + ext j + simp only [WordStack.getSymbol, Matrix.transpose_apply] + rfl + -- ⊢ M_y *ᵥ (f_vals - g_vals) = 0 + rw [h_f_col, h_g_col] at h_col_eq + -- Now h_col_eq: M_y *ᵥ f_vals = M_y *ᵥ g_vals + rw [Matrix.mulVec_sub] + -- Goal: M_y *ᵥ f_vals - M_y *ᵥ g_vals = 0 + rw [← h_col_eq] + -- Goal: M_y *ᵥ f_vals - M_y *ᵥ f_vals = 0 + rw [sub_self] + -- Apply eq_zero_of_mulVec_eq_zero to get f_vals - g_vals = 0 + have h_sub_eq_zero : f_vals - g_vals = 0 := + Matrix.eq_zero_of_mulVec_eq_zero h_det h_mulVec_sub_eq_zero -- `usage of M_y's nonsingularity` + -- Convert to f_vals = g_vals + exact sub_eq_zero.mp h_sub_eq_zero + · -- Backward: Columns different => Fiber different + intro h_col_diff + by_contra h_fiber_eq + -- h_fiber_eq: ¬fiberDiff, i.e., ∀ x, iteratedQuotientMap ... x = y → f x = g x + -- If f and g agree on all points in the fiber of y, then f_vals = g_vals + have h_vals_eq : f_vals = g_vals := by + ext idx + -- f_vals idx = f evaluated at the idx-th point in the fiber of y + -- g_vals idx = g evaluated at the idx-th point in the fiber of y + -- We need to show they're equal + unfold f_vals g_vals fiberEvaluations + -- fiberEvaluations f y idx = f (qMap_total_fiber ... y idx) + -- fiberEvaluations g y idx = g (qMap_total_fiber ... y idx) + let x := qMap_total_fiber 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (y := y) idx + -- x is in the fiber of y, so iteratedQuotientMap ... x = y + have h_x_in_fiber : + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate (i := ⟨i, by omega⟩) (destIdx := destIdx) + (k := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) x = y := by + -- This follows from generates_quotient_point_if_is_fiber_of_y + have h := generates_quotient_point_if_is_fiber_of_y 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (x := x) (y := y) (hx_is_fiber := by use idx) + exact h.symm + -- Since h_fiber_eq says no point in the fiber has f x ≠ g x, + -- we have f x = g x for all x in the fiber + have h_fx_eq_gx : f x = g x := by + -- h_fiber_eq: ¬fiberDiff, which is ¬(∃ x, iteratedQuotientMap ... x = y ∧ f x ≠ g x) + -- By De Morgan: ∀ x, ¬(iteratedQuotientMap ... x = y ∧ f x ≠ g x) + -- Which means: ∀ x, iteratedQuotientMap ... x = y → f x = g x + -- h_fiber_eq is now: ∀ x, iteratedQuotientMap ... x = y → f x = g x + unfold fiberDiff at h_fiber_eq + simp only [ne_eq, Subtype.exists, not_exists, not_and, Decidable.not_not] at h_fiber_eq + let res := h_fiber_eq x (by simp only [SetLike.coe_mem]) h_x_in_fiber + exact res + -- Now f_vals idx = f x = g x = g_vals idx + exact h_fx_eq_gx + -- If f_vals = g_vals, then M_y *ᵥ f_vals = M_y *ᵥ g_vals + have h_col_eq : WordStack.getSymbol (preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le f) y = + WordStack.getSymbol (preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le g) y := by + -- From the forward direction, we know: + -- WordStack.getSymbol (preTensorCombine_WordStack ... f) y = M_y *ᵥ f_vals + -- WordStack.getSymbol (preTensorCombine_WordStack ... g) y = M_y *ᵥ g_vals + have h_f_col : WordStack.getSymbol (preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le f) y = M_y *ᵥ f_vals := by + ext j + simp only [WordStack.getSymbol, Matrix.transpose_apply] + rfl + have h_g_col : WordStack.getSymbol (preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le g) y = M_y *ᵥ g_vals := by + ext j + simp only [WordStack.getSymbol, Matrix.transpose_apply] + rfl + rw [h_f_col, h_g_col] + -- Goal: M_y *ᵥ f_vals = M_y *ᵥ g_vals + rw [h_vals_eq] + -- This contradicts h_col_diff + exact h_col_diff h_col_eq + +end Lift_PreTensorCombine + +open Classical in +/-- **Proposition 4.20 (Case 1)**: +If f⁽ⁱ⁾ is fiber-wise close to the code, the probability of the bad event is bounded. +The bad event here is: `Δ⁽ⁱ⁾(f⁽ⁱ⁾, f̄⁽ⁱ⁾) ⊄ Δ(fold(f⁽ⁱ⁾), fold(f̄⁽ⁱ⁾))`. +-/ +lemma prop_4_20_case_1_fiberwise_close (i : Fin ℓ) (steps : ℕ) [NeZero steps] + {destIdx : Fin r} (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f_i : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) + (h_close : fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i, by omega⟩) (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f_i)) : + let S_next := sDomain 𝔽q β h_ℓ_add_R_rate destIdx + let domain_size := Fintype.card S_next + Pr_{ let r_challenges ←$ᵖ (Fin steps → L) }[ + -- The definition of foldingBadEvent under the "then" branch of h_close + let f_bar_i := UDRCodeword 𝔽q β (i := ⟨i, by omega⟩) (h_i := by + exact Nat.le_of_lt i.isLt) f_i + (UDRClose_of_fiberwiseClose 𝔽q β ⟨i, by omega⟩ steps h_destIdx h_destIdx_le f_i h_close) + let folded_f_i := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ + steps h_destIdx h_destIdx_le f_i r_challenges + let folded_f_bar_i := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ + steps h_destIdx h_destIdx_le f_bar_i r_challenges + ¬ (fiberwiseDisagreementSet 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le f_i f_bar_i ⊆ + disagreementSet 𝔽q β (i := destIdx) (destIdx := destIdx) (h_destIdx := rfl) (f := folded_f_i) (g := folded_f_bar_i)) + ] ≤ ((steps * domain_size) / Fintype.card L) := by + let S_next := sDomain 𝔽q β h_ℓ_add_R_rate destIdx + let L_card := Fintype.card L + -- 1. Setup Definitions + let f_bar_i : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ := + UDRCodeword 𝔽q β (i := ⟨i, by omega⟩) (h_i := by + exact Nat.le_of_lt i.isLt) + (f := f_i) (h_within_radius := UDRClose_of_fiberwiseClose 𝔽q β ⟨i, by omega⟩ steps h_destIdx h_destIdx_le f_i h_close) + let Δ_fiber : Set (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) := + fiberwiseDisagreementSet 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le f_i f_bar_i + -- We apply the Union Bound over `y ∈ Δ_fiber` + -- `Pr[ ∃ y ∈ Δ_fiber, y ∉ Disagreement(folded) ] ≤ ∑ Pr[ y ∉ Disagreement(folded) ]` + have h_union_bound : + Pr_{ let r ←$ᵖ (Fin steps → L) }[ + ¬(Δ_fiber ⊆ disagreementSet 𝔽q β (i := destIdx) (destIdx := destIdx) (h_destIdx := rfl) + (f := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps h_destIdx h_destIdx_le f_i r) + (g := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps h_destIdx h_destIdx_le f_bar_i r)) + ] ≤ ∑ y ∈ Δ_fiber.toFinset, + Pr_{ let r ←$ᵖ (Fin steps → L) }[ + -- The condition y ∉ Disagreement(folded) implies folded values are equal at y + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps h_destIdx h_destIdx_le f_i r) y = + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps h_destIdx h_destIdx_le f_bar_i r) y + ] := by + -- Standard probability union bound logic + -- Convert probability to cardinality ratio for the Union Bound + rw [prob_uniform_eq_card_filter_div_card] + simp_rw [prob_uniform_eq_card_filter_div_card] + simp only [ENNReal.coe_natCast, Fintype.card_pi, prod_const, Finset.card_univ, + Fintype.card_fin, cast_pow, ENNReal.coe_pow] + set left_set : Finset (Fin steps → L) := + Finset.univ.filter fun r => + ¬(Δ_fiber ⊆ + disagreementSet 𝔽q β (i := destIdx) (destIdx := destIdx) (h_destIdx := rfl) (f := iterated_fold 𝔽q β ⟨i, by omega⟩ steps + h_destIdx h_destIdx_le f_i r) + (g := iterated_fold 𝔽q β ⟨↑i, by omega⟩ steps + h_destIdx h_destIdx_le f_bar_i r)) + set right_set : + (x : sDomain 𝔽q β h_ℓ_add_R_rate destIdx) → + Finset (Fin steps → L) := + fun x => + (Finset.univ.filter fun r => + iterated_fold 𝔽q β ⟨↑i, by omega⟩ steps + h_destIdx h_destIdx_le + f_i r x = + iterated_fold 𝔽q β ⟨↑i, by omega⟩ steps + h_destIdx h_destIdx_le + f_bar_i r x) + conv_lhs => + change _ * ((Fintype.card L : ENNReal) ^ steps)⁻¹ + rw [mul_comm] + conv_rhs => + change + ∑ y ∈ Δ_fiber.toFinset, + ((#(right_set y) : ENNReal) * ((Fintype.card L : ENNReal) ^ steps)⁻¹) + conv_rhs => + simp only [mul_comm] + rw [←Finset.mul_sum] + -- ⊢ (↑(Fintype.card L) ^ steps)⁻¹ * ↑(#left_set) ≤ (↑(Fintype.card L) ^ steps)⁻¹ * ∑ i ∈ Δ_fiber.toFinset, ↑(#(right_set i)) + let left_le_right_if := (ENNReal.mul_le_mul_left (a := ((Fintype.card L : ENNReal) ^ steps)⁻¹) (b := (#left_set)) (c := ∑ i ∈ Δ_fiber.toFinset, (#(right_set i))) (h0 := by simp only [ne_eq, + ENNReal.inv_eq_zero, ENNReal.pow_eq_top_iff, ENNReal.natCast_ne_top, false_and, + not_false_eq_true]) (hinf := by simp only [ne_eq, ENNReal.inv_eq_top, pow_eq_zero_iff', + cast_eq_zero, Fintype.card_ne_zero, false_and, not_false_eq_true])).mpr + apply left_le_right_if + -- ⊢ ↑(#left_set) ≤ ∑ i ∈ Δ_fiber.toFinset, ↑(#(right_set i)) + -- 1. Prove the subset relation: left_set ⊆ ⋃_{y ∈ Δ} right_set y + -- This formally connects the failure condition (∃ y, agree) to the union of agreement sets. + have h_subset : left_set ⊆ Δ_fiber.toFinset.biUnion right_set := by + intro r hr + -- Unpack membership in left_set: r is bad if Δ_fiber ⊈ disagreementSet + simp only [Finset.mem_filter, Finset.mem_univ, true_and, left_set] at hr + rw [Set.not_subset] at hr + rcases hr with ⟨y, hy_mem, hy_not_dis⟩ + -- We found a y ∈ Δ_fiber where they do NOT disagree (i.e., they agree) + rw [Finset.mem_biUnion] + use y + constructor + · exact Set.mem_toFinset.mpr hy_mem + · -- Show r ∈ right_set y (which is defined as the set of r where they agree at y) + simp only [Finset.mem_filter, Finset.mem_univ, true_and, right_set] + -- hy_not_dis is ¬(folded_f_i y ≠ folded_f_bar_i y) ↔ folded_f_i y = folded_f_bar_i y + simp only [disagreementSet, ne_eq, coe_filter, mem_univ, true_and, Set.mem_setOf_eq, + Decidable.not_not] at hy_not_dis + exact hy_not_dis + -- 2. Apply cardinality bounds (Union Bound) + calc + (left_set.card : ENNReal) + _ ≤ (Δ_fiber.toFinset.biUnion right_set).card := by + -- Monotonicity of measure/cardinality: A ⊆ B → |A| ≤ |B| + gcongr + _ ≤ ∑ i ∈ Δ_fiber.toFinset, (right_set i).card := by + -- Union Bound: |⋃ S_i| ≤ ∑ |S_i| + -- push_cast moves the ENNReal coercion inside the sum + push_cast + let h_le_in_Nat := Finset.card_biUnion_le (s := Δ_fiber.toFinset) (t := right_set) + norm_cast + _ = _ := by push_cast; rfl + apply le_trans h_union_bound + -- Now bound the individual probabilities using Schwartz-Zippel + have h_prob_y : ∀ y ∈ Δ_fiber, + Pr_{ let r ←$ᵖ (Fin steps → L) }[ + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps h_destIdx h_destIdx_le f_i r) y = + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps h_destIdx h_destIdx_le f_bar_i r) y + ] ≤ (steps) / L_card := by + intro y hy + -- 1. Apply Lemma 4.9 (iterated_fold_eq_matrix_form) to express the equality as a matrix eq. + -- Equality holds iff Tensor(r) * M_y * (f - f_bar)|_fiber = 0. + -- 2. Define the polynomial P(r) = Tensor(r) * w, where w = M_y * (vals_f - vals_f_bar). + -- 3. Show w ≠ 0: + -- a. vals_f - vals_f_bar ≠ 0 because y ∈ Δ_fiber (definitions). + -- b. M_y is nonsingular (Lemma 4.9 / Butterfly structure). + -- 4. Apply prob_schwartz_zippel_mv_polynomial to P(r). + -- degree(P) = steps. + -- 1. Apply Lemma 4.9 to express folding as Matrix Form + -- Equality holds iff [Tensor(r)] * [M_y] * [f - f_bar] = 0 + let vals_f : Fin (2 ^ steps) → L := fiberEvaluations 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) f_i y + let vals_f_bar : Fin (2 ^ steps) → L := fiberEvaluations 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) f_bar_i y + let v_diff : Fin (2 ^ steps) → L := vals_f - vals_f_bar + -- 2. Show `v_diff ≠ 0` because `y ∈ Δ_fiber`, this is actually by definition of `Δ_fiber`. + have hv_ne_zero : v_diff ≠ 0 := by + unfold v_diff + have h_exists_diff_point: ∃ x: Fin (2 ^ steps), vals_f x ≠ vals_f_bar x := by + dsimp only [fiberwiseDisagreementSet, ne_eq, Δ_fiber] at hy + -- ∃ x, iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate ⟨i, by omega⟩ (k := steps) h_destIdx h_destIdx_le x = y ∧ f_i x ≠ f_bar_i x + simp only [Subtype.exists, coe_filter, mem_univ, true_and, Set.mem_setOf_eq] at hy + -- rcases hy with ⟨xL, h_quot, h_ne⟩ + rcases hy with ⟨xL, h_prop_xL⟩ + rcases h_prop_xL with ⟨xL_mem_sDomain, h_quot, h_ne⟩ + set xSDomain : sDomain 𝔽q β h_ℓ_add_R_rate (i := ⟨i, by omega⟩) := ⟨xL, xL_mem_sDomain⟩ + let x_is_fiber_of_y := + is_fiber_iff_generates_quotient_point 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (x := xSDomain) (y := y).mp (by exact id (Eq.symm h_quot)) + let x_fiberIdx : Fin (2 ^ steps) := pointToIterateQuotientIndex 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (x := xSDomain) + use x_fiberIdx + have h_left_eval : vals_f x_fiberIdx = f_i xSDomain := by + unfold vals_f fiberEvaluations + rw [x_is_fiber_of_y] + have h_right_eval : vals_f_bar x_fiberIdx = f_bar_i xSDomain := by + unfold vals_f_bar fiberEvaluations + rw [x_is_fiber_of_y] + rw [h_left_eval, h_right_eval] + exact h_ne + by_contra h_eq_zero + rw [funext_iff] at h_eq_zero + rcases h_exists_diff_point with ⟨x, h_ne⟩ + have h_eq: vals_f x = vals_f_bar x := by + have res := h_eq_zero x + simp only [Pi.sub_apply, Pi.zero_apply] at res + rw [sub_eq_zero] at res + exact res + exact h_ne h_eq + -- 3. M_y is nonsingular (from Lemma 4.9 context/properties of AdditiveNTT) + let M_y := foldMatrix 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) y + have hMy_det_ne_zero : M_y.det ≠ 0 := by + apply foldMatrix_det_ne_zero 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (y := y) + -- 4. w = M_y * v_diff is non-zero + let w := M_y *ᵥ v_diff + have hw_ne_zero : w ≠ 0 := by + intro h + exact hv_ne_zero (by exact Matrix.eq_zero_of_mulVec_eq_zero hMy_det_ne_zero h) + -- 5. Construct the polynomial P(r) = Tensor(r) ⬝ w + -- This is a multilinear polynomial of degree `steps` + -- Tensor(r)_k corresponds to the Lagrange basis polynomial evaluated at r + let P : MvPolynomial (Fin steps) L := + ∑ k : Fin (2^steps), (MvPolynomial.C (w k)) * (MvPolynomial.eqPolynomial (r := bitsOfIndex k)) + have hP_eval : ∀ r, P.eval r = (challengeTensorExpansion steps r) ⬝ᵥ w := by + intro r + simp only [P, MvPolynomial.eval_sum, MvPolynomial.eval_mul, MvPolynomial.eval_C] + rw [dotProduct] + apply Finset.sum_congr rfl + intro k hk_univ + conv_lhs => rw [mul_comm] + congr 1 + -- evaluation of Lagrange basis matches tensor expansion + -- ⊢ (MvPolynomial.eval r) (eqPolynomial (bitsOfIndex k)) = challengeTensorExpansion steps r k + -- Unfold definitions to expose the product structure + unfold eqPolynomial singleEqPolynomial bitsOfIndex challengeTensorExpansion multilinearWeight + rw [MvPolynomial.eval_prod] -- prod structure of `eqPolynomial` + -- Now both sides have form `∏ (j : Fin steps), ...` + apply Finset.prod_congr rfl + intro j _ + -- Simplify polynomial evaluation + simp only [MonoidWithZeroHom.map_ite_one_zero, ite_mul, one_mul, zero_mul, + MvPolynomial.eval_add, MvPolynomial.eval_mul, MvPolynomial.eval_sub, map_one, + MvPolynomial.eval_X] + split_ifs with h_bit + · -- Case: Bit is 1 + simp only [sub_self, zero_mul, MvPolynomial.eval_X, zero_add] + · -- Case: Bit is 0 + simp only [sub_zero, one_mul, map_zero, add_zero] + have hP_nonzero : P ≠ 0 := by + -- Assume P = 0 for contradiction + intro h_P_zero + -- Since w ≠ 0, there exists some index k such that w k ≠ 0 + rcases Function.ne_iff.mp hw_ne_zero with ⟨k, hk_ne_zero⟩ + -- Let r_k be the bit-vector corresponding to index k + let r_k := bitsOfIndex (L := L) k + -- If P = 0, then P(r_k) must be 0 + have h_eval_zero : MvPolynomial.eval r_k P = 0 := by + rw [h_P_zero]; simp only [map_zero] + -- On the other hand, we proved P(r) = Tensor(r) ⬝ w + rw [hP_eval r_k] at h_eval_zero + -- Key Property: Tensor(r_k) is the indicator vector for k. + -- Tensor(r_k)[j] = 1 if j=k, 0 if j≠k. + have h_tensor_k : ∀ j, (challengeTensorExpansion steps r_k) j = if j = k then 1 else 0 := by + intro j + rw [challengeTensorExpansion_bitsOfIndex_is_eq_indicator (L := L) (n := steps) (k := k)] + -- Thus the dot product is exactly w[k] + rw [dotProduct, Finset.sum_eq_single k] at h_eval_zero + · simp only [h_tensor_k, if_true, one_mul] at h_eval_zero + exact hk_ne_zero h_eval_zero + · -- Other terms are zero + intro j _ h_ne + simp [h_tensor_k, h_ne] + · simp only [mem_univ, not_true_eq_false, _root_.mul_eq_zero, IsEmpty.forall_iff] -- Case where index k is not in univ (impossible for Fin n) + have hP_deg : P.totalDegree ≤ steps := by + -- Use the correct lemma from the list: sum degree ≤ d if all terms degree ≤ d + apply MvPolynomial.totalDegree_finsetSum_le + intro k _ + -- Bound degree of each term: deg(C * eqPoly) ≤ deg(C) + deg(eqPoly) = 0 + deg(eqPoly) + apply le_trans (MvPolynomial.totalDegree_mul _ _) + simp only [MvPolynomial.totalDegree_C, zero_add] + -- Bound degree of eqPolynomial (product of linear terms) + unfold eqPolynomial + -- deg(∏ f) ≤ ∑ deg(f) + apply le_trans (MvPolynomial.totalDegree_finset_prod _ _) + -- The sum of `steps` terms, each of degree ≤ 1 + trans ∑ (i : Fin steps), 1 + · apply Finset.sum_le_sum + intro i _ + -- Check degree of singleEqPolynomial: r*X + (1-r)*(1-X) + unfold singleEqPolynomial + -- deg(A + B) ≤ max(deg A, deg B) + apply (MvPolynomial.totalDegree_add _ _).trans + rw [max_le_iff] + constructor + · -- deg(C * X) ≤ 1 + apply (MvPolynomial.totalDegree_mul _ _).trans + -- simp [MvPolynomial.totalDegree_C, MvPolynomial.totalDegree_X] + -- ⊢ (1 - MvPolynomial.C (bitsOfIndex k i)).totalDegree + (1 - MvPolynomial.X i).totalDegree ≤ 1 + calc + _ ≤ ((1 : L[X Fin steps]) - MvPolynomial.X i).totalDegree := by + have h_left_le := MvPolynomial.totalDegree_sub_C_le (p := (1 : L[X Fin steps])) (r := bitsOfIndex k i) + simp only [totalDegree_one] at h_left_le -- (1 - C (bitsOfIndex k i)).totalDegree ≤ 0 + omega + _ ≤ max ((1 : L[X Fin steps]).totalDegree) ((MvPolynomial.X (R := L) i).totalDegree) := by + apply MvPolynomial.totalDegree_sub + _ = _ := by + simp only [totalDegree_one, totalDegree_X, _root_.zero_le, sup_of_le_right] + · -- deg(C * (X)) ≤ 1 + apply (MvPolynomial.totalDegree_mul _ _).trans + simp only [MvPolynomial.totalDegree_C, zero_add] + -- ⊢ (MvPolynomial.X i).totalDegree ≤ 1 + simp only [totalDegree_X, le_refl] + · simp only [sum_const, Finset.card_univ, Fintype.card_fin, smul_eq_mul, mul_one, le_refl] + -- 6. Apply Schwartz-Zippel using Pr_congr to switch the event + rw [Pr_congr (Q := fun r => MvPolynomial.eval r P = 0)] + · apply prob_schwartz_zippel_mv_polynomial P hP_nonzero hP_deg + · intro r + -- Show that (Folding Eq) ↔ (P(r) = 0) + rw [iterated_fold_eq_matrix_form 𝔽q β (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le), iterated_fold_eq_matrix_form 𝔽q β (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le)] + -- Expand the dot product logic: + unfold localized_fold_matrix_form single_point_localized_fold_matrix_form + rw [hP_eval] + rw [Matrix.dotProduct_mulVec] + simp only + -- ⊢ challengeTensorExpansion steps r ᵥ* foldMatrix 𝔽q β ⟨↑i, ⋯⟩ steps ⋯ y ⬝ᵥ fiberEvaluations 𝔽q β ⟨↑i, ⋯⟩ steps ⋯ f_i y = + -- challengeTensorExpansion steps r ⬝ᵥ + -- foldMatrix 𝔽q β ⟨↑i, ⋯⟩ steps ⋯ y *ᵥ fiberEvaluations 𝔽q β ⟨↑i, ⋯⟩ steps ⋯ f_bar_i y ↔ + -- challengeTensorExpansion steps r ⬝ᵥ w = 0 + rw [←sub_eq_zero] + -- Transform LHS: u ⬝ (M * a) - u ⬝ (M * b) = u ⬝ (M * a - M * b) + rw [←Matrix.dotProduct_mulVec] + rw [←dotProduct_sub] + -- Transform inner vector: M * a - M * b = M * (a - b) + rw [←Matrix.mulVec_sub] + -- Substitute definition of w: w = M * (vals_f - vals_f_bar) + -- Note: v_diff was defined as vals_f - vals_f_bar + -- And w was defined as M_y *ᵥ v_diff + -- Sum the bounds: |Δ_fiber| * (steps / |L|) + -- Since |Δ_fiber| ≤ |S_next|, this is bounded by |S_next| * steps / |L| + have h_card_fiber : Δ_fiber.toFinset.card ≤ Fintype.card S_next := + Finset.card_le_univ Δ_fiber.toFinset + calc + _ ≤ ∑ y ∈ Δ_fiber.toFinset, (steps : ENNReal) / L_card := by + apply Finset.sum_le_sum + intro y hy -- hy : y ∈ Δ_fiber.toFinset + let res := h_prob_y y (by exact Set.mem_toFinset.mp hy) + exact res + _ = (Δ_fiber.toFinset.card) * (steps / L_card) := by + simp only [Finset.sum_const, nsmul_eq_mul] + _ ≤ (Fintype.card S_next) * (steps / L_card) := by + gcongr + _ = (steps * Fintype.card S_next) / L_card := by + ring_nf + conv_rhs => rw [mul_div_assoc] + +open Code.InterleavedCode in +/-- **Lemma 4.21** (Interleaved Distance Preservation): +If `d⁽ⁱ⁾(f⁽ⁱ⁾, C⁽ⁱ⁾) ≥ d_{i+ϑ} / 2` (`f` is fiber-wise far wrt UDR), +then `d^{2^ϑ}( (f_j⁽ⁱ⁺ϑ⁾)_{j=0}^{2^ϑ - 1}, C^{(i+ϑ)^{2^ϑ}} ) ≥ d_{i+ϑ} / 2` + (i.e. interleaved distance ≥ UDR distance). +* **Main Idea of Proof:** For an ARBITRARY interleaved codeword `(g_j⁽ⁱ⁺ϑ⁾)`, +a "lift" `g⁽ⁱ⁾ ∈ C⁽ⁱ⁾` is constructed. It's shown that `g⁽ⁱ⁾` relates to `(g_j⁽ⁱ⁺ϑ⁾)` (via +folding with basis vectors as challenges) similarly to how `f⁽ⁱ⁾` relates to `(f_j⁽ⁱ⁺ϑ⁾)` (via +Lemma 4.9 and matrix `M_y`). Since `f⁽ⁱ⁾` is far from `g⁽ⁱ⁾` on many fibers (by hypothesis), and +`M_y` is invertible, the columns `(f_j⁽ⁱ⁺ϑ⁾(y))` and `(g_j⁽ⁱ⁺ϑ⁾(y))` must differ for these `y`, +establishing the distance for the interleaved words. -/ +lemma lemma_4_21_interleaved_word_UDR_far (i : Fin ℓ) (steps : ℕ) [NeZero steps] + {destIdx : Fin r} (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f_i : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) + (h_far : ¬fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f_i)) : + let U := preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le f_i + let C_next : Set (sDomain 𝔽q β h_ℓ_add_R_rate destIdx → L) := + BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx + ¬(jointProximityNat (C := C_next) (u := U) (e := Code.uniqueDecodingRadius (C := C_next))) := by + -- 1. Setup variables and definitions + let m := 2^steps + let S_next := sDomain 𝔽q β h_ℓ_add_R_rate destIdx + let C : Set (sDomain 𝔽q β h_ℓ_add_R_rate ⟨i, by omega⟩ → L) := + (BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) + let C_next := BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx + let C_int := C_next ^⋈ (Fin m) + let U_wordStack := preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le f_i + let U_interleaved : InterleavedWord L (Fin m) S_next := ⋈|U_wordStack + let d_next := BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := destIdx) + let e_udr := Code.uniqueDecodingRadius (C := (C_next : Set (S_next → L))) + -- 2. Analyze the "Far" hypothesis + -- h_far : ¬(2 * fiberwiseDistance < d_next) ↔ 2 * fiberwiseDistance ≥ d_next + -- This means for ANY g ∈ C^(i), the number of fiber disagreements is ≥ d_next/2. + have h_fiber_dist_ge : ∀ g : BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩, + 2 * (fiberwiseDisagreementSet 𝔽q β ⟨i, by omega⟩ steps h_destIdx h_destIdx_le f_i g).card ≥ d_next := by + -- Expand negation of fiberwiseClose definition + intro g + -- 1. Unwrap the "Far" hypothesis + -- "Not Close" means 2 * min_dist ≥ d_next + unfold fiberwiseClose at h_far + rw [not_lt] at h_far + -- 2. Set up the set of all distances + let dist_set := (fun (g' : C) => + (fiberwiseDisagreementSet 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le f_i g').card) '' Set.univ + -- 3. Show that the specific g's distance is >= the minimum distance + -- We use `csInf_le` which says inf(S) ≤ x for any x ∈ S (provided S is bounded below) + have h_min_le_g : fiberwiseDistance 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le f_i ≤ + (fiberwiseDisagreementSet 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le f_i g).card := by + apply csInf_le + -- S must be bounded below (0 is a lower bound for Nat) + · use 0 + rintro _ ⟨_, _, rfl⟩ + simp only [_root_.zero_le] + -- S must be nonempty (g exists) + · use g + simp only [Set.mem_univ, true_and] + rfl + -- 4. Transitivity: d_next ≤ 2 * min ≤ 2 * specific + calc + d_next ≤ 2 * fiberwiseDistance 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le f_i := by + norm_cast at h_far + _ ≤ 2 * (fiberwiseDisagreementSet 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le f_i g).card := by + let res := Nat.mul_le_mul_left (k := (2 : ℕ)) (h := (h_min_le_g)) + exact res + -- 3. Proof by Contradiction + -- Assume U is close to C_int (distance ≤ e_udr). + simp only + intro h_U_close -- Proof by Contradiction: Assume U is UDR-close to C_int. + -- By definition of jointProximityNat, this means U is e_udr-close to C_int. + -- Since C_int is nonempty, there exists a closest codeword V ∈ C_int. + obtain ⟨V_codeword, h_dist_U_V⟩ := jointProximityNat_iff_closeToInterleavedCodeword + (u := U_wordStack) (e := e_udr) (C := C_next) |>.mp h_U_close + -- 4. Construct the "Lifted" Codeword g + -- We claim there exists a g ∈ C^(i) such that applying `preTensorCombine_WordStack` to g yields V. + -- This essentially inverses the folding operation. M_y is invertible, so we can recover + -- the fiber evaluations of g from the columns of V. + -- The algebraic properties of Binius ensure this reconstructed g is a valid codeword in C^(i). + let g := lift_interleavedCodeword 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i steps h_destIdx h_destIdx_le V_codeword + have h_g_is_lift_of_V : (⋈|preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le ↑g) + = V_codeword.val := by + apply preTensorCombine_of_lift_interleavedCodeword_eq_self 𝔽q β + -- 5. Equivalence of Disagreements via Non-Singular M_y + -- We show that U and V differ at column y iff f_i and g differ on the CORRESPONDING fiber of y. + -- This relies on U_y = M_y * f_fiber and V_y = M_y * g_fiber. + have h_disagreement_equiv : ∀ y : S_next, + (∃ x, + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate (i := ⟨i, by omega⟩) (destIdx := destIdx) + (k := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) x = y ∧ + f_i x ≠ g.val x) ↔ + getSymbol U_interleaved y ≠ getSymbol V_codeword y := by + intro y + let res := fiberwise_disagreement_isomorphism 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f_i) (g := g.val) (y := y) + unfold fiberDiff at res + rw [res] + have h_col_U_y_eq : (preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le f_i).getSymbol y + = getSymbol U_interleaved y := by rfl + have h_col_V_y_eq : (preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le g.val).getSymbol y + = getSymbol V_codeword y := by + have h_get_symbol_eq : (preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le + (g.val)).getSymbol y = getSymbol (⋈|preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le ↑g) y := by rfl + rw [h_get_symbol_eq] + rw [h_g_is_lift_of_V] + -- ⊢ getSymbol (↑V_codeword) y = getSymbol V_codeword y (lhs is I(nterleaved) word, rhs is IC) + rfl + rw [h_col_U_y_eq, h_col_V_y_eq] + -- 6. Connect Distances + -- The Hamming distance Δ₀(U, V) is exactly the number of columns where they differ. + -- By the equivalence above, this is exactly the size of `fiberwiseDisagreementSet f_i g`. + have h_dist_eq : Δ₀(U_interleaved, V_codeword.val) ≥ + (fiberwiseDisagreementSet 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le f_i g).card := by + -- Use h_disagreement_equiv and definition of Hamming distance / fiberwiseDisagreementSet + -- We prove equality, which implies ≥ + apply le_of_eq + -- Definition of Hamming distance is count of {y | U y ≠ V y} + unfold hammingDist + -- Definition of fiberwiseDisagreementSet is {y | ∃ x ∈ Fiber(y), f x ≠ g x} + unfold fiberwiseDisagreementSet + -- We want to show card {y | U y ≠ V y} = card {y | fiber_diff } + -- It suffices to show the sets are equal. + congr 1 + ext y + simp only [Finset.mem_filter, Finset.mem_univ, true_and] + -- Apply the equivalence we proved in step 5 + rw [h_disagreement_equiv] + -- The LHS of h_disagreement_equiv matches the RHS of our goal here. + -- The RHS of h_disagreement_equiv matches the LHS of our goal here. + -- Just need to handle the `InterleavedWord` wrapper + rfl + -- 7. Contradiction Algebra + -- We have: + -- (1) 2 * dist(U, V) ≤ 2 * e_udr (by assumption h_U_close) + -- (2) 2 * e_udr < d_next (by definition of UDR) + -- (3) 2 * card(disagreement f g) ≥ d_next (by h_far hypothesis applied to g) + -- (4) dist(U, V) = card(disagreement f g) (by h_dist_eq) + -- Combining (3) and (4): 2 * dist(U, V) ≥ d_next + -- Combining (1) and (2): 2 * dist(U, V) < d_next + -- Contradiction. + have h_ineq_1 : ¬(2 * (fiberwiseDisagreementSet 𝔽q β (i := ⟨i, by omega⟩) steps + h_destIdx h_destIdx_le f_i g).card < d_next) := by + simp only [not_lt, h_fiber_dist_ge (g := ⟨g, by simp only [SetLike.coe_mem]⟩)] + have h_ineq_2 : + 2 * (fiberwiseDisagreementSet 𝔽q β (i := ⟨i, by omega⟩) steps + h_destIdx h_destIdx_le f_i g).card < d_next := by + calc + 2 * (fiberwiseDisagreementSet 𝔽q β (i := ⟨i, by omega⟩) steps + h_destIdx h_destIdx_le f_i g).card + _ ≤ 2 * Δ₀(U_interleaved, V_codeword.val) := by + omega + _ ≤ 2 * e_udr := by + -- {n m : Nat} (k : Nat) (h : n ≤ m) : n * k ≤ m * k := + let res := Nat.mul_le_mul_left (k := 2) (h := h_dist_U_V) + exact res + _ < d_next := by + -- ⊢ 2 * e_udr < d_next + letI : NeZero (‖(C_next : Set (S_next → L))‖₀) := NeZero.of_pos (by + have h_pos : 0 < + BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx := by + -- ⊢ 0 < 2 ^ (ℓ + 𝓡 - destIdx.val) - 2 ^ (ℓ - destIdx.val) + 1 + simp [BBF_CodeDistance_eq (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := destIdx) (h_i := h_destIdx_le)] + simpa [C_next, BBF_CodeDistance] using h_pos + ) + let res := Code.UDRClose_iff_two_mul_proximity_lt_d_UDR + (C := (C_next : Set (S_next → L))) (e := e_udr).mp (by omega) + exact res + exact h_ineq_1 h_ineq_2 + +lemma prop_4_20_case_2_fiberwise_far (i : Fin ℓ) (steps : ℕ) [NeZero steps] + {destIdx : Fin r} (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f_i : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) + (h_far : ¬fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f_i)) : + let next_domain_size := Fintype.card (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) + Pr_{ let r ←$ᵖ (Fin steps → L) }[ + let f_next := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps + h_destIdx h_destIdx_le f_i r + UDRClose 𝔽q β destIdx h_destIdx_le f_next + ] ≤ ((steps * next_domain_size) / Fintype.card L) := by + -- This requires mapping the fiberwise distance to the interleaved code distance + -- and applying the tensor product proximity gap results from DG25.lean. + let S_next := sDomain 𝔽q β h_ℓ_add_R_rate destIdx + let L_card := Fintype.card L + -- 1. Construct the interleaved word U from f_i + -- U is a matrix of size m x |S_next|, where row j corresponds to the j-th fiber index. + let U := + preTensorCombine_WordStack 𝔽q β i steps (destIdx := destIdx) h_destIdx h_destIdx_le f_i + -- 2. Translate Fiber-wise Distance to Interleaved Distance + -- The fiberwise distance is exactly the minimum Hamming distance between + -- the columns of U (viewed as vectors in L^m) and the code C^m (interleaved). + -- Actually, based on Def 4.15/4.16, fiberwiseDistance is the distance of f_i to C_i + -- but viewed through the fibers. This corresponds to the distance of U to C_next^m. + let C_next := BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx + let C_interleaved := C_next ^⋈ (Fin (2^steps)) + let d_next := BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := destIdx) + -- 3. Apply Tensor Gap Theorem (Contrapositive) + -- Theorem 3.6 / Corollary 3.7 states: + -- If Pr[ multilinearCombine(U, r) is close ] > ε/|L|, then U is close to C_int. + -- Contrapositive: If U is FAR from C_int, then Pr[ multilinearCombine(U, r) is close ] ≤ ε/|L|. + -- We identify "close" as distance ≤ e, where e = floor((d_next - 1) / 2). + let e_prox := (d_next - 1) / 2 + -- Check that "far" hypothesis implies "not close" + -- h_U_far says 2*dist ≥ d_next. + -- "Close" means dist ≤ e_prox = (d_next - 1)/2 < d_next/2. + -- So U is strictly greater than e_prox distance away. + have h_U_not_UDR_close : ¬ (jointProximityNat (u := U) (e := e_prox) (C := (C_next : Set _))) := by + apply lemma_4_21_interleaved_word_UDR_far 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i) + (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f_i := f_i) (h_far := h_far) + -- The epsilon for RS codes / Tensor Gaps is typically |S_next| * steps (or similar). + -- In DG25 Cor 3.7, ε = |S_next|. The bound is ϑ * ε / |L|. + let ε_gap := Fintype.card S_next + -- Apply the Tensor Gap Theorem (Corollary 3.7 for RS codes or Theorem 3.6 generic) + have h_prob_bound : + Pr_{ let r ←$ᵖ (Fin steps → L) }[ Δ₀(multilinearCombine U r, C_next) ≤ e_prox ] + ≤ (steps * ε_gap) / L_card := by + -- Apply contrapositive of h_tensor_gap applied to U + by_contra h_prob_gt_bound + let α := Embedding.subtype fun (x : L) ↦ x ∈ S_next + let C_i_plus_steps := BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx + let RS_i_plus_steps := ReedSolomon.code α (2^(ℓ - destIdx.val)) + letI : Nontrivial (RS_i_plus_steps) := by infer_instance + let h_tensor_gap := reedSolomon_multilinearCorrelatedAgreement_Nat (A := L) (ι := sDomain 𝔽q β h_ℓ_add_R_rate destIdx) + (α := α) + (k := 2^(ℓ - destIdx.val)) + (hk := by + rw [sDomain_card 𝔽q β h_ℓ_add_R_rate (i := destIdx) (h_i := Sdomain_bound (by + exact h_destIdx_le)), hF₂.out] + have h_exp : ℓ - destIdx.val ≤ ℓ + 𝓡 - destIdx.val := by + omega + exact Nat.pow_le_pow_right (hx := by omega) h_exp + ) + (e := e_prox) (he := by exact Nat.le_refl _) + (ϑ := steps) (hϑ_gt_0 := by exact Nat.pos_of_neZero steps) + -- 3. Apply the theorem to our specific word U + -- This concludes "U is close" (jointProximityNat) + let h_U_UDR_close : jointProximityNat (C := C_i_plus_steps) U e_prox := + h_tensor_gap U (by + rw [ENNReal.coe_natCast] + rw [not_le] at h_prob_gt_bound + exact h_prob_gt_bound + ) + exact h_U_not_UDR_close h_U_UDR_close + -- 4. Connect Folding to Multilinear Combination + -- Show that `iterated_fold` is exactly `multilinearCombine` of `U` + -- Lemma 4.9 (iterated_fold_eq_matrix_form) essentially establishes this connection + -- multilinearCombine U r = Tensor(r) ⬝ U = iterated_fold f r + have h_fold_eq_combine : ∀ r, + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) f_i r) = + multilinearCombine U r := by + intro r + ext y + rw [iterated_fold_eq_matrix_form 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f_i) (r_challenges := r)] + unfold localized_fold_matrix_form single_point_localized_fold_matrix_form multilinearCombine + simp only [dotProduct, smul_eq_mul] + apply Finset.sum_congr rfl + intro (rowIdx : Fin (2^steps)) h_rowIdx_univ + rfl + -- 5. Conclusion + -- The event inside the probability is: 2 * dist(folded, C_next) < d_next + -- This is equivalent to dist(folded, C_next) ≤ (d_next - 1) / 2 = e_prox + rw [Pr_congr (Q := fun r => Δ₀(multilinearCombine U r, C_next) ≤ e_prox)] + · exact h_prob_bound + · intro r + rw [←h_fold_eq_combine] + rw [UDRClose_iff_within_UDR_radius] + have h_e_prox_def : e_prox = Code.uniqueDecodingRadius + (C := (C_next : Set (sDomain 𝔽q β h_ℓ_add_R_rate destIdx → L))) := by rfl + rw [h_e_prox_def] + +/-! +### Soundness Lemmas (4.20 - 4.25) +-/ + +open Classical in +/-- **Proposition 4.20** (Bound on Bad Folding Event): +The probability (over random challenges `r`) of the bad folding event is bounded. +Bound: `μ(Eᵢ) ≤ ϑ ⋅ |S⁽ⁱ⁺ϑ⁾| / |L|` (where `μ(R) = Pr_{ let r ←$ᵖ (Fin steps → L) }[ R ]`) +**Case 1: Fiber-wise close** => + `μ(Δ⁽ⁱ⁾(f⁽ⁱ⁾, f̄⁽ⁱ⁾) ⊄ Δ_folded_disagreement) ≤ steps · |S⁽ⁱ⁺steps⁾| / |L|` +Proof strategy: +- Show that `∀ y ∈ Δ_fiber, μ(y ∉ Δ_folded_disagreement) ≤ steps / |L|` +- Apply the Union Bound over `y ∈ Δ_fiber` +**Case 2: Fiber-wise far** => + μ(`d(fold(f⁽ⁱ⁾, rᵢ', ..., rᵢ₊steps₋₁'), C⁽ⁱ⁺steps⁾) < dᵢ₊steps / 2`) ≤ steps · |S⁽ⁱ⁺steps⁾| / |L| +-/ +lemma prop_4_20_bad_event_probability (i : Fin ℓ) (steps : ℕ) [NeZero steps] + {destIdx : Fin r} (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f_i : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) : + let domain_size := Fintype.card (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) + Pr_{ let r_challenges ←$ᵖ (Fin steps → L) }[ + foldingBadEvent 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le f_i r_challenges ] ≤ + ((steps * domain_size) / Fintype.card L) := by + let S_next := sDomain 𝔽q β h_ℓ_add_R_rate destIdx + let L_card := Fintype.card L + -- Unfold the event definition to split into the two cases + unfold foldingBadEvent + by_cases h_close : fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i, by omega⟩) (steps := steps) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (f := f_i) + · -- CASE 1: Fiber-wise Close (The main focus of the provided text) + simp only [h_close, ↓reduceDIte] + let res := prop_4_20_case_1_fiberwise_close 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i) + (steps := steps) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (f_i := f_i) (h_close := h_close) + exact res + · -- CASE 2: Fiber-wise Far + -- The bad event is that the folded function becomes UDRClose. + simp only [h_close, ↓reduceDIte] + -- If fiberwise distance is "far" (≥ d_next / 2), + -- then the probability of becoming "close" (< d_next / 2) is bounded. + apply prop_4_20_case_2_fiberwise_far 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (h_far := h_close) + +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ [NeZero 𝓡] [SampleableType L] in +lemma iteratedQuotientMap_succ_comp + (i : Fin r) {midIdx destIdx : Fin r} (steps : ℕ) + (h_midIdx : midIdx.val = i.val + 1) + (h_destIdx : destIdx.val = i.val + (steps + 1)) + (h_destIdx_le : destIdx ≤ ℓ) + (x : sDomain 𝔽q β h_ℓ_add_R_rate i) : + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate + (i := i) (k := steps + 1) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) x + = + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate + (i := midIdx) (k := steps) + (h_destIdx := by omega) + (h_destIdx_le := h_destIdx_le) + (iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate + (i := i) (k := 1) (h_destIdx := h_midIdx) (h_destIdx_le := by omega) x) := by + apply Subtype.ext + simp only [iteratedQuotientMap] + have h_poly_comp := intermediateNormVpoly_comp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := i) (destIdx := midIdx) (k := 1) (l := steps) + (h_destIdx := by simpa using h_midIdx) (h_k := by omega) (h_l := by omega) + have h_poly_comp' : + intermediateNormVpoly 𝔽q β h_ℓ_add_R_rate (i := i) (k := steps + 1) (h_k := by omega) = + (intermediateNormVpoly 𝔽q β h_ℓ_add_R_rate (i := midIdx) (k := steps) + (h_k := by omega)).comp + (intermediateNormVpoly 𝔽q β h_ℓ_add_R_rate (i := i) (k := 1) (h_k := by omega)) := by + simpa [Nat.add_comm] using h_poly_comp + rw [h_poly_comp'] + simp only [Polynomial.eval_comp] + +open Classical in +/-- **Proposition 4.20.2 (Case 1: FiberwiseClose)**. +Incremental bad-event bound for a fixed block start and fixed consumed prefix, under the +block-level close branch. + +The fresh event at step `k` is +`ℰ_{i+k} = ¬ E(i, k) ∧ E(i, k+1)` where `E := incrementalFoldingBadEvent`. + +#### **Case 1: FiberwiseClose** + +**Hypothesis:** `d^{(i)}(f^{(i)}, C^{(i)}) < d_{i+ϑ} / 2`. +**Condition:** We assume the bad event has *not* happened up to step `k` (i.e., `¬ E(i, k)` holds). This implies: +`Δ^{(i)}(f^{(i)}, f_bar^{(i)}) ⊆ Δ^{(i+k)}(fold_k(f^{(i)}), fold_k(f_bar^{(i)}))` +where `Δ^{(i+k)}` is the disagreement set projected to the destination domain `S^{i+ϑ}`. + +We must bound the probability that a quotient point `y ∈ Δ^{(i+k)}` "vanishes" from the disagreement set in the next step `k+1`, i.e. `y ∉ Δ^{(i+k+1)}(fold(fold_k(f^{(i)}), r), fold(fold_k(f_bar^{(i)}), r))`. Let `f_k := fold_k(f^{(i)})` and `f_bar_k := fold_k(f_bar^{(i)})`. + +Fix any `y ∈ Δ^{(i+k)}`. + +* By definition, there exists at least one point `z` in the fiber of `y` (within the current domain `S^{i+k}`) such that `f_k(z) ≠ f_bar_k(z)` (by definition of `Δ^{(i+k)}`). + +Consider the folding step `S^{i+k} → S^{i+k+1}`. The map `q` pairs points in `S^{i+k}` (say `x₀, x₁`) to a single point `w` in `S^{i+k+1}`. +The folded value at `w` is defined as (Definition 4.6): +`fold(f_k, r)(w) = [1-r, r] · M · [f_k(x₀), f_k(x₁)]ᵀ` +where `M = [[x₁, -x₀], [-1, 1]]` is an invertible matrix. + +Let `E_y(r)(w)` (where `y ∈ Δ^{(i+k)}(fold_k(f^{(i)}), fold_k(f_bar^{(i)}))`) be the difference between the folded values of `f_k` and `f_bar_k` in `S^{i+k+1}` at `w`: +`E_y(r)(w) := fold(f_k, r)(w) - fold(f_bar_k, r)(w)` + +Linearity allows us to rewrite this as: +`E_y(r)(w) = [1-r, r] · M · [f_k(x₀) - f_bar_k(x₀), f_k(x₁) - f_bar_k(x₁)]ᵀ` + +Since `y ∈ Δ^{(i+k)} ⊂ S^{i+ϑ}`, the difference vector `v_vec = [f_k(x₀) - f_bar_k(x₀), f_k(x₁) - f_bar_k(x₁)]ᵀ` is non-zero for at least one pair `(x₀, x₁)` in the fiber of `y` (otherwise `f_k` is equal to `f_bar_k` at all points in `S^{i+k}`, contradicting the definition of `Δ^{(i+k)}`). + +Because `M` is invertible, the vector `v_vec' = M · v_vec` is also **non-zero**. Let `v_vec' = [a, b]ᵀ`. Then: +`E_y(r)(w) = a(1-r) + br = a + (b-a)r` + +This is a polynomial in `r` of degree at most 1. Since `v_vec' ≠ 0`, the **coefficients `a` and `b` cannot both be zero**. + +* If `b ≠ a`, `E_y(r)(w)` has exactly one root. +* If `b = a ≠ 0`, `E_y(r)(w) = a ≠ 0`, so it has no roots. + +Thus, `E_y(r)(w) = 0` (i.e. **the case where the point `y` disappears from `Δ^{i+k+1}`, though it was assumed to be in `Δ^{i+k}**`) with probability at most `1 / |L|` (**Schwartz-Zippel Lemma**). + +If `E_y(r)(w) ≠ 0`, then `w ∈ Δ^{(i+k+1)}`, meaning `y` is preserved in the projected disagreement set, so it's not the case we care. + +Applying the Union Bound over all `y ∈ Δ^{(i)} ⊆ S^{i+ϑ}` (noting that `|Δ^{(i)}| ≤ |S^{i+ϑ}|`): +`Pr[∃ y ∈ Δ^{(i)}, y ∉ Δ^{(i+k+1)}] ≤ ∑_{y ∈ Δ^{(i)}} 1 / |L| ≤ |S^{i+ϑ}| / |L|` + +This completes the proof for Case 1. +-/ +lemma prop_4_20_2_case_1_fiberwise_close_incremental + (block_start_idx : Fin r) {midIdx_i midIdx_i_succ destIdx : Fin r} (k : ℕ) (h_k_lt : k < ϑ) + (h_midIdx_i : midIdx_i = block_start_idx + k) (h_midIdx_i_succ : midIdx_i_succ = block_start_idx + k + 1) + (h_destIdx : destIdx = block_start_idx + ϑ) (h_destIdx_le : destIdx ≤ ℓ) + (f_block_start : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) block_start_idx) + (r_prefix : Fin k → L) + (h_block_close : fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := ϑ) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f := f_block_start)) : + let domain_size := Fintype.card (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) + Pr_{ let r_new ← $ᵖ L }[ + ¬ incrementalFoldingBadEvent 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (block_start_idx := block_start_idx) (midIdx := midIdx_i) (destIdx := destIdx) (k := k) + (h_k_le := Nat.le_of_lt h_k_lt) (h_midIdx := h_midIdx_i) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f_block_start := f_block_start) (r_challenges := r_prefix) + ∧ + incrementalFoldingBadEvent 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (block_start_idx := block_start_idx) (midIdx := midIdx_i_succ) (destIdx := destIdx) (k := k + 1) + (h_k_le := Nat.succ_le_of_lt h_k_lt) (h_midIdx := h_midIdx_i_succ) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f_block_start := f_block_start) + (r_challenges := Fin.snoc r_prefix r_new) + ] ≤ + (domain_size / Fintype.card L) := by + -- ──────────────────────────────────────────────────────── + -- Step 0: Simplify incrementalFoldingBadEvent using h_block_close + -- ──────────────────────────────────────────────────────── + dsimp only [incrementalFoldingBadEvent] + have h_k_succ_ne_0 : ¬(k + 1 = 0) := by omega + simp only [h_block_close, ↓reduceDIte] + -- ──────────────────────────────────────────────────────── + -- Step 1: Name the key objects + -- ──────────────────────────────────────────────────────── + let f_i := f_block_start + let f_bar_i : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) block_start_idx := + UDRCodeword 𝔽q β (i := block_start_idx) (h_i := by omega) + (f := f_i) (h_within_radius := UDRClose_of_fiberwiseClose 𝔽q β block_start_idx ϑ h_destIdx h_destIdx_le f_i h_block_close) + let Δ_fiber : Finset (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) := + fiberwiseDisagreementSet 𝔽q β (i := block_start_idx) ϑ h_destIdx h_destIdx_le f_i f_bar_i + -- The k-step folds (fixed, no r_new dependency) + let fold_k_f := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k) (h_destIdx := h_midIdx_i) (h_destIdx_le := by omega) + (f := f_i) (r_challenges := r_prefix) + let fold_k_f_bar := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k) (h_destIdx := h_midIdx_i) (h_destIdx_le := by omega) + (f := f_bar_i) (r_challenges := r_prefix) + -- ──────────────────────────────────────────────────────── + -- Step 2: Factor out the deterministic ¬E(k) conjunct. + -- ¬E(k) = (Δ_fiber ⊆ disagr_set_at_k) does NOT depend on r_new, + -- so we case-split: if false, Pr = 0; if true, use it as hypothesis. + -- ──────────────────────────────────────────────────────── + -- The ¬E(k) predicate (subset condition at step k) + let not_Ek := Δ_fiber ⊆ fiberwiseDisagreementSet 𝔽q β + midIdx_i (ϑ - k) (by omega) h_destIdx_le fold_k_f fold_k_f_bar + by_cases h_not_Ek : not_Ek + swap + · -- Case: ¬not_Ek, i.e. ¬(Δ_fiber ⊆ D_k). Then ¬¬(Δ ⊆ D_k) = False, so conjunction always False. + -- Pr[always False] = 0 ≤ bound. + apply le_trans (Pr_le_Pr_of_implies ($ᵖ L) _ (fun _ => False) (fun r_new h => absurd (not_not.mp h.1) h_not_Ek)) + simp only [PMF.monad_pure_eq_pure, PMF.monad_bind_eq_bind, PMF.bind_const, PMF.pure_apply, + eq_iff_iff, iff_false, not_true_eq_false, ↓reduceIte, _root_.zero_le]; + · -- pos case + -- From here: h_not_Ek : Δ_fiber ⊆ fiberwiseDisagreementSet(midIdx_i, ϑ-k, fold_k_f, fold_k_f_bar) + -- Use prob_mono to drop the ¬E(k) conjunct (it's deterministically true). + apply le_trans (Pr_le_Pr_of_implies ($ᵖ L) _ _ (fun r_new h => h.2)) + -- ──────────────────────────────────────────────────────── + -- Step 3: Bound Pr_{r_new}[E(k+1)] ≤ |S^{destIdx}| / |L| + -- ──────────────────────────────────────────────────────── + -- E(k+1) = ¬(Δ_fiber ⊆ fiberwiseDisagreementSet(midIdx_i_succ, ϑ-(k+1), + -- fold_{k+1}(f, snoc r_prefix r_new), fold_{k+1}(f̄, snoc r_prefix r_new))) + -- + -- Strategy: Union Bound + single-step Schwartz-Zippel (degree ≤ 1 in r_new). + -- + -- (3a) E(k+1) = ∃ y ∈ Δ_fiber, y ∉ disagreement set at step k+1. + -- (3b) By union bound: Pr[∃ y dropped] ≤ ∑_{y ∈ Δ_fiber} Pr[y dropped]. + -- (3c) Per-point bound: Pr[y dropped] ≤ 1/|L|. + -- fold_{k+1} = fold(fold_k, r_new) by iterated_fold_last. + -- The fold difference at any fiber point w is a + (b-a)·r_new (degree ≤ 1). + -- By non-degeneracy (butterfly matrix invertible), the polynomial is non-zero + -- for any y with disagreeing fiber values. By Schwartz-Zippel, ≤ 1/|L|. + -- (3d) Sum: |Δ_fiber| · (1/|L|) ≤ |S^{destIdx}| / |L|. + let L_card := Fintype.card L + -- Convert probability to cardinality ratio + rw [prob_uniform_eq_card_filter_div_card] + -- ── 3d: Per-point Schwartz-Zippel + union bound ── + -- Per-point Schwartz-Zippel: |{r_new : y dropped}| ≤ 1 for each y, + -- because fold difference is degree-1 in r_new with at most 1 root. + have h_per_point_card : ∀ y ∈ Δ_fiber, -- y must be in Δ_fiber to ensure non-trivial fiber disagreement + (Finset.filter (fun r_new => + y ∉ fiberwiseDisagreementSet 𝔽q β + midIdx_i_succ (ϑ - (k + 1)) (by omega) h_destIdx_le + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k + 1) + (h_destIdx := h_midIdx_i_succ) (h_destIdx_le := by omega) + (f := f_i) (r_challenges := Fin.snoc r_prefix r_new)) + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k + 1) + (h_destIdx := h_midIdx_i_succ) (h_destIdx_le := by omega) + (f := f_bar_i) (r_challenges := Fin.snoc r_prefix r_new))) + Finset.univ).card ≤ 1 := by + intro y hy_in_Δ + -- ════════════════════════════════════════════════════════ + -- A. Decompose iterated_fold(k+1, Fin.snoc r_prefix r_new) + -- = fold(fold_k, r_new) via iterated_fold_last + -- ════════════════════════════════════════════════════════ + -- A1. iterated_fold(k+1, snoc r_prefix r_new) pointwise equals + -- fold(iterated_fold(k, Fin.init (snoc r_prefix r_new)), snoc r_prefix r_new (Fin.last k)) + have h_decomp_f : ∀ r_new : L, + iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k + 1) + (h_destIdx := h_midIdx_i_succ) (h_destIdx_le := by omega) + (f := f_i) (r_challenges := Fin.snoc r_prefix r_new) + = fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := midIdx_i) + (destIdx := midIdx_i_succ) (h_destIdx := by omega) (h_destIdx_le := by omega) + (f := fold_k_f) (r_chal := r_new) := by + intro r_new + have := iterated_fold_last 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k) (midIdx := midIdx_i) (destIdx := midIdx_i_succ) + (h_midIdx := h_midIdx_i) (h_destIdx := h_midIdx_i_succ) (h_destIdx_le := by omega) + (f := f_i) (r_challenges := Fin.snoc r_prefix r_new) + simp only [Fin.init_snoc, Fin.snoc_last] at this + exact this + have h_decomp_f_bar : ∀ r_new : L, + iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k + 1) + (h_destIdx := h_midIdx_i_succ) (h_destIdx_le := by omega) + (f := f_bar_i) (r_challenges := Fin.snoc r_prefix r_new) + = fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := midIdx_i) + (destIdx := midIdx_i_succ) (h_destIdx := by omega) (h_destIdx_le := by omega) + (f := fold_k_f_bar) (r_chal := r_new) := by + intro r_new + have := iterated_fold_last 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k) (midIdx := midIdx_i) (destIdx := midIdx_i_succ) + (h_midIdx := h_midIdx_i) (h_destIdx := h_midIdx_i_succ) (h_destIdx_le := by omega) + (f := f_bar_i) (r_challenges := Fin.snoc r_prefix r_new) + simp only [Fin.init_snoc, Fin.snoc_last] at this + exact this + -- ════════════════════════════════════════════════════════ + -- B. Identify a witness fiber point w ∈ S^{i+k+1} where + -- the fold_k values disagree in the fiber of y + -- ════════════════════════════════════════════════════════ + -- B1. y ∈ Δ_fiber means ∃ x in fiber of y at level block_start_idx + -- where f_i(x) ≠ f̄_i(x). We need to lift this to level i+k+1. + -- B2. Construct w ∈ S^{i+k+1} such that: + -- (a) w is in the fiber of y (from midIdx_i_succ to destIdx), and + -- (b) in the fiber of w at level i+k, fold_k values disagree. + have h_exists_disagreeing_w : + ∃ w : sDomain 𝔽q β h_ℓ_add_R_rate midIdx_i_succ, + (iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate + (i := midIdx_i_succ) (k := ϑ - (k + 1)) + (h_destIdx := by omega) (h_destIdx_le := h_destIdx_le) w = y) ∧ + (let fiberMap := qMap_total_fiber 𝔽q β (i := midIdx_i) (steps := 1) + (h_destIdx := by omega) (h_destIdx_le := by omega) (y := w) + let x₀ := fiberMap 0 + let x₁ := fiberMap 1 + (fold_k_f x₀ ≠ fold_k_f_bar x₀ ∨ fold_k_f x₁ ≠ fold_k_f_bar x₁)) := by + -- From h_not_Ek and hy_in_Δ, extract z in the fiber at level midIdx_i + have hy_in_disagr := h_not_Ek hy_in_Δ + simp only [fiberwiseDisagreementSet, Finset.mem_filter, Finset.mem_univ, + true_and] at hy_in_disagr + obtain ⟨z, hz_quotient, hz_ne⟩ := hy_in_disagr + -- Set w := iteratedQuotientMap(z, midIdx_i → midIdx_i_succ) + let w : sDomain 𝔽q β h_ℓ_add_R_rate midIdx_i_succ := + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate + (i := midIdx_i) (k := 1) (h_destIdx := by omega) + (h_destIdx_le := by omega) z + refine ⟨w, ?_, ?_⟩ + · -- iteratedQuotientMap(w, midIdx_i_succ → destIdx) = y + have h_factor := iteratedQuotientMap_succ_comp 𝔽q β + (i := midIdx_i) (midIdx := midIdx_i_succ) (destIdx := destIdx) + (steps := ϑ - k - 1) (h_midIdx := by omega) + (h_destIdx := by omega) (h_destIdx_le := h_destIdx_le) z + rw [←hz_quotient] + have h_factor_congr := iteratedQuotientMap_congr_k 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := midIdx_i) (k₁ := (ϑ - k - 1) + 1) (k₂ := ϑ - k) + (hk := by omega) (h_destIdx₁ := by omega) (h_destIdx₂ := by omega) + (h_destIdx_le := h_destIdx_le) z + rw [← h_factor_congr, h_factor] + · -- z is one of x₀ or x₁ in the fiber of w, hence fold_k disagreement + intro fiberMap x₀ x₁ + have h_midIdx_i_succ_le : midIdx_i_succ.val ≤ ℓ := by omega + have hw_eq : w = iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate + (i := midIdx_i) (k := 1) (h_destIdx := by omega) + (h_destIdx_le := h_midIdx_i_succ_le) z := rfl + have hz_fiber := (is_fiber_iff_generates_quotient_point 𝔽q β + (i := midIdx_i) (steps := 1) (h_destIdx := by omega) + (h_destIdx_le := h_midIdx_i_succ_le) + z w).mp hw_eq + set idx := pointToIterateQuotientIndex 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := midIdx_i) (steps := 1) (h_destIdx := by omega) + (h_destIdx_le := h_midIdx_i_succ_le) z with h_idx_def + have hz_eq : fiberMap idx = z := hz_fiber + by_cases h0 : idx = 0 + · left; rw [h0] at hz_eq + change fold_k_f (fiberMap 0) ≠ fold_k_f_bar (fiberMap 0) + rw [hz_eq]; exact hz_ne + · right; have h1 : idx = 1 := Fin.eq_one_of_ne_zero idx h0 + rw [h1] at hz_eq + change fold_k_f (fiberMap 1) ≠ fold_k_f_bar (fiberMap 1) + rw [hz_eq]; exact hz_ne + obtain ⟨w, hw_in_fiber, hw_disagree⟩ := h_exists_disagreeing_w + -- ════════════════════════════════════════════════════════ + -- C. The fold difference at w is a degree-≤1 polynomial in r_new. + -- fold(fold_k_f, r)(w) - fold(fold_k_f̄, r)(w) + -- = Δ₀ · ((1-r)·x₁ - r) + Δ₁ · (r - (1-r)·x₀) + -- where Δ_j = fold_k_f(x_j) - fold_k_f̄(x_j). + -- ════════════════════════════════════════════════════════ + let fiberMap_w := qMap_total_fiber 𝔽q β (i := midIdx_i) (steps := 1) + (h_destIdx := by omega) (h_destIdx_le := by omega) (y := w) + let x₀ := fiberMap_w 0 + let x₁ := fiberMap_w 1 + let Δ₀ := fold_k_f x₀ - fold_k_f_bar x₀ + let Δ₁ := fold_k_f x₁ - fold_k_f_bar x₁ + -- C1. The fold difference equals the affine polynomial + have h_fold_diff : ∀ r_new : L, + fold 𝔽q β (i := midIdx_i) (h_destIdx := by omega) (h_destIdx_le := by omega) + (f := fold_k_f) (r_chal := r_new) w + - fold 𝔽q β (i := midIdx_i) (h_destIdx := by omega) (h_destIdx_le := by omega) + (f := fold_k_f_bar) (r_chal := r_new) w + = Δ₀ * ((1 - r_new) * x₁.val - r_new) + + Δ₁ * (r_new - (1 - r_new) * x₀.val) := by + intro r_new + simp only [fold, Δ₀, Δ₁, x₀, x₁, fiberMap_w] + ring + -- C2. (Δ₀, Δ₁) ≠ (0, 0) from hw_disagree + have h_Δ_ne_zero : Δ₀ ≠ 0 ∨ Δ₁ ≠ 0 := by + rcases hw_disagree with h0 | h1 + · left; exact sub_ne_zero.mpr h0 + · right; exact sub_ne_zero.mpr h1 + -- ════════════════════════════════════════════════════════ + -- D. The polynomial a + (b-a)·r has at most 1 root. + -- Here a = Δ₀·x₁ - Δ₁·x₀ and (b-a) involves the + -- butterfly matrix coefficients. Since the butterfly + -- matrix [[x₁, -x₀],[-1,1]] is invertible (det = x₁-x₀ ≠ 0) + -- and (Δ₀,Δ₁) ≠ 0, we get (a,b) ≠ (0,0), so the + -- polynomial is non-trivial → ≤ 1 root. + -- ════════════════════════════════════════════════════════ + -- The polynomial P(r) = Δ₀·((1-r)·x₁-r) + Δ₁·(r-(1-r)·x₀) can be rewritten as: + -- P(r) = (Δ₀·x₁ - Δ₁·x₀) + r·(Δ₁·(1+x₀) - Δ₀·(1+x₁)) + -- This corresponds to [1-r, r] · M · [Δ₀, Δ₁]ᵀ where M = [[x₁,-x₀],[-1,1]]. + -- det(M) = x₁ - x₀ ≠ 0 (distinct NTT points in the fiber). + -- Since (Δ₀,Δ₁) ≠ 0 and M invertible, M·[Δ₀,Δ₁]ᵀ ≠ 0. + -- P has at most 1 root → P(r₁) = P(r₂) = 0 ⟹ r₁ = r₂. + have h_x₀_ne_x₁ : (x₀ : L) ≠ (x₁ : L) := by + have h_inj := qMap_total_fiber_injective 𝔽q β midIdx_i 1 + (by omega) (by omega : midIdx_i_succ.val ≤ ℓ) w + have h_ne : (0 : Fin (2 ^ 1)) ≠ 1 := by decide + exact Subtype.val_injective.ne (h_inj.ne h_ne) + -- In char 2: sub = add, neg = id. So P(r) simplifies to: + -- P(r) = Δ₀·((1+r)·x₁ + r) + Δ₁·(r + (1+r)·x₀) + -- = (Δ₀·x₁ + Δ₁·x₀) + r·(Δ₀·(x₁+1) + Δ₁·(x₀+1)) + -- Let a := Δ₀·x₁ + Δ₁·x₀, c := Δ₀·(x₁+1) + Δ₁·(x₀+1). + -- Then P(r) = a + c·r. If c ≠ 0, exactly 1 root. If c = 0, then a ≠ 0 + -- (by butterfly invertibility + (Δ₀,Δ₁) ≠ 0), so no roots. + -- Either way, P(r₁)=P(r₂)=0 ⟹ r₁=r₂. + -- Char-2 rewrite of the polynomial + have h_poly_char2 : ∀ r_val : L, + Δ₀ * ((1 - r_val) * x₁.val - r_val) + Δ₁ * (r_val - (1 - r_val) * x₀.val) = + (Δ₀ * x₁.val + Δ₁ * x₀.val) + + r_val * (Δ₀ * (x₁.val + 1) + Δ₁ * (x₀.val + 1)) := by + intro r_val + simp only [CharTwo.sub_eq_add] + ring + -- Helper: in char 2, u + v = 0 ↔ u = v + have char2_add_zero : ∀ (u v : L), u + v = 0 ↔ u = v := + sum_zero_iff_eq_of_self_sum_zero (F := L) (h_self_sum_eq_zero := by + intro x; exact CharTwo.add_self_eq_zero x) + have h_at_most_one_root : ∀ r₁ r₂ : L, + (Δ₀ * ((1 - r₁) * x₁.val - r₁) + Δ₁ * (r₁ - (1 - r₁) * x₀.val) = 0) → + (Δ₀ * ((1 - r₂) * x₁.val - r₂) + Δ₁ * (r₂ - (1 - r₂) * x₀.val) = 0) → + r₁ = r₂ := by + intro r₁ r₂ h1 h2 + rw [h_poly_char2] at h1 h2 + -- h1 : A + r₁*C = 0, h2 : A + r₂*C = 0 where A,C are the constant/linear coeffs + -- From h1,h2: A = r₁*C and A = r₂*C, so r₁*C = r₂*C, so (r₁+r₂)*C = 0 + have h_sub : (r₁ + r₂) * (Δ₀ * (↑x₁ + 1) + Δ₁ * (↑x₀ + 1)) = 0 := by + have h1' := (char2_add_zero _ _).mp h1 + have h2' := (char2_add_zero _ _).mp h2 + rw [add_mul, ← h1', ← h2', CharTwo.add_self_eq_zero] + rcases mul_eq_zero.mp h_sub with h_diff | h_coeff + · exact (char2_add_zero r₁ r₂).mp h_diff + · exfalso + have h_a_eq_0 : Δ₀ * ↑x₁ + Δ₁ * ↑x₀ = 0 := by + rw [h_coeff, mul_zero, add_zero] at h1; exact h1 + have h_Δ_eq : Δ₀ = Δ₁ := by + have hc : Δ₀ * (↑x₁ + 1) + Δ₁ * (↑x₀ + 1) = + (Δ₀ * ↑x₁ + Δ₁ * ↑x₀) + (Δ₀ + Δ₁) := by ring + rw [h_a_eq_0, zero_add] at hc + rw [hc] at h_coeff + exact (char2_add_zero Δ₀ Δ₁).mp h_coeff + have h_Δ₀_mul : Δ₀ * (↑x₁ + ↑x₀) = 0 := by + have : Δ₀ * ↑x₁ + Δ₀ * ↑x₀ = 0 := h_Δ_eq ▸ h_a_eq_0 + rwa [← mul_add] at this + have h_sum_ne : (↑x₁ : L) + ↑x₀ ≠ 0 := by + rwa [Ne, ← CharTwo.sub_eq_add, sub_eq_zero, eq_comm] + have h_Δ₀_zero := (mul_eq_zero.mp h_Δ₀_mul).resolve_right h_sum_ne + exact h_Δ_ne_zero.elim (absurd h_Δ₀_zero) (absurd (h_Δ_eq ▸ h_Δ₀_zero)) + -- ════════════════════════════════════════════════════════ + -- E. Conclude |{r_new : y dropped}| ≤ 1 + -- ════════════════════════════════════════════════════════ + -- E1. If y is NOT in the (k+1)-step disagreement set, then in particular + -- fold_{k+1}(f) and fold_{k+1}(f̄) agree at w, hence the fold + -- difference polynomial evaluated at r_new is 0. + -- E2. By h_at_most_one_root, this can happen for ≤ 1 value of r_new. + rw [Finset.card_le_one] + intro a ha b hb + simp only [Finset.mem_filter, Finset.mem_univ, true_and] at ha hb + -- ha : y ∉ fiberwiseDisagreementSet(…, fold_{k+1}(f, snoc … a), …) + -- hb : y ∉ fiberwiseDisagreementSet(…, fold_{k+1}(f, snoc … b), …) + -- Need: a = b + -- Extract that fold difference = 0 at w for both a and b, + -- then apply h_at_most_one_root. + -- E3. Connect "y ∉ fiberwiseDisagreementSet(k+1)" to fold agreement at w + -- Helper: extract pointwise agreement from non-membership in disagreement set + have h_agree_at_w : ∀ (r_val : L), + y ∉ fiberwiseDisagreementSet 𝔽q β + midIdx_i_succ (ϑ - (k + 1)) (by omega) h_destIdx_le + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k + 1) + (h_destIdx := h_midIdx_i_succ) (h_destIdx_le := by omega) + (f := f_i) (r_challenges := Fin.snoc r_prefix r_val)) + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k + 1) + (h_destIdx := h_midIdx_i_succ) (h_destIdx_le := by omega) + (f := f_bar_i) (r_challenges := Fin.snoc r_prefix r_val)) → + fold 𝔽q β (i := midIdx_i) (h_destIdx := by omega) (h_destIdx_le := by omega) + (f := fold_k_f) (r_chal := r_val) w + = fold 𝔽q β (i := midIdx_i) (h_destIdx := by omega) (h_destIdx_le := by omega) + (f := fold_k_f_bar) (r_chal := r_val) w := by + intro r_val h_not_in + -- y ∉ fiberwiseDisagreementSet means: no z in fiber of y has disagreeing values. + -- In particular, w is in y's fiber (by hw_in_fiber), so values agree at w. + -- Rewrite iterated_fold(k+1) as fold(fold_k, r_val) + rw [h_decomp_f r_val, h_decomp_f_bar r_val] at h_not_in + -- h_not_in : y ∉ fiberwiseDisagreementSet(midIdx_i_succ, ϑ-(k+1), ..., fold(fold_k_f, r_val), fold(fold_k_f̄, r_val)) + -- Unfold fiberwiseDisagreementSet + simp only [fiberwiseDisagreementSet, Finset.mem_filter, Finset.mem_univ, + true_and, not_exists, not_and] at h_not_in + -- h_not_in : ∀ z, iteratedQuotientMap z = y → fold(fold_k_f, r_val)(z) = fold(fold_k_f̄, r_val)(z) + exact not_not.mp (h_not_in w hw_in_fiber) + -- E4. From fold agreement → polynomial = 0 → apply injectivity + have h_agree_a := h_agree_at_w a ha + have h_agree_b := h_agree_at_w b hb + have h_poly_zero_a : Δ₀ * ((1 - a) * x₁.val - a) + Δ₁ * (a - (1 - a) * x₀.val) = 0 := by + rw [← h_fold_diff a, sub_eq_zero]; exact h_agree_a + have h_poly_zero_b : Δ₀ * ((1 - b) * x₁.val - b) + Δ₁ * (b - (1 - b) * x₀.val) = 0 := by + rw [← h_fold_diff b, sub_eq_zero]; exact h_agree_b + exact h_at_most_one_root a b h_poly_zero_a h_poly_zero_b + -- The bad set {r_new : ¬(Δ ⊆ ...)} ⊆ ⋃_{y ∈ Δ_fiber} {r_new : y dropped} + have h_bad_subset : (Finset.filter (fun r_new => + ¬(↑Δ_fiber ⊆ ↑(fiberwiseDisagreementSet 𝔽q β + midIdx_i_succ (ϑ - (k + 1)) (by omega) h_destIdx_le + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k + 1) + (h_destIdx := h_midIdx_i_succ) (h_destIdx_le := by omega) + (f := f_i) (r_challenges := Fin.snoc r_prefix r_new)) + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k + 1) + (h_destIdx := h_midIdx_i_succ) (h_destIdx_le := by omega) + (f := f_bar_i) (r_challenges := Fin.snoc r_prefix r_new))))) + Finset.univ) ⊆ + Δ_fiber.biUnion (fun y => + Finset.filter (fun r_new => + y ∉ fiberwiseDisagreementSet 𝔽q β + midIdx_i_succ (ϑ - (k + 1)) (by omega) h_destIdx_le + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k + 1) + (h_destIdx := h_midIdx_i_succ) (h_destIdx_le := by omega) + (f := f_i) (r_challenges := Fin.snoc r_prefix r_new)) + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k + 1) + (h_destIdx := h_midIdx_i_succ) (h_destIdx_le := by omega) + (f := f_bar_i) (r_challenges := Fin.snoc r_prefix r_new))) + Finset.univ) := by + intro r_new hr + simp only [Finset.mem_filter, Finset.mem_univ, true_and] at hr + rw [Finset.not_subset] at hr + rcases hr with ⟨y, hy_mem, hy_not_in⟩ + simp only [Finset.mem_biUnion, Finset.mem_filter, Finset.mem_univ, true_and] + exact ⟨y, hy_mem, hy_not_in⟩ + -- |bad set| ≤ |⋃ per-y sets| ≤ ∑_{y ∈ Δ_fiber} |per-y set| ≤ |Δ_fiber| ≤ |S^{destIdx}| + calc ((Finset.filter _ Finset.univ).card : ENNReal) / (L_card : ENNReal) + _ ≤ (Fintype.card (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) : ENNReal) / L_card := by + gcongr + calc (Finset.filter _ Finset.univ).card + _ ≤ (Δ_fiber.biUnion _).card := Finset.card_le_card h_bad_subset + _ ≤ ∑ y ∈ Δ_fiber, (Finset.filter _ Finset.univ).card := Finset.card_biUnion_le + _ ≤ ∑ _ ∈ Δ_fiber, 1 := Finset.sum_le_sum (fun y hy => h_per_point_card y hy) + _ = Δ_fiber.card := by simp + _ ≤ Fintype.card (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) := Finset.card_le_univ _ + +-- ════════════════════════════════════════════════════════════════════════════ +-- Infrastructure lemmas for Case 2 (incremental fiberwise-far) +-- ════════════════════════════════════════════════════════════════════════════ + +/-- **Converse of Lemma 4.21**: fiberwiseClose implies jointProximityNat. +If the fiberwise distance `d^{(i)}(f, C^{(i)}) < d_{i+steps}/2`, then the +preTensorCombine word stack U is within unique-decoding radius of the +interleaved code `C_{dest}^{2^steps}`. + +Proof sketch: fiberwiseClose gives a codeword `g ∈ C^{(i)}` with +`|fiberwiseDisagreementSet(f, g)| < d/2`. Then `preTensorCombine(g)` is an +interleaved codeword (by `preTensorCombine_is_interleavedCodeword_of_codeword`), +and the column-wise Hamming distance between `preTensorCombine(f)` and +`preTensorCombine(g)` equals `|fiberwiseDisagreementSet(f, g)|` (M_y is +invertible, so columns agree iff fiber values agree). Hence +`Δ₀(⋈|U, C_dest^{2^steps}) ≤ e`. -/ +lemma fiberwiseClose_implies_jointProximityNat (i : Fin ℓ) (steps : ℕ) + {destIdx : Fin r} (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f_i : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) + (h_close : fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨i, by omega⟩) + (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f_i)) : + let U := preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le f_i + let C_next : Set (sDomain 𝔽q β h_ℓ_add_R_rate destIdx → L) := + BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx + jointProximityNat (C := C_next) (u := U) (e := Code.uniqueDecodingRadius (C := C_next)) := by + intro U C_next + -- Step 1: Extract witness g ∈ C^(i) achieving the minimum fiberwise distance + let C_i := BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ + have h_close' := h_close + unfold fiberwiseClose fiberwiseDistance at h_close' + let dist_set := (fun (g : C_i) => + pair_fiberwiseDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le (f := f_i) (g := g)) '' Set.univ + have h_dist_set_nonempty : dist_set.Nonempty := + ⟨_, ⟨0, Set.mem_univ _, rfl⟩⟩ + have h_inf_mem : sInf dist_set ∈ dist_set := Nat.sInf_mem h_dist_set_nonempty + obtain ⟨g, _, h_g_dist⟩ := h_inf_mem + -- g is the closest codeword; its fiberwise distance = sInf + have h_g_close_nat : pair_fiberwiseDistance 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx + h_destIdx_le f_i g ≤ Code.uniqueDecodingRadius (C := C_next) := by + -- pfd g = sInf dist_set, and 2 * sInf < BBF_CodeDistance + have h_pfd_eq_sinf : pair_fiberwiseDistance 𝔽q β (i := ⟨i, by omega⟩) + steps h_destIdx h_destIdx_le f_i g = sInf dist_set := h_g_dist + have h_2pfd_lt_d : 2 * pair_fiberwiseDistance 𝔽q β (i := ⟨i, by omega⟩) + steps h_destIdx h_destIdx_le f_i g < + (BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx) := by + rw [h_pfd_eq_sinf]; exact_mod_cast h_close' + -- BBF_CodeDistance = ‖C_next‖₀ = Code.dist C_next + have h_dist_eq_norm : (BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + destIdx : ℕ) = ‖(C_next : Set _)‖₀ := by + simp only [C_next, BBF_CodeDistance] + have h_2pfd_lt_norm : 2 * pair_fiberwiseDistance 𝔽q β (i := ⟨i, by omega⟩) + steps h_destIdx h_destIdx_le f_i g < ‖(C_next : Set _)‖₀ := by + rw [← h_dist_eq_norm]; exact h_2pfd_lt_d + haveI : NeZero ‖(C_next : Set _)‖₀ := ⟨by omega⟩ + exact (Code.UDRClose_iff_two_mul_proximity_lt_d_UDR (C := C_next)).mpr h_2pfd_lt_norm + -- Step 2: preTensorCombine(g) is an interleaved codeword + let V := preTensorCombine_WordStack 𝔽q β i steps h_destIdx h_destIdx_le g.val + have h_V_codeword : (⋈|V) ∈ (C_next ^⋈ (Fin (2^steps))) := + preTensorCombine_is_interleavedCodeword_of_codeword 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i steps h_destIdx h_destIdx_le + ⟨g.val, g.property⟩ + -- Step 3: Column-wise Hamming distance = pair_fiberwiseDistance + have h_dist_eq : Δ₀(⋈|U, ⋈|V) = + pair_fiberwiseDistance 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le f_i g := by + unfold hammingDist pair_fiberwiseDistance fiberwiseDisagreementSet + congr 1; ext y + simp only [Finset.mem_filter, Finset.mem_univ, true_and] + have h_iso := fiberwise_disagreement_isomorphism 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := i) (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f := f_i) (g := g.val) (y := y) + unfold fiberDiff at h_iso + simp only [WordStack.getSymbol] at h_iso + exact h_iso.symm + -- Step 4: Conclude jointProximityNat + unfold jointProximityNat + calc Δ₀(⋈|U, (C_next ^⋈ (Fin (2^steps)))) + ≤ Δ₀(⋈|U, ⋈|V) := Code.distFromCode_le_dist_to_mem _ _ h_V_codeword + _ = ↑(pair_fiberwiseDistance 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le f_i g) := by + exact_mod_cast h_dist_eq + _ ≤ ↑(Code.uniqueDecodingRadius (C := C_next)) := by + exact_mod_cast h_g_close_nat + +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero ℓ] [SampleableType L] in +/-- **Splitting a WordStack preserves non-closeness.** +If `U : WordStack L (Fin (2^{s+1})) ι` is NOT `e`-close to `C^{2^{s+1}}`, then +the interleaved pair `(⋈|U₀, ⋈|U₁)` is NOT `e`-close to `(C^{2^s})^⋈(Fin 2)`, +where `(U₀, U₁) := splitHalfRowWiseInterleavedWords(U)`. + +The key is that `mergeHalfRowWiseInterleavedWords(U₀, U₁) = U` and the +column-wise Hamming distance is preserved under the split/merge. -/ +lemma not_jointProximityNat_of_not_jointProximityNat_split + {ι : Type*} [Fintype ι] [Nonempty ι] [DecidableEq ι] + {s : ℕ} (C : Set (ι → L)) + (U : WordStack (A := L) (κ := Fin (2 ^ (s + 1))) (ι := ι)) + (e : ℕ) (h_far : ¬ jointProximityNat (C := C) (u := U) (e := e)) : + let U₀ := (splitHalfRowWiseInterleavedWords (ϑ := s) U).1 + let U₁ := (splitHalfRowWiseInterleavedWords (ϑ := s) U).2 + ¬ jointProximityNat₂ (A := InterleavedSymbol L (Fin (2^s))) + (C := (C ^⋈ (Fin (2^s)))) + (u₀ := ⋈|U₀) (u₁ := ⋈|U₁) (e := e) := by + exact fun h_close => h_far (CA_split_rowwise_implies_CA C U e h_close) + +open Classical in +omit [CharP L 2] [DecidableEq 𝔽q] h_β₀_eq_1 [NeZero ℓ] [SampleableType L] in +/-- **Affine proximity gap bound for RS interleaved codes (contrapositive form).** +If the pair `(u₀, u₁)` is NOT `e`-close to the interleaved code, then the +affine line `(1-r)·u₀ + r·u₁` is `e`-close to `C` for at most `|S|` values +of `r ∈ L`, giving `Pr_r[close] ≤ |S|/|L|`. + +This follows from the contrapositive of: +- DG25 Thm 2.2 (RS codes exhibit affine line proximity gaps with `ε = |S|`), and +- DG25 Thm 3.1 (affine line proximity gaps lift to interleaved codes). -/ +lemma affineProximityGap_RS_interleaved_contrapositive + {m : ℕ} (hm : m ≥ 1) {destIdx : Fin r} (h_destIdx_le : destIdx ≤ ℓ) + (u₀ u₁ : Word (InterleavedSymbol L (Fin m)) + (sDomain 𝔽q β h_ℓ_add_R_rate destIdx)) + (e : ℕ) (he : e ≤ Code.uniqueDecodingRadius + (ι := sDomain 𝔽q β h_ℓ_add_R_rate destIdx) (F := L) + (C := BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx)) + (h_far : ¬ jointProximityNat₂ (A := InterleavedSymbol L (Fin m)) + (C := ((BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx) ^⋈ (Fin m))) + (u₀ := u₀) (u₁ := u₁) (e := e)) : + Pr_{let r ← $ᵖ L}[ + Δ₀(affineLineEvaluation (F := L) u₀ u₁ r, + ((BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx) ^⋈ (Fin m))) ≤ e] + ≤ (Fintype.card (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) : ℝ≥0) / (Fintype.card L) := by + by_contra h_prob_gt_bound + apply h_far + let S_dest := sDomain 𝔽q β h_ℓ_add_R_rate destIdx + let α := Embedding.subtype fun (x : L) ↦ x ∈ S_dest + let C_dest := BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx + let RS_dest := ReedSolomon.code α (2^(ℓ - destIdx.val)) + letI : Nontrivial RS_dest := by infer_instance + let h_RS_affine := ReedSolomon_ProximityGapAffineLines_UniqueDecoding + (A := L) (ι := S_dest) (α := α) (k := 2^(ℓ - destIdx.val)) + (hk := by + rw [sDomain_card 𝔽q β h_ℓ_add_R_rate (i := destIdx) + (h_i := Sdomain_bound (by exact h_destIdx_le))] + calc 2 ^ (ℓ - destIdx.val) ≤ 2 ^ (ℓ + 𝓡 - destIdx.val) := + Nat.pow_le_pow_right (by omega) (by omega) + _ = Fintype.card 𝔽q ^ (ℓ + 𝓡 - destIdx.val) := by rw [hF₂.out]) + e (by exact he) + let h_lifted := affine_gaps_lifted_to_interleaved_codes (A := L) + (F := L) (ι := S_dest) (MC := RS_dest) (m := m) + (e := e) (he := he) (ε := Fintype.card S_dest) + (hε := by + have h_dist_pos : 0 < ‖(C_dest : Set (S_dest → L))‖₀ := by + have h_pos : 0 < + BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx := by + simp [BBF_CodeDistance_eq (L := L) 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := destIdx) (h_i := h_destIdx_le)] + simpa [C_dest, BBF_CodeDistance] using h_pos + haveI : NeZero ‖(C_dest : Set (S_dest → L))‖₀ := NeZero.of_pos h_dist_pos + have h_2e_lt_d : 2 * e < ‖(C_dest : Set (S_dest → L))‖₀ := by + exact (Code.UDRClose_iff_two_mul_proximity_lt_d_UDR + (C := (C_dest : Set (S_dest → L))) (e := e)).1 (by simpa [C_dest, S_dest] using he) + have h_e_add_one_le_d : e + 1 ≤ ‖(C_dest : Set (S_dest → L))‖₀ := by + omega + have h_d_le_card : ‖(C_dest : Set (S_dest → L))‖₀ ≤ Fintype.card S_dest := by + exact Code.dist_le_card (C := (C_dest : Set (S_dest → L))) + exact le_trans h_e_add_one_le_d h_d_le_card) + h_RS_affine + exact h_lifted u₀ u₁ (by + rw [ENNReal.coe_natCast] + rw [not_le] at h_prob_gt_bound + exact h_prob_gt_bound) + +section EvenOddSplit +/-! **Even/odd split for Binius folding** + +The Binius protocol folds out the **least significant bit** (dimension `i`) first. +`splitHalfRowWiseInterleavedWords` splits by the **most significant bit**, which +corresponds to factoring the last challenge. For the fold-to-affineLineEvaluation +equivalence we need an **even/odd split** that factors the **first** challenge: +- `U_even[j] = U[2j]` (rows with LSB = 0) +- `U_odd[j] = U[2j+1]` (rows with LSB = 1) + +Then `affineLineEvaluation(U_even, U_odd, r_new)` correctly folds dimension `i` first. -/ + +variable {A : Type*} [AddCommMonoid A] [Module L A] {ι : Type*} + +/-- Even/odd split: separate rows by LSB. `U_even[j] = U[2j]`, `U_odd[j] = U[2j+1]`. -/ +def splitEvenOddRowWiseInterleavedWords {ϑ : ℕ} + (u : (Fin (2 ^ (ϑ + 1))) → ι → A) : + ((Fin (2 ^ ϑ)) → ι → A) × ((Fin (2 ^ ϑ)) → ι → A) := by + have h : ∀ j : Fin (2 ^ ϑ), 2 * j.val < 2 ^ (ϑ + 1) := fun j => by omega + let u_even : (Fin (2 ^ ϑ)) → ι → A := fun j => u ⟨2 * j.val, h j⟩ + let u_odd : (Fin (2 ^ ϑ)) → ι → A := fun j => + u ⟨2 * j.val + 1, by calc 2 * j.val + 1 < 2 * (2 ^ ϑ) := by omega + _ = 2 ^ (ϑ + 1) := by ring⟩ + exact ⟨u_even, u_odd⟩ + +/-- Factor the **first** challenge (LSB): `multilinearCombine u r` equals +`multilinearCombine (affineLineEval U_even U_odd (r 0)) (fun j => r (j+1))`. -/ +lemma multilinearCombine_recursive_form_first {ϑ : ℕ} + (u : (Fin (2 ^ (ϑ + 1))) → ι → A) (r_challenges : Fin (ϑ + 1) → L) : + let U_even := (splitEvenOddRowWiseInterleavedWords (r := r) (ℓ := ℓ) (𝓡 := 𝓡) + (ϑ := ϑ) u).1 + let U_odd := (splitEvenOddRowWiseInterleavedWords (r := r) (ℓ := ℓ) (𝓡 := 𝓡) + (ϑ := ϑ) u).2 + let r_tail : Fin ϑ → L := fun j => r_challenges (Fin.succ j) + multilinearCombine (F := L) u r_challenges = + multilinearCombine (F := L) (affineLineEvaluation (F := L) U_even U_odd (r_challenges 0)) r_tail := by + intro U_even U_odd r_tail + funext colIdx + -- Split the LHS sum into even and odd row indices. + unfold multilinearCombine + let f : ℕ → A := fun j => + if hj : j < 2 ^ (ϑ + 1) then + multilinearWeight r_challenges ⟨j, hj⟩ • u ⟨j, hj⟩ colIdx + else 0 + have h_lhs_as_f : + (∑ rowIdx : Fin (2 ^ (ϑ + 1)), + multilinearWeight r_challenges rowIdx • u rowIdx colIdx) + = ∑ rowIdx : Fin (2 ^ (ϑ + 1)), f rowIdx := by + apply Finset.sum_congr rfl + intro rowIdx _ + simp [f] + rw [h_lhs_as_f] + rw [← Fin.sum_univ_odd_even (n := ϑ) (f := f)] + simp [f] + simp only [U_even, U_odd, splitEvenOddRowWiseInterleavedWords] + -- Factor multilinear weights for even/odd indices by bit-0. + have h_tensor_even : ∀ i : Fin (2 ^ ϑ), + multilinearWeight r_challenges ⟨2 * i, by omega⟩ = + multilinearWeight r_tail i * (1 - r_challenges 0) := by + intro i + unfold multilinearWeight + rw [Fin.prod_univ_succ] + have h_bit0 : (2 * i.val).testBit 0 = false := by + rw [Nat.testBit_false_eq_getBit_eq_0] + simpa using (Nat.getBit_zero_of_two_mul (n := i.val)) + have h_bit0' : (2 * i.val).testBit (↑(0 : Fin (ϑ + 1))) = false := by + simpa using h_bit0 + have h_prod : + (∏ x : Fin ϑ, + if (2 * i.val).testBit x.succ = true then r_challenges x.succ else 1 - r_challenges x.succ) + = ∏ j : Fin ϑ, if i.val.testBit j.val = true then r_tail j else 1 - r_tail j := by + apply Finset.prod_congr rfl + intro j _ + have h_test : + ((2 * i.val).testBit (↑j.succ) = true) = (i.val.testBit j.val = true) := by + rw [Nat.testBit_true_eq_getBit_eq_1, Nat.testBit_true_eq_getBit_eq_1] + simpa [Fin.succ, Nat.add_comm, Nat.add_left_comm, Nat.add_assoc] using + congrArg (fun t : ℕ => t = 1) + (Nat.getBit_eq_succ_getBit_of_mul_two (n := i.val) (k := j.val)) + have h_succ : (↑j.succ : ℕ) = ↑j + 1 := by simp [Fin.succ] + have h_test' : + ((2 * i.val).testBit (↑j + 1) = true) = (i.val.testBit j.val = true) := by + simpa [h_succ] using h_test + simpa [h_test', r_tail] + rw [h_prod] + simp [h_bit0'] + ring + have h_tensor_odd : ∀ i : Fin (2 ^ ϑ), + multilinearWeight r_challenges ⟨2 * i + 1, by omega⟩ = + multilinearWeight r_tail i * (r_challenges 0) := by + intro i + unfold multilinearWeight + rw [Fin.prod_univ_succ] + have h_bit0 : (2 * i.val + 1).testBit 0 = true := by + rw [Nat.testBit_true_eq_getBit_eq_1] + unfold Nat.getBit + simp [Nat.and_one_is_mod] + have h_bit0' : (2 * i.val + 1).testBit (↑(0 : Fin (ϑ + 1))) = true := by + simpa using h_bit0 + have h_prod : + (∏ x : Fin ϑ, + if (2 * i.val + 1).testBit x.succ = true then r_challenges x.succ else 1 - r_challenges x.succ) + = ∏ j : Fin ϑ, if i.val.testBit j.val = true then r_tail j else 1 - r_tail j := by + apply Finset.prod_congr rfl + intro j _ + have h_test : + ((2 * i.val + 1).testBit (↑j.succ) = true) = (i.val.testBit j.val = true) := by + rw [Nat.testBit_true_eq_getBit_eq_1, Nat.testBit_true_eq_getBit_eq_1] + simpa [Fin.succ, Nat.add_comm, Nat.add_left_comm, Nat.add_assoc] using + congrArg (fun t : ℕ => t = 1) + (Nat.getBit_eq_succ_getBit_of_mul_two_add_one (n := i.val) (k := j.val)) + have h_succ : (↑j.succ : ℕ) = ↑j + 1 := by simp [Fin.succ] + have h_test' : + ((2 * i.val + 1).testBit (↑j + 1) = true) = (i.val.testBit j.val = true) := by + simpa [h_succ] using h_test + simp only [Fin.val_succ, h_test', r_tail] + rw [h_prod] + simp [h_bit0'] + ring + -- Apply tensor factorization to both even and odd sums. + simp_rw [h_tensor_even, h_tensor_odd] + have h_even_lt : ∀ x : Fin (2 ^ ϑ), 2 * x.val < 2 ^ (ϑ + 1) := by + intro x; omega + have h_odd_lt : ∀ x : Fin (2 ^ ϑ), 2 * x.val + 1 < 2 ^ (ϑ + 1) := by + intro x; omega + simp [h_even_lt, h_odd_lt] + rw [← Finset.sum_add_distrib] + -- Re-associate scalars to match affine-line form. + apply Finset.sum_congr rfl + intro x _ + rw [affineLineEvaluation, Pi.add_apply, Pi.smul_apply] + simp only [Word, Pi.smul_apply, Pi.add_apply, smul_add] + rw [←smul_assoc, ←smul_assoc] + rw [smul_eq_mul, smul_eq_mul] + +end EvenOddSplit + +/-- Even/odd split preserves non-closeness (bridge lemma for Binius first-step fold flow). +If `U` is not close to `C^⋈(Fin (2^(s+1)))`, then the even/odd split pair is not +jointly close to `C^⋈(Fin (2^s))`. -/ +lemma not_jointProximityNat_of_not_jointProximityNat_evenOdd_split + {ι : Type*} [Fintype ι] [Nonempty ι] [DecidableEq ι] + {s : ℕ} (C : Set (ι → L)) + (U : WordStack (A := L) (κ := Fin (2 ^ (s + 1))) (ι := ι)) + (e : ℕ) + (U_even : WordStack (A := L) (κ := Fin (2 ^ s)) (ι := ι)) + (U_odd : WordStack (A := L) (κ := Fin (2 ^ s)) (ι := ι)) + (hU_even : U_even = + (splitEvenOddRowWiseInterleavedWords (r := r) (ℓ := ℓ) (𝓡 := 𝓡) (ϑ := s) U).1 := by rfl) + (hU_odd : U_odd = + (splitEvenOddRowWiseInterleavedWords (r := r) (ℓ := ℓ) (𝓡 := 𝓡) (ϑ := s) U).2 := by rfl) + (h_far : ¬ jointProximityNat (C := C) (u := U) (e := e)) : + ¬ jointProximityNat₂ (A := InterleavedSymbol L (Fin (2^s))) + (C := (C ^⋈ (Fin (2^s)))) + (u₀ := interleaveWordStack U_even) (u₁ := interleaveWordStack U_odd) (e := e) := by + subst hU_even hU_odd + intro h_close + apply h_far + -- Unpack pairwise closeness witness on the even/odd split. + unfold jointProximityNat₂ jointProximityNat at h_close + simp only at h_close + rw [Code.closeToCode_iff_closeToCodeword_of_minDist] at h_close + rcases h_close with ⟨vSplit, hvSplit_mem, hvSplit_dist_le_e⟩ + rw [closeToWord_iff_exists_possibleDisagreeCols] at hvSplit_dist_le_e + rcases hvSplit_dist_le_e with ⟨D, hD_card_le_e, h_agree_outside_D⟩ + -- Build a single interleaved codeword for height `2^(s+1)` by re-merging rows by parity. + unfold jointProximityNat + rw [Code.closeToCode_iff_closeToCodeword_of_minDist + (u := ⋈|U) (e := e) (C := interleavedCodeSet (κ := Fin (2 ^ (s + 1))) C)] + simp_rw [closeToWord_iff_exists_possibleDisagreeCols] + let VSplit_rowwise := Matrix.transpose vSplit + let VSplit_even_rowwise := Matrix.transpose (VSplit_rowwise 0) + let VSplit_odd_rowwise := Matrix.transpose (VSplit_rowwise 1) + let v_rowwise_finmap : WordStack L (Fin (2 ^ (s + 1))) ι := fun rowIdx => + if h_even : rowIdx.val % 2 = 0 then + VSplit_even_rowwise ⟨rowIdx.val / 2, by omega⟩ + else + VSplit_odd_rowwise ⟨rowIdx.val / 2, by omega⟩ + let v_IC := ⋈|v_rowwise_finmap + use v_IC + constructor + · -- `v_IC` is a codeword in `C^⋈(Fin (2^(s+1)))`. + intro rowIdx + have h_vSplit_rows_mem : ∀ (i : Fin 2) (j : Fin (2 ^ s)), (fun col ↦ vSplit col i j) ∈ C := by + intro i j + exact hvSplit_mem i j + dsimp only [v_IC] + by_cases h_even : rowIdx.val % 2 = 0 + · let j : Fin (2 ^ s) := ⟨rowIdx.val / 2, by omega⟩ + have hRes := h_vSplit_rows_mem 0 j + simpa [v_IC, v_rowwise_finmap, h_even, VSplit_even_rowwise, VSplit_rowwise, j] using hRes + · let j : Fin (2 ^ s) := ⟨rowIdx.val / 2, by omega⟩ + have hRes := h_vSplit_rows_mem 1 j + simpa [v_IC, v_rowwise_finmap, h_even, VSplit_odd_rowwise, VSplit_rowwise, j] using hRes + · use D + constructor + · exact hD_card_le_e + · intro colIdx h_colIdx_notin_D + funext rowIdx + dsimp only [v_IC] + have hRes0 : + interleaveWordStack + ((splitEvenOddRowWiseInterleavedWords (r := r) (ℓ := ℓ) (𝓡 := 𝓡) (ϑ := s) U).1) + colIdx + = vSplit colIdx 0 := by + exact congrFun (h_agree_outside_D colIdx h_colIdx_notin_D) 0 + have hRes1 : + interleaveWordStack + ((splitEvenOddRowWiseInterleavedWords (r := r) (ℓ := ℓ) (𝓡 := 𝓡) (ϑ := s) U).2) + colIdx + = vSplit colIdx 1 := by + exact congrFun (h_agree_outside_D colIdx h_colIdx_notin_D) 1 + by_cases h_even : rowIdx.val % 2 = 0 + · -- even row: rowIdx = 2 * (rowIdx / 2) + have h_row_val : rowIdx.val = 2 * (rowIdx.val / 2) := by + have h_divmod := Nat.mod_add_div rowIdx.val 2 + omega + have h_row_eq : + (⟨2 * (rowIdx.val / 2), by omega⟩ : Fin (2 ^ (s + 1))) = rowIdx := by + apply Fin.eq_of_val_eq + exact h_row_val.symm + have hRes₀ := congrFun hRes0 ⟨rowIdx.val / 2, by omega⟩ + dsimp [splitEvenOddRowWiseInterleavedWords] at hRes₀ + simp [v_rowwise_finmap, h_even, VSplit_even_rowwise, VSplit_rowwise] + simpa [h_row_eq] using hRes₀ + · -- odd row: rowIdx = 2 * (rowIdx / 2) + 1 + have h_row_val : rowIdx.val = 2 * (rowIdx.val / 2) + 1 := by + have h_divmod := Nat.mod_add_div rowIdx.val 2 + omega + have h_row_eq : + (⟨2 * (rowIdx.val / 2) + 1, by omega⟩ : Fin (2 ^ (s + 1))) = rowIdx := by + apply Fin.eq_of_val_eq + exact h_row_val.symm + have hRes₁ := congrFun hRes1 ⟨rowIdx.val / 2, by omega⟩ + dsimp [splitEvenOddRowWiseInterleavedWords] at hRes₁ + simp [v_rowwise_finmap, h_even, VSplit_odd_rowwise, VSplit_rowwise] + simpa [h_row_eq] using hRes₁ + +/-- **One fold step on preTensorCombine = affine line evaluation on even/odd split.** +Given `f_i : S^i → L` and its preTensorCombine WordStack `U` of height `2^(steps+1)`, +using the **even/odd split** (LSB-first, see `splitEvenOddRowWiseInterleavedWords`): +`U_even[j] = U[2j]`, `U_odd[j] = U[2j+1]`. Folding dimension `i` first gives: +``` +⋈|preTensorCombine(i+1, steps, destIdx, fold(f_i, r_new)) + = affineLineEvaluation(⋈|U_even, ⋈|U_odd, r_new) +``` -/ +lemma fold_preTensorCombine_eq_affineLineEvaluation_split + (i : Fin ℓ) (steps : ℕ) [NeZero steps] {midIdx destIdx : Fin r} + (h_midIdx : midIdx.val = i.val + 1) + (h_destIdx : destIdx.val = i.val + (steps + 1)) + (h_destIdx_le : destIdx ≤ ℓ) + (f_i : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨i, by omega⟩) + (r_new : L) : + let h_midIdx_lt_ℓ : midIdx.val < ℓ := by + have := NeZero.pos steps; omega + let U := preTensorCombine_WordStack 𝔽q β i (steps + 1) + (destIdx := destIdx) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) f_i + let U_even := (splitEvenOddRowWiseInterleavedWords (r := r) (ℓ := ℓ) (𝓡 := 𝓡) + (ϑ := steps) U).1 + let U_odd := (splitEvenOddRowWiseInterleavedWords (r := r) (ℓ := ℓ) (𝓡 := 𝓡) + (ϑ := steps) U).2 + let fold_1_f := fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨i, by omega⟩ (destIdx := midIdx) (h_destIdx := h_midIdx) + (h_destIdx_le := by omega) f_i r_new + let midIdx_fin_ℓ : Fin ℓ := ⟨midIdx.val, h_midIdx_lt_ℓ⟩ + let V := preTensorCombine_WordStack 𝔽q β midIdx_fin_ℓ steps + (destIdx := destIdx) + (h_destIdx := by simp [midIdx_fin_ℓ]; omega) + (h_destIdx_le := h_destIdx_le) (by exact fold_1_f) + interleaveWordStack V = + affineLineEvaluation (F := L) + (interleaveWordStack U_even) (interleaveWordStack U_odd) r_new := by + intro h_midIdx_lt_ℓ U U_even U_odd fold_1_f midIdx_fin_ℓ V + -- Connect V and U to iterated_fold via multilinearCombine + have h_fold_eq_U : ∀ r_chal : Fin (steps + 1) → L, + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ + (steps := steps + 1) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + f_i r_chal) = multilinearCombine U r_chal := by + intro r_chal; ext y' + rw [iterated_fold_eq_matrix_form] + unfold localized_fold_matrix_form single_point_localized_fold_matrix_form multilinearCombine + simp only [dotProduct, smul_eq_mul] + exact Finset.sum_congr rfl fun _ _ => rfl + have h_fold_eq_V : ∀ r_chal : Fin steps → L, + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) midIdx + (steps := steps) (h_destIdx := by omega) (h_destIdx_le := h_destIdx_le) + fold_1_f r_chal) = multilinearCombine V r_chal := by + intro r_chal; ext y' + rw [iterated_fold_eq_matrix_form] + unfold localized_fold_matrix_form single_point_localized_fold_matrix_form multilinearCombine + simp only [dotProduct, smul_eq_mul] + exact Finset.sum_congr rfl fun _ _ => rfl + -- Indicator: multilinearCombine W (bitsOfIndex j) y = W j y + have h_indicator : ∀ (W : WordStack L (Fin (2 ^ steps)) + (sDomain 𝔽q β h_ℓ_add_R_rate destIdx)) (j' : Fin (2 ^ steps)) + (y' : sDomain 𝔽q β h_ℓ_add_R_rate destIdx), + multilinearCombine (F := L) W (bitsOfIndex j') y' = W j' y' := by + intro W' j' y' + simp only [multilinearCombine, smul_eq_mul] + rw [show (∑ rowIdx, multilinearWeight (bitsOfIndex j') rowIdx * W' rowIdx y') = + ∑ rowIdx, (if rowIdx = j' then 1 else 0) * W' rowIdx y' from by + apply Finset.sum_congr rfl; intro k _ + congr 1 + have := congr_fun + (challengeTensorExpansion_bitsOfIndex_is_eq_indicator (L := L) j') k + simp only [challengeTensorExpansion, multilinearWeight] at this + exact this] + simp only [boole_mul, Finset.sum_ite_eq', Finset.mem_univ, ↓reduceIte] + -- multilinearCombine_recursive_form_first: factor LSB (r 0), connect U to U_even, U_odd + have h_recursive : ∀ r_chal : Fin (steps + 1) → L, + multilinearCombine U r_chal = + multilinearCombine (affineLineEvaluation (F := L) U_even U_odd (r_chal 0)) + (fun k => r_chal (Fin.succ k)) := by + intro r_chal + simpa [U_even, U_odd] using + (multilinearCombine_recursive_form_first (u := U) (r_challenges := r_chal)) + -- Main equality pointwise. Chain: V j y = ... = affineLineEval U_even U_odd r_new j y + ext y j + change V j y = affineLineEvaluation U_even U_odd r_new j y + -- Step 1: V j y = multilinearCombine V (bitsOfIndex j) y [indicator] + rw [←h_indicator V j y] + -- Step 2: multilinearCombine V (bits j) y = iterated_fold(midIdx, steps, fold_f, bits j) y + conv_lhs => rw [←h_fold_eq_V (bitsOfIndex j)] + -- Step 3: iterated_fold(midIdx, steps, fold_f, bits j) + -- = iterated_fold(i, steps+1, f, cons(r_new, bits j)) [iterated_fold_first] + have h_first : + iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i, by omega⟩) (steps := steps + 1) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (f := f_i) + (r_challenges := Fin.cons r_new (bitsOfIndex j)) = + iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := midIdx) (steps := steps) (h_destIdx := by omega) + (h_destIdx_le := h_destIdx_le) (f := fold_1_f) + (r_challenges := bitsOfIndex j) := by + simpa [fold_1_f, Fin.cons_zero, Fin.cons_succ] using + (iterated_fold_first 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨i, by omega⟩ (steps := steps) h_midIdx h_destIdx h_destIdx_le f_i + (Fin.cons r_new (bitsOfIndex j))) + rw [←h_first] + -- Step 4: iterated_fold(i, steps+1, f, cons(r_new, bits j)) + -- = multilinearCombine U (cons r_new (bits j)) + rw [h_fold_eq_U (Fin.cons r_new (bitsOfIndex j))] + -- Step 5: multilinearCombine U (cons r_new (bits j)) + -- = multilinearCombine (affineLineEval U_even U_odd r_new) (bits j) + -- [multilinearCombine_recursive_form_first; cons 0 = r_new, succ = bits j] + rw [h_recursive (Fin.cons r_new (bitsOfIndex j))] + simp only [Fin.cons_zero, Fin.cons_succ] + -- Step 6: multilinearCombine (affineLineEval ...) (bits j) y = affineLineEval ... j y [indicator] + rw [h_indicator (affineLineEvaluation (F := L) U_even U_odd r_new) j y] + +section Fin1Interleaving +variable {A : Type*} [DecidableEq A] {ι : Type*} [Fintype ι] [DecidableEq ι] + +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero ℓ] [NeZero 𝓡] [SampleableType L] + [Field L] [Fintype L] [DecidableEq L] [Field 𝔽q] [Fintype 𝔽q] h_Fq_char_prime [Algebra 𝔽q L] + hβ_lin_indep h_ℓ_add_R_rate in +/-- For `κ = Fin 1`, the Hamming distance between two interleaved words equals the +Hamming distance between their row-0 projections. -/ +lemma hammingDist_fin1_eq [DecidableEq (Fin 1 → A)] {u v : ι → Fin 1 → A} : + hammingDist u v = hammingDist (fun y => u y 0) (fun y => v y 0) := by + simp only [hammingDist] + congr 1; ext y; simp only [Finset.mem_filter, Finset.mem_univ, true_and] + constructor + · intro h heq; exact h (funext fun k => by rwa [show k = 0 from Subsingleton.elim k 0]) + · intro h heq; exact h (congr_fun heq 0) + +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ h_β₀_eq_1 [NeZero ℓ] [NeZero 𝓡] [SampleableType L] + [Field L] [Fintype L] [DecidableEq L] [Field 𝔽q] [Fintype 𝔽q] h_Fq_char_prime [Algebra 𝔽q L] + hβ_lin_indep h_ℓ_add_R_rate in +/-- For `κ = Fin 1`, the distance from an interleaved word to an interleaved code equals +the distance from its row-0 projection to the base code. -/ +lemma distFromCode_fin1_eq [DecidableEq (Fin 1 → A)] (u : ι → Fin 1 → A) (C : Set (ι → A)) : + Δ₀(u, interleavedCodeSet (κ := Fin 1) C) = Δ₀((fun y => u y 0), C) := by + simp only [distFromCode] + congr 1; ext d; simp only [Set.mem_setOf_eq]; constructor + · rintro ⟨v, hv_mem, hv_dist⟩ + refine ⟨fun y => v y 0, hv_mem 0, ?_⟩ + rwa [←hammingDist_fin1_eq (u := u) (v := v)] + · rintro ⟨w, hw_mem, hw_dist⟩ + refine ⟨fun y _ => w y, + fun k => by rwa [show k = 0 from Subsingleton.elim k 0], ?_⟩ + rwa [hammingDist_fin1_eq (A := A) (u := u) (v := fun y _ => w y)] + +end Fin1Interleaving + +/-- Single-step fold equals multilinearCombine on the corresponding preTensorCombine stack. -/ +lemma fold_eq_multilinearCombine_preTensorCombine_step1 + (i : Fin ℓ) {destIdx : Fin r} + (h_destIdx : destIdx.val = i.val + 1) (h_destIdx_le : destIdx ≤ ℓ) + (f_i : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) + (r_new : L) : + let U := preTensorCombine_WordStack 𝔽q β i 1 + (destIdx := destIdx) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) f_i + fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨i, by omega⟩) + (destIdx := destIdx) (h_destIdx := by omega) (h_destIdx_le := h_destIdx_le) f_i r_new + = multilinearCombine (F := L) U (fun (_ : Fin 1) => r_new) := by + intro U + ext y + rw [fold_eval_single_matrix_mul_form 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i, by omega⟩) (destIdx := destIdx) (h_destIdx := by omega) + (h_destIdx_le := h_destIdx_le) (f := f_i) (r_challenge := r_new)] + unfold fold_single_matrix_mul_form multilinearCombine + dsimp [U] + have h_blk : + blockDiagMatrix (L := L) (r := r) (ℓ := ℓ) (𝓡 := 𝓡) (n := 0) + (Mz₀ := (1 : Matrix (Fin (2 ^ 0)) (Fin (2 ^ 0)) L)) + (Mz₁ := (1 : Matrix (Fin (2 ^ 0)) (Fin (2 ^ 0)) L)) + = (1 : Matrix (Fin (2 ^ 1)) (Fin (2 ^ 1)) L) := by + ext a b <;> fin_cases a <;> fin_cases b <;> + simp [blockDiagMatrix, reindexSquareMatrix, Matrix.from4Blocks] + simp [preTensorCombine_WordStack, foldMatrix, challengeTensorExpansion, h_blk] + have h_w0 : + vecHead (multilinearWeight (F := L) (r := fun _ : Fin 1 => r_new)) = + multilinearWeight (F := L) (r := fun _ : Fin 1 => r_new) 0 := by + rfl + have h_w1 : + vecHead (vecTail (multilinearWeight (F := L) (r := fun _ : Fin 1 => r_new))) = + multilinearWeight (F := L) (r := fun _ : Fin 1 => r_new) 1 := by + rfl + rw [h_w0, h_w1] + +/-- **Connecting fiberwiseClose of a folded function to affine line evaluation proximity.** +Given `f_i : S^i → L` with preTensorCombine `U := preTensorCombine(i, s+1, destIdx, f_i)` of +height `2^{s+1}`, and `r_new : L`, if +`fiberwiseClose(iterated_fold(i, s+1, destIdx, f_i, snoc r r_new), ...)` holds, then +`Δ₀(affineLineEval(⋈|U_even, ⋈|U_odd, r_new), C^⋈(2^s)) ≤ UDR(C)`. + +**Proof sketch:** +1. By `iterated_fold_last`: the folded function is `fold(f_i, r_new)`. +2. `fiberwiseClose(fold(f_i,r_new), s) → jointProximityNat(V)` where + `V = preTensorCombine(midIdx, s, destIdx, fold(f_i,r_new))` + (by `fiberwiseClose_implies_jointProximityNat`). +3. `⋈|V = affineLineEval(⋈|U_even, ⋈|U_odd, r_new)` + (by `fold_preTensorCombine_eq_affineLineEvaluation_split`). +4. Combine 2 and 3 to get the distance bound. -/ +lemma fiberwiseClose_fold_implies_affineLineEval_close + (i : Fin r) (h_i_lt_ℓ : i.val < ℓ) (s : ℕ) + {midIdx destIdx : Fin r} + (h_midIdx : midIdx.val = i.val + 1) + (h_destIdx : destIdx.val = i.val + (s + 1)) + (h_destIdx_le : destIdx ≤ ℓ) + (f_i : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) + (r_new : L) + (h_fw_close : fiberwiseClose 𝔽q β midIdx s + (h_destIdx := by omega) (h_destIdx_le := h_destIdx_le) + (fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + i (destIdx := midIdx) (h_destIdx := h_midIdx) + (h_destIdx_le := by omega) f_i r_new)) : + let i_ℓ : Fin ℓ := ⟨i.val, h_i_lt_ℓ⟩ + let U := preTensorCombine_WordStack 𝔽q β i_ℓ (s + 1) + (destIdx := destIdx) (h_destIdx := by simp [i_ℓ]; omega) + (h_destIdx_le := h_destIdx_le) f_i + let U_even := (splitEvenOddRowWiseInterleavedWords (r := r) (ℓ := ℓ) (𝓡 := 𝓡) + (ϑ := s) U).1 + let U_odd := (splitEvenOddRowWiseInterleavedWords (r := r) (ℓ := ℓ) (𝓡 := 𝓡) + (ϑ := s) U).2 + let C_dest : Set (sDomain 𝔽q β h_ℓ_add_R_rate destIdx → L) := + BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx + Δ₀(affineLineEvaluation (F := L) + (interleaveWordStack U_even) (interleaveWordStack U_odd) r_new, + (C_dest ^⋈ (Fin (2^s)))) ≤ + Code.uniqueDecodingRadius (C := C_dest) := by + classical + intro i_ℓ U U_even U_odd C_dest + have h_midIdx_le_ℓ : midIdx.val ≤ ℓ := by omega + let fold_1_f := fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + i (destIdx := midIdx) (h_destIdx := h_midIdx) + (h_destIdx_le := by omega) f_i r_new + by_cases hs : s = 0 + · -- Case s = 0: midIdx = destIdx, fiberwiseClose with 0 steps = UDRClose, + -- and affineLineEvaluation of (U_even, U_odd) with height-1 stacks = fold. + -- The conclusion Δ₀(affineLineEval(⋈|U_even, ⋈|U_odd, r), C^⋈(Fin 1)) ≤ UDR + -- reduces to Δ₀(fold_f, C) ≤ UDR, which follows from UDRClose. + subst hs + -- After substitution: s = 0, destIdx = i + 1 = midIdx + have h_midIdx_eq_destIdx : midIdx = destIdx := Fin.eq_of_val_eq (by omega) + -- fiberwiseClose with 0 steps → UDRClose + have h_udr_close : UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + midIdx (h_i := h_midIdx_le_ℓ) fold_1_f := by + rw [←fiberwiseClose_steps_zero_iff_UDRClose] + exact h_fw_close + -- UDRClose → Δ₀(fold_f, C_midIdx) ≤ UDR + rw [UDRClose_iff_within_UDR_radius] at h_udr_close + -- h_udr_close : Δ₀(fold_1_f, BBF_Code midIdx) ≤ UDR(BBF_Code midIdx) + -- Since midIdx = destIdx, C_dest = BBF_Code destIdx = BBF_Code midIdx + subst h_midIdx_eq_destIdx + -- Step 1: Convert interleaved distance to non-interleaved via distFromCode_fin1_eq + change Δ₀(affineLineEvaluation (F := L) + (interleaveWordStack U_even) (interleaveWordStack U_odd) r_new, + interleavedCodeSet (κ := Fin (2 ^ 0)) C_dest) ≤ + Code.uniqueDecodingRadius (C := C_dest) + rw [distFromCode_fin1_eq] + -- Goal: Δ₀(fun y => affineLineEval(⋈|U_even, ⋈|U_odd, r_new) y 0, C_dest) ≤ UDR(C_dest) + -- Step 2: Show fun y => affineLineEval(⋈|U_even, ⋈|U_odd, r_new) y 0 = fold_1_f + suffices h_eq : (fun y => affineLineEvaluation + (interleaveWordStack U_even) (interleaveWordStack U_odd) r_new y + (0 : Fin (2 ^ 0))) = + fold_1_f by + rw [h_eq]; exact h_udr_close + -- Part A: fold_1_f = multilinearCombine U [r_new] by iterated_fold_eq_matrix_form + have h_rhs : fold_1_f = multilinearCombine (F := L) U (fun (_ : Fin 1) => r_new) := by + simpa [fold_1_f, i_ℓ] using + fold_eq_multilinearCombine_preTensorCombine_step1 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i_ℓ) + (destIdx := midIdx) (h_destIdx := by simp [i_ℓ]; omega) + (h_destIdx_le := h_midIdx_le_ℓ) (f_i := f_i) (r_new := r_new) + have h_affine_eq_mc : + (fun y => affineLineEvaluation + (interleaveWordStack U_even) (interleaveWordStack U_odd) r_new y + (0 : Fin (2 ^ 0))) = + multilinearCombine (F := L) U (fun (_ : Fin 1) => r_new) := by + ext y + simp [U_even, U_odd, splitEvenOddRowWiseInterleavedWords, affineLineEvaluation, + interleaveWordStack, multilinearCombine, multilinearWeight, smul_eq_mul] + -- Combine: fun y => affineLineEval(...)(y)(0) = fold_1_f + have h_fn_eq : (fun y => affineLineEvaluation + (interleaveWordStack U_even) (interleaveWordStack U_odd) r_new y + (0 : Fin (2 ^ 0))) = fold_1_f := by + rw [h_affine_eq_mc, h_rhs] + rw [h_fn_eq]; + · -- Case s ≥ 1: midIdx.val < ℓ follows from arithmetic + have h_midIdx_lt_ℓ : midIdx.val < ℓ := by omega + let midIdx_ℓ : Fin ℓ := ⟨midIdx.val, h_midIdx_lt_ℓ⟩ + haveI : NeZero s := ⟨hs⟩ + -- fiberwiseClose → jointProximityNat(V) + have h_joint := fiberwiseClose_implies_jointProximityNat 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := midIdx_ℓ) (steps := s) + (h_destIdx := by simp only [midIdx_ℓ]; omega) + (h_destIdx_le := h_destIdx_le) + (f_i := fold_1_f) + (h_close := h_fw_close) + -- ⋈|V = affineLineEvaluation (⋈|U_even) (⋈|U_odd) r_new + have h_eq := fold_preTensorCombine_eq_affineLineEvaluation_split 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := i_ℓ) (steps := s) + (midIdx := midIdx) (destIdx := destIdx) + (h_midIdx := by simp only [i_ℓ]; omega) + (h_destIdx := by simp only [i_ℓ]; omega) + (h_destIdx_le := h_destIdx_le) + (f_i := f_i) (r_new := r_new) + -- Combine: rewrite goal using h_eq, then use h_joint + have h_eq' : + interleaveWordStack + (preTensorCombine_WordStack 𝔽q β + (i := ⟨midIdx.val, h_midIdx_lt_ℓ⟩) (steps := s) + (destIdx := destIdx) + (h_destIdx := by simp [h_midIdx_lt_ℓ]; omega) + (h_destIdx_le := h_destIdx_le) fold_1_f) = + affineLineEvaluation (F := L) + (interleaveWordStack U_even) (interleaveWordStack U_odd) r_new := by + simpa [U_even, U_odd] using h_eq + unfold jointProximityNat at h_joint + rw [← h_eq'] + exact h_joint + +/-- +#### **Case 2: FiberwiseFar (Incremental)** + +**Proof outline (see infrastructure lemmas above for details):** +1. Build `U := preTensorCombine(midIdx_i, ϑ-k, destIdx, fold_k_f)` of height `2^{ϑ-k}`. +2. By Lemma 4.21: `¬fiberwiseClose(fold_k_f) → ¬jointProximityNat(U, e)`. +3. Split `U` into even/odd stacks `(U_even, U_odd) = splitEvenOdd(U)`, + each of height `2^{ϑ-k-1}`. + By `not_jointProximityNat_of_not_jointProximityNat_evenOdd_split`: + `¬jointProximityNat₂(U_even, U_odd, e)` for `C_dest^{2^{ϑ-k-1}}`. +4. Fold step gives affine combination: + `preTensorCombine(fold_{k+1}_f) = affineLineEval(U_even, U_odd, r_new)` + (by `fold_preTensorCombine_eq_affineLineEvaluation_split`). +5. `fiberwiseClose(fold_{k+1}_f) → jointProximityNat(preTensorCombine(fold_{k+1}_f), e)` + (by `fiberwiseClose_implies_jointProximityNat`). +6. Contrapositive of DG25 affine proximity gap + (by `affineProximityGap_RS_interleaved_contrapositive`): + `Pr_r[close] ≤ |S|/|L|`. +-/ +lemma prop_4_20_2_case_2_fiberwise_far_incremental + (block_start_idx : Fin r) {midIdx_i midIdx_i_succ destIdx : Fin r} (k : ℕ) (h_k_lt : k < ϑ) + (h_midIdx_i : midIdx_i = block_start_idx + k) (h_midIdx_i_succ : midIdx_i_succ = block_start_idx + k + 1) + (h_destIdx : destIdx = block_start_idx + ϑ) (h_destIdx_le : destIdx ≤ ℓ) + (f_block_start : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) block_start_idx) + (r_prefix : Fin k → L) + (h_block_far : ¬ fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := ϑ) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f := f_block_start)) : + let domain_size := Fintype.card (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) + Pr_{ let r_new ← $ᵖ L }[ + ¬ incrementalFoldingBadEvent 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (block_start_idx := block_start_idx) (midIdx := midIdx_i) (destIdx := destIdx) (k := k) + (h_k_le := Nat.le_of_lt h_k_lt) (h_midIdx := h_midIdx_i) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f_block_start := f_block_start) (r_challenges := r_prefix) + ∧ + incrementalFoldingBadEvent 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (block_start_idx := block_start_idx) (midIdx := midIdx_i_succ) (destIdx := destIdx) (k := k + 1) + (h_k_le := Nat.succ_le_of_lt h_k_lt) (h_midIdx := h_midIdx_i_succ) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f_block_start := f_block_start) + (r_challenges := Fin.snoc r_prefix r_new) + ] ≤ + (domain_size / Fintype.card L) := by + classical + -- ════════════════════════════════════════════════════════ + -- Step 0: Simplify incrementalFoldingBadEvent using h_block_far + -- ════════════════════════════════════════════════════════ + dsimp only [incrementalFoldingBadEvent] + simp only [h_block_far, not_false_eq_true, ↓reduceDIte, dite_false] + -- ════════════════════════════════════════════════════════ + -- Step 1: Name the key objects + -- ════════════════════════════════════════════════════════ + let fold_k_f := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := k) (h_destIdx := h_midIdx_i) (h_destIdx_le := by omega) + (f := f_block_start) (r_challenges := r_prefix) + -- ════════════════════════════════════════════════════════ + -- Step 2: Factor out the deterministic ¬E(k) conjunct. + -- ════════════════════════════════════════════════════════ + let Ek_close := fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := midIdx_i) (steps := ϑ - k) (h_destIdx := by omega) + (h_destIdx_le := h_destIdx_le) (f := fold_k_f) + by_cases h_Ek_close : Ek_close + · -- pos case: fold_k_f IS fiberwiseClose → ¬fiberwiseClose = False → Pr[False ∧ _] = 0 + apply le_trans (Pr_le_Pr_of_implies ($ᵖ L) _ _ (fun r_new h => h.1)) + have : Pr_{ let r_new ← $ᵖ L }[¬Ek_close] = 0 := by + rw [prob_uniform_eq_card_filter_div_card] + simp only [not_not.mpr h_Ek_close, filter_False, card_empty, CharP.cast_eq_zero, + ENNReal.coe_zero, ENNReal.coe_natCast, ENNReal.zero_div] + rw [this]; exact zero_le _ + · -- neg case: h_Ek_close : ¬fiberwiseClose (fold_k_f is FAR) + apply le_trans (Pr_le_Pr_of_implies ($ᵖ L) _ _ (fun r_new h => h.2)) + -- ════════════════════════════════════════════════════════ + -- Step 3: Decompose steps_remaining = s + 1 up front + -- ════════════════════════════════════════════════════════ + have h_midIdx_i_lt_ℓ : midIdx_i.val < ℓ := by omega + let s := ϑ - k - 1 + have h_steps_eq : ϑ - k = s + 1 := by omega + haveI : NeZero (s + 1) := ⟨by omega⟩ + -- ════════════════════════════════════════════════════════ + -- Step 4: Build U with height 2^(s+1) directly + -- ════════════════════════════════════════════════════════ + let S_dest := sDomain 𝔽q β h_ℓ_add_R_rate destIdx + let C_dest : Set (S_dest → L) := + BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx + let e_prox := Code.uniqueDecodingRadius (C := C_dest) + let i_ℓ : Fin ℓ := ⟨midIdx_i.val, h_midIdx_i_lt_ℓ⟩ + let U := preTensorCombine_WordStack 𝔽q β i_ℓ (s + 1) + (destIdx := destIdx) + (h_destIdx := by simp [i_ℓ]; omega) + (h_destIdx_le := h_destIdx_le) + fold_k_f + have h_U_far : ¬jointProximityNat (C := C_dest) (u := U) + (e := e_prox) := by + apply lemma_4_21_interleaved_word_UDR_far 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := i_ℓ) (steps := s + 1) + (h_destIdx := by simp [i_ℓ]; omega) + (h_destIdx_le := h_destIdx_le) + (f_i := fold_k_f) + (h_far := by convert h_Ek_close using 2; omega) + -- ════════════════════════════════════════════════════════ + -- Step 5: Split U into even/odd rows and establish pair-farness + -- ════════════════════════════════════════════════════════ + let U_even := (splitEvenOddRowWiseInterleavedWords (r := r) (ℓ := ℓ) (𝓡 := 𝓡) + (ϑ := s) U).1 + let U_odd := (splitEvenOddRowWiseInterleavedWords (r := r) (ℓ := ℓ) (𝓡 := 𝓡) + (ϑ := s) U).2 + let u_even := interleaveWordStack U_even + let u_odd := interleaveWordStack U_odd + have h_pair_far : ¬ jointProximityNat₂ + (A := InterleavedSymbol L (Fin (2^s))) + (C := (C_dest ^⋈ (Fin (2^s)))) + (u₀ := u_even) (u₁ := u_odd) (e := e_prox) := + not_jointProximityNat_of_not_jointProximityNat_evenOdd_split + (s := s) (C := C_dest) (U := U) (e := e_prox) + (U_even := U_even) (U_odd := U_odd) + (hU_even := by rfl) (hU_odd := by rfl) + (h_far := h_U_far) + -- ════════════════════════════════════════════════════════ + -- Step 6: Apply affine proximity gap (contrapositive) + -- ════════════════════════════════════════════════════════ + have h_affine_bound : + Pr_{let r ← $ᵖ L}[ + Δ₀(affineLineEvaluation (F := L) u_even u_odd r, + (C_dest ^⋈ (Fin (2^s)))) ≤ e_prox] + ≤ (Fintype.card S_dest : ℝ≥0) / (Fintype.card L) := + affineProximityGap_RS_interleaved_contrapositive + 𝔽q β (hm := Nat.one_le_two_pow) (h_destIdx_le := h_destIdx_le) + (e := e_prox) (he := le_refl _) (h_far := h_pair_far) + -- ════════════════════════════════════════════════════════ + -- Step 7: Connect the events and conclude + -- ════════════════════════════════════════════════════════ + apply le_trans _ h_affine_bound + apply Pr_le_Pr_of_implies ($ᵖ L) _ _ + intro r_new h_fw_close + -- fiberwiseClose(fold_{k+1}_f) → Δ₀(affineLineEval ...) ≤ e_prox + -- via fiberwiseClose_fold_implies_affineLineEval_close + exact fiberwiseClose_fold_implies_affineLineEval_close 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := midIdx_i) (h_i_lt_ℓ := h_midIdx_i_lt_ℓ) (s := s) + (midIdx := midIdx_i_succ) (destIdx := destIdx) + (h_midIdx := by omega) + (h_destIdx := by omega) + (h_destIdx_le := h_destIdx_le) + (f_i := fold_k_f) (r_new := r_new) + (h_fw_close := by + rw [iterated_fold_last 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + block_start_idx (steps := k) + (h_midIdx := h_midIdx_i) (h_destIdx := h_midIdx_i_succ) + (h_destIdx_le := by omega) + f_block_start (Fin.snoc r_prefix r_new)] at h_fw_close + simp only [Fin.init_snoc, Fin.snoc_last] at h_fw_close + convert h_fw_close using 1) + +lemma prop_4_20_2_incremental_bad_event_probability + (block_start_idx : Fin r) {midIdx_i midIdx_i_succ destIdx : Fin r} (k : ℕ) (h_k_lt : k < ϑ) + (h_midIdx_i : midIdx_i = block_start_idx + k) (h_midIdx_i_succ : midIdx_i_succ = block_start_idx + k + 1) + (h_destIdx : destIdx = block_start_idx + ϑ) (h_destIdx_le : destIdx ≤ ℓ) + (f_block_start : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) block_start_idx) + (r_prefix : Fin k → L) : + let domain_size := Fintype.card (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) + Pr_{ let r_new ← $ᵖ L }[ + ¬ incrementalFoldingBadEvent 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (block_start_idx := block_start_idx) (midIdx := midIdx_i) (destIdx := destIdx) (k := k) + (h_k_le := Nat.le_of_lt h_k_lt) (h_midIdx := h_midIdx_i) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f_block_start := f_block_start) (r_challenges := r_prefix) + ∧ + incrementalFoldingBadEvent 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (block_start_idx := block_start_idx) (midIdx := midIdx_i_succ) (destIdx := destIdx) (k := k + 1) + (h_k_le := Nat.succ_le_of_lt h_k_lt) (h_midIdx := h_midIdx_i_succ) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f_block_start := f_block_start) + (r_challenges := Fin.snoc r_prefix r_new) + ] ≤ + (domain_size / Fintype.card L) := by + by_cases h_block_close : fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := block_start_idx) (steps := ϑ) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f := f_block_start) + · exact prop_4_20_2_case_1_fiberwise_close_incremental 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (block_start_idx := block_start_idx) + (midIdx_i := midIdx_i) (midIdx_i_succ := midIdx_i_succ) (destIdx := destIdx) (k := k) (h_k_lt := h_k_lt) (h_midIdx_i := h_midIdx_i) (h_midIdx_i_succ := h_midIdx_i_succ) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (f_block_start := f_block_start) + (r_prefix := r_prefix) (h_block_close := h_block_close) + · exact prop_4_20_2_case_2_fiberwise_far_incremental 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (block_start_idx := block_start_idx) + (midIdx_i := midIdx_i) (midIdx_i_succ := midIdx_i_succ) (destIdx := destIdx) (k := k) (h_k_lt := h_k_lt) (h_midIdx_i := h_midIdx_i) (h_midIdx_i_succ := h_midIdx_i_succ) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (f_block_start := f_block_start) + (r_prefix := r_prefix) (h_block_far := h_block_close) + +open Classical in +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ in +/-- Helper: If `f` and `g` agree on the fiber of `y`, their folds agree at `y`. +NOTE: this might not be needed -/ +lemma fold_agreement_of_fiber_agreement (i : Fin ℓ) (steps : ℕ) + {destIdx : Fin r} (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) + (r_challenges : Fin steps → L) (y : sDomain 𝔽q β h_ℓ_add_R_rate destIdx) : + (∀ x, + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate (i := ⟨i, by omega⟩) (destIdx := destIdx) + (k := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) x = y → + f x = g x) → + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps h_destIdx h_destIdx_le f (r_challenges := r_challenges) y = + (iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps h_destIdx h_destIdx_le g (r_challenges := r_challenges) (y := y))) := by + intro h_fiber_agree + -- Expand to matrix form: fold(y) = Tensor(r) * M_y * fiber_vals + rw [iterated_fold_eq_matrix_form 𝔽q β (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le)] + rw [iterated_fold_eq_matrix_form 𝔽q β (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le)] + -- ⊢ localized_fold_matrix_form 𝔽q β i steps h_destIdx h_destIdx_le f r y = + -- localized_fold_matrix_form 𝔽q β i steps h_destIdx h_destIdx_le g r y + unfold localized_fold_matrix_form single_point_localized_fold_matrix_form + simp only + congr 2 + let left := fiberEvaluations 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) f y + let right := fiberEvaluations 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) g y + have h_fiber_eval_eq : left = right := by + unfold left right fiberEvaluations + ext idx + let x := qMap_total_fiber 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) y idx + have h_x_folds_to_y := generates_quotient_point_if_is_fiber_of_y 𝔽q β (i := ⟨i, by omega⟩) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (x := x) (y := y) (hx_is_fiber := by use idx) + exact h_fiber_agree x h_x_folds_to_y.symm + unfold left right at h_fiber_eval_eq + rw [h_fiber_eval_eq] + +omit [CharP L 2] [DecidableEq 𝔽q] hF₂ in +/-- Helper: The disagreement set of the folded functions is a subset of the fiberwise disagreement set. -/ +lemma disagreement_fold_subset_fiberwiseDisagreement (i : Fin ℓ) (steps : ℕ) + {destIdx : Fin r} (h_destIdx : destIdx.val = i.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩) + (r_challenges : Fin steps → L) : + let folded_f := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps h_destIdx h_destIdx_le f (r_challenges := r_challenges) + let folded_g := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i, by omega⟩ steps h_destIdx h_destIdx_le g (r_challenges := r_challenges) + disagreementSet 𝔽q β (i := destIdx) (destIdx := destIdx) (h_destIdx := rfl) (f := folded_f) (g := folded_g) ⊆ + fiberwiseDisagreementSet 𝔽q β (i := ⟨i, by omega⟩) steps h_destIdx h_destIdx_le f g := by + simp only + intro y hy_mem + simp only [disagreementSet, ne_eq, mem_filter, mem_univ, true_and] at hy_mem + simp only [fiberwiseDisagreementSet, ne_eq, Subtype.exists, mem_filter, mem_univ, true_and] + -- Contrapositive: If y is NOT in fiberwise disagreement, then f and g agree on fiber. + -- Then folds must agree (lemma above). Then y is NOT in disagreement set. + by_contra h_not_in_fiber_diff + have h_agree_on_fiber : ∀ x, + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate (i := ⟨i, by omega⟩) (destIdx := destIdx) + (k := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) x = y → + f x = g x := by + intro x hx + by_contra h_neq + exact h_not_in_fiber_diff ⟨x, (by simp only [SetLike.coe_mem]), (by simp only [Subtype.coe_eta]; constructor; exact hx; exact h_neq)⟩ + have h_fold_eq := fold_agreement_of_fiber_agreement 𝔽q β i steps h_destIdx h_destIdx_le f g (r_challenges := r_challenges) (y := y) h_agree_on_fiber + exact hy_mem h_fold_eq + +/-- **Lemma 4.24** +For `i*` where `f^(i)` is non-compliant, `f^(i+ϑ)` is UDR-close, and the bad event `E_{i*}` +doesn't occur, the folded function of `f^(i)` is not UDR-close to the UDR-decoded codeword +of `f^(i+ϑ)`. -/ +lemma lemma_4_24_dist_folded_ge_of_last_noncompliant (i_star : Fin ℓ) (steps : ℕ) [NeZero steps] + {destIdx : Fin r} (h_destIdx : destIdx.val = i_star.val + steps) (h_destIdx_le : destIdx ≤ ℓ) + (f_star : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i_star, by omega⟩) + (f_next : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx) + (r_challenges : Fin steps → L) + -- 1. f_next is the actual folded function + -- 2. i* is non-compliant + (h_not_compliant : ¬ isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i_star, by omega⟩ steps h_destIdx h_destIdx_le + f_star f_next (challenges := r_challenges)) + -- 3. No bad event occurred at i* + (h_no_bad_event : ¬ foldingBadEvent 𝔽q β (i := ⟨i_star, by omega⟩) steps h_destIdx h_destIdx_le f_star r_challenges) + -- 4. The next function `f_next` IS close enough to have a unique codeword `f_bar_next` + (h_next_close : UDRClose 𝔽q β destIdx h_destIdx_le f_next) : + let f_i_star_folded := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨i_star, by omega⟩ steps h_destIdx h_destIdx_le f_star r_challenges + -- **CONCLUSION**: 2 * d(f_next, f_bar_next) ≥ d_{i* + steps} + let f_bar_next := UDRCodeword 𝔽q β destIdx h_destIdx_le (f := f_next) (h_within_radius := h_next_close) + ¬ pair_UDRClose 𝔽q β destIdx h_destIdx_le f_i_star_folded f_bar_next := by + -- Definitions for clarity + let d_next := BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx + let S_next := sDomain 𝔽q β h_ℓ_add_R_rate destIdx + let C_cur := BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i_star, by omega⟩ + let C_next := BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx + let f_bar_next := UDRCodeword 𝔽q β destIdx h_destIdx_le + (f := f_next) (h_within_radius := h_next_close) + let f_i_star_folded := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨i_star, by omega⟩ steps h_destIdx h_destIdx_le f_star r_challenges + have h_f_bar_next_mem_C_next : f_bar_next ∈ C_next := by -- due to definition + unfold f_bar_next UDRCodeword + apply UDRCodeword_mem_BBF_Code (i := destIdx) (h_i := h_destIdx_le) (f := f_next) (h_within_radius := h_next_close) + have h_d_next_ne_0 : d_next ≠ 0 := by + unfold d_next + simp [BBF_CodeDistance_eq (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := destIdx) (h_i := h_destIdx_le)] + let d_fw := fiberwiseDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨i_star, by omega⟩) + steps h_destIdx h_destIdx_le (f := f_star) + -- Split into Case 1 (Close) and Case 2 (Far) + by_cases h_fw_close : fiberwiseClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i_star, by omega⟩) (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f_star) + -- Case 1: Fiberwise Close (d < d_next / 2) + · let h_fw_dist_lt := h_fw_close -- This gives 2 * fiber_dist < d_next + -- Define f_bar_star (the unique decoded codeword for f_star) to be the **fiberwise**-close codeword to f_star + obtain ⟨f_bar_star, ⟨h_f_bar_star_mem, h_f_bar_star_min_card, h_f_bar_star_eq_UDRCodeword⟩, h_unique⟩ := exists_unique_fiberwiseClosestCodeword_within_UDR 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨i_star, by omega⟩) (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f_star) (h_fw_close := h_fw_close) + have h_fw_dist_f_g_eq : #(fiberwiseDisagreementSet 𝔽q β ⟨i_star, by omega⟩ steps h_destIdx h_destIdx_le f_star f_bar_star) = d_fw := by + unfold d_fw + rw [h_f_bar_star_min_card]; rfl + let folded_f_bar_star := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ⟨i_star, by omega⟩ + steps h_destIdx h_destIdx_le f_bar_star (r_challenges := r_challenges) + have h_folded_f_bar_star_mem_C_next : folded_f_bar_star ∈ C_next := by + unfold folded_f_bar_star + apply iterated_fold_preserves_BBF_Code_membership 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i_star, by omega⟩) (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := ⟨f_bar_star, h_f_bar_star_mem⟩) (r_challenges := r_challenges) + -- We prove two inequalities (1) and (2) as per the proof sketch. + -- **Step (1): Distance between the two codewords in C_next** + -- First, show that `folded_f_bar_star` ≠ `f_bar_next`. + -- This follows because `f_star` is NON-COMPLIANT. + have h_codewords_neq : f_bar_next ≠ folded_f_bar_star := by + by_contra h_eq + -- If they were equal, `isCompliant` would be true (satisfying all 3 conditions). + apply h_not_compliant + use h_fw_dist_lt -- Condition 1: f_star is close + use h_next_close -- Condition 2: f_next is close + -- Condition 3: folded decoding equals next decoding + simp only + rw [←h_f_bar_star_eq_UDRCodeword] + -- ⊢ iterated_fold ⟨i*, ⋯⟩ steps ⋯ f_bar_star r_challenges = UDRCodeword 𝔽q β ⟨i* + steps, ⋯⟩ f_next h_next_close + exact id (Eq.symm h_eq) + -- Since they are distinct codewords, their distance is at least `d_next`. + have h_ineq_1 : Δ₀(f_bar_next, folded_f_bar_star) ≥ d_next := by + apply Code.pairDist_ge_code_mindist_of_ne (C := (C_next : Set _)) + (u := f_bar_next) (v := folded_f_bar_star) + · exact h_f_bar_next_mem_C_next + · exact h_folded_f_bar_star_mem_C_next + · exact h_codewords_neq + -- **Step (2): Distance between folded functions** + -- We know |Δ_fiber(f*, f_bar*)| < d_next / 2 (from fiberwise close hypothesis). + have h_fiber_dist_lt_half : + 2 * (fiberwiseDisagreementSet 𝔽q β (i := ⟨i_star, by omega⟩) steps h_destIdx h_destIdx_le f_star f_bar_star).card < d_next := by + rw [Nat.two_mul_lt_iff_le_half_of_sub_one (h_b_pos := by omega)] + -- ⊢ #(fiberwiseDisagreementSet 𝔽q β i_star steps h_destIdx h_destIdx_le f_star f_bar_star) ≤ (d_next - 1) / 2 + rw [h_fw_dist_f_g_eq] + rw [←Nat.two_mul_lt_iff_le_half_of_sub_one (h_b_pos := by omega)] + unfold d_fw + unfold fiberwiseClose at h_fw_close + norm_cast at h_fw_close + -- Lemma 4.18 (Geometric): d(fold(f), fold(g)) ≤ |Δ_fiber(f, g)| + have h_ineq_2 : 2 * Δ₀(f_i_star_folded, folded_f_bar_star) < d_next := by + calc + 2 * Δ₀(iterated_fold 𝔽q β ⟨i_star, by omega⟩ steps h_destIdx h_destIdx_le f_star (r_challenges := r_challenges), folded_f_bar_star) + _ ≤ 2 * (fiberwiseDisagreementSet 𝔽q β (i := ⟨i_star, by omega⟩) steps h_destIdx h_destIdx_le f_star f_bar_star).card := by + -- Hamming distance is card(disagreementSet) + -- disagreementSet ⊆ fiberwiseDisagreementSet (Lemma 4.18 Helper) + apply Nat.mul_le_mul_left + let res := disagreement_fold_subset_fiberwiseDisagreement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i_star) (steps := steps) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) (f := f_star) (g := f_bar_star) (r_challenges := r_challenges) + simp only at res + apply Finset.card_le_card + exact res + _ < d_next := h_fiber_dist_lt_half + -- **Final Step: Reverse Triangle Inequality** + -- d(A, C) ≥ d(B, C) - d(A, B) + -- We want 2 * d(f_next, f_bar_next) ≥ d_next + have h_triangle : Δ₀(f_bar_next, folded_f_bar_star) ≤ Δ₀(f_bar_next, f_i_star_folded) + Δ₀(f_i_star_folded, folded_f_bar_star) := + hammingDist_triangle f_bar_next f_i_star_folded folded_f_bar_star + have h_final_bound : 2 * d_next ≤ 2 * Δ₀(f_bar_next, f_i_star_folded) + 2 * Δ₀(f_i_star_folded, folded_f_bar_star) := by + have h_trans : d_next ≤ Δ₀(f_bar_next, folded_f_bar_star) := h_ineq_1 + have h_mul : 2 * d_next ≤ 2 * Δ₀(f_bar_next, folded_f_bar_star) := Nat.mul_le_mul_left 2 h_trans + linarith [h_triangle, h_mul] + -- We have 2*d_next ≤ 2*d(Target) + (something < d_next) + -- This implies 2*d(Target) > d_next + -- Or in integer arithmetic: 2*d(Target) ≥ d_next + rw [hammingDist_comm] at h_final_bound -- align directions + unfold pair_UDRClose + simp only [not_lt, ge_iff_le] + apply le_of_not_gt + intro h_contra + -- If 2 * d(target) < d_next: + -- Sum < d_next + d_next = 2*d_next. Contradiction. + linarith [h_ineq_2, h_final_bound, h_contra] + -- **Case 2: Fiberwise Far (d ≥ d_next / 2)** + · -- In this case, the definition of `foldingBadEvent` (Case 2 branch) simplifies. + -- The bad event is defined as: UDRClose(f_next). + unfold foldingBadEvent at h_no_bad_event + simp only [h_fw_close, ↓reduceDIte] at h_no_bad_event + -- h_no_bad_event : ¬ UDRClose ... + -- This means f_next is NOT close to the code C_next. + -- Definition of not UDRClose: 2 * dist(f_next, C_next) ≥ d_next + unfold UDRClose at h_no_bad_event + simp only [not_lt] at h_no_bad_event + -- ↑(BBF_CodeDistance 𝔽q β destIdx) + have h_no_bad_event_alt : (d_next : ℕ∞) ≤ 2 * Δ₀(f_i_star_folded, f_bar_next):= by + calc + d_next ≤ 2 * Δ₀(f_i_star_folded, (C_next : Set (S_next → L))) := by + exact h_no_bad_event + _ ≤ 2 * Δ₀(f_i_star_folded, f_bar_next) := by + rw [ENat.mul_le_mul_left_iff] + · simp only [Code.distFromCode_le_dist_to_mem (C := (C_next : Set (S_next → L))) (u := + f_i_star_folded) (v := f_bar_next) (hv := h_f_bar_next_mem_C_next)] + · simp only [ne_eq, OfNat.ofNat_ne_zero, not_false_eq_true] + · simp only [ne_eq, ENat.ofNat_ne_top, not_false_eq_true] + unfold pair_UDRClose + simp only [not_lt, ge_iff_le] + norm_cast at h_no_bad_event_alt + +section QueryPhaseSoundnessStatements + +variable [hdiv : Fact (ϑ ∣ ℓ)] +variable [SampleableType L] +open QueryPhase + +/-- Number of oracle blocks at the end of the protocol. -/ +abbrev nBlocks : ℕ := toOutCodewordsCount ℓ ϑ (Fin.last ℓ) + +/-- A block index is *bad* if the corresponding folding-compliance check fails. -/ +def badBlockProp + (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) : + Fin (nBlocks (ϑ := ϑ) (ℓ := ℓ)) → Prop := fun j => + have h_ϑ_le_ℓ : ϑ ≤ ℓ := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out) + if hj : j.val + 1 < nBlocks then + let curDomainIdx : Fin r := ⟨j.val * ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (x := j.val * ϑ) + have h := oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) (j := j) + exact (Nat.le_add_right _ _).trans h⟩ + let destIdx : Fin r := ⟨j.val * ϑ + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (x := j.val * ϑ + ϑ) + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) (j := j)⟩ + ¬ isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := curDomainIdx) (steps := ϑ) (destIdx := destIdx) + (h_destIdx := by rfl) (h_destIdx_le := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) (j := j)) + (f_i := oStmtIn j) + (f_i_plus_steps := + getNextOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + (i := Fin.last ℓ) (oStmt := oStmtIn) (j := j) (hj := by + simp only [nBlocks] at hj ⊢ + exact hj) (destDomainIdx := destIdx) (h_destDomainIdx := by rfl)) + (challenges := + getFoldingChallenges (r := r) (𝓡 := 𝓡) (ϑ := ϑ) (i := Fin.last ℓ) + stmtIn.challenges (k := j.val * ϑ) (h := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) (j := j))) + else + let j_last := getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ) + let k := j_last.val * ϑ + have h_k : k = ℓ - ϑ := by + dsimp [j_last, k] + simp only [getLastOraclePositionIndex_last, Nat.sub_mul, Nat.div_mul_cancel (hdiv.out), + one_mul] + have hk_add : k + ϑ = ℓ := by + simp only [h_k] at h_k ⊢ + exact Nat.sub_add_cancel (by omega) + have hk_le : k ≤ ℓ := by omega + ¬ isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨k, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (x := k); omega + ⟩) (steps := ϑ) (destIdx := ⟨k + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (x := k + ϑ); omega + ⟩) + (h_destIdx := by rfl) + (h_destIdx_le := by + -- k + ϑ = ℓ, so the bound holds + simp only [hk_add, le_refl]) + (f_i := oStmtIn j_last) + (f_i_plus_steps := fun _ => stmtIn.final_constant) + (challenges := + getFoldingChallenges (r := r) (𝓡 := 𝓡) (ϑ := ϑ) (i := Fin.last ℓ) + stmtIn.challenges (k := k) (h := by + simp only [hk_add, Fin.val_last, le_refl])) + +open Classical in +/-- Finset of bad blocks. -/ +def badBlockSet + (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) : + Finset (Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ))) := + Finset.filter (badBlockProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn)) Finset.univ + +open Classical in +noncomputable def highestBadBlock + (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) + (h_exists : ∃ j : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ)), + badBlockProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) j) : + Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ)) := + (badBlockSet 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn)).max' (by + rcases h_exists with ⟨j, hj⟩ + refine ⟨j, ?_⟩ + exact (Finset.mem_filter.mpr ⟨by simp, hj⟩)) + +lemma highestBadBlock_is_bad + (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) + (h_exists : ∃ j : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ)), + badBlockProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) j) : + badBlockProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) + (highestBadBlock 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) h_exists) := by + classical + have hmem : + highestBadBlock 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) h_exists + ∈ badBlockSet 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) := by + -- max' is always a member of the set + simpa [highestBadBlock] using + (Finset.max'_mem + (badBlockSet 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn)) + (by + rcases h_exists with ⟨j, hj⟩ + refine ⟨j, ?_⟩ + exact (Finset.mem_filter.mpr ⟨by simp, hj⟩))) + have hmem' := Finset.mem_filter.mp hmem + exact hmem'.2 + +lemma not_badBlock_of_lt_highest + (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) + (h_exists : ∃ j : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ)), + badBlockProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) j) + {j : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ))} + (hlt : highestBadBlock 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) h_exists < j) : + ¬ badBlockProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) j := by + classical + intro hj_bad + have hj_mem : + j ∈ badBlockSet 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) := by + exact (Finset.mem_filter.mpr ⟨by simp, hj_bad⟩) + have h_nonempty : + (badBlockSet 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn)).Nonempty := by + rcases h_exists with ⟨j', hj'⟩ + refine ⟨j', ?_⟩ + exact (Finset.mem_filter.mpr ⟨by simp, hj'⟩) + have hle : j ≤ highestBadBlock 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) h_exists := + by + -- le_max' takes the membership proof; Nonempty is inferred from max' + simpa [highestBadBlock] using + (Finset.le_max' (badBlockSet 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn)) j hj_mem) + exact not_lt_of_ge hle hlt + +/-- Congruence lemma for `UDRClose`: transport along a `Fin r` equality. +Given two `Fin r` indices with the same value and `HEq` functions, `UDRClose` transfers. -/ +lemma UDRClose_of_fin_eq {i j : Fin r} (hij : i = j) + {hi : ↑i ≤ ℓ} {hj : ↑j ≤ ℓ} + {f : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i} + {g : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) j} + (hfg : HEq f g) (h : UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hi f) : + UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) j hj g := by + subst hij; exact eq_of_heq hfg ▸ h + +/-- If block `j` is not bad (i.e. it is compliant), then the oracle `oStmtIn j` is UDR-close +at its domain position `j.val * ϑ`. This extracts `fiberwiseClose` from `isCompliant` +(the negation of `badBlockProp`) and converts it to `UDRClose` via `UDRClose_of_fiberwiseClose`. -/ +lemma goodBlock_implies_UDRClose + (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) + (j : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ))) + (h_good : ¬ badBlockProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmtIn oStmtIn j) + {destIdx : Fin r} + (h_idx : (⟨j.val * ϑ, lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ((Nat.le_add_right _ _).trans + (oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j)))⟩ : Fin r) = destIdx) + (h_le : destIdx.val ≤ ℓ) : + UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + destIdx h_le (fun y => (oStmtIn j) (cast (by rw [h_idx]) y)) := by + subst h_idx; simp only [cast_eq] + -- Unfold badBlockProp: it's `¬isCompliant` in both branches. + simp only [badBlockProp] at h_good + by_cases h_last : j.val + 1 < nBlocks (ℓ := ℓ) (ϑ := ϑ) + · -- Intermediate block: badBlockProp = ¬isCompliant + simp only [h_last, ↓reduceDIte, not_not] at h_good + obtain ⟨h_fw, _, _⟩ := h_good + exact UDRClose_of_fiberwiseClose 𝔽q β _ ϑ (by rfl) + (oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) (j := j)) + (oStmtIn j) h_fw + · -- Final block: need getLastOraclePositionIndex = j + simp only [h_last, ↓reduceDIte, not_not] at h_good + have h_j_eq : getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ) = j := by + apply Fin.ext + simp only [getLastOraclePositionIndex, toOutCodewordsCount_last] + have h_ge : nBlocks (ℓ := ℓ) (ϑ := ϑ) ≤ j.val + 1 := Nat.le_of_not_gt h_last + simp only [nBlocks, toOutCodewordsCount_last] at h_ge + have h_lt : j.val < nBlocks (ℓ := ℓ) (ϑ := ϑ) := j.isLt + simp only [nBlocks, toOutCodewordsCount_last] at h_lt + omega + subst h_j_eq + obtain ⟨h_fw, _, _⟩ := h_good + exact UDRClose_of_fiberwiseClose 𝔽q β _ ϑ (by rfl) + (oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ))) + (oStmtIn (getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ))) h_fw + +open Classical in +lemma prob_uniform_suffix_mem + (destIdx : Fin r) (h_destIdx_le : destIdx ≤ ℓ) + (D : Finset (sDomain 𝔽q β h_ℓ_add_R_rate destIdx)) : + Pr_{ let v ←$ᵖ (sDomain 𝔽q β h_ℓ_add_R_rate 0) }[ + extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (v := v) (destIdx := destIdx) (h_destIdx_le := h_destIdx_le) ∈ D + ] = (D.card : ENNReal) / + Fintype.card (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) := by + classical + -- Setup + let S0 := sDomain 𝔽q β h_ℓ_add_R_rate 0 + let Sdest := sDomain 𝔽q β h_ℓ_add_R_rate destIdx + let steps : ℕ := destIdx.val + have h_destIdx : destIdx.val = (0 : Fin r).val + steps := by simp [steps] + let suffix : S0 → Sdest := + extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (destIdx := destIdx) (h_destIdx_le := h_destIdx_le) + -- Express probability via cardinalities + rw [prob_uniform_eq_card_filter_div_card] + -- Define the preimage set + let preimage : Finset S0 := Finset.univ.filter (fun v => suffix v ∈ D) + -- Each fiber over y has size 2^steps + let fiberSet : Sdest → Finset S0 := fun y => + (Set.image (qMap_total_fiber 𝔽q β (i := (0 : Fin r)) (steps := steps) + h_destIdx h_destIdx_le (y := y)) (Set.univ : Set (Fin (2 ^ steps)))).toFinset + have h_fiber_card : ∀ y : Sdest, (fiberSet y).card = 2 ^ steps := by + intro y + have h := + card_qMap_total_fiber 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) (steps := steps) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (y := y) + -- Convert Fintype.card of the set to Finset.card + have h_card : + (fiberSet y).card = + Fintype.card + (Set.image (qMap_total_fiber 𝔽q β (i := (0 : Fin r)) (steps := steps) + h_destIdx h_destIdx_le (y := y)) (Set.univ : Set (Fin (2 ^ steps)))) := by + classical + simpa [fiberSet] using + (Set.toFinset_card + (s := Set.image (qMap_total_fiber 𝔽q β (i := (0 : Fin r)) (steps := steps) + h_destIdx h_destIdx_le (y := y)) (Set.univ : Set (Fin (2 ^ steps))))) + calc + (fiberSet y).card = + Fintype.card + (Set.image (qMap_total_fiber 𝔽q β (i := (0 : Fin r)) (steps := steps) + h_destIdx h_destIdx_le (y := y)) (Set.univ : Set (Fin (2 ^ steps)))) := h_card + _ = 2 ^ steps := h + -- Preimage equals union of fibers over D + have h_preimage_eq : + preimage = D.biUnion fiberSet := by + ext v + constructor + · intro hv + have hv' : suffix v ∈ D := by + simp only [preimage] at hv ⊢ + exact (Finset.mem_filter.mp hv).2 + -- v is in the fiber of its suffix + have hv_fiber : v ∈ fiberSet (suffix v) := by + -- Use the fiber index corresponding to v + let k := + pointToIterateQuotientIndex 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) (steps := steps) h_destIdx h_destIdx_le (x := v) + have hk : + qMap_total_fiber 𝔽q β (i := (0 : Fin r)) (steps := steps) + h_destIdx h_destIdx_le (y := suffix v) k = v := by + -- suffix v is exactly the iterated quotient of v + have h_eq : + suffix v = + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate (i := (0 : Fin r)) + (destIdx := destIdx) (k := steps) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (x := v) := by + simp [suffix, extractSuffixFromChallenge, steps] + -- Use the characterization of fibers + exact (is_fiber_iff_generates_quotient_point 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) (steps := steps) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (x := v) (y := suffix v)).1 h_eq + -- Show membership in the fiber set + have : v ∈ Set.image (qMap_total_fiber 𝔽q β (i := (0 : Fin r)) (steps := steps) + h_destIdx h_destIdx_le (y := suffix v)) (Set.univ : Set (Fin (2 ^ steps))) := by + refine ⟨k, by simp, hk⟩ + simpa [fiberSet] using this + -- Put together + refine Finset.mem_biUnion.mpr ?_ + exact ⟨suffix v, hv', hv_fiber⟩ + · intro hv + rcases Finset.mem_biUnion.mp hv with ⟨y, hyD, hv_fiber⟩ + -- From v ∈ fiberSet y, deduce suffix v = y + have hv_fiber' : + v ∈ Set.image (qMap_total_fiber 𝔽q β (i := (0 : Fin r)) (steps := steps) + h_destIdx h_destIdx_le (y := y)) (Set.univ : Set (Fin (2 ^ steps))) := by + simpa [fiberSet] using hv_fiber + rcases hv_fiber' with ⟨k, hk_mem, hk_eq⟩ + have h_eq : + y = + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate (i := (0 : Fin r)) + (destIdx := destIdx) (k := steps) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (x := v) := by + -- v is in the fiber of y, so y is the iterated quotient of v + apply generates_quotient_point_if_is_fiber_of_y 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := (0 : Fin r)) (steps := steps) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (x := v) (y := y) + refine ⟨k, ?_⟩ + exact hk_eq.symm + have : suffix v = y := by + -- Rewrite suffix v as iteratedQuotientMap + simpa [suffix, extractSuffixFromChallenge, steps] using h_eq.symm + -- Conclude v ∈ preimage + apply Finset.mem_filter.mpr + constructor + · simp only [mem_univ] + · -- suffix v ∈ D + simpa [this] using hyD + -- Cardinality of the preimage + have h_preimage_card : preimage.card = D.card * 2 ^ steps := by + -- Use disjoint union of fibers + have h_disjoint : + ∀ y₁ ∈ D, ∀ y₂ ∈ D, y₁ ≠ y₂ → + Disjoint (fiberSet y₁) (fiberSet y₂) := by + intro y₁ hy₁ y₂ hy₂ hy_ne + -- Apply fiber disjointness lemma + have h := + qMap_total_fiber_disjoint 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) (steps := steps) (h_destIdx := h_destIdx) + (h_destIdx_le := h_destIdx_le) (y₁ := y₁) (y₂ := y₂) hy_ne + simp only [fiberSet] at h ⊢ + exact h + -- Now compute the card via biUnion + calc + preimage.card + = (D.biUnion fiberSet).card := by simp only [h_preimage_eq] + _ = ∑ y ∈ D, (fiberSet y).card := by + exact Finset.card_biUnion (s := D) (t := fiberSet) (h := h_disjoint) + _ = ∑ y ∈ D, 2 ^ steps := by + refine Finset.sum_congr rfl ?_ + intro y hy + simp only [h_fiber_card] + _ = D.card * 2 ^ steps := by + simp only [sum_const, smul_eq_mul] + -- Cardinality of the source domain + have h_card_S0 : Fintype.card S0 = Fintype.card Sdest * 2 ^ steps := by + -- Use sDomain_card and the fact |𝔽q| = 2 + have h0 : + Fintype.card S0 = (Fintype.card 𝔽q) ^ (ℓ + 𝓡 - (0 : Fin r)) := by + simpa using (sDomain_card 𝔽q β h_ℓ_add_R_rate (i := (0 : Fin r)) + (h_i := Sdomain_bound (by omega))) + have hdest : + Fintype.card Sdest = (Fintype.card 𝔽q) ^ (ℓ + 𝓡 - destIdx) := by + simpa using (sDomain_card 𝔽q β h_ℓ_add_R_rate (i := destIdx) + (h_i := Sdomain_bound (by omega))) + -- Rewrite and use pow_add + have h_add : (ℓ + 𝓡) = (ℓ + 𝓡 - destIdx.val) + destIdx.val := by + have h_le : destIdx.val ≤ ℓ + 𝓡 := by omega + exact (Nat.sub_add_cancel h_le).symm + -- Convert to the desired form + -- We use hF₂.out to rewrite |𝔽q| = 2 + have hFq : Fintype.card 𝔽q = 2 := hF₂.out + calc + Fintype.card S0 + = (Fintype.card 𝔽q) ^ (ℓ + 𝓡) := by + simpa using h0 + _ = (Fintype.card 𝔽q) ^ ((ℓ + 𝓡 - destIdx.val) + destIdx.val) := by + exact congrArg (HPow.hPow (Fintype.card 𝔽q)) h_add + _ = (Fintype.card 𝔽q) ^ (ℓ + 𝓡 - destIdx.val) * + (Fintype.card 𝔽q) ^ destIdx.val := by + simp [pow_add] + _ = Fintype.card Sdest * 2 ^ steps := by + -- rewrite with hdest and |𝔽q| = 2 + simp only [hFq, hdest, steps] + -- Finish the probability computation + have h_card_pos : (2 ^ steps : ENNReal) ≠ 0 := by + exact_mod_cast (pow_ne_zero steps (by decide : (2 : ℕ) ≠ 0)) + have h_card_fin : (2 ^ steps : ENNReal) ≠ ⊤ := by + simp + -- Rewrite in terms of cards + have h_prob : + (preimage.card : ENNReal) / Fintype.card S0 + = (D.card : ENNReal) / Fintype.card Sdest := by + calc + (preimage.card : ENNReal) / Fintype.card S0 + = ((D.card * 2 ^ steps : ℕ) : ENNReal) / + (Fintype.card Sdest * 2 ^ steps : ℕ) := by + simp [h_preimage_card, h_card_S0, preimage, S0, Sdest] + _ = (D.card : ENNReal) / Fintype.card Sdest := by + -- Cancel the factor 2^steps + -- (a*b)/(c*b) = a/c + simpa [mul_comm, mul_left_comm, mul_assoc] using + (ENNReal.mul_div_mul_left (a := (D.card : ENNReal)) + (b := (Fintype.card Sdest : ENNReal)) (c := (2 ^ steps : ENNReal)) + h_card_pos h_card_fin) + simpa [preimage] using h_prob + +open Classical in +/-- **Lemma 4.25** (Query rejection from disagreement suffix). + +If the verifier's query point `v` has its suffix in the disagreement set between +`fold(f^{(j*\cdot\vartheta)})` and `\bar f^{(j*\cdot\vartheta+\vartheta)}`, then the +single-repetition logical check rejects. + +**Hypotheses (following BinaryBasefold.md, Lemma 4.25):** +- `h_no_bad_event`: The bad event at block `j*` didn't occur. +- `h_good_after`: All blocks after `j*` are compliant (maximality of `j*`). +- `h_no_bad_global`: No bad events occur at any block (for the inductive step). + +**Proof sketch (per spec):** +- **Base case (i = j*\cdot\vartheta):** `V` computes the fold inline. + Since the suffix is in the disagreement set, the folded value differs from the codeword. +- **Inductive step (i > j*\cdot\vartheta):** Disagreement propagates using no-bad-events + and compliance of subsequent blocks. +- **Final step:** The final check fails. -/ +theorem lemma_4_25_reject_if_suffix_in_disagreement + (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) + -- Block index j* ∈ {0, ..., ℓ/ϑ - 1} (the highest non-compliant block) + (j_star : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ))) + {destIdx : Fin r} (h_destIdx : destIdx.val = j_star.val * ϑ + ϑ) (h_destIdx_le : destIdx ≤ ℓ) + (f_next : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx) + (r_challenges : Fin ϑ → L) + (h_r_challenges : + r_challenges = + getFoldingChallenges (r := r) (𝓡 := 𝓡) (ϑ := ϑ) (i := Fin.last ℓ) + stmtIn.challenges (k := j_star.val * ϑ) (h := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j_star))) + (h_f_next_shape : + (∃ hj : j_star.val + 1 < nBlocks (ℓ := ℓ) (ϑ := ϑ), + f_next = + getNextOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + (i := Fin.last ℓ) (oStmt := oStmtIn) (j := j_star) (hj := by + simpa only [toOutCodewordsCount_last, nBlocks] using hj) + (destDomainIdx := destIdx) (h_destDomainIdx := by + simpa only using h_destIdx)) + ∨ (f_next = fun _ => stmtIn.final_constant)) + (h_no_bad_event : ¬ foldingBadEvent 𝔽q β (i := ⟨j_star.val * ϑ, by omega⟩) ϑ + h_destIdx h_destIdx_le (oStmtIn j_star) r_challenges) + (h_next_close : UDRClose 𝔽q β destIdx h_destIdx_le f_next) + -- All blocks after j* are compliant (consequence of maximality of j*) + (h_good_after : ∀ j : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ)), j_star < j → + ¬ badBlockProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmtIn oStmtIn j) + -- No bad events globally (for the inductive step at subsequent blocks) + (h_no_bad_global : ¬ blockBadEventExistsProp 𝔽q β (stmtIdx := Fin.last ℓ) + (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (Fin.last ℓ)) + (oStmt := oStmtIn) (challenges := stmtIn.challenges)) + (v : sDomain 𝔽q β h_ℓ_add_R_rate 0) : + let f_star := oStmtIn j_star + let folded_f := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨j_star.val * ϑ, by omega⟩ ϑ h_destIdx h_destIdx_le f_star (r_challenges := r_challenges) + let f_bar_next := UDRCodeword 𝔽q β destIdx h_destIdx_le + (f := f_next) (h_within_radius := h_next_close) + let v_suffix := + iteratedQuotientMap 𝔽q β h_ℓ_add_R_rate (i := (0 : Fin r)) (destIdx := destIdx) + (k := destIdx.val) + (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, zero_mod, zero_add]) + (h_destIdx_le := h_destIdx_le) (x := v) + v_suffix ∈ disagreementSet 𝔽q β (i := destIdx) (destIdx := destIdx) (h_destIdx := rfl) (f := folded_f) (g := f_bar_next) → + ¬ logical_checkSingleRepetition 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + oStmtIn v stmtIn stmtIn.final_constant := by + stop + classical + -- Proof per BinaryBasefold.md, Lemma 4.25. + -- We show: assuming all step conditions pass, the fold value at the last step + -- disagrees with `final_constant`, contradicting the final step condition. + -- Introduce the let bindings and hypotheses + intro f_star folded_f f_bar_next v_suffix h_disagree h_accept + -- Step 1: Extract the final step condition from h_accept. + -- At k = ℓ/ϑ (the terminal index), the step condition says: + -- fold(oStmt(ℓ/ϑ-1), ...)(v_ℓ) = final_constant + have h_div_pos : ℓ / ϑ > 0 := + Nat.div_pos (Nat.le_of_dvd (Nat.pos_of_neZero ℓ) hdiv.out) (Nat.pos_of_neZero ϑ) + have h_final := h_accept (⟨ℓ / ϑ, by omega⟩ : Fin (ℓ / ϑ + 1)) + unfold logical_stepCondition at h_final + split_ifs at h_final with h_absurd + · exact absurd h_absurd (lt_irrefl _) + -- Step 2: Key inductive invariant — disagreement propagates from j* to the final step. + -- For each block k from j_star to ℓ/ϑ - 1 (inclusive), conditioned on all guard checks + -- in h_accept passing, the fold value computed by the verifier at step k disagrees with + -- the closest codeword at the next level. + -- + -- At k = ℓ/ϑ - 1 (the last block), this gives: + -- fold(oStmt(ℓ/ϑ-1), ...)(v_ℓ) ≠ final_constant + -- + -- The induction proceeds as follows (per the spec): + -- Base case (k = j*): The fold value = iterated_fold(oStmt(j*), ...)(v_suffix). + -- Since v_suffix ∈ Δ(fold(oStmt(j*)), f̄), the fold value ≠ f̄(v_suffix). + -- Inductive step (k > j*): The guard check c_{k*ϑ} = oStmt(k)(v_{k*ϑ}) passes (from h_accept). + -- Combined with c_{k*ϑ} ≠ f̄^{(k)}(v_{k*ϑ}), we get oStmt(k) ≠ f̄^{(k)} at v_{k*ϑ}. + -- By ¬E_k (from h_no_bad_global), disagreement propagates through folding. + -- By compliance (from h_good_after), fold(f̄^{(k)}) = f̄^{(k+1)}. + -- So fold(oStmt(k))(v_{(k+1)*ϑ}) ≠ f̄^{(k+1)}(v_{(k+1)*ϑ}). + have h_fold_ne_const : + logical_computeFoldedValue 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨ℓ / ϑ - 1, by omega⟩ v stmtIn + (logical_queryFiberPoints 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + oStmtIn ⟨ℓ / ϑ - 1, by omega⟩ v) ≠ + stmtIn.final_constant := by + -- Base disagreement at block j*: computed fold value differs from decoded next codeword. + have h_base_ne : + logical_computeFoldedValue 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨j_star.val, by simpa [nBlocks, toOutCodewordsCount_last] using j_star.isLt⟩ v stmtIn + (logical_queryFiberPoints 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + oStmtIn ⟨j_star.val, by simpa [nBlocks, toOutCodewordsCount_last] using j_star.isLt⟩ v) + ≠ f_bar_next v_suffix := by + have h_eval_eq_fold : + logical_computeFoldedValue 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨j_star.val, by simpa [nBlocks, toOutCodewordsCount_last] using j_star.isLt⟩ v stmtIn + (logical_queryFiberPoints 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + oStmtIn ⟨j_star.val, by simpa [nBlocks, toOutCodewordsCount_last] using j_star.isLt⟩ v) + = folded_f v_suffix := by + have h_eval_eq_fold_raw := + logical_computeFoldedValue_eq_iterated_fold 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oStmt := oStmtIn) + (k := ⟨j_star.val, by + simpa [nBlocks, toOutCodewordsCount_last] using j_star.isLt⟩) + (v := v) (stmt := stmtIn) + -- Most of the remaining work here is index-transport: + -- identifying `getChallengeSuffix` with `v_suffix`, and aligning the + -- challenge slice with `r_challenges`. + simpa [folded_f, f_star, v_suffix, h_r_challenges, h_destIdx] + using h_eval_eq_fold_raw + have h_disagree_val : folded_f v_suffix ≠ f_bar_next v_suffix := by + simpa [disagreementSet] using h_disagree + intro h_eq + exact h_disagree_val (by simpa [h_eval_eq_fold] using h_eq) + by_cases h_more : j_star.val + 1 < nBlocks (ℓ := ℓ) (ϑ := ϑ) + · -- Main propagation case (j* is not the last block). + have h_propagates_to_last : + logical_computeFoldedValue 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨ℓ / ϑ - 1, by omega⟩ v stmtIn + (logical_queryFiberPoints 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + oStmtIn ⟨ℓ / ϑ - 1, by omega⟩ v) ≠ + stmtIn.final_constant := by + -- Main induction from `j*+1` to the last block: + -- 1. At each block `j`, acceptance gives the guard equality + -- `c_{j*ϑ} = f^(j)(overlap_j)`. + -- 2. Combined with the current mismatch hypothesis + -- `c_{j*ϑ} ≠ f̄^(j)(overlap_j)`, we get + -- `f^(j)(overlap_j) ≠ f̄^(j)(overlap_j)`. + -- 3. This implies the next suffix lies in the projected fiberwise disagreement set. + -- 4. `¬ E_j` (from `h_no_bad_global`) propagates disagreement through folding. + -- 5. Compliance of all `j > j*` (`h_good_after`) identifies + -- `fold(f̄^(j)) = f̄^(j+1)`. + -- 6. Therefore mismatch persists up to `j = ℓ/ϑ - 1`, yielding final rejection. + -- + -- Remaining work is mostly dependent index transport between: + -- - logical loop indices `k : Fin (ℓ/ϑ)`, + -- - oracle block indices `j : Fin (nBlocks)`, + -- - suffix points `S^(j*ϑ)` / `S^((j+1)ϑ)`, + -- plus rewriting `getNextOracle`/`UDRCodeword` casts. + -- + -- This is intentionally isolated as the only remaining technical gap in this branch. + sorry + exact h_propagates_to_last + · -- Terminal case: j* is the last block, so base disagreement is already + -- the final-step disagreement. + have hj_lt_div : j_star.val < ℓ / ϑ := by + simpa [nBlocks, toOutCodewordsCount_last] using j_star.isLt + have hj_not_lt_succ : ¬ j_star.val + 1 < ℓ / ϑ := by + simpa [nBlocks, toOutCodewordsCount_last] using h_more + have h_k_last_eq_jstar : + (⟨ℓ / ϑ - 1, by omega⟩ : Fin (ℓ / ϑ)) = + ⟨j_star.val, by simpa [nBlocks, toOutCodewordsCount_last] using j_star.isLt⟩ := by + apply Fin.eq_of_val_eq + have h_ge : ℓ / ϑ ≤ j_star.val + 1 := Nat.le_of_not_gt hj_not_lt_succ + have h1 : ℓ / ϑ - 1 ≤ j_star.val := by omega + have h2 : j_star.val ≤ ℓ / ϑ - 1 := by omega + exact Nat.le_antisymm h1 h2 + have h_f_next_const : f_next = fun _ => stmtIn.final_constant := by + rcases h_f_next_shape with h_shape | h_const + · exact False.elim (h_more h_shape.1) + · exact h_const + have h_f_bar_next_const : f_bar_next = fun _ => stmtIn.final_constant := by + subst f_next + dsimp [f_bar_next] + simpa using + (UDRCodeword_constFunc_eq_self (𝔽q := 𝔽q) (β := β) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := destIdx) (h_i := h_destIdx_le) (c := stmtIn.final_constant)) + have h_base_ne_const : + logical_computeFoldedValue 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨j_star.val, by simpa [nBlocks, toOutCodewordsCount_last] using j_star.isLt⟩ v stmtIn + (logical_queryFiberPoints 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + oStmtIn ⟨j_star.val, by simpa [nBlocks, toOutCodewordsCount_last] using j_star.isLt⟩ v) + ≠ stmtIn.final_constant := by + intro h_eq_const + exact h_base_ne (by simpa [h_f_bar_next_const] using h_eq_const) + simpa [h_k_last_eq_jstar] using h_base_ne_const + -- Step 3: Contradiction. + exact h_fold_ne_const h_final + +open Classical in +/-- **Proposition 4.23** (Query-phase soundness, assuming no bad events). + +If any oracle is non-compliant and no bad folding event occurs, then a single +repetition of the proximity check accepts with probability at most +`(1/2) + 1/(2 * 2^𝓡)`. -/ +theorem prop_4_23_singleRepetition_proximityCheck_bound + {h_le : ϑ ≤ ℓ} + (stmtIn : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) j) + (h_not_consistent : ¬ finalSumcheckStepOracleConsistencyProp 𝔽q β + (h_le := h_le) (stmtOut := stmtIn) (oStmtOut := oStmtIn)) + (h_no_bad : ¬ blockBadEventExistsProp 𝔽q β (stmtIdx := Fin.last ℓ) + (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (Fin.last ℓ)) + (oStmt := oStmtIn) (challenges := stmtIn.challenges)) : + Pr_{ let v ←$ᵖ (sDomain 𝔽q β h_ℓ_add_R_rate 0) }[ + logical_checkSingleRepetition 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + oStmtIn v stmtIn stmtIn.final_constant + ] ≤ ((1/2 : ℝ≥0) + (1 : ℝ≥0) / (2 * 2^𝓡)) := by + classical + -- Extract a concrete bad block from `h_not_consistent`. + have h_exists_badBlock : + ∃ j : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ)), + badBlockProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) j := by + -- Define final-step compliance as in `badBlockProp`'s last branch. + let j_last := getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ) + let k := j_last.val * ϑ + have h_k : k = ℓ - ϑ := by + dsimp [j_last, k] + simp only [getLastOraclePositionIndex_last, Nat.sub_mul, Nat.div_mul_cancel (hdiv.out), + one_mul] + have hk_add : k + ϑ = ℓ := by + simpa [h_k] using (Nat.sub_add_cancel (by omega)) + let final_compliance : Prop := + isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨k, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (x := k); omega⟩) + (steps := ϑ) (destIdx := ⟨k + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (x := k + ϑ); omega⟩) + (h_destIdx := by rfl) (h_destIdx_le := by simp only [hk_add, le_refl]) + (f_i := oStmtIn j_last) + (f_i_plus_steps := fun _ => stmtIn.final_constant) + (challenges := + getFoldingChallenges (r := r) (𝓡 := 𝓡) (ϑ := ϑ) (i := Fin.last ℓ) + stmtIn.challenges (k := k) (h := by simp only [hk_add, Fin.val_last, le_refl])) + have h_not_and : + ¬ (oracleFoldingConsistencyProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := Fin.last ℓ) (challenges := stmtIn.challenges) (oStmt := oStmtIn) ∧ + final_compliance) := by + simpa only [not_and, finalSumcheckStepOracleConsistencyProp, Fin.val_last] using + h_not_consistent + by_cases h_final_ok : final_compliance + · -- Final block compliant: then oracleFoldingConsistencyProp must fail. + have h_oracle_bad : + ¬ oracleFoldingConsistencyProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := Fin.last ℓ) (challenges := stmtIn.challenges) (oStmt := oStmtIn) := by + intro h_oracle_ok + exact h_not_and ⟨h_oracle_ok, h_final_ok⟩ + have h_oracle_bad' : + ∃ (j : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ))) + (hj : j.val + 1 < nBlocks (ℓ := ℓ) (ϑ := ϑ)), + ¬ isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨j.val * ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + exact (Nat.le_add_right _ _).trans + (oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j)) + ⟩) + (steps := ϑ) + (destIdx := ⟨j.val * ϑ + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j) + ⟩) + (h_destIdx := by rfl) + (h_destIdx_le := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j)) + (f_i := oStmtIn j) + (f_i_plus_steps := + getNextOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + (i := Fin.last ℓ) (oStmt := oStmtIn) (j := j) (hj := by + simpa [nBlocks] using hj) + (destDomainIdx := ⟨j.val * ϑ + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j)⟩) + (h_destDomainIdx := by rfl)) + (challenges := + getFoldingChallenges (r := r) (𝓡 := 𝓡) (ϑ := ϑ) (i := Fin.last ℓ) + stmtIn.challenges (k := j.val * ϑ) (h := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j))) := by + -- Unfold oracleFoldingConsistencyProp and push the negation inside. + have h_not_forall : + ¬ (∀ (j : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ))) + (hj : j.val + 1 < nBlocks (ℓ := ℓ) (ϑ := ϑ)), + isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨j.val * ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + exact (Nat.le_add_right _ _).trans + (oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j)) + ⟩) + (steps := ϑ) + (destIdx := ⟨j.val * ϑ + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j) + ⟩) + (h_destIdx := by rfl) + (h_destIdx_le := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j)) + (f_i := oStmtIn j) + (f_i_plus_steps := + getNextOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + (i := Fin.last ℓ) (oStmt := oStmtIn) (j := j) (hj := by + simpa [nBlocks] using hj) + (destDomainIdx := ⟨j.val * ϑ + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j)⟩) + (h_destDomainIdx := by rfl)) + (challenges := + getFoldingChallenges (r := r) (𝓡 := 𝓡) (ϑ := ϑ) (i := Fin.last ℓ) + stmtIn.challenges (k := j.val * ϑ) (h := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j)))) := by + simpa [oracleFoldingConsistencyProp, nBlocks] using h_oracle_bad + classical + push_neg at h_not_forall + exact h_not_forall + rcases h_oracle_bad' with ⟨j, hj, hbad⟩ + refine ⟨j, ?_⟩ + have hj' : j.val + 1 < ℓ / ϑ := by + simpa [nBlocks, toOutCodewordsCount_last] using hj + simpa [badBlockProp, hj', nBlocks, toOutCodewordsCount_last] using hbad + · -- Final block non-compliant: take the last block. + refine ⟨j_last, ?_⟩ + have h_no_succ : + ¬ j_last.val + 1 < nBlocks (ℓ := ℓ) (ϑ := ϑ) := by + have h_div_pos : 0 < ℓ / ϑ := + Nat.div_pos (Nat.le_of_dvd (Nat.pos_of_neZero ℓ) hdiv.out) (Nat.pos_of_neZero ϑ) + have h_div_pos' : 1 ≤ ℓ / ϑ := Nat.succ_le_iff.mpr h_div_pos + have h_eq : j_last.val + 1 = nBlocks (ℓ := ℓ) (ϑ := ϑ) := by + simp [j_last, nBlocks, getLastOraclePositionIndex_last, + toOutCodewordsCount_last, Nat.sub_add_cancel h_div_pos'] + exact not_lt_of_ge (by simp only [toOutCodewordsCount_last, h_eq, le_refl]) + have h_no_succ' : ¬ j_last.val + 1 < ℓ / ϑ := by + simpa [nBlocks, toOutCodewordsCount_last] using h_no_succ + simpa [badBlockProp, h_no_succ', final_compliance, nBlocks, toOutCodewordsCount_last] + using h_final_ok + -- Pick the highest bad block. + let j_star := + highestBadBlock 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) h_exists_badBlock + have h_j_star_bad : + badBlockProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) j_star := by + simpa using + (highestBadBlock_is_bad 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) h_exists_badBlock) + have h_good_of_lt {j : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ))} (hlt : j_star < j) : + ¬ badBlockProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) j := + not_badBlock_of_lt_highest 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) h_exists_badBlock hlt + -- Indices for the chosen bad block. + set i_star : Fin ℓ := ⟨j_star.val * ϑ, by + simp only [(oracle_block_k_bound (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) (j := j_star))]⟩ + set destIdx : Fin r := ⟨j_star.val * ϑ + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + simp only [(oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) (j := j_star))]⟩ + have h_destIdx : destIdx.val = i_star.val + ϑ := by + simp [i_star, destIdx] + have h_destIdx_le : destIdx ≤ ℓ := by + simpa [destIdx] using + (oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) (j := j_star)) + let f_star := oStmtIn j_star + let r_challenges : Fin ϑ → L := + getFoldingChallenges (r := r) (𝓡 := 𝓡) (ϑ := ϑ) (i := Fin.last ℓ) + stmtIn.challenges (k := j_star.val * ϑ) + (h := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) (j := j_star)) + -- No bad event at the chosen block. + have h_no_bad_event : + ¬ foldingBadEvent 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i_star, by omega⟩) ϑ (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f_i := f_star) (r_challenges := r_challenges) := by + intro h_bad + apply h_no_bad + refine ⟨j_star, ?_⟩ + have h_branch : + (oraclePositionToDomainIndex (positionIdx := j_star)).val + ϑ ≤ (Fin.last ℓ).val := by + simp only [Fin.val_last, + (oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) (j := j_star))] + simpa [foldingBadEventAtBlock, h_branch, r_challenges, i_star, destIdx] using h_bad + -- Choose `f_next` and extract compliance/UDR-close facts. + have h_choose : + ∃ f_next : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx, + (¬ isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i_star, by omega⟩) ϑ (destIdx := destIdx) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f_i := f_star) (f_i_plus_steps := f_next) (challenges := r_challenges)) ∧ + UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + destIdx h_destIdx_le f_next ∧ + ((∃ hj : j_star.val + 1 < nBlocks (ℓ := ℓ) (ϑ := ϑ), + f_next = + getNextOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + (i := Fin.last ℓ) (oStmt := oStmtIn) (j := j_star) (hj := by + simpa [nBlocks, toOutCodewordsCount_last] using hj) + (destDomainIdx := destIdx) (h_destDomainIdx := by + simpa using h_destIdx)) + ∨ (f_next = fun _ => stmtIn.final_constant)) := by + by_cases h_last : j_star.val + 1 < nBlocks (ℓ := ℓ) (ϑ := ϑ) + · -- Intermediate bad block. + let f_next : + OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx := + getNextOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + (i := Fin.last ℓ) (oStmt := oStmtIn) (j := j_star) (hj := by + simpa [nBlocks] using h_last) + (destDomainIdx := destIdx) (h_destDomainIdx := by rfl) + have h_not_compliant : + ¬ isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i_star, by omega⟩) ϑ (destIdx := destIdx) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f_i := f_star) (f_i_plus_steps := f_next) (challenges := r_challenges) := by + have h_last' : j_star.val + 1 < ℓ / ϑ := by + simpa [nBlocks, toOutCodewordsCount_last] using h_last + simpa [badBlockProp, h_last', nBlocks, toOutCodewordsCount_last, i_star, destIdx] + using h_j_star_bad + let j_next : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ)) := ⟨j_star.val + 1, h_last⟩ + have h_next_good : + ¬ badBlockProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIn := stmtIn) (oStmtIn := oStmtIn) j_next := by + have hlt : j_star < j_next := by + exact Fin.lt_iff_val_lt_val.mpr (by simp [j_next]) + exact h_good_of_lt (j := j_next) hlt + have h_next_close : + UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + destIdx h_destIdx_le f_next := by + have h_j_next_mul_lt_r : ↑j_next * ϑ < r := by + have : ↑j_next * ϑ = destIdx.val := by simp only [j_next, destIdx]; ring + omega + have h_idx : (⟨↑j_next * ϑ, h_j_next_mul_lt_r⟩ : Fin r) = destIdx := by + apply Fin.ext; simp only [j_next, destIdx]; ring + have h_udr := goodBlock_implies_UDRClose 𝔽q β stmtIn oStmtIn j_next h_next_good + (h_idx := h_idx) (h_le := h_destIdx_le) + exact h_udr + exact ⟨f_next, h_not_compliant, h_next_close, Or.inl ⟨h_last, rfl⟩⟩ + · -- Final bad block: `f_next` is the constant oracle. + let f_next : + OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx := + fun _ => stmtIn.final_constant + have h_not_compliant : + ¬ isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨i_star, by omega⟩) ϑ (destIdx := destIdx) + (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f_i := f_star) (f_i_plus_steps := f_next) (challenges := r_challenges) := by + -- Reduce `badBlockProp` to its final-block branch, then rewrite `j_last` to `j_star`. + have h_j_star_last : j_star = getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ) := by + apply Fin.ext + have h_val : j_star.val = (nBlocks (ℓ := ℓ) (ϑ := ϑ)) - 1 := by + have h_lt : j_star.val < (nBlocks (ℓ := ℓ) (ϑ := ϑ)) := j_star.isLt + have h_ge : (nBlocks (ℓ := ℓ) (ϑ := ϑ)) ≤ j_star.val + 1 := by + exact Nat.le_of_not_gt h_last + omega + simp [getLastOraclePositionIndex, nBlocks, h_val] + have h_no_succ' : ¬ j_star.val + 1 < ℓ / ϑ := by + simp only [nBlocks, toOutCodewordsCount_last] at h_last + exact h_last + let j_last : Fin (nBlocks (ℓ := ℓ) (ϑ := ϑ)) := + getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ) + have h_j_star_bad' : + ¬ isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨j_last.val * ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + exact (Nat.le_add_right _ _).trans + (oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j_last))⟩) + (steps := ϑ) + (destIdx := ⟨j_last.val * ϑ + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j_last)⟩) + (h_destIdx := by rfl) + (h_destIdx_le := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j_last)) + (f_i := oStmtIn j_last) + (f_i_plus_steps := fun _ => stmtIn.final_constant) + (challenges := + getFoldingChallenges (r := r) (𝓡 := 𝓡) (ϑ := ϑ) (i := Fin.last ℓ) + stmtIn.challenges + (k := j_last.val * ϑ) (h := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j_last))) := by + simp only [badBlockProp, h_no_succ', nBlocks, toOutCodewordsCount_last] at h_j_star_bad + exact h_j_star_bad + have h_j_last_to_star : j_last = j_star := by + simp only [j_last] at h_j_star_last ⊢ + exact h_j_star_last.symm + have h_j_star_bad'' : ¬ isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := ⟨j_star.val * ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + exact (Nat.le_add_right _ _).trans + (oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j_star))⟩) + (steps := ϑ) + (destIdx := ⟨j_star.val * ϑ + ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j_star)⟩) + (h_destIdx := by rfl) + (h_destIdx_le := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j_star)) + (f_i := oStmtIn j_star) + (f_i_plus_steps := fun _ => stmtIn.final_constant) + (challenges := + getFoldingChallenges (r := r) (𝓡 := 𝓡) (ϑ := ϑ) (i := Fin.last ℓ) + stmtIn.challenges (k := j_star.val * ϑ) (h := by + exact oracle_index_add_steps_le_ℓ (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j_star))) := by + have h := h_j_star_bad' + rw [h_j_last_to_star] at h + exact h + simp only [i_star, destIdx, f_star, f_next, r_challenges] at h_j_star_bad'' ⊢ + exact h_j_star_bad'' + have h_next_close : + UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + destIdx h_destIdx_le f_next := by + exact + (constFunc_UDRClose (𝔽q := 𝔽q) (β := β) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := destIdx) + (h_i := h_destIdx_le) (c := stmtIn.final_constant)) + exact ⟨f_next, h_not_compliant, h_next_close, Or.inr rfl⟩ + rcases h_choose with ⟨f_next, h_not_compliant, h_next_close, h_f_next_shape⟩ + -- Apply Lemma 4.24: folded function is far from the decoded next codeword. + have h_not_pair : + let f_i_star_folded := + iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨i_star, by omega⟩ ϑ h_destIdx h_destIdx_le f_star (r_challenges := r_challenges) + let f_bar_next := UDRCodeword 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + destIdx h_destIdx_le (f := f_next) (h_within_radius := h_next_close) + ¬ pair_UDRClose 𝔽q β destIdx h_destIdx_le f_i_star_folded f_bar_next := by + exact + lemma_4_24_dist_folded_ge_of_last_noncompliant (𝔽q := 𝔽q) (β := β) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i_star := i_star) (steps := ϑ) + (destIdx := destIdx) (h_destIdx := h_destIdx) (h_destIdx_le := h_destIdx_le) + (f_star := f_star) (f_next := f_next) (r_challenges := r_challenges) + h_not_compliant h_no_bad_event h_next_close + -- Disagreement set between folded oracle and decoded next codeword. + let folded_f := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + ⟨i_star, by omega⟩ ϑ h_destIdx h_destIdx_le f_star (r_challenges := r_challenges) + let f_bar_next := UDRCodeword 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + destIdx h_destIdx_le (f := f_next) (h_within_radius := h_next_close) + let D := + disagreementSet 𝔽q β (i := destIdx) (destIdx := destIdx) (h_destIdx := rfl) (f := folded_f) (g := f_bar_next) + -- From `¬ pair_UDRClose`, derive a lower bound on |D|. + have h_dist_ge : + BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx ≤ + 2 * D.card := by + have h' : + BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx ≤ + 2 * Δ₀(folded_f, f_bar_next) := by + have h'' := not_lt.mp h_not_pair + -- f_bar_next = UDRCodeword ... = Classical.choose ..., so types should match + exact h'' + simp only [D, disagreementSet, hammingDist] at h' ⊢ + exact h' + -- Acceptance implies the suffix is NOT in the disagreement set. + have h_accept_subset : + ∀ v : sDomain 𝔽q β h_ℓ_add_R_rate 0, + logical_checkSingleRepetition 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + oStmtIn v stmtIn stmtIn.final_constant → + extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (v := v) (destIdx := destIdx) (h_destIdx_le := h_destIdx_le) ∉ D := by + intro v h_accept h_mem + have h_reject := + lemma_4_25_reject_if_suffix_in_disagreement (𝔽q := 𝔽q) (β := β) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (stmtIn := stmtIn) (oStmtIn := oStmtIn) + (j_star := j_star) (destIdx := destIdx) (h_destIdx := by simp [destIdx]) + (h_destIdx_le := h_destIdx_le) (f_next := f_next) + (r_challenges := r_challenges) (h_r_challenges := by rfl) + (h_f_next_shape := h_f_next_shape) + (h_no_bad_event := by + simp only [i_star, destIdx] at h_no_bad_event; exact h_no_bad_event) + (h_next_close := h_next_close) + (h_good_after := fun j hlt => h_good_of_lt hlt) + (h_no_bad_global := h_no_bad) (v := v) + exact h_reject (by + simp only [UDRCodeword, SetLike.mem_coe, uniqueDecodingRadius_eq_floor_div_2, and_imp, D, + folded_f, f_bar_next] at h_mem ⊢ + exact h_mem) h_accept + -- Probability bound via monotonicity. + have h_prob_accept_le : + Pr_{ let v ←$ᵖ (sDomain 𝔽q β h_ℓ_add_R_rate 0) }[ + logical_checkSingleRepetition 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + oStmtIn v stmtIn stmtIn.final_constant + ] ≤ + Pr_{ let v ←$ᵖ (sDomain 𝔽q β h_ℓ_add_R_rate 0) }[ + extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (v := v) (destIdx := destIdx) (h_destIdx_le := h_destIdx_le) ∉ D + ] := by + apply prob_mono + exact h_accept_subset + -- Evaluate the suffix probability for the complement set. + have h_prob_suffix_not : + Pr_{ let v ←$ᵖ (sDomain 𝔽q β h_ℓ_add_R_rate 0) }[ + extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (v := v) (destIdx := destIdx) (h_destIdx_le := h_destIdx_le) ∉ D + ] = + ((Dᶜ).card : ENNReal) / + Fintype.card (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) := by + have h := + prob_uniform_suffix_mem (𝔽q := 𝔽q) (β := β) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (destIdx := destIdx) (h_destIdx_le := h_destIdx_le) (D := Dᶜ) + simp only [Finset.mem_compl, D] at h ⊢ + exact h + -- Bound the complement probability using the distance bound. + have h_prob_bound : + ((Dᶜ).card : ENNReal) / + Fintype.card (sDomain 𝔽q β h_ℓ_add_R_rate destIdx) + ≤ ((1/2 : ℝ≥0) + (1 : ℝ≥0) / (2 * 2^𝓡)) := by + -- Set up notation. + let Sdest := sDomain 𝔽q β h_ℓ_add_R_rate destIdx + let n : ℕ := Fintype.card Sdest + have h_card_Sdest : + n = 2 ^ (ℓ + 𝓡 - destIdx.val) := by + have h := (sDomain_card 𝔽q β h_ℓ_add_R_rate (i := destIdx) + (h_i := Sdomain_bound (by omega))) + simp only [n, hF₂.out] at h ⊢ + exact h + have h_exp : + ℓ + 𝓡 - destIdx.val = 𝓡 + (ℓ - destIdx.val) := by + have h_le : destIdx.val ≤ ℓ := by exact h_destIdx_le + calc + ℓ + 𝓡 - destIdx.val = 𝓡 + ℓ - destIdx.val := by omega + _ = 𝓡 + (ℓ - destIdx.val) := by + exact Nat.add_sub_assoc h_destIdx_le 𝓡 + have h_n_div : + n / 2 ^ 𝓡 = 2 ^ (ℓ - destIdx.val) := by + have h_pos : 0 < 2 ^ 𝓡 := by + exact pow_pos (by decide : 0 < (2 : ℕ)) _ + calc + n / 2 ^ 𝓡 + = (2 ^ (ℓ + 𝓡 - destIdx.val)) / 2 ^ 𝓡 := by simp [h_card_Sdest] + _ = (2 ^ (𝓡 + (ℓ - destIdx.val))) / 2 ^ 𝓡 := by + simp [h_exp] + _ = (2 ^ 𝓡 * 2 ^ (ℓ - destIdx.val)) / 2 ^ 𝓡 := by + simp only [pow_add, ne_eq, Nat.pow_eq_zero, OfNat.ofNat_ne_zero, false_and, + not_false_eq_true, mul_div_cancel_left₀] + _ = 2 ^ (ℓ - destIdx.val) := by + have h := Nat.mul_div_left (2 ^ (ℓ - destIdx.val)) h_pos + simp only [Nat.mul_comm] at h ⊢ + exact h + have h_d_next : + BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx = + n - n / 2 ^ 𝓡 + 1 := by + have h_d := + BBF_CodeDistance_eq (𝔽q := 𝔽q) (β := β) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := destIdx) (h_i := h_destIdx_le) + rw [h_n_div, h_card_Sdest] + exact h_d + have h_Dcomp_nat : + 2 * (Dᶜ).card ≤ n + n / 2 ^ 𝓡 := by + have h_card_compl : + (Dᶜ).card = n - D.card := by + have h := Finset.card_compl (s := D) + simp only [Sdest, n] at h ⊢ + exact h + have h1 : + 2 * (Dᶜ).card = 2 * n - 2 * D.card := by + simp only [h_card_compl, Nat.mul_sub_left_distrib] + have h2 : + 2 * n - 2 * D.card ≤ 2 * n - + BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx := by + exact Nat.sub_le_sub_left h_dist_ge _ + have h3 : + 2 * n - + BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx ≤ + n + n / 2 ^ 𝓡 := by + simp [h_d_next]; omega + exact le_trans (by + simp only [h1] at h2 ⊢ + exact h2) h3 + have h_n_pos : (n : ENNReal) ≠ 0 := by + -- exact_mod_cast (pow_ne_zero (ℓ + 𝓡 - destIdx.val) (by decide : (2 : ℕ) ≠ 0)) + simp [h_card_Sdest] + have h_n_fin : (n : ENNReal) ≠ ⊤ := by simp + have h_Dcomp_ennreal : + (2 * (Dᶜ).card : ENNReal) ≤ + (n : ENNReal) + ((n / 2 ^ 𝓡 : ℕ) : ENNReal) := by + exact_mod_cast h_Dcomp_nat + have h_div_cast : + ((n / 2 ^ 𝓡 : ℕ) : ENNReal) = + (n : ENNReal) / (2 ^ 𝓡 : ENNReal) := by + have h_dvd : (2 ^ 𝓡) ∣ n := by + refine ⟨2 ^ (ℓ - destIdx.val), ?_⟩ + calc + n = 2 ^ (ℓ + 𝓡 - destIdx.val) := h_card_Sdest + _ = 2 ^ (𝓡 + (ℓ - destIdx.val)) := by simp [h_exp] + _ = 2 ^ 𝓡 * 2 ^ (ℓ - destIdx.val) := by + simp only [pow_add] + have h_pos : (2 ^ 𝓡 : ENNReal) ≠ 0 := by + exact_mod_cast (pow_ne_zero 𝓡 (by decide : (2 : ℕ) ≠ 0)) + have h_pos_nn : (2 ^ 𝓡 : NNReal) ≠ 0 := by + exact_mod_cast (pow_ne_zero 𝓡 (by decide : (2 : ℕ) ≠ 0)) + have h_div_nn : ((n / 2 ^ 𝓡 : ℕ) : NNReal) = (n : NNReal) / (2 ^ 𝓡 : NNReal) := by + have h := Nat.cast_div (K := NNReal) h_dvd (by + simp only [cast_pow, cast_ofNat, ne_eq, h_pos_nn, not_false_eq_true]) + simp only [cast_pow, cast_ofNat] at h + exact h + simpa [ENNReal.coe_div h_pos_nn] using (congr_arg (ENNReal.ofNNReal) h_div_nn) + have h_Dcomp_ennreal' : + (2 * (Dᶜ).card : ENNReal) ≤ + (n : ENNReal) + (n : ENNReal) / (2 ^ 𝓡 : ENNReal) := by + simpa [h_div_cast] using h_Dcomp_ennreal + have h_step : + ((Dᶜ).card : ENNReal) ≤ + ((2 : ENNReal)⁻¹ * ((n : ENNReal) + (n : ENNReal) / (2 ^ 𝓡 : ENNReal))) := by + rw [← ENNReal.mul_le_iff_le_inv (by simp) (by simp)] + simpa [mul_comm] using h_Dcomp_ennreal' + apply (ENNReal.div_le_iff h_n_pos h_n_fin).2 + have h_rhs : + ((2 : ENNReal)⁻¹ * ((n : ENNReal) + (n : ENNReal) / (2 ^ 𝓡 : ENNReal))) = + ((1/2 : ℝ≥0) + (1 : ℝ≥0) / (2 * 2 ^ 𝓡)) * (n : ENNReal) := by + have h_inv : (2 * 2 ^ 𝓡 : ENNReal)⁻¹ = 2⁻¹ * (2 ^ 𝓡 : ENNReal)⁻¹ := by + apply ENNReal.mul_inv (Or.inl (by simp)) (Or.inl (by simp)) + simp [mul_add, add_mul, h_inv, div_eq_mul_inv, mul_assoc, mul_left_comm, mul_comm] + simpa [h_rhs] using h_step + have h_prob_suffix_not' : + Pr_{ let v ←$ᵖ (sDomain 𝔽q β h_ℓ_add_R_rate 0) }[ + extractSuffixFromChallenge 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (v := v) (destIdx := destIdx) (h_destIdx_le := h_destIdx_le) ∉ D + ] ≤ ((1/2 : ℝ≥0) + (1 : ℝ≥0) / (2 * 2^𝓡)) := by + rw [h_prob_suffix_not] + exact h_prob_bound + exact le_trans h_prob_accept_le h_prob_suffix_not' + +end QueryPhaseSoundnessStatements + +end +end Binius.BinaryBasefold + +#min_imports diff --git a/ArkLib/ProofSystem/Binius/BinaryBasefold/Spec.lean b/ArkLib/ProofSystem/Binius/BinaryBasefold/Spec.lean index 13d5e1ea3..306b04994 100644 --- a/ArkLib/ProofSystem/Binius/BinaryBasefold/Spec.lean +++ b/ArkLib/ProofSystem/Binius/BinaryBasefold/Spec.lean @@ -292,9 +292,12 @@ def fullPSpec := (pSpecCoreInteraction 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ /-! ## Oracle Interface instances for Messages-/ -instance : ∀ j, OracleInterface ((pSpecFold (L:=L)).Message j) -- this cover .Message only - | ⟨0, h⟩ => by exact OracleInterface.instDefault -- h_i(X) polynomial - | ⟨1, _⟩ => by exact OracleInterface.instDefault -- challenge r'_i +instance instOracleInterfaceMessagePSpecFold : + ∀ j, OracleInterface ((pSpecFold (L:=L)).Message j) := + fun _ => OracleInterface.instDefault + +instance : ∀ j, OracleInterface ((pSpecFold (L:=L)).Challenge j) := + fun _ => OracleInterface.instDefault instance : ∀ j, OracleInterface ((pSpecFold (L := L)).Challenge j) := ProtocolSpec.challengeOracleInterface @@ -452,6 +455,11 @@ instance : ∀ i, SampleableType ((pSpecQuery 𝔽q β γ_repetitions instance : ∀ j, SampleableType ((fullPSpec 𝔽q β γ_repetitions (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Challenge j) := instSampleableTypeChallengeAppend +instance : SampleableType (Fin γ_repetitions → ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0)) := by + let res := instSDomain 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := 0) (h_i := by + apply Nat.lt_add_of_pos_right_of_le; simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_le]) + exact instSampleableTypeFinFunc + /-! ## Additional OracleInterface and Fintype instances -/ /-- OracleInterface instance for the matrix-indexed message type family using instDefault. -/ @@ -510,7 +518,7 @@ The response types are the polynomial and field element themselves, both finite and inhabited. -/ instance : ([(pSpecFold (L:=L)).Message]ₒ).Fintype := by sorry -instance instOracleStatementFiniteRange {i : Fin (ℓ + 1)} : +instance instOracleStatementFintype {i : Fin (ℓ + 1)} : [OracleStatement 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i]ₒ.Fintype := by sorry instance instFintypePSpecFinalSumcheck_AllChallenges: ∀ i, Fintype ((pSpecFinalSumcheckStep (L:=L)).Challenge i) := by sorry @@ -518,6 +526,8 @@ instance instFintypePSpecFinalSumcheck_AllChallenges: ∀ i, Fintype ((pSpecFina instance instFintypePSpecFinalSumcheckStepChallenge : [pSpecFinalSumcheckStep (L := L).Challenge]ₒ.Fintype := by sorry +instance : Fintype (Fin γ_repetitions → ↥(sDomain 𝔽q β h_ℓ_add_R_rate 0)) := by + infer_instance instance instInhabitedPSpecFinalSumcheckStepChallenge : [(pSpecFinalSumcheckStep (L:=L)).Challenge]ₒ.Inhabited := by sorry diff --git a/ArkLib/ProofSystem/Binius/BinaryBasefold/Steps.lean b/ArkLib/ProofSystem/Binius/BinaryBasefold/Steps.lean index 12b65409c..c739a62b4 100644 --- a/ArkLib/ProofSystem/Binius/BinaryBasefold/Steps.lean +++ b/ArkLib/ProofSystem/Binius/BinaryBasefold/Steps.lean @@ -7,10 +7,8 @@ import ArkLib.ProofSystem.Binius.BinaryBasefold.ReductionLogic import ArkLib.ToVCVio.Oracle import ArkLib.ToVCVio.Simulation import ArkLib.OracleReduction.Completeness +import ArkLib.ProofSystem.Binius.BinaryBasefold.Soundness -set_option maxHeartbeats 400000 -- Increase if needed -set_option profiler true -set_option profiler.threshold 20 -- Show anything taking over 10ms namespace Binius.BinaryBasefold.CoreInteraction /-! ## Binary Basefold single steps @@ -33,7 +31,7 @@ namespace Binius.BinaryBasefold.CoreInteraction noncomputable section open OracleSpec OracleComp ProtocolSpec Finset AdditiveNTT Polynomial MvPolynomial open Binius.BinaryBasefold -open scoped NNReal +open scoped NNReal ProbabilityTheory variable {r : ℕ} [NeZero r] variable {L : Type} [Field L] [Fintype L] [DecidableEq L] [CharP L 2] @@ -50,9 +48,10 @@ variable [hdiv : Fact (ϑ ∣ ℓ)] section SingleIteratedSteps variable {Context : Type} {mp : SumcheckMultiplierParam L ℓ Context} -- Sumcheck context + section FoldStep -/-- The prover for the `i`-th round of Binary Foldfold. -/ +/-! The prover for the `i`-th round of Binary Foldfold. -/ noncomputable def foldOracleProver (i : Fin ℓ) : OracleProver (oSpec := []ₒ) -- current round @@ -64,11 +63,8 @@ noncomputable def foldOracleProver (i : Fin ℓ) : (OStmtOut := OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc) (WitOut := Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ) i.succ) (pSpec := pSpecFold (L := L)) where - PrvState := foldPrvState 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i - input := fun ⟨⟨stmt, oStmt⟩, wit⟩ => (stmt, oStmt, wit) - sendMessage -- There are either 2 or 3 messages in the pSpec depending on commitment rounds | ⟨0, _⟩ => fun ⟨stmt, oStmt, wit⟩ => do -- USE THE SHARED KERNEL (Guarantees match with foldStepLogic) @@ -77,13 +73,11 @@ noncomputable def foldOracleProver (i : Fin ℓ) : -- Return message and update state pure ⟨h_i, (stmt, oStmt, wit, h_i)⟩ | ⟨1, _⟩ => by contradiction - receiveChallenge | ⟨0, h⟩ => nomatch h -- i.e. contradiction | ⟨1, _⟩ => fun ⟨stmt, oStmt, wit, h_i⟩ => do pure (fun r_i' => (stmt, oStmt, wit, h_i, r_i')) -- | ⟨2, h⟩ => nomatch h -- no challenge after third message - -- output : PrvState → StmtOut × (∀i, OracleStatement i) × WitOut output := fun finalPrvState => let (stmt, oStmt, wit, h_i, r_i') := finalPrvState @@ -92,8 +86,8 @@ noncomputable def foldOracleProver (i : Fin ℓ) : pure ((foldStepLogic 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) i).proverOut stmt wit oStmt t) +/-! The oracle verifier for the `i`-th round of Binary Foldfold. -/ open Classical in -/-- The oracle verifier for the `i`-th round of Binary Foldfold. -/ def foldOracleVerifier (i : Fin ℓ) : OracleVerifier (oSpec := []ₒ) @@ -106,7 +100,6 @@ def foldOracleVerifier (i : Fin ℓ) : (OStmtOut := OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc) (pSpec := pSpecFold (L := L)) where - -- The core verification logic. Takes the input statement `stmtIn` and the transcript, and -- performs an oracle computation that outputs a new statement verify := fun stmtIn pSpecChallenges => do @@ -117,14 +110,13 @@ def foldOracleVerifier (i : Fin ℓ) : (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) i) guard (logic.verifierCheck stmtIn t) pure (logic.verifierOut stmtIn t) - -- Reuse embed and hEq from foldStepLogic to ensure consistency embed := (foldStepLogic 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) i).embed hEq := (foldStepLogic 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) i).hEq -/-- The oracle reduction that is the `i`-th round of Binary Foldfold. -/ +/-! The oracle reduction that is the `i`-th round of Binary Foldfold. -/ noncomputable def foldOracleReduction (i : Fin ℓ) : OracleReduction (oSpec := []ₒ) (StmtIn := Statement (L := L) Context i.castSucc) @@ -144,7 +136,7 @@ variable {R : Type} [CommSemiring R] [DecidableEq R] [SampleableType R] {n : ℕ} {deg : ℕ} {m : ℕ} {D : Fin m ↪ R} variable {σ : Type} {init : ProbComp σ} {impl : QueryImpl []ₒ (StateT σ ProbComp)} -/-- Simplifies membership in a conditional singleton set. +/-! Simplifies membership in a conditional singleton set. `x ∈ (if c then {a} else {b})` is equivalent to `x = (if c then a else b)`. -/ lemma mem_ite_singleton {α : Type*} {c : Prop} [Decidable c] {a b x : α} : @@ -172,6 +164,7 @@ always succeeds (with probability 1) and produces valid outputs. - Apply the logic properties to complete the proof -/ open Classical in +omit [DecidableEq 𝔽q] in theorem foldOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : Fin ℓ) : OracleReduction.perfectCompleteness @@ -184,6 +177,7 @@ theorem foldOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : Fi (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) i) (init := init) (impl := impl) := by + classical -- Step 1: Unroll the 2-message reduction to convert from probability to logic -- **NOTE**: this requires `ProtocolSpec.challengeOracleInterface` to avoid conflict rw [OracleReduction.unroll_2_message_reduction_perfectCompleteness (oSpec := []ₒ) @@ -200,7 +194,6 @@ theorem foldOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : Fi let step := (foldStepLogic 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) i) let strongly_complete : step.IsStronglyComplete := foldStep_is_logic_complete (L := L) 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) (i := i) - -- Step 4: Split into safety and correctness goals refine ⟨?_, ?_⟩ -- GOAL 1: SAFETY - Prove the verifier never crashes ([⊥|...] = 0) @@ -218,7 +211,6 @@ theorem foldOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : Fi simp only [ChallengeIdx, Challenge, Fin.isValue, Matrix.cons_val_one, Matrix.cons_val_zero, liftComp_eq_liftM, OptionT.probFailure_lift, HasEvalPMF.probFailure_eq_zero] rw [true_and] - intro r_i' h_r_i'_mem_query_1_support conv => enter [1]; @@ -228,7 +220,6 @@ theorem foldOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : Fi Fin.succ_one_eq_two, Message, Fin.succ_zero_eq_one, Fin.castSucc_one, liftComp_eq_liftM, OptionT.probFailure_lift, HasEvalPMF.probFailure_eq_zero] rw [true_and] - intro h_receive_challenge_fn h_receive_challenge_fn_mem_support conv => enter [1]; @@ -264,7 +255,6 @@ theorem foldOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : Fi conv at h_vStmtOut_mem_support => erw [simulateQ_bind] -- turn the simulated oracle query into OracleInterface.answer form - rw [OptionT.simulateQ_simOracle2_liftM_query_T2] change vStmtOut ∈ (Bind.bind (m := (OracleComp []ₒ)) _ _).support erw [_root_.bind_pure_simulateQ_comp] simp only [Matrix.cons_val_zero, guard_eq] @@ -328,9 +318,7 @@ theorem foldOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : Fi Fin.reduceLast, MessageIdx, Message, exists_eq_left] at hx_mem_support -- Step 2b: Extract the challenge r1 and the trace equations obtain ⟨r1, ⟨_h_r1_mem_challenge_support, h_trace_support⟩⟩ := hx_mem_support - rcases h_trace_support with ⟨prvOut_eq, h_verOut_mem_support⟩ - -- Step 2c: Simplify the verifier computation conv at h_verOut_mem_support => erw [simulateQ_bind] @@ -363,7 +351,6 @@ theorem foldOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : Fi exact hj_ne hj | 1 => exact r1 ) - have h_V_check_is_true : V_check := h_V_check simp only [h_V_check_is_true, ↓reduceIte, Fin.isValue, pure_bind] at h_verOut_mem_support erw [simulateQ_pure, liftM_pure] at h_verOut_mem_support @@ -372,9 +359,7 @@ theorem foldOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : Fi rcases h_verOut_mem_support with ⟨verStmtOut_eq, verOStmtOut_eq⟩ dsimp only [foldStepLogic, foldProverComputeMsg, step, getFoldProverFinalOutput] at prvOut_eq rw [Prod.mk.injEq, Prod.mk.injEq] at prvOut_eq - obtain ⟨⟨prvStmtOut_eq, prvOStmtOut_eq⟩, prvWitOut_eq⟩ := prvOut_eq - constructor · rw [prvWitOut_eq, verStmtOut_eq, verOStmtOut_eq]; exact h_rel @@ -386,26 +371,36 @@ theorem foldOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : Fi open scoped NNReal open Classical in -/-- Definition of the per-round RBR KS error for Binary FoldFold. -This combines the Sumcheck error (1/|L|) and the LDT Bad Event probability. +/-! Definition of the per-round RBR KS error for Binary FoldFold. +This combines the Sumcheck error (2/|L|) and the LDT Bad Event probability. For round i : rbrKnowledgeError(i) = err_SC + err_BE where -- err_SC = 1/|L| (Schwartz-Zippel for degree 1) -- err_BE = (if ϑ ∣ (i + 1) then ϑ * |S^(i+1)| / |L| else 0) - where k = i / ϑ and |S^(j)| is the size of the j-th domain +- err_SC = 2/|L| (Schwartz-Zippel for degree 1) +- err_BE = |S^(last_oracle_domain_index_of_i + ϑ)| / |L| -/ def foldKnowledgeError (i : Fin ℓ) (_ : (pSpecFold (L := L)).ChallengeIdx) : ℝ≥0 := - let err_SC := (1 : ℝ≥0) / (Fintype.card L) - -- bad event of `fⱼ` exists RIGHT AFTER the V's challenge of sumcheck round `j+ϑ-1`, - let err_BE := if hi : ϑ ∣ (i.val + 1) then - -- HERE: we view `i` as `j+ϑ-1`, error rate is `ϑ * |S^(j+ϑ)| / |L| = ϑ * |S^(i+1)| / |L|` - ϑ * (Fintype.card ((sDomain 𝔽q β h_ℓ_add_R_rate) - ⟨i.val + 1, by -- ⊢ ↑i + 1 < r + let err_SC := (2 : ℝ≥0) / (Fintype.card L) + -- Distributed fold-error budget: one incremental bad-event charge per fold round. + let err_BE := + let lastDomainIdx := getLastOracleDomainIndex ℓ ϑ i.castSucc + (Fintype.card ((sDomain 𝔽q β h_ℓ_add_R_rate) + ⟨lastDomainIdx.val + ϑ, by + have h_le := getLastOracleDomainIndex_add_ϑ_le ℓ ϑ i.castSucc omega⟩) : ℝ≥0) / (Fintype.card L) - else 0 err_SC + err_BE -/-- The round-by-round extractor for a single round. -Since f^(0) is always available, we can invoke the extractMLP function directly. -/ +/-! WitMid type for fold step: Witness i.succ at final round, Witness i.castSucc otherwise. +This allows the extractor to work with the actual output witness type at the final round. -/ +def foldWitMid (i : Fin ℓ) : Fin (2 + 1) → Type := + fun m => match m with + | ⟨0, _⟩ => Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc + | ⟨1, _⟩ => Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc + | ⟨2, _⟩ => Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ + +/-! The round-by-round extractor for a single round. +Since f^(0) is always available, we can invoke the extractMLP function directly. + +Key design: WitMid at the final round (m=2) is Witness i.succ, matching WitOut. +This allows extractOut to be identity and simplifies toFun_full proofs. -/ noncomputable def foldRbrExtractor (i : Fin ℓ) : Extractor.RoundByRound []ₒ (StmtIn := (Statement (L := L) Context i.castSucc) × (∀ j, @@ -413,80 +408,78 @@ noncomputable def foldRbrExtractor (i : Fin ℓ) : (WitIn := Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc) (WitOut := Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ) (pSpec := pSpecFold (L := L)) - (WitMid := fun _messageIdx => Witness (L := L) 𝔽q β - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc) where + (WitMid := foldWitMid 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) where eqIn := rfl - extractMid := fun _ _ _ witMidSucc => witMidSucc - extractOut := fun ⟨stmtIn, oStmtIn⟩ fullTranscript witOut => by - exact { - t := witOut.t, - H := - projectToMidSumcheckPoly (L := L) (ℓ := ℓ) - (t := witOut.t) (m := mp.multpoly stmtIn.ctx) + extractMid := fun m ⟨stmtIn, _oStmtIn⟩ _tr witMidSucc => + match m with + | ⟨0, _⟩ => witMidSucc -- WitMid 1 → WitMid 0, both are Witness i.castSucc + | ⟨1, _⟩ => + -- WitMid 2 → WitMid 1, i.e., Witness i.succ → Witness i.castSucc + -- Extract backward using the transcript + { + t := witMidSucc.t, + H := projectToMidSumcheckPoly (L := L) (ℓ := ℓ) + (t := witMidSucc.t) (m := mp.multpoly stmtIn.ctx) (i := i.castSucc) (challenges := stmtIn.challenges), - f := getMidCodewords 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) witOut.t - (challenges := stmtIn.challenges) - } + f := getMidCodewords 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) witMidSucc.t + (challenges := stmtIn.challenges) + } + -- extractOut is now identity since WitMid (Fin.last 2) = WitOut = Witness i.succ + extractOut := fun _stmtIn _fullTranscript witOut => witOut -/-- This follows the KState of sum-check -/ +/-! This follows the KState of sum-check -/ def foldKStateProp {i : Fin ℓ} (m : Fin (2 + 1)) - (tr : Transcript m (pSpecFold (L := L))) (stmt : Statement (L := L) Context i.castSucc) - (witMid : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc) - (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc j) : + (tr : Transcript m (pSpecFold (L := L))) (stmtMid : Statement (L := L) Context i.castSucc) + (witMid : foldWitMid 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i m) + (oStmtMid : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc j) : Prop := -- Ground-truth polynomial from witness - let h_star : ↥L⦃≤ 2⦄[X] := getSumcheckRoundPoly ℓ 𝓑 (i := i) (h := witMid.H) - -- Checks available after message 1 (P -> V : hᵢ(X)) - let get_Hᵢ := fun (m: Fin (2 + 1)) (tr: Transcript m pSpecFold) (hm: 1 ≤ m.val) => - let ⟨msgsUpTo, _⟩ := Transcript.equivMessagesChallenges (k := m) - (pSpec := pSpecFold (L := L)) tr - let i_msg1 : ((pSpecFold (L := L)).take m m.is_le).MessageIdx := - ⟨⟨0, Nat.lt_of_succ_le hm⟩, by simp [pSpecFold]; rfl⟩ - let h_i : L⦃≤ 2⦄[X] := msgsUpTo i_msg1 - h_i - let get_rᵢ' := fun (m: Fin (2 + 1)) (tr: Transcript m pSpecFold) (hm: 2 ≤ m.val) => - let ⟨msgsUpTo, chalsUpTo⟩ := Transcript.equivMessagesChallenges (k := m) - (pSpec := pSpecFold (L := L)) tr - let i_msg1 : ((pSpecFold (L := L)).take m m.is_le).MessageIdx := - ⟨⟨0, Nat.lt_of_succ_le (Nat.le_trans (by decide) hm)⟩, by simp; rfl⟩ - let h_i : L⦃≤ 2⦄[X] := msgsUpTo i_msg1 - let i_msg2 : ((pSpecFold (L := L)).take m m.is_le).ChallengeIdx := - ⟨⟨1, Nat.lt_of_succ_le hm⟩, by simp only [Nat.reduceAdd]; rfl⟩ - let r_i' : L := chalsUpTo i_msg2 - r_i' match m with - | ⟨0, _⟩ => -- equiv s relIn + | ⟨0, _⟩ => -- Same as relIn (roundRelation at i.castSucc) masterKStateProp (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (stmtIdx := i.castSucc) (oracleIdx := OracleFrontierIndex.mkFromStmtIdx i.castSucc) - (stmt := stmt) (wit := witMid) (oStmt := oStmt) - (localChecks := sumcheckConsistencyProp (𝓑 := 𝓑) stmt.sumcheck_target witMid.H) - | ⟨1, h1⟩ => -- P sends hᵢ(X) + (stmt := stmtMid) (wit := witMid) (oStmt := oStmtMid) + (localChecks := sumcheckConsistencyProp (𝓑 := 𝓑) stmtMid.sumcheck_target witMid.H) + | ⟨1, _⟩ => -- After P sends hᵢ(X), before V sends r_i' + let h_star : ↥L⦃≤ 2⦄[X] := getSumcheckRoundPoly ℓ 𝓑 (i := i) (h := witMid.H) + let h_i : ↥L⦃≤ 2⦄[X] := tr.messages ⟨0, rfl⟩ masterKStateProp (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (stmtIdx := i.castSucc) (oracleIdx := OracleFrontierIndex.mkFromStmtIdx i.castSucc) - (stmt := stmt) (wit := witMid) (oStmt := oStmt) + (stmt := stmtMid) (wit := witMid) (oStmt := oStmtMid) (localChecks := - let h_i := get_Hᵢ (m := ⟨1, h1⟩) (tr := tr) (hm := by simp only [le_refl]) - let explicitVCheck := h_i.val.eval (𝓑 0) + h_i.val.eval (𝓑 1) = stmt.sumcheck_target + -- Verifier's explicit check: h_i(0) + h_i(1) = sumcheck_target + let explicitVCheck := h_i.val.eval (𝓑 0) + h_i.val.eval (𝓑 1) = stmtMid.sumcheck_target + -- Honest prover check: h_i matches ground truth let localizedRoundPolyCheck := h_i = h_star explicitVCheck ∧ localizedRoundPolyCheck ) - | ⟨2, h2⟩ => -- implied by (relOut + V's check) - -- The bad-folding-event of `fᵢ` is also introduced internaly by `masterKStateProp` + | ⟨2, _⟩ => -- After V sends r_i': use OUTPUT state (consistent with foldStepRelOut) + let h_i : ↥L⦃≤ 2⦄[X] := tr.messages ⟨0, rfl⟩ + let r_i' : L := tr.challenges ⟨1, rfl⟩ + -- Forward-compute the output statement using transcript-derived values + let newSumcheckTarget : L := h_i.val.eval r_i' + let stmtOut : Statement (L := L) Context i.succ := { + -- same as in Verifier's output & getFoldProverFinalOutput + ctx := stmtMid.ctx, + sumcheck_target := newSumcheckTarget, + challenges := Fin.snoc stmtMid.challenges r_i' + } + let oStmtOut := oStmtMid + let witOut := witMid + -- Use OUTPUT state: stmtIdx advances to i.succ, oracleIdx stays at i.castSucc (no new oracle) masterKStateProp (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (stmtIdx := i.castSucc) (oracleIdx := OracleFrontierIndex.mkFromStmtIdx i.castSucc) - (stmt := stmt) (wit := witMid) (oStmt := oStmt) + (stmtIdx := i.succ) (oracleIdx := OracleFrontierIndex.mkFromStmtIdxCastSuccOfSucc i) + (stmt := stmtOut) (wit := witOut) (oStmt := oStmtOut) (localChecks := - let h_i := get_Hᵢ (m := ⟨2, h2⟩) (tr := tr) (hm := by simp only [Nat.one_le_ofNat]) - let r_i' := get_rᵢ' (m := ⟨2, h2⟩) (tr := tr) (hm := by simp only [le_refl]) - let localizedRoundPolyCheck := h_i = h_star - let nextSumcheckTargetCheck := -- this presents sumcheck of next round (sᵢ = s^*ᵢ) - h_i.val.eval r_i' = h_star.val.eval r_i' - localizedRoundPolyCheck ∧ nextSumcheckTargetCheck - ) -- this holds the constraint for witOut in relOut + let explicitVCheck := + h_i.val.eval (𝓑 0) + h_i.val.eval (𝓑 1) = stmtMid.sumcheck_target + explicitVCheck ∧ + -- we also keep the output-state sumcheck consistency + sumcheckConsistencyProp (𝓑 := 𝓑) stmtOut.sumcheck_target witOut.H) -- Note: this fold step couldn't carry bad-event errors, because we don't have oracles yet. -/-- Knowledge state function (KState) for single round -/ +/-! Knowledge state function (KState) for single round -/ def foldKnowledgeStateFunction (i : Fin ℓ) : (foldOracleVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) i).KnowledgeStateFunction init impl @@ -496,27 +489,908 @@ def foldKnowledgeStateFunction (i : Fin ℓ) : (𝓑 := 𝓑) i) (extractor := foldRbrExtractor (mp:=mp) (𝓡 := 𝓡) (ϑ := ϑ) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) where - toFun := fun m ⟨stmt, oStmt⟩ tr witMid => + toFun := fun m ⟨stmtMid, oStmtMid⟩ tr witMid => foldKStateProp (mp:=mp) 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) - (i := i) (m := m) (tr := tr) (stmt := stmt) (witMid := witMid) (oStmt := oStmt) + (i := i) (m := m) (tr := tr) (stmtMid := stmtMid) (witMid := witMid) (oStmtMid := oStmtMid) toFun_empty := fun _ _ => by rfl - toFun_next := fun m hDir stmtIn tr msg witMid => by - obtain ⟨stmt, oStmt⟩ := stmtIn - fin_cases m - · exact fun ⟨_, h⟩ => by sorry - · simp at hDir - toFun_full := fun ⟨stmtLast, oStmtLast⟩ tr witOut h_relOut => by - simp at h_relOut - rcases h_relOut with ⟨stmtOut, ⟨oStmtOut, h_conj⟩⟩ - have h_simulateQ := h_conj.1 - have h_foldStepRelOut := h_conj.2 - set witLast := (foldRbrExtractor (mp:=mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i).extractOut - ⟨stmtLast, oStmtLast⟩ tr witOut - simp only [Fin.reduceLast, Fin.isValue] - have h_oStmt : oStmtLast = oStmtOut := by sorry - sorry + toFun_next := fun m hDir ⟨stmtMid, oStmtMid⟩ tr msg witMid => by + -- For pSpecFold, the only P_to_V message is at index 0 + -- So m = 0, m.succ = 1, m.castSucc = 0 + have h_m_eq_0 : m = 0 := by + cases m using Fin.cases with + | zero => rfl + | succ m' => simp only [ne_eq, reduceCtorEq, not_false_eq_true, Matrix.cons_val_succ, + Matrix.cons_val_fin_one, Direction.not_V_to_P_eq_P_to_V] at hDir + subst h_m_eq_0 + intro h_kState_round1 + unfold foldKStateProp at h_kState_round1 ⊢ + simp only [Fin.isValue, Fin.succ_zero_eq_one, Nat.reduceAdd, Fin.mk_one, + Fin.coe_ofNat_eq_mod, Nat.reduceMod] at h_kState_round1 + simp only [Fin.castSucc_zero] + -- At round 1: bad ∨ (localChecks ∧ structural ∧ initial ∧ oracleFoldingConsistency) + -- At round 0: bad ∨ (sumcheckConsistency ∧ structural ∧ initial ∧ oracleFoldingConsistency) + cases h_kState_round1 with + | inl h_bad => + exact Or.inl h_bad + | inr h_good => + have h_explicit := h_good.1.1 + have h_localized := h_good.1.2 + have h_struct : witnessStructuralInvariant 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmtMid witMid := h_good.2.1 + have h_init : firstOracleWitnessConsistencyProp 𝔽q β witMid.t + (getFirstOracle 𝔽q β oStmtMid) := h_good.2.2.1 + have h_fold := h_good.2.2.2 + have h_sumcheck : sumcheckConsistencyProp (𝓑 := 𝓑) stmtMid.sumcheck_target witMid.H := by + simp_rw [h_localized] at h_explicit + rw [h_explicit.symm] + exact getSumcheckRoundPoly_sum_eq (L := L) (ℓ := ℓ) (𝓑 := 𝓑) (i := i) (h := witMid.H) + exact Or.inr ⟨h_sumcheck, h_struct, h_init, h_fold⟩ + toFun_full := fun ⟨stmtIn, oStmtIn⟩ tr witOut probEvent_relOut_gt_0 => by + -- h_relOut: ∃ stmtOut oStmtOut, verifier outputs (stmtOut, oStmtOut) with prob > 0 + -- and ((stmtOut, oStmtOut), witOut) ∈ foldStepRelOut + simp only [StateT.run'_eq, gt_iff_lt, probEvent_pos_iff, Prod.exists] at probEvent_relOut_gt_0 + rcases probEvent_relOut_gt_0 with ⟨stmtOut, oStmtOut, h_output_mem_V_run_support, h_relOut⟩ + have h_output_mem_V_run_support' : + some (stmtOut, oStmtOut) ∈ + (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (foldOracleVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑) (mp := mp) i).toVerifier)).run s).support := by + exact (OptionT.mem_support_iff + (mx := OptionT.mk (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (foldOracleVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑) (mp := mp) i).toVerifier)).run s)) + (x := (stmtOut, oStmtOut))).1 h_output_mem_V_run_support + simp only [support_bind, Set.mem_iUnion, exists_prop] at h_output_mem_V_run_support' + rcases h_output_mem_V_run_support' with ⟨s, hs_init, h_output_mem_V_run_support⟩ + conv at h_output_mem_V_run_support => + simp only [Verifier.run, OracleVerifier.toVerifier] + -- Now unfold the foldOracleVerifier's `verify()` method + simp only [foldOracleVerifier] + -- dsimp only [StateT.run] + -- simp only [simulateQ_bind, simulateQ_query, simulateQ_pure] + -- oracle query unfolding + simp only [support_bind, Set.mem_iUnion] + dsimp only [StateT.run] + -- enter [1, i_1, 2, 1, x] + simp only [simulateQ_bind] + unfold OracleInterface.answer + --------------------------------------- + -- Now simplify the `guard` and `ite` of StateT.map generated from it + simp only [MessageIdx, Fin.isValue, Matrix.cons_val_zero, simulateQ_pure, Message, guard_eq, + pure_bind, Function.comp_apply, simulateQ_map, simulateQ_ite, + OptionT.simulateQ_failure', bind_map_left] + simp only [MessageIdx, Message, Fin.isValue, Matrix.cons_val_zero, Matrix.cons_val_one, + bind_pure_comp, simulateQ_map, simulateQ_ite, simulateQ_pure, OptionT.simulateQ_failure', + bind_map_left, Function.comp_apply] + simp only [support_ite] + simp only [Fin.isValue, Set.mem_ite_empty_right, Set.mem_singleton_iff, Prod.mk.injEq, + exists_and_left, exists_eq', exists_eq_right, exists_and_right] + erw [simulateQ_bind] + enter [1, x, 1, 2, 1, 2]; + erw [simulateQ_bind] + erw [OptionT.simulateQ_simOracle2_liftM_query_T2] + simp only [Fin.isValue, FullTranscript.mk1_eq_snoc, pure_bind, OptionT.simulateQ_map] + conv at h_output_mem_V_run_support => + simp only [Fin.isValue, FullTranscript.mk1_eq_snoc, Function.comp_apply] + erw [support_bind] at h_output_mem_V_run_support + let step := (foldStepLogic 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) i) + set V_check := step.verifierCheck stmtIn + (FullTranscript.mk2 (msg0 := _) (msg1 := _)) with h_V_check_def + by_cases h_V_check : V_check + · simp only [Fin.isValue, Matrix.cons_val_zero, id_eq, h_V_check, ↓reduceIte, OptionT.run_pure, + simulateQ_pure, Function.comp_apply, Set.mem_iUnion, exists_prop, Prod.exists, + exists_and_right] at h_output_mem_V_run_support + erw [simulateQ_bind] at h_output_mem_V_run_support + simp only [simulateQ_pure, Fin.isValue, Function.comp_apply, + pure_bind] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Fin.isValue, Set.mem_singleton_iff, Prod.mk.injEq, exists_eq_right, + exists_eq_left] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Fin.isValue, Set.mem_singleton_iff, Option.some.injEq, + Prod.mk.injEq] at h_output_mem_V_run_support + -- simp only [support_map, Set.mem_image, exists_prop] at h_output_mem_V_run_support + rcases h_output_mem_V_run_support with ⟨h_stmtOut_eq, h_oStmtOut_eq⟩ + simp only [Fin.reduceLast, Fin.isValue] -- simp the `match` + dsimp only [foldStepRelOut, foldStepRelOutProp, masterKStateProp] at h_relOut + simp only [Fin.val_succ, Set.mem_setOf_eq] at h_relOut + dsimp only [foldKStateProp] + set h_i : ↥L⦃≤ 2⦄[X] := tr.messages ⟨⟨0, by simp only [Nat.reduceAdd, + Fin.reduceLast, Fin.coe_ofNat_eq_mod, Nat.mod_succ, Nat.ofNat_pos]⟩, rfl⟩ with h_i_def + set r_i' : L := tr.challenges ⟨⟨1, by simp only [Nat.reduceAdd, Fin.reduceLast, + Fin.coe_ofNat_eq_mod, Nat.mod_succ, Nat.one_lt_ofNat]⟩, rfl⟩ with h_i_def + set extractedWitLast : Witness (L := L) 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ := + (foldRbrExtractor 𝔽q β i).extractOut (stmtIn, oStmtIn) tr witOut + have h_oStmtOut_eq_oStmtIn : oStmtOut = oStmtIn := by + rw [h_oStmtOut_eq] + funext j + -- ⊢ OracleVerifier.mkVerifierOStmtOut (foldStepLogic 𝔽q β i).embed ⋯ oStmtIn tr j + -- = oStmtIn j + simp only [foldStepLogic, Prod.mk.eta, Fin.isValue, MessageIdx, Fin.is_lt, ↓reduceDIte, + Fin.eta, Fin.zero_eta, Fin.mk_one, Function.Embedding.coeFn_mk, Sum.inl.injEq, + OracleVerifier.mkVerifierOStmtOut_inl, cast_eq] + have h_stmtOut_challenges_eq : + ((Fin.snoc stmtIn.challenges r_i') : Fin (↑i + 1) → L) = stmtOut.challenges := by + -- use the h_stmtOut_eq to prove this + rw [h_stmtOut_eq] + unfold foldStepLogic foldVerifierStmtOut + simp only [Fin.val_succ, Fin.isValue, Fin.snoc_inj, true_and] + rfl + rw [h_oStmtOut_eq_oStmtIn] at h_relOut + have h_stmtOut_sumcheck_target_eq : + stmtOut.sumcheck_target = (Polynomial.eval r_i' ↑h_i) := by + rw [h_stmtOut_eq]; rfl + dsimp only [masterKStateProp] + rw [h_stmtOut_sumcheck_target_eq] at h_relOut + have h_explicit : h_i.val.eval (𝓑 0) + h_i.val.eval (𝓑 1) = stmtIn.sumcheck_target := by + simpa [h_i_def] using h_V_check + cases h_relOut with + | inl h_bad => + have h_bad' : incrementalBadEventExistsProp 𝔽q β i.succ + (OracleFrontierIndex.mkFromStmtIdxCastSuccOfSucc i) oStmtIn + (Fin.snoc stmtIn.challenges r_i') := by + simpa [h_stmtOut_challenges_eq] using h_bad + exact Or.inl h_bad' + | inr h_good => + refine Or.inr ?_ + refine ⟨?_, ?_, ?_, ?_⟩ + · exact ⟨h_explicit, h_good.1⟩ + · simpa [h_stmtOut_eq] using h_good.2.1 + · simpa [h_stmtOut_eq] using h_good.2.2.1 + · have h_res := h_good.2.2.2 + simp only [h_stmtOut_eq] at ⊢ h_res + exact h_res + · simp only [Fin.isValue, h_V_check, ↓reduceIte, OptionT.run_failure, simulateQ_pure, + Set.mem_iUnion, exists_prop, Prod.exists] at h_output_mem_V_run_support + erw [simulateQ_bind] at h_output_mem_V_run_support + simp only [simulateQ_pure, Fin.isValue, Function.comp_apply, + pure_bind] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, ↓existsAndEq, and_true, exists_eq_left, + ] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, reduceCtorEq] at h_output_mem_V_run_support + +/-! This follows the KState of sum-check -/ +def foldKStateProps {i : Fin ℓ} (m : Fin (2 + 1)) + (tr : Transcript m (pSpecFold (L := L))) (stmtMid : Statement (L := L) Context i.castSucc) + (witMid : foldWitMid 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i m) + (oStmtMid : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc j) : + Prop := + -- Ground-truth polynomial from witness + match m with + | ⟨0, _⟩ => -- Same as relIn (roundRelation at i.castSucc) + masterKStateProp (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIdx := i.castSucc) (oracleIdx := OracleFrontierIndex.mkFromStmtIdx i.castSucc) + (stmt := stmtMid) (wit := witMid) (oStmt := oStmtMid) + (localChecks := sumcheckConsistencyProp (𝓑 := 𝓑) stmtMid.sumcheck_target witMid.H) + | ⟨1, _⟩ => -- After P sends hᵢ(X), before V sends r_i' + let h_star : ↥L⦃≤ 2⦄[X] := getSumcheckRoundPoly ℓ 𝓑 (i := i) (h := witMid.H) + let h_i : ↥L⦃≤ 2⦄[X] := tr.messages ⟨0, rfl⟩ + masterKStateProp (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIdx := i.castSucc) (oracleIdx := OracleFrontierIndex.mkFromStmtIdx i.castSucc) + (stmt := stmtMid) (wit := witMid) (oStmt := oStmtMid) + (localChecks := + -- Verifier's explicit check: h_i(0) + h_i(1) = sumcheck_target + let explicitVCheck := h_i.val.eval (𝓑 0) + h_i.val.eval (𝓑 1) = stmtMid.sumcheck_target + -- Honest prover check: h_i matches ground truth + let localizedRoundPolyCheck := h_i = h_star + explicitVCheck ∧ localizedRoundPolyCheck + ) + | ⟨2, _⟩ => -- After V sends r_i': use OUTPUT state (consistent with foldStepRelOut) + let h_i : ↥L⦃≤ 2⦄[X] := tr.messages ⟨0, rfl⟩ + let r_i' : L := tr.challenges ⟨1, rfl⟩ + -- Forward-compute the output statement using transcript-derived values + let newSumcheckTarget : L := h_i.val.eval r_i' + let stmtOut : Statement (L := L) Context i.succ := { -- same as in getFoldProverFinalOutput + ctx := stmtMid.ctx, + sumcheck_target := newSumcheckTarget, + challenges := Fin.snoc stmtMid.challenges r_i' + } + let oStmtOut := oStmtMid + let witOut := witMid + -- Use OUTPUT state: stmtIdx advances to i.succ, oracleIdx stays at i.castSucc (no new oracle) + masterKStateProp (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (stmtIdx := i.succ) (oracleIdx := OracleFrontierIndex.mkFromStmtIdxCastSuccOfSucc i) + (stmt := stmtOut) (wit := witOut) (oStmt := oStmtOut) + (localChecks := + -- we reduce the sumcheck consistency check here + sumcheckConsistencyProp (𝓑 := 𝓑) stmtOut.sumcheck_target witOut.H) -/-- RBR knowledge soundness for a single round oracle verifier -/ +/- +The fold-step extraction failure event implies either: +1. a sumcheck bad event at the sampled challenge, or +2. an incremental folding bad event at the current oracle frontier. + +More precisely: +- **Sumcheck bad**: `h_i ≠ h_star ∧ h_i.eval r_i' = h_star.eval r_i'`, + where `h_star = getSumcheckRoundPoly ℓ 𝓑 i witIn.H`. +- **Folding bad**: an incremental bad-event witness exists at frontier `i.castSucc` + using challenges extended by `r_i'`. + +Proof plan for `foldStep_rbrExtractionFailureEvent_imply_sumcheck_or_badEvent`: + +Goal shape: + Doom-escape at challenge round `⟨1, rfl⟩` gives an existential `witMid` with + `¬kSF@castSucc` and `kSF@succ`; we must derive: + `badSumcheckEventProp r_i' h_i h_star(witIn) ∨ incrementalFoldingBadEvent`. + +Plan: +1. Unfold the doom-escape witness: + Expand `rbrExtractionFailureEvent`, `foldKnowledgeStateFunction`, and `foldKStateProp` + at rounds `m=1` and `m=2`, obtaining the two KState facts carried by `witMid`. + +2. Isolate the KState core: + From `masterKStateProp`, separate local checks from the core disjunction + `incrementalBadEventExistsProp ∨ oracleWitnessConsistency`. + +3. Split by the incremental bad event: + Case A: `incrementalFoldingBadEvent` holds; finish by `Or.inr`. + Case B: `¬ incrementalFoldingBadEvent`; show this forces the KState-2 core to use + `oracleWitnessConsistency` (good branch). + +4. Overlap-cancellation for bad events: + In Case B, any bad event witnessed at round 2 must already be present at round 1. + Old events are preserved backward to round 1 (same oracle frontier / challenge prefix), + contradicting `¬kSF@round1`. Hence no bad-event branch remains. + +5. Fix the round polynomial on the good branch: + Use the good branch (`oracleWitnessConsistency`, plus local checks) to identify the + witness-derived round polynomial and compare it with `h_i`. + Then combine with `¬kSF@round1` to obtain: + `h_i ≠ h_star` and `h_i(r_i') = h_star(r_i')`. + +6. Conclude sumcheck bad: + Package Step 5 as `badSumcheckEventProp r_i' h_i h_star(witIn)` and finish by `Or.inl`. + +Expected helper lemmas: +- backward preservation of incremental bad events from round-2 to round-1 view; +- extraction of localized round-poly equalities from fold-step local checks. +-/ +omit [SampleableType L] [DecidableEq 𝔽q] in +lemma firstOracleWitnessConsistency_unique (i : Fin ℓ) + (oStmt : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc j) + {t₁ t₂ : MultilinearPoly L ℓ} + (h₁ : firstOracleWitnessConsistencyProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + t₁ (getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt)) + (h₂ : firstOracleWitnessConsistencyProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + t₂ (getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt)) : + t₁ = t₂ := by + classical + have h₁_some : + extractMLP 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) 0 + (getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt) = some t₁ := + (extractMLP_eq_some_iff_pair_UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (f := getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt) (tpoly := t₁)).2 h₁ + have h₂_some : + extractMLP 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) 0 + (getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt) = some t₂ := + (extractMLP_eq_some_iff_pair_UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (f := getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt) (tpoly := t₂)).2 h₂ + rw [h₁_some] at h₂_some + simpa using h₂_some + +/-! Extract the round-`i` witness (before the verifier challenge) from a fold-step output +witness. -/ +@[reducible] +def foldStepWitBeforeFromWitMid (i : Fin ℓ) + (stmtOStmtIn : (Statement (L := L) Context i.castSucc) × (∀ j, + OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc j)) + (h_i : (pSpecFold (L := L)).Message ⟨0, rfl⟩) (r_i' : L) + (witMid : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ) : + Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc := + (foldRbrExtractor.{0} (mp := mp) 𝔽q β i).extractMid + (m := 1) stmtOStmtIn (FullTranscript.mk2 h_i r_i') witMid + +/-! Canonical fold-step round polynomial extracted from a specific `witMid`. -/ +@[reducible] +def foldStepHStarFromWitMid (i : Fin ℓ) + (stmtOStmtIn : (Statement (L := L) Context i.castSucc) × (∀ j, + OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc j)) + (h_i : (pSpecFold (L := L)).Message ⟨0, rfl⟩) (r_i' : L) + (witMid : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ) : + L⦃≤ 2⦄[X] := + let witBefore := foldStepWitBeforeFromWitMid + (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i stmtOStmtIn h_i r_i' witMid + getSumcheckRoundPoly ℓ 𝓑 (i := i) (h := witBefore.H) + +/-! At the same fold-step output state, `witnessStructuralInvariant` +and `firstOracleWitnessConsistencyProp` determine a unique witness. +Consequently, any witness-dependent extracted `h_star` is canonical. -/ +omit [SampleableType L] [DecidableEq 𝔽q] in +lemma foldStep_oracleWitnessConsistency_unique_witMid (i : Fin ℓ) + (stmtOut : Statement (L := L) Context i.succ) + (oStmt : ∀ j, OracleStatement 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc j) + {witMid₁ witMid₂ : Witness (L := L) 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ} + (h_struct₁ : witnessStructuralInvariant 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmtOut witMid₁) + (h_struct₂ : witnessStructuralInvariant 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmtOut witMid₂) + (h_init₁ : firstOracleWitnessConsistencyProp 𝔽q β witMid₁.t + (getFirstOracle 𝔽q β oStmt)) + (h_init₂ : firstOracleWitnessConsistencyProp 𝔽q β witMid₂.t + (getFirstOracle 𝔽q β oStmt)) : + witMid₁ = witMid₂ := by + classical + have h_t : witMid₁.t = witMid₂.t := by + exact firstOracleWitnessConsistency_unique 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (ϑ := ϑ) (i := i) (oStmt := oStmt) (h₁ := h_init₁) (h₂ := h_init₂) + have h_H : witMid₁.H = witMid₂.H := by + calc + witMid₁.H = projectToMidSumcheckPoly (L := L) (ℓ := ℓ) (t := witMid₁.t) + (m := mp.multpoly stmtOut.ctx) (i := i.succ) + (challenges := stmtOut.challenges) := h_struct₁.1 + _ = projectToMidSumcheckPoly (L := L) (ℓ := ℓ) (t := witMid₂.t) + (m := mp.multpoly stmtOut.ctx) (i := i.succ) + (challenges := stmtOut.challenges) := by simp only [Fin.val_succ, h_t] + _ = witMid₂.H := h_struct₂.1.symm + have h_f : witMid₁.f = witMid₂.f := by + calc + witMid₁.f = getMidCodewords 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := i.succ) (t := witMid₁.t) (challenges := stmtOut.challenges) := h_struct₁.2 + _ = getMidCodewords 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := i.succ) (t := witMid₂.t) + (challenges := stmtOut.challenges) := by simp only [Fin.val_succ, h_t] + _ = witMid₂.f := h_struct₂.2.symm + cases witMid₁ + cases witMid₂ + simp only [Fin.val_succ, Witness.mk.injEq] at h_t h_H h_f ⊢ + exact ⟨h_t, h_H, h_f⟩ + +omit [SampleableType L] [DecidableEq 𝔽q] in +lemma foldStepHStarFromWitMid_eq_of_oracleWitnessConsistency (i : Fin ℓ) + (stmtOStmtIn : (Statement (L := L) Context i.castSucc) × (∀ j, + OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc j)) + (h_i : (pSpecFold (L := L)).Message ⟨0, rfl⟩) (r_i' : L) + {witMid₁ witMid₂ : Witness (L := L) 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ} + (h_struct₁ : witnessStructuralInvariant 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + { + sumcheck_target := h_i.val.eval r_i', + challenges := Fin.snoc stmtOStmtIn.1.challenges r_i', + ctx := stmtOStmtIn.1.ctx + } witMid₁) + (h_struct₂ : witnessStructuralInvariant 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + { + sumcheck_target := h_i.val.eval r_i', + challenges := Fin.snoc stmtOStmtIn.1.challenges r_i', + ctx := stmtOStmtIn.1.ctx + } witMid₂) + (h_init₁ : firstOracleWitnessConsistencyProp 𝔽q β witMid₁.t + (getFirstOracle 𝔽q β stmtOStmtIn.2)) + (h_init₂ : firstOracleWitnessConsistencyProp 𝔽q β witMid₂.t + (getFirstOracle 𝔽q β stmtOStmtIn.2)) : + foldStepHStarFromWitMid (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑) i stmtOStmtIn h_i r_i' witMid₁ = + foldStepHStarFromWitMid (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑) i stmtOStmtIn h_i r_i' witMid₂ := by + have h_wit_eq : + witMid₁ = witMid₂ := foldStep_oracleWitnessConsistency_unique_witMid + 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (mp := mp) (ϑ := ϑ) + (i := i) + (stmtOut := { + sumcheck_target := h_i.val.eval r_i', + challenges := Fin.snoc stmtOStmtIn.1.challenges r_i', + ctx := stmtOStmtIn.1.ctx + }) + (oStmt := stmtOStmtIn.2) h_struct₁ h_struct₂ h_init₁ h_init₂ + subst h_wit_eq + rfl + +/-! Fresh incremental bad-event for the **latest oracle block** at the fold-step: +`¬ E_before ∧ E_after`, where `E_*` is `incrementalFoldingBadEvent` evaluated +before/after appending `r_i'`. -/ +@[reducible] +def foldStepFreshDoomPreservationEvent (i : Fin ℓ) + (stmtOStmtIn : (Statement (L := L) Context i.castSucc) × (∀ j, + OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc j)) + (r_i' : L) : Prop := + let stmtIdxBefore : Fin (ℓ + 1) := i.castSucc + let challengesBefore : Fin stmtIdxBefore → L := stmtOStmtIn.1.challenges + let j := getLastOraclePositionIndex ℓ ϑ i.castSucc + let curOracleDomainIdx : Fin r := ⟨oraclePositionToDomainIndex (positionIdx := j), by omega⟩ + let kBefore : ℕ := min ϑ (stmtIdxBefore.val - curOracleDomainIdx.val) + -- NOTE: actually `kBefore` is always less than `ϑ`, so `kBefore + 1 ≤ ϑ` + have h_j_val : j.val = i.val / ϑ := by + have h_i_lt_ℓ : i.val < ℓ := i.isLt + have h_i_cast_lt_ℓ : i.val < ℓ := by simp only [h_i_lt_ℓ] + dsimp only [j, getLastOraclePositionIndex] + unfold toOutCodewordsCount + simp only [Fin.val_castSucc, h_i_lt_ℓ, ↓reduceIte, add_tsub_cancel_right] + have h_cur_eq : curOracleDomainIdx.val = (i.val / ϑ) * ϑ := by + dsimp only [curOracleDomainIdx, oraclePositionToDomainIndex] + simp only [h_j_val] + have h_diff_lt : stmtIdxBefore.val - curOracleDomainIdx.val < ϑ := by + have h_div_mod : (i.val / ϑ) * ϑ + i.val % ϑ = i.val := by + simpa [Nat.mul_comm] using (Nat.div_add_mod i.val ϑ) + have h_cur_le : curOracleDomainIdx.val ≤ stmtIdxBefore.val := by + dsimp only [stmtIdxBefore] + calc + curOracleDomainIdx.val = (i.val / ϑ) * ϑ := h_cur_eq + _ ≤ i.val := Nat.div_mul_le_self i.val ϑ + have h_sum : curOracleDomainIdx.val + i.val % ϑ = stmtIdxBefore.val := by + dsimp only [stmtIdxBefore] + calc + curOracleDomainIdx.val + i.val % ϑ = (i.val / ϑ) * ϑ + i.val % ϑ := by + simp only [h_cur_eq] + _ = i.val := h_div_mod + have h_diff_eq : stmtIdxBefore.val - curOracleDomainIdx.val = i.val % ϑ := by omega + rw [h_diff_eq] + exact Nat.mod_lt i.val (Nat.pos_of_neZero ϑ) + have h_kBefore_lt : kBefore < ϑ := by + exact lt_of_le_of_lt + (Nat.min_le_right ϑ (stmtIdxBefore.val - curOracleDomainIdx.val)) h_diff_lt + let destIdx : Fin r := ⟨curOracleDomainIdx.val + ϑ, by + have h1 := oracle_index_add_steps_le_ℓ ℓ ϑ (i := i.castSucc) (j := j) + have h2 : ℓ + 𝓡 < r := h_ℓ_add_R_rate + have _ : 𝓡 > 0 := Nat.pos_of_neZero 𝓡 + dsimp only [oraclePositionToDomainIndex, curOracleDomainIdx] + omega + ⟩ + let r_prefix : Fin kBefore → L := fun cId => challengesBefore + ⟨curOracleDomainIdx.val + cId.val, by + have h_k_le_stmt : kBefore ≤ stmtIdxBefore.val - curOracleDomainIdx.val := + Nat.min_le_right ϑ (stmtIdxBefore.val - curOracleDomainIdx.val) + have h_cId_lt_k : cId.val < kBefore := cId.isLt + omega + ⟩ + let E_before := + Binius.BinaryBasefold.incrementalFoldingBadEvent 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (block_start_idx := curOracleDomainIdx) + (midIdx := ⟨curOracleDomainIdx.val + kBefore, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + have h_k_le : kBefore ≤ ϑ := Nat.min_le_left ϑ (stmtIdxBefore.val - curOracleDomainIdx.val) + have h_add_le : curOracleDomainIdx.val + ϑ ≤ ℓ := + oracle_index_add_steps_le_ℓ ℓ ϑ (i := i.castSucc) (j := j) + omega + ⟩) + (destIdx := destIdx) (k := kBefore) + (h_k_le := Nat.min_le_left ϑ (stmtIdxBefore.val - curOracleDomainIdx.val)) + (h_midIdx := by simp only) + (h_destIdx := rfl) + (h_destIdx_le := by + simp only [(oracle_index_add_steps_le_ℓ ℓ ϑ (i := i.castSucc) (j := j)), j, destIdx, + curOracleDomainIdx]) + (f_block_start := stmtOStmtIn.2 j) + (r_challenges := r_prefix) + let E_after := + Binius.BinaryBasefold.incrementalFoldingBadEvent 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (block_start_idx := curOracleDomainIdx) + (midIdx := ⟨curOracleDomainIdx.val + (kBefore + 1), by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + have h_k_le : kBefore + 1 ≤ ϑ := Nat.succ_le_of_lt h_kBefore_lt + have h_add_le : curOracleDomainIdx.val + ϑ ≤ ℓ := + oracle_index_add_steps_le_ℓ ℓ ϑ (i := i.castSucc) (j := j) + omega + ⟩) + (destIdx := destIdx) (k := kBefore + 1) + (h_k_le := Nat.succ_le_of_lt h_kBefore_lt) + (h_midIdx := by simp only) + (h_destIdx := rfl) + (h_destIdx_le := by + simp only [(oracle_index_add_steps_le_ℓ ℓ ϑ (i := i.castSucc) (j := j)), j, destIdx, + curOracleDomainIdx]) + (f_block_start := stmtOStmtIn.2 j) + (r_challenges := Fin.snoc r_prefix r_i') + ¬ E_before ∧ E_after + +/-! Oracle-witness consistency for a candidate fold-step output witness. -/ +@[reducible] +def foldStepWitMidOracleConsistency (i : Fin ℓ) + (stmtOStmtIn : (Statement (L := L) Context i.castSucc) × (∀ j, + OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc j)) + (h_i : (pSpecFold (L := L)).Message ⟨0, rfl⟩) (r_i' : L) + (witMid : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ) : Prop := + let stmt : Statement (L := L) Context i.succ := { + sumcheck_target := h_i.val.eval r_i', + challenges := Fin.snoc stmtOStmtIn.1.challenges r_i', + ctx := stmtOStmtIn.1.ctx + } + let structural := witnessStructuralInvariant 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmt witMid + let initial := firstOracleWitnessConsistencyProp 𝔽q β witMid.t (getFirstOracle 𝔽q β stmtOStmtIn.2) + structural ∧ initial + +/-! Proof sketch: +let `j` be the **oracle position index** of the last oracle at oracle frontier `i`. +Note that `k = i - j * ϑ < ϑ`, since if `k = ϑ`, + then `i` must be an oracle domain, therefore `k = 0`, contradiction. +We have: + h_bad_after = `|__|__|...|__|__|j*ϑ|====i===(i+1)| ↔ exists_bad_until_j OR incBad(j -> i+1)` + h_not_fresh = `¬(¬incBad(j -> i) ∧ incBad(j -> i+1)) ↔ incBad(j -> i) ∨ (¬incBad(j -> i+1))` +Goal: h_bad_before = `|__|__|...|__|__|j*ϑ|====i| = exists_bad_until_j OR incBad(j -> i)` +-------- +We rcases on h_not_fresh: + If `incBad(j -> i)` holds, then h_bad_before = true, Q.E.D. + else we have `¬incBad(j -> i+1)`, + which implies `exists_bad_until_j` to be true from `h_bad_after` + => `h_bad_before = true` by definition +-/ +lemma incrementalBadEventExistsProp_fold_step_backward (i : Fin ℓ) + (stmtOStmtIn : (Statement (L := L) Context i.castSucc) × (∀ j, + OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc j)) + (r_i' : L) + (h_bad_after : incrementalBadEventExistsProp 𝔽q β i.succ + (OracleFrontierIndex.mkFromStmtIdxCastSuccOfSucc i) stmtOStmtIn.2 + (Fin.snoc stmtOStmtIn.1.challenges r_i')) + (h_not_fresh : ¬ foldStepFreshDoomPreservationEvent 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ) i stmtOStmtIn r_i') : + incrementalBadEventExistsProp 𝔽q β i.castSucc + (OracleFrontierIndex.mkFromStmtIdx i.castSucc) stmtOStmtIn.2 + stmtOStmtIn.1.challenges := by + sorry + +lemma foldStep_rbrExtractionFailureEvent_imply_sumcheck_or_badEvent (i : Fin ℓ) + (stmtOStmtIn : (Statement (L := L) Context i.castSucc) × (∀ j, + OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc j)) + (h_i : (pSpecFold (L := L)).Message ⟨0, rfl⟩) (r_i' : L) + (doomEscape : rbrExtractionFailureEvent + (kSF := foldKnowledgeStateFunction (mp := mp) (𝓑 := 𝓑) (init := init) + (impl := impl) (σ := σ) 𝔽q β i) + (extractor := foldRbrExtractor (mp := mp) 𝔽q β i) (i := ⟨1, rfl⟩) (stmtIn := stmtOStmtIn) + (transcript := FullTranscript.mk1 h_i) (challenge := r_i')) : + let incrementalFoldingBadEvent := + foldStepFreshDoomPreservationEvent 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ) i stmtOStmtIn r_i' + incrementalFoldingBadEvent ∨ ( + ¬incrementalFoldingBadEvent ∧ + (∃ witMid : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ, + (foldStepWitMidOracleConsistency (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + (i := i) stmtOStmtIn h_i r_i' witMid) + ∧ (badSumcheckEventProp r_i' h_i + (foldStepHStarFromWitMid (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑) i stmtOStmtIn h_i r_i' witMid)) + ) + ) := by + classical + let incrementalFoldingBadEvent : Prop := + foldStepFreshDoomPreservationEvent 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ) i stmtOStmtIn r_i' + unfold rbrExtractionFailureEvent at doomEscape + rcases doomEscape with ⟨witMid, h_kState_before_false, h_kState_after_true⟩ + simp only [foldKnowledgeStateFunction] at h_kState_before_false h_kState_after_true + unfold foldKStateProp at h_kState_before_false h_kState_after_true + simp only [Fin.isValue, Fin.castSucc_one, Fin.succ_one_eq_two, Nat.reduceAdd, + Transcript.concat] at h_kState_before_false h_kState_after_true + by_cases h_bad : incrementalFoldingBadEvent + · left + exact h_bad + · right + refine ⟨h_bad, ?_⟩ + -- Under ¬ fresh bad-event, the m=2 KState cannot be on the bad branch. + have h_after_good_exists : ∃ h_after_good, h_kState_after_true = Or.inr h_after_good := by + cases h_kState_after_true with + | inl h_bad_after => + exfalso + have h_bad_before : incrementalBadEventExistsProp 𝔽q β i.castSucc + (OracleFrontierIndex.mkFromStmtIdx i.castSucc) stmtOStmtIn.2 + stmtOStmtIn.1.challenges := + incrementalBadEventExistsProp_fold_step_backward 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + i stmtOStmtIn r_i' h_bad_after h_bad + exact h_kState_before_false (Or.inl h_bad_before) + | inr h_after_good => + exact ⟨h_after_good, rfl⟩ + rcases h_after_good_exists with ⟨h_after_good, rfl⟩ + have h_explicit_after : + h_i.val.eval (𝓑 0) + h_i.val.eval (𝓑 1) = stmtOStmtIn.1.sumcheck_target := by + simpa using h_after_good.1.1 + have h_sumcheck_after : + sumcheckConsistencyProp (𝓑 := 𝓑) (Polynomial.eval r_i' h_i.val) witMid.H := by + simpa using h_after_good.1.2 + have h_consistency : foldStepWitMidOracleConsistency 𝔽q β i stmtOStmtIn h_i r_i' witMid := + ⟨h_after_good.2.1, h_after_good.2.2.1⟩ + have h_left_from_consistency : + badSumcheckEventProp r_i' h_i + (foldStepHStarFromWitMid (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑) i stmtOStmtIn h_i r_i' witMid) := by + have h_wit_struct_after : + witMid.H = projectToMidSumcheckPoly (L := L) (ℓ := ℓ) (t := witMid.t) + (m := mp.multpoly stmtOStmtIn.1.ctx) (i := i.succ) + (challenges := Fin.snoc stmtOStmtIn.1.challenges r_i') := by + exact h_consistency.1.1 + let H_before : L⦃≤ 2⦄[X Fin (ℓ - i.castSucc)] := + projectToMidSumcheckPoly (L := L) (ℓ := ℓ) (t := witMid.t) + (m := mp.multpoly stmtOStmtIn.1.ctx) (i := i.castSucc) + (challenges := stmtOStmtIn.1.challenges) + let h_star_extracted : L⦃≤ 2⦄[X] := getSumcheckRoundPoly ℓ 𝓑 (i := i) (h := H_before) + have h_eval_eq_extracted : + Polynomial.eval r_i' h_i.val = Polynomial.eval r_i' h_star_extracted.val := by + unfold sumcheckConsistencyProp at h_sumcheck_after + rw [h_wit_struct_after] at h_sumcheck_after + rw [projectToMidSumcheckPoly_succ (L := L) (ℓ := ℓ) (t := witMid.t) + (m := mp.multpoly stmtOStmtIn.1.ctx) (i := i) + (challenges := stmtOStmtIn.1.challenges) (r_i' := r_i')] at h_sumcheck_after + have h_sum_eq := + projectToNextSumcheckPoly_sum_eq (L := L) (𝓑 := 𝓑) (ℓ := ℓ) + (i := i) (Hᵢ := H_before) (rᵢ := r_i') + have h_sum_eq' : + Polynomial.eval r_i' h_star_extracted.val = + ∑ x ∈ (univ.map 𝓑) ^ᶠ (ℓ - i.succ), + (projectToNextSumcheckPoly (L := L) (ℓ := ℓ) (i := i) + (Hᵢ := H_before) (rᵢ := r_i')).val.eval x := by + simpa [h_star_extracted] using h_sum_eq + calc + Polynomial.eval r_i' h_i.val + = ∑ x ∈ (univ.map 𝓑) ^ᶠ (ℓ - i.succ), + (projectToNextSumcheckPoly (L := L) (ℓ := ℓ) (i := i) + (Hᵢ := H_before) (rᵢ := r_i')).val.eval x := h_sumcheck_after + _ = Polynomial.eval r_i' h_star_extracted.val := by + symm + exact h_sum_eq' + have h_hi_ne_extracted : h_i ≠ h_star_extracted := by + intro h_eq + apply h_kState_before_false + right + refine ⟨?_, ?_, ?_, ?_⟩ + · constructor + · exact h_explicit_after + · simpa [h_star_extracted, H_before, foldRbrExtractor, Fin.isValue] using h_eq + · unfold witnessStructuralInvariant + simp only [Fin.val_castSucc, foldRbrExtractor, Fin.zero_eta, Fin.isValue, + Fin.succ_zero_eq_one, Fin.mk_one, Fin.succ_one_eq_two, + Fin.coe_ofNat_eq_mod, Nat.reduceMod, and_self] + · exact h_consistency.2 + · have h_folding_after := h_after_good.2.2.2 + unfold oracleFoldingConsistencyProp at h_folding_after ⊢ + simpa [OracleFrontierIndex.val_mkFromStmtIdx, + OracleFrontierIndex.val_mkFromStmtIdxCastSuccOfSucc] using h_folding_after + simpa [foldStepHStarFromWitMid, foldStepWitBeforeFromWitMid, h_star_extracted, + H_before, foldRbrExtractor, Fin.isValue] using + (show badSumcheckEventProp r_i' h_i h_star_extracted from + ⟨h_hi_ne_extracted, h_eval_eq_extracted⟩) + exact ⟨witMid, h_consistency, h_left_from_consistency⟩ + +#check prop_4_20_2_incremental_bad_event_probability +/-! Per-transcript bound: for the first prover message `msg0`, the probability (over the verifier + challenge `y`) that extraction fails is at most `foldKnowledgeError`. Stated for + `P (FullTranscript.mk1 msg0)` so it matches the goal after `tsum_uniform_Pr_eq_Pr` in the main + soundness proof. + **Proof strategy:** + 1. **Implication**: Show that extraction failure `P(tr, y)` implies either + - a SINGLE sumcheck “bad” event + - or an incremental folding bad event (bad oracle / consistency failure) + 2. **Monotonicity**: Conclude `Pr[P] ≤ Pr[SZ ∨ BE]` via `prob_mono`. + 3. **Union bound**: Apply `Pr_or_le` to get `Pr[SZ ∨ BE] ≤ Pr[SZ] + Pr[BE]`. + 4. **Schwartz–Zippel**: Bound `Pr[SZ]` by `1/|L|` using univariate degree-1 + agreement (lemmas from Instances.lean) + 5. **Bad event**: Bound `Pr[BE]` using the incremental folding bad-event probability + (`prop_4_20_2_incremental_bad_event_probability`). + 6. **Combine**: Add the two bounds and match the RHS to `foldKnowledgeError`. -/ +lemma foldStep_doom_escape_probability_bound (i : Fin ℓ) + (stmtOStmtIn : (Statement (L := L) Context i.castSucc) × (∀ j, + OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc j)) + (h_i : (pSpecFold (L := L)).Message ⟨0, rfl⟩) : + Pr_{ let y ← $ᵖ L }[ + rbrExtractionFailureEvent + (kSF := foldKnowledgeStateFunction (mp := mp) (𝓑 := 𝓑) + (init := init) (impl := impl) (σ := σ) 𝔽q β i) + (extractor := foldRbrExtractor (mp := mp) 𝔽q β i) ⟨1, rfl⟩ + stmtOStmtIn (FullTranscript.mk1 h_i) y ] ≤ + foldKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i ⟨1, by rfl⟩ := by + classical + let doomEvent := fun y : L => + rbrExtractionFailureEvent + (kSF := foldKnowledgeStateFunction (mp := mp) (𝓑 := 𝓑) + (init := init) (impl := impl) (σ := σ) 𝔽q β i) + (extractor := foldRbrExtractor (mp := mp) 𝔽q β i) ⟨1, rfl⟩ + stmtOStmtIn (FullTranscript.mk1 h_i) y + let sumcheckBadEvent : L → Prop := fun y => + let incrementalFoldingBadEvent := + foldStepFreshDoomPreservationEvent 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ) i stmtOStmtIn y + (¬incrementalFoldingBadEvent ∧ + (∃ witMid : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ, + (foldStepWitMidOracleConsistency (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + (i := i) stmtOStmtIn h_i y witMid) + ∧ (badSumcheckEventProp y h_i + (foldStepHStarFromWitMid (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑) i stmtOStmtIn h_i y witMid)) + )) + let incrementalBadFoldEvent := fun y : L => + foldStepFreshDoomPreservationEvent 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ) i stmtOStmtIn y + let incrementalBadFoldEvent_or_sumcheckBadEvent := fun y : L => + (incrementalBadFoldEvent y) ∨ (sumcheckBadEvent y) + have h_prob_mono := prob_mono (D := $ᵖ L) + (f := doomEvent) (g := incrementalBadFoldEvent_or_sumcheckBadEvent) + (h_imp := by + intro y h_doomEscape + have h_imp := (foldStep_rbrExtractionFailureEvent_imply_sumcheck_or_badEvent + (mp := mp) (𝓑 := 𝓑) (init := init) (impl := impl) 𝔽q β + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := i) (stmtOStmtIn := stmtOStmtIn) (h_i := h_i) + (r_i' := y) (doomEscape := h_doomEscape)) + dsimp only [incrementalBadFoldEvent_or_sumcheckBadEvent, sumcheckBadEvent, + incrementalBadFoldEvent] + by_cases h_bad : foldStepFreshDoomPreservationEvent 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ) i stmtOStmtIn y + · exact Or.inl h_bad + · simpa [h_bad] using h_imp + ) + refine le_trans h_prob_mono ?_ + dsimp only [incrementalBadFoldEvent_or_sumcheckBadEvent, foldKnowledgeError] + apply le_trans ( + Pr_or_le ($ᵖ L) (f := incrementalBadFoldEvent) (g := sumcheckBadEvent) + ) + conv_rhs => simp only [ENNReal.coe_add]; rw [add_comm] + apply add_le_add + · dsimp only [incrementalBadFoldEvent, foldStepFreshDoomPreservationEvent] + let stmtIdxBefore : Fin (ℓ + 1) := i.castSucc + let challengesBefore : Fin stmtIdxBefore → L := stmtOStmtIn.1.challenges + let j := getLastOraclePositionIndex ℓ ϑ i.castSucc + let curOracleDomainIdx : Fin r := ⟨oraclePositionToDomainIndex (positionIdx := j), by omega⟩ + let kBefore : ℕ := min ϑ (stmtIdxBefore.val - curOracleDomainIdx.val) + have h_j_val : j.val = i.val / ϑ := by + have h_i_lt_ℓ : i.val < ℓ := i.isLt + have h_i_cast_lt_ℓ : i.val < ℓ := by + simp only [h_i_lt_ℓ] + dsimp only [j, getLastOraclePositionIndex] + unfold toOutCodewordsCount + simp only [Fin.val_castSucc, h_i_lt_ℓ, ↓reduceIte, add_tsub_cancel_right] + have h_cur_eq : curOracleDomainIdx.val = (i.val / ϑ) * ϑ := by + dsimp only [curOracleDomainIdx, oraclePositionToDomainIndex] + simp only [h_j_val] + have h_diff_lt : stmtIdxBefore.val - curOracleDomainIdx.val < ϑ := by + have h_div_mod : (i.val / ϑ) * ϑ + i.val % ϑ = i.val := by + simpa [Nat.mul_comm] using (Nat.div_add_mod i.val ϑ) + have h_cur_le : curOracleDomainIdx.val ≤ stmtIdxBefore.val := by + dsimp only [stmtIdxBefore] + calc + curOracleDomainIdx.val = (i.val / ϑ) * ϑ := h_cur_eq + _ ≤ i.val := Nat.div_mul_le_self i.val ϑ + have h_sum : curOracleDomainIdx.val + i.val % ϑ = stmtIdxBefore.val := by + dsimp only [stmtIdxBefore] + calc + curOracleDomainIdx.val + i.val % ϑ = (i.val / ϑ) * ϑ + i.val % ϑ := by + simp only [h_cur_eq] + _ = i.val := h_div_mod + have h_diff_eq : stmtIdxBefore.val - curOracleDomainIdx.val = i.val % ϑ := by omega + rw [h_diff_eq] + exact Nat.mod_lt i.val (Nat.pos_of_neZero ϑ) + have h_kBefore_lt : kBefore < ϑ := by + exact lt_of_le_of_lt + (Nat.min_le_right ϑ (stmtIdxBefore.val - curOracleDomainIdx.val)) h_diff_lt + let destIdx : Fin r := ⟨curOracleDomainIdx.val + ϑ, by + have h1 := oracle_index_add_steps_le_ℓ ℓ ϑ (i := i.castSucc) (j := j) + have h2 : ℓ + 𝓡 < r := h_ℓ_add_R_rate + have _ : 𝓡 > 0 := Nat.pos_of_neZero 𝓡 + dsimp only [oraclePositionToDomainIndex, curOracleDomainIdx] + omega + ⟩ + let r_prefix : Fin kBefore → L := fun cId => challengesBefore + ⟨curOracleDomainIdx.val + cId.val, by + have h_k_le_stmt : kBefore ≤ stmtIdxBefore.val - curOracleDomainIdx.val := + Nat.min_le_right ϑ (stmtIdxBefore.val - curOracleDomainIdx.val) + have h_cId_lt_k : cId.val < kBefore := cId.isLt + omega + ⟩ + have h_res := prop_4_20_2_incremental_bad_event_probability 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (block_start_idx := curOracleDomainIdx) + (midIdx_i := ⟨curOracleDomainIdx.val + kBefore, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + have h_k_le : kBefore ≤ ϑ := Nat.min_le_left ϑ (stmtIdxBefore.val - curOracleDomainIdx.val) + have h_add_le : curOracleDomainIdx.val + ϑ ≤ ℓ := + oracle_index_add_steps_le_ℓ ℓ ϑ (i := i.castSucc) (j := j) + omega + ⟩) + (midIdx_i_succ := ⟨curOracleDomainIdx.val + kBefore + 1, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + have h_k_le : kBefore + 1 ≤ ϑ := Nat.succ_le_of_lt h_kBefore_lt + have h_add_le : curOracleDomainIdx.val + ϑ ≤ ℓ := + oracle_index_add_steps_le_ℓ ℓ ϑ (i := i.castSucc) (j := j) + omega + ⟩) + (destIdx := destIdx) (k := kBefore) + (h_k_lt := h_kBefore_lt) + (h_midIdx_i := by simp only) + (h_midIdx_i_succ := by simp only) + (h_destIdx := rfl) + (h_destIdx_le := oracle_index_add_steps_le_ℓ ℓ ϑ (i := i.castSucc) (j := j)) + (f_block_start := stmtOStmtIn.2 j) + (r_prefix := r_prefix) + dsimp only [destIdx, curOracleDomainIdx, j, kBefore, r_prefix, stmtIdxBefore, challengesBefore] + at h_res + have h_cur_le_stmt : curOracleDomainIdx.val ≤ stmtIdxBefore.val := by + dsimp only [stmtIdxBefore] + calc + curOracleDomainIdx.val = (i.val / ϑ) * ϑ := h_cur_eq + _ ≤ i.val := Nat.div_mul_le_self i.val ϑ + have h_kBefore_eq : kBefore = stmtIdxBefore.val - curOracleDomainIdx.val := by + dsimp only [kBefore] + exact Nat.min_eq_right (Nat.le_of_lt h_diff_lt) + have h_kAfter_eq : min ϑ (i.succ.val - curOracleDomainIdx.val) = kBefore + 1 := by + have h_cur_le_i : curOracleDomainIdx.val ≤ i.val := by + simpa [stmtIdxBefore] using h_cur_le_stmt + have h_sub_succ : i.val + 1 - curOracleDomainIdx.val + = (i.val - curOracleDomainIdx.val) + 1 := by + simpa [Nat.succ_eq_add_one] using (Nat.succ_sub h_cur_le_i) + have h_kBefore_eq' : kBefore = i.val - curOracleDomainIdx.val := by + simpa [stmtIdxBefore] using h_kBefore_eq + simp only [Fin.val_succ] + rw [h_sub_succ, ← h_kBefore_eq'] + exact Nat.min_eq_right (Nat.succ_le_of_lt h_kBefore_lt) + have h_snoc_eq : + ∀ r_new : L, + (fun cId : Fin (kBefore + 1) => + if h : curOracleDomainIdx.val + cId.val < stmtIdxBefore.val then + challengesBefore ⟨curOracleDomainIdx.val + cId.val, h⟩ + else + r_new) = Fin.snoc r_prefix r_new := by + intro r_new + funext cId + by_cases h_lt : cId.val < kBefore + · have h_guard : curOracleDomainIdx.val + cId.val < stmtIdxBefore.val := by + omega + simp [Fin.snoc, r_prefix, h_lt, h_guard] + · have h_guard_false : ¬ curOracleDomainIdx.val + cId.val < stmtIdxBefore.val := by + omega + simp [Fin.snoc, h_lt, h_guard_false] + conv_rhs => simp only [ne_eq, Nat.cast_eq_zero, Fintype.card_ne_zero, not_false_eq_true, + ENNReal.coe_div, ENNReal.coe_natCast] + exact h_res + · dsimp only [sumcheckBadEvent] + -- Strategy: ignore the `foldStepFreshDoomPreservationEvent`, plus `oracleWitnessConsistency` + -- guarantees uniqueness of witMid, then we can transform this to prove the bound via + -- `probability_bound_badSumcheckEventProp` + let compatPred : MultilinearPoly L ℓ → Prop := fun t => + firstOracleWitnessConsistencyProp 𝔽q β t (getFirstOracle 𝔽q β stmtOStmtIn.2) + by_cases hCompat : ∃ t : MultilinearPoly L ℓ, compatPred t + · rcases hCompat with ⟨t_fixed, h_t_fixed_compat⟩ + let H_fixed : L⦃≤ 2⦄[X Fin (ℓ - i.castSucc)] := + projectToMidSumcheckPoly (L := L) (ℓ := ℓ) (t := t_fixed) + (m := mp.multpoly stmtOStmtIn.1.ctx) + (i := i.castSucc) (challenges := stmtOStmtIn.1.challenges) + let h_star_fixed : L⦃≤ 2⦄[X] := getSumcheckRoundPoly ℓ 𝓑 (i := i) (h := H_fixed) + have h_prob_mono_sum := prob_mono (D := $ᵖ L) + (f := fun y => sumcheckBadEvent y) + (g := fun y => badSumcheckEventProp y h_i h_star_fixed) + (h_imp := by + intro y h_sum + rcases h_sum with ⟨_h_not_fresh, witMid, h_cons, h_bad⟩ + have h_t_eq : witMid.t = t_fixed := + firstOracleWitnessConsistency_unique 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) (i := i) + (oStmt := stmtOStmtIn.2) (h₁ := h_cons.2) (h₂ := h_t_fixed_compat) + simpa [h_star_fixed, H_fixed, foldStepHStarFromWitMid, + foldStepWitBeforeFromWitMid, foldRbrExtractor, Fin.isValue, h_t_eq] + using h_bad) + refine le_trans h_prob_mono_sum ?_ + have h_sz := probability_bound_badSumcheckEventProp (h_i := h_i) (h_star := h_star_fixed) + conv_rhs => + rw [ENNReal.coe_div (hr := by + simp only [ne_eq, Nat.cast_eq_zero, Fintype.card_ne_zero, not_false_eq_true])] + simp only [ENNReal.coe_ofNat, ENNReal.coe_natCast] + exact h_sz + · have h_prob_mono_false := prob_mono (D := $ᵖ L) + (f := fun y => sumcheckBadEvent y) + (g := fun _ => False) + (h_imp := by + intro y h_sum + rcases h_sum with ⟨_h_not_fresh, witMid, h_cons, _h_bad⟩ + exact (hCompat ⟨witMid.t, h_cons.2⟩).elim) + refine le_trans h_prob_mono_false ?_ + simp only [PMF.monad_pure_eq_pure, PMF.monad_bind_eq_bind, PMF.bind_const, PMF.pure_apply, + eq_iff_iff, iff_false, not_true_eq_false, ↓reduceIte, _root_.zero_le] + +/-! RBR knowledge soundness for a single round oracle verifier -/ +open Classical in theorem foldOracleVerifier_rbrKnowledgeSoundness (i : Fin ℓ) : (foldOracleVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) i).rbrKnowledgeSoundness init impl @@ -525,12 +1399,82 @@ theorem foldOracleVerifier_rbrKnowledgeSoundness (i : Fin ℓ) : (relOut := foldStepRelOut (mp := mp) 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i) (foldKnowledgeError 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) := by - use fun _ => Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.castSucc - use foldRbrExtractor (mp:=mp) (𝓡 := 𝓡) (ϑ := ϑ) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i - use foldKnowledgeStateFunction (mp:=mp) (𝓡 := 𝓡) (ϑ := ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) 𝔽q β i - intro stmtIn witIn prover j - sorry + apply OracleReduction.unroll_rbrKnowledgeSoundness (kSF := foldKnowledgeStateFunction + (mp:=mp) (𝓡 := 𝓡) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) 𝔽q β i) + intro stmtOStmtIn witIn prover j initState + let P := rbrExtractionFailureEvent + (foldKnowledgeStateFunction (mp := mp) (𝓑 := 𝓑) (init := init) (impl := impl) (σ := σ) 𝔽q β i) + (foldRbrExtractor (mp := mp) 𝔽q β i) + j + stmtOStmtIn + rw [OracleReduction.probEvent_soundness_goal_unroll_log' (pSpec := pSpecFold + (L := L)) (P := P) (impl := impl) (prover := prover) (i := j) (stmt := stmtOStmtIn) + (wit := witIn) (s := initState)] + have h_j_eq_1 : j = ⟨1, rfl⟩ := by + match j with + | ⟨0, h0⟩ => nomatch h0 + | ⟨1, _⟩ => rfl + subst h_j_eq_1 + conv_lhs => simp only [Fin.isValue, Fin.castSucc_one]; + rw [OracleReduction.soundness_unroll_runToRound_1_P_to_V_pSpec_2 + (pSpec := pSpecFold (L := L)) (prover := prover) (hDir0 := rfl)] + simp only [Fin.isValue, Challenge, Matrix.cons_val_one, Matrix.cons_val_zero, ChallengeIdx, + QueryImpl.addLift_def, QueryImpl.liftTarget_self, Message, Fin.succ_zero_eq_one, Nat.reduceAdd, + Fin.coe_ofNat_eq_mod, Nat.reduceMod, FullTranscript.mk1_eq_snoc, bind_pure_comp, + liftComp_eq_liftM, bind_map_left, simulateQ_bind, simulateQ_map, StateT.run'_eq, + StateT.run_bind, StateT.run_map, map_bind, Functor.map_map] + rw [probEvent_bind_eq_tsum] + apply OracleReduction.ENNReal.tsum_mul_le_of_le_of_sum_le_one + · -- Bound the conditional probability for each transcript + intro x + -- rw [OracleComp.probEvent_map] + simp only [Fin.isValue, probEvent_map] + let q : OracleQuery [(pSpecFold (L := L)).Challenge]ₒ _ := query ⟨⟨1, by rfl⟩, ()⟩ + erw [OracleReduction.probEvent_StateT_run_ignore_state + (comp := simulateQ (impl.addLift challengeQueryImpl) (liftM (query q.input))) + (s := x.2) + (P := fun a => P (FullTranscript.mk1 x.1.1) (q.cont a))] + rw [probEvent_eq_tsum_ite] + erw [simulateQ_query] + simp only [ChallengeIdx, Challenge, Fin.isValue, Nat.reduceAdd, Fin.castSucc_one, + Fin.coe_ofNat_eq_mod, Nat.reduceMod, monadLift_self, + QueryImpl.addLift_def, QueryImpl.liftTarget_self, StateT.run'_eq, StateT.run_map, + Functor.map_map, ge_iff_le] + have h_L_inhabited : Inhabited L := ⟨0⟩ + conv_lhs => + enter [1, x_1, 2, 1, 2] + rw [addLift_challengeQueryImpl_input_run_eq_liftM_run (impl := impl) (q := q) (s := x.2)] + erw [StateT.run_monadLift, monadLift_self, liftComp_id] + rw [bind_pure_comp] + conv => + enter [1, 1, x_1, 2] + rw [Functor.map_map] + rw [← probEvent_eq_eq_probOutput] + rw [probEvent_map] + rw [OracleQuery.cont_apply] + dsimp only [MonadLift.monadLift] + rw [OracleQuery.cont_apply] + dsimp only [q] + simp_rw [OracleQuery.input_query, OracleQuery.snd_query] + conv_lhs => change (∑' (x_1 : L), _) + simp only [Function.comp_id] + conv => + enter [1, 1, x_1, 2] + rw [probEvent_eq_eq_probOutput] + change Pr[=x_1 | $ᵗ L] + rw [OracleReduction.probOutput_uniformOfFintype_eq_Pr (L := _) (x := x_1)] + rw [OracleReduction.tsum_uniform_Pr_eq_Pr + (L := L) (P := fun x_1 => P (FullTranscript.mk1 x.1.1) (q.2 x_1))] + -- Now the goal is in do-notation form, which is exactly what Pr_ notation expands to + -- Make this explicit using change + change Pr_{ let y ← $ᵖ L }[ P (FullTranscript.mk1 x.1.1) y ] ≤ + foldKnowledgeError 𝔽q β i ⟨1, by rfl⟩ + -- Apply the per-transcript bound + exact foldStep_doom_escape_probability_bound 𝔽q β (i := i) + (stmtOStmtIn := stmtOStmtIn) (h_i := x.1.1) (init := init) (impl := impl) (mp := mp) + (𝓑 := 𝓑) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + · -- Prove: ∑' x, [=x|transcript computation] ≤ 1 + apply tsum_probOutput_le_one end FoldStep section CommitStep @@ -548,14 +1492,14 @@ def getCommitProverFinalOutput (i : Fin ℓ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i 0) : (↥(sDomain 𝔽q β h_ℓ_add_R_rate ⟨↑i + 1, by omega⟩) → L) × commitPrvState (Context := Context) 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i 1 := - let (stmt, oStmtIn, wit) := inputPrvState - let fᵢ_succ := wit.f + let (stmtIn, oStmtIn, witIn) := inputPrvState + let fᵢ_succ := witIn.f let oStmtOut := snoc_oracle 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oStmtIn := oStmtIn) (newOracleFn := fᵢ_succ) (h_destIdx := by rfl) -- The only thing the prover does is to sends f_{i+1} as an oracle - (fᵢ_succ, (stmt, oStmtOut, wit)) + (fᵢ_succ, (stmtIn, oStmtOut, witIn)) -/-- The prover for the `i`-th round of Binary commitmentfold. -/ +/-! The prover for the `i`-th round of Binary commitmentfold. -/ noncomputable def commitOracleProver (i : Fin ℓ) : OracleProver (oSpec := []ₒ) -- current round @@ -568,24 +1512,19 @@ noncomputable def commitOracleProver (i : Fin ℓ) : (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ) (WitOut := Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ) i.succ) (pSpec := pSpecCommit 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) where - PrvState := commitPrvState 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i - input := fun ⟨⟨stmt, oStmt⟩, wit⟩ => (stmt, oStmt, wit) - sendMessage -- There are either 2 or 3 messages in the pSpec depending on commitment rounds | ⟨0, _⟩ => fun inputPrvState => by let res := getCommitProverFinalOutput 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i inputPrvState exact pure res - receiveChallenge | ⟨0, h⟩ => nomatch h -- i.e. contradiction - output := fun ⟨stmt, oStmt, wit⟩ => by exact pure ⟨⟨stmt, oStmt⟩, wit⟩ -/-- The oracle verifier for the `i`-th round of Binary commitmentfold. -/ +/-! The oracle verifier for the `i`-th round of Binary commitmentfold. -/ noncomputable def commitOracleVerifier (i : Fin ℓ) (hCR : isCommitmentRound ℓ ϑ i) : OracleVerifier (oSpec := []ₒ) @@ -598,18 +1537,16 @@ noncomputable def commitOracleVerifier (i : Fin ℓ) (hCR : isCommitmentRound (OStmtOut := OracleStatement 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ) (pSpec := pSpecCommit 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) where - -- The core verification logic. Takes the input statement `stmtIn` and the transcript, and -- performs an oracle computation that outputs a new statement verify := fun stmtIn _pSpecChallenges => do pure stmtIn - embed := (commitStepLogic (mp := mp) 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i hCR).embed hEq := (commitStepLogic (mp := mp) 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i hCR).hEq -/-- The oracle reduction that is the `i`-th round of Binary commitmentfold. -/ +/-! The oracle reduction that is the `i`-th round of Binary commitmentfold. -/ noncomputable def commitOracleReduction (i : Fin ℓ) (hCR : isCommitmentRound ℓ ϑ i) : OracleReduction (oSpec := []ₒ) (StmtIn := Statement (L := L) Context i.succ) @@ -630,7 +1567,7 @@ variable {R : Type} [CommSemiring R] [DecidableEq R] [SampleableType R] variable {σ : Type} {init : ProbComp σ} {impl : QueryImpl []ₒ (StateT σ ProbComp)} -/-- +/-! Perfect completeness for the commit step oracle reduction. This theorem proves that the honest prover-verifier interaction for the commit step @@ -674,7 +1611,6 @@ theorem commitOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : (mp := mp) i (hCR := hCR)) let strongly_complete : step.IsStronglyComplete := commitStep_is_logic_complete (L := L) 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) (i := i) (hCR := hCR) - -- Step 4: Split into safety and correctness goals refine ⟨?_, ?_⟩ -- GOAL 1: SAFETY - Prove the verifier never crashes ([⊥|...] = 0) @@ -694,7 +1630,6 @@ theorem commitOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : simp only [ChallengeIdx, Challenge, Fin.isValue, Matrix.cons_val_one, Matrix.cons_val_zero, liftComp_eq_liftM, OptionT.probFailure_lift, HasEvalPMF.probFailure_eq_zero] rw [true_and] - intro r_i' h_r_i'_mem_query_1_support conv => enter [2]; @@ -763,9 +1698,8 @@ theorem commitOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : simp only -- Step 2a: Simplify the support membership to extract the challenge simp only [ - support_bind, support_pure, support_liftComp, - Set.mem_iUnion, Set.mem_singleton_iff, - exists_eq_left, exists_prop, Prod.exists + support_bind, support_pure, + Set.mem_iUnion, Set.mem_singleton_iff, exists_prop, Prod.exists ] at hx_mem_support conv at hx_mem_support => erw [OptionT.support_mk, support_pure] @@ -778,7 +1712,7 @@ theorem commitOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : exists_eq_right_right', liftM_pure, support_pure, exists_eq_left] dsimp only [monadLift, MonadLift.monadLift] simp only [Fin.isValue, Message, Matrix.cons_val_zero, Fin.succ_zero_eq_one, ChallengeIdx, - Challenge, Fin.reduceLast, liftComp_eq_liftM, MessageIdx] at hx_mem_support + Challenge, Fin.reduceLast, liftComp_eq_liftM] at hx_mem_support obtain ⟨newOracleFn, lastPrvState, h_prvFinalState_eq, ⟨h_prvOut_mem_support, h_verOut_mem_support⟩⟩ := hx_mem_support conv at h_prvFinalState_eq => @@ -808,7 +1742,6 @@ theorem commitOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : erw [support_pure] simp only [Set.mem_singleton_iff, Option.some.injEq, Prod.mk.injEq] -- pure equalities now - -- Step 2e: Apply the logic completeness lemma obtain ⟨h_V_check, h_rel, h_agree⟩ := strongly_complete (stmtIn := stmtIn) (witIn := witIn) (h_relIn := h_relIn) @@ -818,22 +1751,18 @@ theorem commitOracleReduction_perfectCompleteness (hInit : NeverFail init) (i : case zero => simp at hj case succ j' => exact j'.elim0 ) - obtain ⟨newOracleFn_eq, lastPrvState_eq⟩ := h_prvFinalState_eq obtain ⟨⟨prvStmtOut_eq, prvOStmtOut_eq⟩, prvWitOut_eq⟩ := h_prvOut_mem_support obtain ⟨verStmtOut_eq, verOStmtOut_eq⟩ := h_verOut_mem_support - -- Step 2f: Simplify the verifier check -- simp only [commitStepLogic] at h_V_check -- unfold FullTranscript.mk1 at h_V_check simp only [Fin.isValue] at h_V_check - rw [ -- lastPrvState_eq, prvStmtOut_eq, prvOStmtOut_eq, prvWitOut_eq, verStmtOut_eq, verOStmtOut_eq, ] - constructor · rw [newOracleFn_eq] exact h_rel @@ -851,7 +1780,7 @@ def commitKnowledgeError {i : Fin ℓ} simp only [ne_eq, reduceCtorEq, not_false_eq_true, Matrix.cons_val_fin_one, Direction.not_P_to_V_eq_V_to_P] at hj -- not a V challenge -/-- The round-by-round extractor for a single round. +/-! The round-by-round extractor for a single round. Since f^(0) is always available, we can invoke the extractMLP function directly. -/ noncomputable def commitRbrExtractor (i : Fin ℓ) : Extractor.RoundByRound []ₒ @@ -866,29 +1795,32 @@ noncomputable def commitRbrExtractor (i : Fin ℓ) : extractMid := fun _ _ _ witMidSucc => witMidSucc extractOut := fun _ _ witOut => witOut -/-- Note : stmtIn and witMid already advances to state `(i+1)` from the fold step, +/-! Note : stmtIn and witMid already advances to state `(i+1)` from the fold step, while oStmtIn is not. -/ def commitKStateProp (i : Fin ℓ) (m : Fin (1 + 1)) (stmtIn : Statement (L := L) Context i.succ) (witMid : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ) (oStmtIn : (i_1 : Fin (toOutCodewordsCount ℓ ϑ i.castSucc)) → OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc i_1) + (tr : Transcript m (pSpecCommit 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i)) : Prop := match m with | ⟨0, _⟩ => -- same as relIn masterKStateProp (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) -- (𝓑 := 𝓑) (stmtIdx := i.succ) (oracleIdx := OracleFrontierIndex.mkFromStmtIdxCastSuccOfSucc i) (stmt := stmtIn) (wit := witMid) (oStmt := oStmtIn) - (localChecks := True) - | ⟨1, _⟩ => -- implied by relOut - let ⟨_, stmtOut, oStmtOut, witOut⟩ := getCommitProverFinalOutput 𝔽q β (ϑ := ϑ) - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i ⟨stmtIn, oStmtIn, witMid⟩ + (localChecks := sumcheckConsistencyProp (𝓑 := 𝓑) stmtIn.sumcheck_target witMid.H) + | ⟨1, _⟩ => -- implied by relOut: use transcript message as oracle (what verifier sees) + -- The verifier sees tr.messages ⟨0, rfl⟩ as the new oracle, not witMid.f + let newOracle := tr.messages ⟨0, rfl⟩ + let oStmtOut := snoc_oracle 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (oStmtIn := oStmtIn) (newOracleFn := newOracle) (h_destIdx := by rfl) masterKStateProp (mp := mp) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) -- (𝓑 := 𝓑) (stmtIdx := i.succ) (oracleIdx := OracleFrontierIndex.mkFromStmtIdx i.succ) - (stmt := stmtOut) (wit := witOut) (oStmt := oStmtOut) - (localChecks := True) + (stmt := stmtIn) (wit := witMid) (oStmt := oStmtOut) + (localChecks := sumcheckConsistencyProp (𝓑 := 𝓑) stmtIn.sumcheck_target witMid.H) -/-- Knowledge state function (KState) for single round -/ +/-! Knowledge state function (KState) for single round -/ def commitKState (i : Fin ℓ) (hCR : isCommitmentRound ℓ ϑ i) : (commitOracleVerifier 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := mp) i hCR).KnowledgeStateFunction init impl @@ -898,18 +1830,128 @@ def commitKState (i : Fin ℓ) (hCR : isCommitmentRound ℓ ϑ i) : (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i.succ) (extractor := commitRbrExtractor 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i) where toFun := fun m ⟨stmtIn, oStmtIn⟩ tr witMid => - commitKStateProp 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) -- (𝓑 := 𝓑) - (i := i) (m := m) (stmtIn := stmtIn) (witMid := witMid) (oStmtIn := oStmtIn) (mp:=mp) - toFun_empty := fun stmtIn witMid => by sorry + commitKStateProp 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) + (i := i) (m := m) (stmtIn := stmtIn) (witMid := witMid) (oStmtIn := oStmtIn) + (tr := tr) (mp:=mp) + toFun_empty := fun ⟨stmtIn, oStmtIn⟩ witMid => by + -- commitKStateProp 0 = foldStepRelOutProp i (same masterKStateProp) + rw [cast_eq] + simp only [foldStepRelOut, foldStepRelOutProp, Set.mem_setOf_eq, commitKStateProp] toFun_next := fun m hDir (stmtIn, oStmtIn) tr msg witMid => by + -- For pSpecCommit, the only P_to_V message is at index 0 + -- So m = 0, m.succ = 1, m.castSucc = 0 + have h_m_eq_0 : m = 0 := by + cases m using Fin.cases with + | zero => rfl + | succ m' => omega + subst h_m_eq_0 + intro h_kState_round1 + unfold commitKStateProp masterKStateProp at h_kState_round1 ⊢ + simp only [Fin.isValue, Fin.succ_zero_eq_one, Nat.reduceAdd, Fin.mk_one, + Fin.coe_ofNat_eq_mod, Nat.reduceMod] at h_kState_round1 + simp only [Fin.castSucc_zero] + -- Round-1 state is bad ∨ good under Option B. + cases h_kState_round1 with + | inl hBad => + left + have hBad_cast := + incrementalBadEventExistsProp_commit_step_backward 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hCR oStmtIn + _ _ hBad + simpa using hBad_cast + | inr hGood => + have h_sumcheck : sumcheckConsistencyProp (𝓑 := 𝓑) stmtIn.sumcheck_target witMid.H := hGood.1 + have h_struct : witnessStructuralInvariant 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmtIn witMid := hGood.2.1 + have h_init : firstOracleWitnessConsistencyProp 𝔽q β witMid.t + (getFirstOracle 𝔽q β + (snoc_oracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_destIdx := rfl) + oStmtIn + (msg : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (domainIdx := ⟨i.val + 1, by omega⟩)))) := hGood.2.2.1 + have h_fold : oracleFoldingConsistencyProp 𝔽q β (i := i.succ) stmtIn.challenges + (snoc_oracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_destIdx := rfl) + oStmtIn + (msg : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (domainIdx := ⟨i.val + 1, by omega⟩))) := hGood.2.2.2 + have h_init_cast : firstOracleWitnessConsistencyProp 𝔽q β witMid.t + (getFirstOracle 𝔽q β oStmtIn) := by + have h_pos : 0 < toOutCodewordsCount ℓ ϑ i.castSucc := by + exact Nat.pos_of_neZero (toOutCodewordsCount ℓ ϑ i.castSucc) + simpa [commitRbrExtractor, getFirstOracle, snoc_oracle, h_pos] using h_init + have h_fold_cast : + oracleFoldingConsistencyProp 𝔽q β (i := i.castSucc) (Fin.init stmtIn.challenges) + oStmtIn := by + exact oracleFoldingConsistencyProp_commit_step_backward 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hCR _ oStmtIn _ h_fold + right + exact ⟨h_sumcheck, h_struct, h_init_cast, h_fold_cast⟩ + toFun_full := fun ⟨stmtIn, oStmtIn⟩ tr witOut probEvent_relOut_gt_0 => by + -- probEvent_relOut_gt_0: the relOut is satisified under oracle verifier's execution + -- Now we simp the probEvent_relOut_gt_0 to extract equalities for stmtOut, oStmtOut as + -- deterministic computations (oracle verifier execution) of stmtIn, oStmtIn + simp only [StateT.run'_eq, gt_iff_lt, probEvent_pos_iff, Prod.exists] at probEvent_relOut_gt_0 + rcases probEvent_relOut_gt_0 with ⟨stmtOut, oStmtOut, h_output_mem_V_run_support, h_relOut⟩ + have h_output_mem_V_run_support' : + some (stmtOut, oStmtOut) ∈ + (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (commitOracleVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑) (mp := mp) i hCR).toVerifier)).run s).support := by + exact (OptionT.mem_support_iff + (mx := OptionT.mk (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (commitOracleVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑) (mp := mp) i hCR).toVerifier)).run s)) + (x := (stmtOut, oStmtOut))).1 h_output_mem_V_run_support + simp only [support_bind, Set.mem_iUnion, exists_prop] at h_output_mem_V_run_support' + rcases h_output_mem_V_run_support' with ⟨s, hs_init, h_output_mem_V_run_support⟩ + conv at h_output_mem_V_run_support => + simp only [Verifier.run, OracleVerifier.toVerifier] + simp only [commitOracleVerifier] + simp only [support_bind, Set.mem_iUnion] + dsimp only [StateT.run] + simp only [simulateQ_pure, pure_bind, Function.comp_apply] + dsimp only [ProbComp] + simp only [MessageIdx, support_pure, Set.mem_singleton_iff, Prod.mk.injEq, exists_eq_right, + exists_and_right] + --- + erw [simulateQ_bind] + erw [simulateQ_pure] + simp only [pure_bind, support_map, Set.mem_image, Prod.exists, exists_and_right, + exists_eq_right] + erw [simulateQ_pure, support_pure] + simp only [Set.mem_singleton_iff, Prod.mk.injEq, Option.some.injEq, exists_eq_right] + rcases h_output_mem_V_run_support with ⟨h_stmtOut_eq, h_oStmtOut_eq⟩ simp only [Nat.reduceAdd] - intro kState_next - sorry - toFun_full := fun (stmtIn, oStmtIn) tr witOut=> by - sorry - -omit [CharP L 2] [SampleableType L] in -/-- RBR knowledge soundness for a single round oracle verifier -/ + -- h_relOut : ((stmtOut, oStmtOut), witOut) ∈ roundRelation 𝔽q β i.succ + simp only [roundRelation, roundRelationProp, Set.mem_setOf_eq] at h_relOut + set extractedWitIn : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ := + (commitRbrExtractor 𝔽q β i).extractOut (stmtIn, oStmtIn) tr witOut + -- extractedWitIn = witOut by definition of commitRbrExtractor + -- ⊢ commitKStateProp 𝔽q β i (Fin.last 1) stmtIn extractedWitIn oStmtIn tr + unfold commitKStateProp + simp only [Fin.reduceLast, Fin.isValue, Fin.val_succ, h_stmtOut_eq] at h_relOut ⊢ + -- Key: goal's oStmt = snoc_oracle oStmtIn (tr.messages ⟨0, rfl⟩) = oStmtOut + let msgIdx0 : (pSpecCommit 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i).MessageIdx := ⟨0, rfl⟩ + have h_oStmt_eq : snoc_oracle 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (destIdx := ⟨i.val + 1, by omega⟩) (h_destIdx := by rfl) (oStmtIn := oStmtIn) + (newOracleFn := tr.messages msgIdx0) = oStmtOut := by + simpa [h_oStmtOut_eq] using + (snoc_oracle_eq_mkVerifierOStmtOut_commitStep 𝔽q β (mp := mp) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i hCR oStmtIn + (tr.messages msgIdx0) tr rfl) + rw [h_oStmt_eq] + exact h_relOut + +/-! RBR knowledge soundness for a single round oracle verifier -/ +omit [SampleableType L] in theorem commitOracleVerifier_rbrKnowledgeSoundness (i : Fin ℓ) (hCR : isCommitmentRound ℓ ϑ i) : (commitOracleVerifier 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) @@ -922,8 +1964,11 @@ theorem commitOracleVerifier_rbrKnowledgeSoundness (i : Fin ℓ) use fun _ => Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ use commitRbrExtractor 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i use commitKState (mp:=mp) 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i hCR - intro stmtIn witIn prover j - exact absurd j.2 (by simp [pSpecCommit]) + intro stmtIn witIn prover ⟨j, hj⟩ + cases j using Fin.cases with + | zero => simp only [ne_eq, reduceCtorEq, not_false_eq_true, Fin.isValue, Matrix.cons_val_fin_one, + Direction.not_P_to_V_eq_V_to_P] at hj + | succ j' => exact Fin.elim0 j' end CommitStep @@ -935,7 +1980,7 @@ def relayPrvState (i : Fin ℓ) : Fin (0 + 1) → Type := fun (∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc j) × Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ -/-- The prover for the `i`-th round of Binary relayfold. -/ +/-! The prover for the `i`-th round of Binary relayfold. -/ noncomputable def relayOracleProver (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ i) : OracleProver (oSpec := []ₒ) -- current round @@ -966,9 +2011,9 @@ def relayOracleVerifier_embed (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ Fin (toOutCodewordsCount ℓ ϑ i.castSucc) ⊕ pSpecRelay.MessageIdx := fun j => Sum.inl ⟨j.val, by rw [h_oracle_size_eq_relay i hNCR]; omega⟩ -/-- The oracle verifier for the `i`-th round of Binary relayfold. -/ +/-! The oracle verifier for the `i`-th round of Binary relayfold. -/ noncomputable def relayOracleVerifier (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ i) : - OracleVerifier + OracleVerifier.{0, 0} (oSpec := []ₒ) (StmtIn := Statement (L := L) Context i.succ) (OStmtIn := OracleStatement 𝔽q β (ϑ := ϑ) @@ -987,7 +2032,7 @@ noncomputable def relayOracleVerifier (i : Fin ℓ) (hNCR : ¬ isCommitmentRound hEq := fun oracleIdx => by simp only [MessageIdx, Function.Embedding.coeFn_mk, relayOracleVerifier_embed] -/-- The oracle reduction that is the `i`-th round of Binary relayfold. -/ +/-! The oracle reduction that is the `i`-th round of Binary relayfold. -/ noncomputable def relayOracleReduction (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ i) : OracleReduction (oSpec := []ₒ) (StmtIn := Statement (L := L) Context i.succ) @@ -1144,7 +2189,7 @@ def relayKnowledgeError (m : pSpecRelay.ChallengeIdx) : ℝ≥0 := match m with | ⟨j, _⟩ => j.elim0 -/-- The round-by-round extractor for a single round. +/-! The round-by-round extractor for a single round. Since f^(0) is always available, we can invoke the extractMLP function directly. -/ noncomputable def relayRbrExtractor (i : Fin ℓ) : Extractor.RoundByRound []ₒ @@ -1164,13 +2209,38 @@ def relayKStateProp (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ i) (witMid : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i.succ) (oStmtIn : (∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc j)) : Prop := + -- Relay step inherits sumcheckConsistency from foldStepRelOut (relIn) and preserves it + let sumCheckConsistency: Prop := sumcheckConsistencyProp (𝓑 := 𝓑) stmtIn.sumcheck_target witMid.H masterKStateProp (mp := mp) (ϑ := ϑ) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) -- (𝓑 := 𝓑) (stmtIdx := i.succ) (oracleIdx := OracleFrontierIndex.mkFromStmtIdx i.succ) (stmt := stmtIn) (wit := witMid) (oStmt := mapOStmtOutRelayStep 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR oStmtIn) - (localChecks := True) + (localChecks := sumCheckConsistency) + +/-! The relay step oracle transformation equals mkVerifierOStmtOut. +This shows that mapOStmtOutRelayStep is exactly what the verifier produces. -/ +lemma mapOStmtOut_eq_mkVerifierOStmtOut_relayStep + (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ i) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc j) + (transcript : FullTranscript pSpecRelay) : + let v := relayOracleVerifier (Context := Context) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR + mapOStmtOutRelayStep 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR oStmtIn = + OracleVerifier.mkVerifierOStmtOut v.embed v.hEq oStmtIn transcript := by + intro v + funext j + simp only [mapOStmtOutRelayStep, OracleVerifier.mkVerifierOStmtOut, relayOracleVerifier, v] + sorry -/-- Knowledge state function (KState) for single round -/ +lemma getFirstOracle_mapOStmtOutRelayStep_eq (i : Fin ℓ) + (hNCR : ¬ isCommitmentRound ℓ ϑ i) + (oStmtIn : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ i.castSucc j) : + getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (mapOStmtOutRelayStep 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR oStmtIn) = + getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmtIn := by + funext y + simp only [getFirstOracle, mapOStmtOutRelayStep] + +/-! Knowledge state function (KState) for single round -/ def relayKnowledgeStateFunction (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ i) : (relayOracleVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR).KnowledgeStateFunction init impl @@ -1183,18 +2253,117 @@ def relayKnowledgeStateFunction (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ relayKStateProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (mp:=mp) -- (𝓑 := 𝓑) i hNCR stmtIn witMid oStmtIn toFun_empty := fun ⟨stmtIn, oStmtIn⟩ witIn => by + rw [cast_eq] simp only [foldStepRelOut, foldStepRelOutProp, Set.mem_setOf_eq, relayKStateProp] unfold masterKStateProp - simp only [Fin.val_succ, true_and] - have hRight := oracleWitnessConsistency_relay_preserved (mp := mp) 𝔽q β i -- (𝓑 := 𝓑) - hNCR stmtIn witIn oStmtIn - -- rw [hRight] - sorry - toFun_next := fun m hDir (stmtIn, oStmtIn) tr msg witMid => by exact fun a ↦ a - toFun_full := fun (stmtIn, oStmtIn) tr witOut=> by sorry - -omit [SampleableType L] in -/-- RBR knowledge soundness for a single round oracle verifier -/ + simp only [Fin.val_succ] + constructor <;> intro h + · -- Forward: castSuccOfSucc/original oStmt -> mkFromStmtIdx/mapped oStmt + cases h with + | inl hBad => + left + exact (incrementalBadEventExistsProp_relay_preserved 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR oStmtIn stmtIn.challenges).1 hBad + | inr hGood => + right + refine ⟨hGood.1, hGood.2.1, ?_, ?_⟩ + · simpa [getFirstOracle_mapOStmtOutRelayStep_eq (i := i) (hNCR := hNCR) + (oStmtIn := oStmtIn)] using hGood.2.2.1 + · have hFold' : + oracleFoldingConsistencyProp 𝔽q β (i := i.castSucc) + (Fin.init stmtIn.challenges) oStmtIn := by + simpa using hGood.2.2.2 + have hFold_map := + (oracleFoldingConsistencyProp_relay_preserved 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR stmtIn.challenges oStmtIn).1 hFold' + simpa using hFold_map + · -- Backward: mkFromStmtIdx/mapped oStmt -> castSuccOfSucc/original oStmt + cases h with + | inl hBad => + left + exact (incrementalBadEventExistsProp_relay_preserved 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR oStmtIn stmtIn.challenges).2 hBad + | inr hGood => + right + refine ⟨hGood.1, hGood.2.1, ?_, ?_⟩ + · simpa [getFirstOracle_mapOStmtOutRelayStep_eq (i := i) (hNCR := hNCR) + (oStmtIn := oStmtIn)] using hGood.2.2.1 + · have hFold' : + oracleFoldingConsistencyProp 𝔽q β (i := i.succ) + stmtIn.challenges + (mapOStmtOutRelayStep 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR oStmtIn) := by + simpa using hGood.2.2.2 + have hFold_cast := + (oracleFoldingConsistencyProp_relay_preserved 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR stmtIn.challenges oStmtIn).2 hFold' + simpa using hFold_cast + toFun_next := fun m hDir (stmtIn, oStmtIn) tr msg witMid => Fin.elim0 m + toFun_full := by + intro stmtOStmtIn tr witOut probEvent_relOut_gt_0 + rcases stmtOStmtIn with ⟨stmtIn, oStmtIn⟩ + -- h_relOut: ∃ stmtOut oStmtOut, verifier outputs (stmtOut, oStmtOut) with prob > 0 + -- and ((stmtOut, oStmtOut), witOut) ∈ foldStepRelOut + simp only [StateT.run'_eq, gt_iff_lt, probEvent_pos_iff, Prod.exists] at probEvent_relOut_gt_0 + rcases probEvent_relOut_gt_0 with ⟨stmtOut, oStmtOut, h_output_mem_V_run_support, h_relOut⟩ + have h_output_mem_V_run_support' : + some (stmtOut, oStmtOut) ∈ + (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (relayOracleVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + i hNCR).toVerifier)).run s).support := by + exact (OptionT.mem_support_iff + (mx := OptionT.mk (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (relayOracleVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + i hNCR).toVerifier)).run s)) + (x := (stmtOut, oStmtOut))).1 h_output_mem_V_run_support + simp only [support_bind, Set.mem_iUnion, exists_prop] at h_output_mem_V_run_support' + rcases h_output_mem_V_run_support' with ⟨s, hs_init, h_output_mem_V_run_support⟩ + conv at h_output_mem_V_run_support => + simp only [Verifier.run, OracleVerifier.toVerifier] + -- Now unfold the foldOracleVerifier's `verify()` method + simp only [relayOracleVerifier] + simp only [support_bind, Set.mem_iUnion] + dsimp only [StateT.run] + simp only [simulateQ_pure, pure_bind, Function.comp_apply] + dsimp only [ProbComp] -- unfold ProbComp back to OracleComp + simp only [MessageIdx, support_pure, Set.mem_singleton_iff, Prod.mk.injEq, exists_eq_right, + exists_and_right] + --- + erw [simulateQ_bind] + erw [simulateQ_pure, support_pure] + simp only [Set.mem_singleton_iff, Option.some.injEq, Prod.mk.injEq] + rcases h_output_mem_V_run_support with ⟨h_stmtOut_eq, h_oStmtOut_eq⟩ + simp only [Nat.reduceAdd] + let v := relayOracleVerifier (Context := Context) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR + -- Now h_relOut : ((stmtIn, oStmtOut), witOut) ∈ roundRelation 𝔽q β i.succ + -- where oStmtOut = OracleVerifier.mkVerifierOStmtOut ... + simp only [roundRelation, roundRelationProp, Set.mem_setOf_eq] at h_relOut + unfold masterKStateProp at h_relOut + -- The goal is relayKStateProp, which expands to masterKStateProp with sumcheckConsistency + simp only [relayKStateProp] + unfold masterKStateProp + -- relayRbrExtractor.extractOut is identity + rw [h_stmtOut_eq] at h_relOut + -- Rewrite verifier-produced oracle statement to the relay map and conclude directly. + have h_oStmt_eq_map : oStmtOut = + mapOStmtOutRelayStep 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR oStmtIn := by + calc + oStmtOut = OracleVerifier.mkVerifierOStmtOut v.embed v.hEq oStmtIn tr := h_oStmtOut_eq + _ = mapOStmtOutRelayStep 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) i hNCR oStmtIn := by + symm + simpa [v] using mapOStmtOut_eq_mkVerifierOStmtOut_relayStep + (Context := Context) (i := i) (hNCR := hNCR) (oStmtIn := oStmtIn) (transcript := tr) + rw [h_oStmt_eq_map] at h_relOut + simpa [relayRbrExtractor] using h_relOut + +/-! RBR knowledge soundness for a single round oracle verifier -/ theorem relayOracleVerifier_rbrKnowledgeSoundness (i : Fin ℓ) (hNCR : ¬ isCommitmentRound ℓ ϑ i) : (relayOracleVerifier 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) @@ -1209,7 +2378,7 @@ theorem relayOracleVerifier_rbrKnowledgeSoundness (i : Fin ℓ) use relayKnowledgeStateFunction (mp:=mp) 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) i hNCR intro stmtIn witIn prover j - exact j.val.elim0 + exact Fin.elim0 j end RelayStep @@ -1234,7 +2403,7 @@ The step consists of : -/ open Classical in -/-- The prover for the final sumcheck step -/ +/-! The prover for the final sumcheck step -/ noncomputable def finalSumcheckProver : OracleProver (oSpec := []ₒ) @@ -1253,16 +2422,13 @@ noncomputable def finalSumcheckProver : (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) × Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) × L input := fun ⟨⟨stmt, oStmt⟩, wit⟩ => (stmt, oStmt, wit) - sendMessage | ⟨0, _⟩ => fun ⟨stmtIn, oStmtIn, witIn⟩ => do -- Compute the message using the honest transcript from logic let c : L := witIn.f ⟨0, by simp only [zero_mem]⟩ -- f^(ℓ)(0, ..., 0) pure ⟨c, (stmtIn, oStmtIn, witIn, c)⟩ - receiveChallenge | ⟨0, h⟩ => nomatch h -- No challenges in this step - output := fun ⟨stmtIn, oStmtIn, witIn, c⟩ => do -- Construct the transcript from the message and challenges (no challenges in this step) let t := FullTranscript.mk1 (pSpec := pSpecFinalSumcheckStep (L := L)) c @@ -1270,8 +2436,8 @@ noncomputable def finalSumcheckProver : pure ((finalSumcheckStepLogic 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)).proverOut stmtIn witIn oStmtIn t) +/-! The verifier for the final sumcheck step -/ open Classical in -/-- The verifier for the final sumcheck step -/ noncomputable def finalSumcheckVerifier : OracleVerifier (oSpec := []ₒ) @@ -1292,13 +2458,12 @@ noncomputable def finalSumcheckVerifier : -- Use guard for verifier check (fails if check doesn't pass) guard (logic.verifierCheck stmtIn t) pure (logic.verifierOut stmtIn t) - embed := (finalSumcheckStepLogic 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)).embed hEq := (finalSumcheckStepLogic 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)).hEq -/-- The oracle reduction for the final sumcheck step -/ +/-! The oracle reduction for the final sumcheck step -/ noncomputable def finalSumcheckOracleReduction : OracleReduction (oSpec := []ₒ) @@ -1312,7 +2477,7 @@ noncomputable def finalSumcheckOracleReduction : prover := finalSumcheckProver 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) verifier := finalSumcheckVerifier 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) -/-- Perfect completeness for the final sumcheck step -/ +/-! Perfect completeness for the final sumcheck step -/ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} (init : ProbComp σ) (hInit : NeverFail init) (impl : QueryImpl []ₒ (StateT σ ProbComp)) : @@ -1327,18 +2492,15 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} rw [OracleReduction.unroll_1_message_reduction_perfectCompleteness_P_to_V (hInit := hInit) (hDir0 := by rfl) (hImplSupp := by simp only [Set.fmap_eq_image, IsEmpty.forall_iff, implies_true])] - intro stmtIn oStmtIn witIn h_relIn -- Step 2: Convert probability 1 to universal quantification over support rw [probEvent_eq_one_iff] -- Step 3: Unfold protocol definitions dsimp only [finalSumcheckOracleReduction, finalSumcheckProver, finalSumcheckVerifier, OracleVerifier.toVerifier, FullTranscript.mk1] - let step := (finalSumcheckStepLogic 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)) let strongly_complete : step.IsStronglyComplete := finalSumcheckStep_is_logic_complete (L := L) 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) - -- Step 4: Split into safety and correctness goals refine ⟨?_, ?_⟩ -- GOAL 1: SAFETY - Prove the verifier never crashes ([⊥|...] = 0) @@ -1415,7 +2577,6 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} exfalso exact hj_ne hj ) - have h_V_check_is_true : V_check := h_V_check simp only [h_V_check_is_true, ↓reduceIte, support_pure, Set.mem_singleton_iff, Fin.isValue, Fin.val_last, exists_eq_left, OptionT.support_OptionT_pure_run] at h_vStmtOut_mem_support @@ -1443,7 +2604,7 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} dsimp only [monadLift, MonadLift.monadLift] simp only [Fin.isValue, Challenge, ChallengeIdx, liftComp_eq_liftM, liftM_pure, liftComp_pure, support_pure, Set.mem_singleton_iff, - Fin.reduceLast, MessageIdx, Message] at hx_mem_support + MessageIdx, Message] at hx_mem_support -- Step 2b: Extract the challenge r1 and the trace equations rcases hx_mem_support with ⟨prvWitOut, h_prvOut_mem_support, h_verOut_mem_support⟩ conv at h_prvOut_mem_support => @@ -1462,11 +2623,9 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} Set.mem_iUnion, exists_prop] rw [simulateQ_ite]; erw [simulateQ_pure] simp only [OptionT.simulateQ_failure'] - set V_check := step.verifierCheck stmtIn (FullTranscript.mk1 (msg0 := _))with h_V_check_def - -- Step 2e: Apply the logic completeness lemma obtain ⟨h_V_check, h_rel, h_agree⟩ := strongly_complete (stmtIn := stmtIn) (witIn := witIn) (h_relIn := h_relIn) (challenges := @@ -1479,7 +2638,6 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} exfalso exact hj_ne hj ) - have h_V_check_is_true : V_check := h_V_check simp only [h_V_check_is_true, ↓reduceIte, Fin.isValue] at h_verOut_mem_support erw [support_bind, support_pure] at h_verOut_mem_support @@ -1487,9 +2645,7 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} OptionT.support_OptionT_pure_run, exists_eq_left, Option.some.injEq, Prod.mk.injEq] at h_verOut_mem_support rcases h_verOut_mem_support with ⟨verStmtOut_eq, verOStmtOut_eq⟩ - obtain ⟨prvStmtOut_eq, prvOStmtOut_eq⟩ := h_prvOut_mem_support - constructor · rw [verStmtOut_eq, verOStmtOut_eq]; exact h_rel @@ -1498,18 +2654,380 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} · rw [verOStmtOut_eq, prvOStmtOut_eq]; exact h_agree.2 -/-- RBR knowledge error for the final sumcheck step -/ +/-! RBR knowledge error for the final sumcheck step -/ def finalSumcheckKnowledgeError (m : pSpecFinalSumcheckStep (L := L).ChallengeIdx) : ℝ≥0 := match m with | ⟨0, h0⟩ => nomatch h0 +omit [SampleableType L] in +/-! When final-sumcheck oracle consistency holds, extractMLP must succeed. + +This connects the proximity-based `finalSumcheckStepOracleConsistencyProp` to the decoder: +- That prop implies oracle folding consistency and final compliance (last oracle → constant) +- Folding consistency implies the first oracle is within unique decoding radius +- Berlekamp-Welch decoder succeeds when within UDR, returning `some` -/ +lemma extractMLP_some_of_oracleFoldingConsistency + (stmtOut : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) + (h_oracle_consistency : finalSumcheckStepOracleConsistencyProp 𝔽q β + (h_le := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out)) + (stmtOut := stmtOut) (oStmtOut := oStmt)) : + -- extractMLP is used in `finalSumcheckRbrExtractor` + ∃ tpoly, extractMLP 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := 0) (f := getFirstOracle 𝔽q β oStmt) = some tpoly := by + -- Proof strategy: the first oracle must be fiberwise-close due to isCompliant + -- constraint, hence it's UDR-close, Q.E.D + have h_le : ϑ ≤ ℓ := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out) + have h_ϑ_pos : ϑ > 0 := by exact Nat.pos_of_neZero ϑ + dsimp only [finalSumcheckStepOracleConsistencyProp] at h_oracle_consistency + rcases h_oracle_consistency with ⟨h_oracle_cons, h_final_cons⟩ + let j0 : Fin (toOutCodewordsCount ℓ ϑ (Fin.last ℓ)) := ⟨0, by + exact Nat.pos_of_neZero (toOutCodewordsCount ℓ ϑ (Fin.last ℓ)) + ⟩ + by_cases h_ℓ_eq_ϑ : ℓ = ϑ + · -- We reason on h_final_cons + have h_div : ℓ / ϑ = 1 := by + rw [h_ℓ_eq_ϑ]; rw [Nat.div_self (n := ϑ) (H := by omega)] + have h_getLastOraclePositionIndex_last : getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ) = 0 := by + dsimp only [getLastOraclePositionIndex] + simp only [toOutCodewordsCount_last, Fin.mk_eq_zero, h_div] + let jLast : Fin (toOutCodewordsCount ℓ ϑ (Fin.last ℓ)) := + getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ) + have h_jLast_eq_zero : jLast = 0 := by + simpa [jLast] using h_getLastOraclePositionIndex_last + let zeroIdxLast : Fin r := ⟨↑jLast * ϑ, by + simpa [h_jLast_eq_zero] using (Nat.pos_of_neZero r)⟩ + let destIdxLast : Fin r := ⟨↑jLast * ϑ + ϑ, by + have h_ℓ_lt_r : ℓ < r := by omega + simpa [h_jLast_eq_zero, h_ℓ_eq_ϑ] using h_ℓ_lt_r⟩ + let challengesLast : Fin ϑ → L := fun cId => + stmtOut.challenges ⟨↑jLast * ϑ + ↑cId, by + simp only [h_jLast_eq_zero, Fin.coe_ofNat_eq_mod, toOutCodewordsCount_last, h_ℓ_eq_ϑ, + Nat.zero_mod, zero_mul, zero_add, Fin.val_last, cId.isLt]⟩ + have h_zeroIdxLast : zeroIdxLast.val = 0 := by + simp [zeroIdxLast, h_jLast_eq_zero] + have h_zeroIdxLast_eq : zeroIdxLast = 0 := Fin.eq_of_val_eq h_zeroIdxLast + have h_destIdxLast : destIdxLast = 0 + ϑ := by + simp [destIdxLast, h_jLast_eq_zero] + have h_destIdxLast_le : destIdxLast ≤ ℓ := by + simp only [h_jLast_eq_zero, Fin.coe_ofNat_eq_mod, toOutCodewordsCount_last, h_ℓ_eq_ϑ, + Nat.zero_mod, zero_mul, zero_add, le_refl, destIdxLast] + have h_compl0 : + isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := zeroIdxLast) + (steps := ϑ) + (destIdx := destIdxLast) + (h_destIdx := by simpa [h_zeroIdxLast_eq] using h_destIdxLast) + (h_destIdx_le := h_destIdxLast_le) + (f_i := oStmt jLast) + (f_i_plus_steps := fun _ => stmtOut.final_constant) + (challenges := challengesLast) := by + simpa [jLast, zeroIdxLast, destIdxLast, challengesLast, h_ℓ_eq_ϑ] using h_final_cons + rcases (extractMLP_some_of_isCompliant_at_zero 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (steps := ϑ) + (zero_Idx := zeroIdxLast) + (h_zero_Idx := h_zeroIdxLast) + (destIdx := destIdxLast) + (h_destIdx := h_destIdxLast) + (h_destIdx_le := h_destIdxLast_le) + (f_i := oStmt jLast) + (f_next := fun _ => stmtOut.final_constant) + (challenges := challengesLast) + (h_compl := h_compl0)) with + ⟨tpoly, h_extract⟩ + refine ⟨tpoly, ?_⟩ + convert h_extract using 1 + congr 1 + funext x + dsimp [getFirstOracle] + refine OracleStatement.oracle_eval_congr (oStmtIn := oStmt) + (h_j := h_jLast_eq_zero.symm) (h_x := ?_) + simp only [Fin.coe_ofNat_eq_mod, cast_cast] + · -- We reason on h_oracle_cons + dsimp only [oracleFoldingConsistencyProp] at h_oracle_cons + have h_lt : ϑ < ℓ := by omega + have h_div_gt_1 : ℓ / ϑ > 1 := by + have h_res := (Nat.div_lt_div_right (a := ϑ) (b := ϑ) (c := ℓ) (ha := by omega) + (by simp only [dvd_refl]) (by exact hdiv.out)).mpr h_lt + rw [Nat.div_self (n := ϑ) (H := by omega)] at h_res + exact h_res + have h_j0_next_lt : ↑j0 + 1 < toOutCodewordsCount ℓ ϑ (Fin.last ℓ) := by + simpa [j0, toOutCodewordsCount_last] using h_div_gt_1 + let zeroIdx0 : Fin r := ⟨↑j0 * ϑ, by + simpa [j0] using (Nat.pos_of_neZero r)⟩ + let destIdx0 : Fin r := ⟨↑j0 * ϑ + ϑ, by + have h_ℓ_lt_r : ℓ < r := by omega + have h_ϑ_lt_r : ϑ < r := lt_of_le_of_lt h_le h_ℓ_lt_r + simpa [j0] using h_ϑ_lt_r⟩ + have h_zeroIdx0 : zeroIdx0.val = 0 := by + simp [zeroIdx0, j0] + have h_destIdx0 : destIdx0 = 0 + ϑ := by + simp [destIdx0, j0] + have h_destIdx0_le : destIdx0 ≤ ℓ := by + simpa [destIdx0, j0] using h_le + have h_k_next_le_last : ↑j0 * ϑ + ϑ ≤ Fin.last ℓ := by + exact oracle_block_k_next_le_i (ℓ := ℓ) (ϑ := ϑ) + (i := Fin.last ℓ) (j := j0) (hj := h_j0_next_lt) + let fNext0 : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) destIdx0 := + getNextOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := Fin.last ℓ) + oStmt j0 h_j0_next_lt (destDomainIdx := destIdx0) (h_destDomainIdx := by simp only [destIdx0]) + let challenges0 : Fin ϑ → L := + getFoldingChallenges (r := r) (𝓡 := 𝓡) (ϑ := ϑ) (i := Fin.last ℓ) + (challenges := stmtOut.challenges) (k := ↑j0 * ϑ) (h := h_k_next_le_last) + have h_isCompliant_f₀ : + isCompliant 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := zeroIdx0) (steps := ϑ) + (destIdx := destIdx0) + (h_destIdx := by simpa [h_zeroIdx0] using h_destIdx0) + (h_destIdx_le := h_destIdx0_le) + (f_i := oStmt ⟨↑j0, by exact j0.isLt⟩) + (f_i_plus_steps := fNext0) + (challenges := challenges0) := by + simpa [zeroIdx0, destIdx0, fNext0, challenges0] using h_oracle_cons j0 h_j0_next_lt + rcases (extractMLP_some_of_isCompliant_at_zero 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (steps := ϑ) + (zero_Idx := zeroIdx0) + (h_zero_Idx := h_zeroIdx0) + (destIdx := destIdx0) + (h_destIdx := h_destIdx0) + (h_destIdx_le := h_destIdx0_le) + (f_i := oStmt ⟨↑j0, by exact j0.isLt⟩) + (f_next := fNext0) + (challenges := challenges0) + (h_compl := h_isCompliant_f₀)) with + ⟨tpoly, h_extract⟩ + refine ⟨tpoly, ?_⟩ + simpa [getFirstOracle, j0] using h_extract + +/-! When oracle folding consistency holds from first oracle through the final constant, +the extracted polynomial's evaluation at challenges equals the final constant. + +This is the key lemma connecting extraction to the final sumcheck verification: +- `oracleFoldingConsistencyProp` ensures all intermediate foldings are correct +- `h_finalFolding` (isCompliant to final constant) ensures the last step is correct +- Together, they imply the extracted `tpoly` satisfies `tpoly.eval(challenges) = c` -/ +lemma extracted_t_poly_eval_eq_final_constant + (stmtOut : FinalSumcheckStatementOut (L := L) (ℓ := ℓ)) + (oStmtOut : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) + (tpoly : MultilinearPoly L ℓ) + (h_extractMLP : extractMLP 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := 0) (f := getFirstOracle 𝔽q β oStmtOut) = some tpoly) + (h_finalSumcheckStepOracleConsistency : finalSumcheckStepOracleConsistencyProp 𝔽q β + (h_le := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out)) + (stmtOut := stmtOut) (oStmtOut := oStmtOut)) : + stmtOut.final_constant = tpoly.val.eval stmtOut.challenges := by + -- Proof strategy: + -- 1. We can see that tpoly satisifes firstOracleWitnessConsistencyProp + -- 2. From h_finalSumcheckStepOracleConsistency, we can inductively prove that + -- UDR-decoded(f_j) = iterated_fold (UDR-decoded(f_0), challenges_{0->j*ϑ}) + -- 3. We have UDR-decoded(f_0) = encoded (tpoly's evaluations) + -- 4. We have UDR-decoded(f_{ℓ/ϑ}) = fun x => stmtOut.final_constant + -- 5. Therefore, tpoly.val.eval stmtOut.challenges = stmtOut.final_constant + -- Somehow similar to the strict version `iterated_fold_to_const_strict` + classical + rcases (by + simpa [finalSumcheckStepOracleConsistencyProp] using h_finalSumcheckStepOracleConsistency + ) with ⟨h_oracle_cons, h_final_cons⟩ + let P₀ : L⦃< 2^ℓ⦄[X] := + polynomialFromNovelCoeffsF₂ 𝔽q β ℓ (by omega) + (fun ω => tpoly.val.eval (bitsOfIndex ω)) + let f₀ : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := (0 : Fin r)) := + polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := 0) (P := P₀) + have h_pair : + pair_UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) (h_i := by simp) + (f := getFirstOracle 𝔽q β oStmtOut) (g := f₀) := by + simpa [f₀] using + (extractMLP_eq_some_iff_pair_UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (f := getFirstOracle 𝔽q β oStmtOut) (tpoly := tpoly)).1 h_extractMLP + let C₀ : Set ((sDomain 𝔽q β h_ℓ_add_R_rate (0 : Fin r)) → L) := + (BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := (0 : Fin r))) + have h_f0_mem : f₀ ∈ C₀ := by + dsimp [C₀, f₀] + change polyToOracleFunc 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (domainIdx := (0 : Fin r)) (P := P₀) ∈ + BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := (0 : Fin r)) + simpa [getBBF_Codeword_of_poly] using + (getBBF_Codeword_of_poly 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) (h_i := by simp) (P := P₀)).property + have h_close_first : + UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) (h_i := by simp) (f := getFirstOracle 𝔽q β oStmtOut) := by + unfold UDRClose + calc + 2 * Δ₀(getFirstOracle 𝔽q β oStmtOut, C₀) ≤ + 2 * Δ₀(getFirstOracle 𝔽q β oStmtOut, f₀) := by + rw [ENat.mul_le_mul_left_iff (ha := by + simp only [ne_eq, OfNat.ofNat_ne_zero, not_false_eq_true]) + (h_top := by simp only [ne_eq, ENat.ofNat_ne_top, not_false_eq_true])] + exact Code.distFromCode_le_dist_to_mem (C := C₀) + (u := getFirstOracle 𝔽q β oStmtOut) (v := f₀) h_f0_mem + _ < BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := (0 : Fin r)) := by + norm_cast + have h_neZero_C₀ : NeZero ‖C₀‖₀ := by + have h_dist_ne_zero : + BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := (0 : Fin r)) ≠ 0 := by + rw [BBF_CodeDistance_eq 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) (h_i := by simp)] + omega + dsimp [C₀] + exact ⟨by simpa [BBF_CodeDistance] using h_dist_ne_zero⟩ + letI : NeZero ‖C₀‖₀ := h_neZero_C₀ + have h_f0_close_to_first : + Δ₀(getFirstOracle 𝔽q β oStmtOut, f₀) ≤ Code.uniqueDecodingRadius C₀ := by + exact (Code.UDRClose_iff_two_mul_proximity_lt_d_UDR (C := C₀)).2 + (by simpa [pair_UDRClose, C₀] using h_pair) + have h_dec0_eq_f0 : + UDRCodeword 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) (h_i := by simp) (f := getFirstOracle 𝔽q β oStmtOut) + (h_within_radius := h_close_first) = f₀ := by + symm + exact Code.eq_of_le_uniqueDecodingRadius (C := C₀) + (u := getFirstOracle 𝔽q β oStmtOut) + (v := f₀) + (w := UDRCodeword 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) (h_i := by simp) (f := getFirstOracle 𝔽q β oStmtOut) + (h_within_radius := h_close_first)) + (hv := h_f0_mem) + (hw := by + simpa [C₀] using UDRCodeword_mem_BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) (h_i := by simp) (f := getFirstOracle 𝔽q β oStmtOut) + (h_within_radius := h_close_first)) + (huv := h_f0_close_to_first) + (huw := by + simpa [C₀] using + dist_to_UDRCodeword_le_uniqueDecodingRadius 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) (h_i := by simp) + (f := getFirstOracle 𝔽q β oStmtOut) (h_within_radius := h_close_first)) + have h_oracle_cons' := h_oracle_cons + dsimp only [oracleFoldingConsistencyProp] at h_oracle_cons' + have h_final_cons_all := h_final_cons + rcases h_final_cons with ⟨h_fw_last, h_close_const, h_fold_last⟩ + have h_last_const := congr_fun h_fold_last 0 + simp only at h_last_const + -- The last decoded oracle equals the constant oracle fun _ => stmtOut.final_constant. + -- We apply the same unique-decoding argument as for the first oracle, but now at the + -- last oracle index with code C_last and center u := oStmtOut jLast. + let jLast : Fin (toOutCodewordsCount ℓ ϑ (Fin.last ℓ)) := + getLastOraclePositionIndex ℓ ϑ (Fin.last ℓ) + let lastDomainIdx : Fin r := + ⟨jLast.val * ϑ, by + apply lt_r_of_le_ℓ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (x := jLast.val * ϑ) + exact oracle_index_le_ℓ (ℓ := ℓ) (ϑ := ϑ) (i := Fin.last ℓ) (j := jLast)⟩ + let k := lastDomainIdx.val + have h_k: k = ℓ - ϑ := by + dsimp only [k, lastDomainIdx, jLast] + rw [getLastOraclePositionIndex_last, Nat.sub_mul, Nat.one_mul, Nat.div_mul_cancel (hdiv.out)] + have h_ϑ_le_ℓ : ϑ ≤ ℓ := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out) + let C_last : Set ((sDomain 𝔽q β h_ℓ_add_R_rate lastDomainIdx) → L) := + BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := lastDomainIdx) + let finalDomainIdx : Fin r := ⟨ℓ, by exact Nat.lt_of_add_right_lt h_ℓ_add_R_rate⟩ + -- final virtual oracle's evaluation domain + let C_final : Set ((sDomain 𝔽q β h_ℓ_add_R_rate finalDomainIdx) → L) := + BBF_Code 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := finalDomainIdx) + -- Constant codeword is in C_final + have h_const_mem : (fun _ => stmtOut.final_constant) ∈ C_final := by + dsimp [C_final] + exact constFunc_mem_BBFCode 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := finalDomainIdx) + (h_i := by exact Nat.le_refl finalDomainIdx.val) + stmtOut.final_constant + let f_last : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) lastDomainIdx := + getLastOracle (h_destIdx := by rfl) (oracleFrontierIdx := Fin.last ℓ) 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oStmt := oStmtOut) + let f_final_virtual : OracleFunction 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) finalDomainIdx := + fun _ => stmtOut.final_constant + let preFinalChallenges : (Fin k) → L := fun cId => stmtOut.challenges ⟨cId, by + simp only [Fin.val_last]; omega⟩ + let finalChallenges : Fin ϑ → L := fun cId => stmtOut.challenges ⟨k + cId, by + rw [h_k] + have h_le : ϑ ≤ ℓ := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ) (hdiv.out) + have h_cId : cId.val < ϑ := cId.isLt + have h_last : (Fin.last ℓ).val = ℓ := rfl + simp only [Fin.val_last, gt_iff_lt] + -- ⊢ ℓ - ϑ + ↑cId < ℓ + omega + ⟩ + -- **f_last = iterated_fold (f_0, ...)** + let f_f₀_folded_to_last := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := 0) (steps := k) (destIdx := lastDomainIdx) (h_destIdx := by + dsimp only [k]; simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]) + (h_destIdx_le := by omega) (f := f₀) (r_challenges := preFinalChallenges) + have h_f_last_eq_iterated_fold_f₀ : + f_last = f_f₀_folded_to_last := by + -- From `isCompliant`, quite straightforward. + sorry + -- **f_final_virtual = iterated_fold (f_last, ...)** + let f_last_folded_to_final := iterated_fold 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := lastDomainIdx) (steps := ϑ) (destIdx := finalDomainIdx) (h_destIdx := by + change finalDomainIdx.val = k + ϑ; rw [h_k]; dsimp only [finalDomainIdx]; omega + ) + (h_destIdx_le := by + dsimp only [finalDomainIdx]; omega + ) (f := f_last) + (r_challenges := finalChallenges) + have h_f_final_virtual_eq : + f_final_virtual = f_last_folded_to_final := by + -- From `isCompliant`, quite straightforward. + sorry + -- **=> f_final_virtual = iterated_fold (f_0, ...)** + -- Now we construct the nested `iterated_fold` form + dsimp only [f_final_virtual, f_last_folded_to_final] at h_f_final_virtual_eq + rw [h_f_last_eq_iterated_fold_f₀] at h_f_final_virtual_eq + dsimp only [f_f₀_folded_to_last] at h_f_final_virtual_eq + -- h_f_final_virtual_eq : (fun x ↦ stmtOut.final_constant) = + -- iterated_fold 𝔽q β lastDomainIdx ϑ ⋯ ⋯ + -- (iterated_fold 𝔽q β 0 k ⋯ ⋯ f₀ preFinalChallenges) finalChallenges + rw [iterated_fold_transitivity 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (h_destIdx := by + rw [h_k]; dsimp only [finalDomainIdx]; + simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega + ) + ] at h_f_final_virtual_eq + have h_congr_steps := iterated_fold_congr_steps_index 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := 0) (steps := k + ϑ) (destIdx := finalDomainIdx) + (h_destIdx := by + rw [h_k]; dsimp only [finalDomainIdx]; + simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega) + (h_destIdx_le := by dsimp only [finalDomainIdx]; omega) + (h_steps_eq_steps' := by rw [h_k]; omega) + (f := f₀) (r_challenges := Fin.append preFinalChallenges finalChallenges) (steps' := ℓ) + have h_congr_steps_fn := funext (h := h_congr_steps) + rw [h_congr_steps_fn] at h_f_final_virtual_eq + -- Hint: study the proof strategy of `finalSumcheckStep_verifierCheck_passed`, + -- `iterated_fold_to_const_strict`, `iterated_fold_to_level_ℓ_is_constant` + rw [iterated_fold_to_level_ℓ_eval 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (destIdx := finalDomainIdx) (h_destIdx := by dsimp only [finalDomainIdx]) (t := tpoly)] + at h_f_final_virtual_eq + have h_res := congr_fun (h := h_f_final_virtual_eq) (a := 0) + rw [h_res] + have h_concat_challenges_eq : (fun (cId : Fin ℓ) => + (Fin.append preFinalChallenges finalChallenges) ⟨cId, by + rw [h_k]; rw [Nat.sub_add_cancel (n := ℓ) (m := ϑ) (h := by omega)]; simp only [cId.isLt]⟩) + = (fun (cId : Fin ℓ) => (stmtOut.challenges cId)) := by + funext cId + dsimp only [preFinalChallenges, finalChallenges] + by_cases h : cId.val < k + · -- Case 1: cId < k_steps, so it's from the first part + simp only [Fin.val_last] + dsimp only [Fin.append, Fin.addCases] + -- dsimp only [preFinalChallenges] + simp only [h, ↓reduceDIte, Fin.castLT_mk, Fin.eta] + · -- Case 2: cId >= k_steps, so it's from the second part + simp only [Fin.val_last] + dsimp only [Fin.append, Fin.addCases] + simp only [h, ↓reduceDIte, Fin.cast_mk, Fin.subNat_mk, Fin.natAdd_mk, eq_rec_constant] + congr 1; apply Fin.eq_of_val_eq; simp only; rw [add_comm]; omega + rw [h_concat_challenges_eq]; rfl + def FinalSumcheckWit := fun (m : Fin (1 + 1)) => match m with | ⟨0, _⟩ => Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ) | ⟨1, _⟩ => Unit -/-- The round-by-round extractor for the final sumcheck step -/ +/-! The round-by-round extractor for the final sumcheck step -/ noncomputable def finalSumcheckRbrExtractor : Extractor.RoundByRound []ₒ (StmtIn := (Statement (L := L) (SumcheckBaseContext L ℓ) (Fin.last ℓ)) × @@ -1522,78 +3040,260 @@ noncomputable def finalSumcheckRbrExtractor : extractMid := fun m ⟨stmtMid, oStmtMid⟩ trSucc witMidSucc => by have hm : m = 0 := by omega subst hm + have _ : witMidSucc = () := by rfl -- witMidSucc is of type Unit -- Decode t from the first oracle f^(0) let f0 := getFirstOracle 𝔽q β oStmtMid let polyOpt := extractMLP 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨0, by exact Nat.pos_of_neZero ℓ⟩) (f := f0) + let H_constant : L⦃≤ 2⦄[X Fin (ℓ - ↑(Fin.last ℓ))] := ⟨MvPolynomial.C stmtMid.sumcheck_target, + by + simp only [Fin.val_last, mem_restrictDegree, MvPolynomial.mem_support_iff, + MvPolynomial.coeff_C, ne_eq, ite_eq_right_iff, Classical.not_imp, and_imp, forall_eq', + Finsupp.coe_zero, Pi.zero_apply, zero_le, implies_true]⟩ match polyOpt with - | none => -- NOTE, In proofs of toFun_next, this case would be eliminated - exact dummyLastWitness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + | none => + -- Extraction failed - use constant H to satisfy sumcheckConsistencyProp trivially + exact { + t := ⟨0, by apply zero_mem⟩, + H := H_constant, + f := fun _ => 0 + } | some tpoly => -- Build H_ℓ from t and challenges r' exact { t := tpoly, - H := projectToMidSumcheckPoly (L := L) (ℓ := ℓ) (t := tpoly) - (m := BBF_SumcheckMultiplierParam.multpoly stmtMid.ctx) - (i := Fin.last ℓ) (challenges := stmtMid.challenges), + -- projectToMidSumcheckPoly (L := L) (ℓ := ℓ) (t := tpoly) + -- (m := BBF_SumcheckMultiplierParam.multpoly stmtMid.ctx) + -- (i := Fin.last ℓ) (challenges := stmtMid.challenges), + H := H_constant, f := getMidCodewords 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) tpoly stmtMid.challenges } extractOut := fun ⟨stmtIn, oStmtIn⟩ tr witOut => () def finalSumcheckKStateProp {m : Fin (1 + 1)} (tr : Transcript m (pSpecFinalSumcheckStep (L := L))) - (stmt : Statement (L := L) (SumcheckBaseContext L ℓ) (Fin.last ℓ)) + (stmtIn : Statement (L := L) (SumcheckBaseContext L ℓ) (Fin.last ℓ)) (witMid : FinalSumcheckWit (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ) m) - (oStmt : ∀ j, OracleStatement 𝔽q β + (oStmtIn : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ) j) : Prop := match m with | ⟨0, _⟩ => -- same as relIn masterKStateProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) -- (𝓑 := 𝓑) (mp := BBF_SumcheckMultiplierParam) (stmtIdx := Fin.last ℓ) (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (Fin.last ℓ)) - (stmt := stmt) (wit := witMid) (oStmt := oStmt) (localChecks := True) + (stmt := stmtIn) (wit := witMid) (oStmt := oStmtIn) + (localChecks := sumcheckConsistencyProp (𝓑 := 𝓑) stmtIn.sumcheck_target witMid.H) | ⟨1, _⟩ => -- implied by relOut + local checks via extractOut proofs - let tr_so_far := (pSpecFinalSumcheckStep (L := L)).take 1 (by omega) - let i_msg0 : tr_so_far.MessageIdx := ⟨⟨0, by omega⟩, rfl⟩ - let c : L := (ProtocolSpec.Transcript.equivMessagesChallenges (k := 1) - (pSpec := pSpecFinalSumcheckStep (L := L)) tr).1 i_msg0 + let c : L := tr.messages ⟨0, rfl⟩ let stmtOut : FinalSumcheckStatementOut (L := L) (ℓ := ℓ) := { - ctx := stmt.ctx, - sumcheck_target := stmt.sumcheck_target, - challenges := stmt.challenges, + ctx := stmtIn.ctx, + sumcheck_target := stmtIn.sumcheck_target, + challenges := stmtIn.challenges, final_constant := c } - let sumcheckFinalCheck : Prop := stmt.sumcheck_target = eqTilde r stmt.challenges * c - let finalFoldingProp := finalFoldingStateProp 𝔽q β (ϑ := ϑ) + let sumcheckFinalCheck : Prop := stmtIn.sumcheck_target + = eqTilde (stmtIn.ctx.t_eval_point) stmtIn.challenges * c + let finalFoldingProp := finalSumcheckStepFoldingStateProp 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_le := by apply Nat.le_of_dvd; · exact Nat.pos_of_neZero ℓ - · exact hdiv.out) (input := ⟨stmtOut, oStmt⟩) + · exact hdiv.out) (input := ⟨stmtOut, oStmtIn⟩) sumcheckFinalCheck ∧ finalFoldingProp -- local checks ∧ (oracleConsitency ∨ badEventExists) -/-- The knowledge state function for the final sumcheck step -/ +/-! The knowledge state function for the final sumcheck step -/ noncomputable def finalSumcheckKnowledgeStateFunction {σ : Type} (init : ProbComp σ) (impl : QueryImpl []ₒ (StateT σ ProbComp)) : - (finalSumcheckVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)).KnowledgeStateFunction - init impl + (finalSumcheckVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑)).KnowledgeStateFunction init impl (relIn := roundRelation 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := BBF_SumcheckMultiplierParam) (Fin.last ℓ) ) (relOut := finalSumcheckRelOut 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ) (extractor := finalSumcheckRbrExtractor 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) where - toFun := fun m ⟨stmt, oStmt⟩ tr witMid => - finalSumcheckKStateProp 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - (tr := tr) (stmt := stmt) (witMid := witMid) (oStmt := oStmt) -- (𝓑 := 𝓑) - toFun_empty := fun stmt witMid => by simp only; sorry - toFun_next := fun m hDir stmt tr msg witMid h => by - -- Either bad events exist, or (oracleFoldingConsistency is true so - -- the extractor can construct a satisfying witness) - sorry - toFun_full := fun stmt tr witOut h => by - sorry + toFun := fun m ⟨stmtIn, oStmtIn⟩ tr witMid => + finalSumcheckKStateProp 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) + (tr := tr) (stmtIn := stmtIn) (witMid := witMid) (oStmtIn := oStmtIn) + toFun_empty := fun ⟨stmtIn, oStmtIn⟩ witMid => by + rw [cast_eq]; rfl + toFun_next := fun m hDir (stmtIn, oStmtIn) tr msg witMid => by + -- toFun_next is impacted by how we build extractMid + -- For pSpecCommit, the only P_to_V message is at index 0 + -- So m = 0, m.succ = 1, m.castSucc = 0 + have h_m_eq_0 : m = 0 := by + cases m using Fin.cases with + | zero => rfl + | succ m' => omega + subst h_m_eq_0 + simp only [Fin.isValue, Fin.succ_zero_eq_one, Fin.castSucc_zero] + -- declare c and stmtOut as in KState (m=1), as well as in honest verifier + -- For the final sumcheck step, there is a single P→V message carrying the final constant, + -- so we can read it directly from `msg` without reconstructing a truncated transcript. + let c : L := msg + let stmtOut : FinalSumcheckStatementOut (L := L) (ℓ := ℓ) := { + ctx := stmtIn.ctx, + sumcheck_target := stmtIn.sumcheck_target, + challenges := stmtIn.challenges, + final_constant := c + } + intro h_kState_round1 + unfold finalSumcheckKStateProp finalSumcheckStepFoldingStateProp + masterKStateProp at h_kState_round1 ⊢ + simp only [Fin.isValue, Nat.reduceAdd, Fin.mk_one, + Fin.coe_ofNat_eq_mod, Nat.reduceMod] at h_kState_round1 + -- At m=1 we have local final-check and (oracle-consistency ∨ block-bad-event). + -- At m=0 the target is Option-B masterKState: + -- incremental-bad-event ∨ (local ∧ structural ∧ initial ∧ oracleFoldingConsistency). + obtain ⟨h_V_check, h_core⟩ := h_kState_round1 + -- Case split on the m=1 final-folding state: consistency or block bad-event. + cases h_core with + | inl hConsistent => + -- When we have finalSumcheckStepOracleConsistencyProp, extractMLP must succeed. + have ⟨tpoly, h_extractMLP⟩ := extractMLP_some_of_oracleFoldingConsistency 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oStmt := oStmtIn) (h_oracle_consistency := hConsistent) + refine Or.inr ?_ + refine ⟨?_, ?_, ?_, ?_⟩ + · -- local check at m=0 + unfold finalSumcheckRbrExtractor sumcheckConsistencyProp + simp only [Fin.val_last, Fin.mk_zero', h_extractMLP, Fin.coe_ofNat_eq_mod] + simp only [MvPolynomial.eval_C, sum_const, Fintype.card_piFinset, card_map, card_univ, + Fintype.card_fin, prod_const, tsub_self, Fintype.card_eq_zero, pow_zero, one_smul] + · -- witnessStructuralInvariant + unfold finalSumcheckRbrExtractor witnessStructuralInvariant + simp only [Fin.val_last, Fin.mk_zero', h_extractMLP, Fin.coe_ofNat_eq_mod, and_true] + refine SetLike.coe_eq_coe.mp ?_ + rw [projectToMidSumcheckPoly_at_last_eq] + have h_sumcheck_target_eq : stmtIn.sumcheck_target = + (MvPolynomial.eval stmtIn.challenges + (BBF_SumcheckMultiplierParam.multpoly stmtIn.ctx).val) * + (MvPolynomial.eval stmtIn.challenges tpoly.val) := by + rw [h_V_check] + congr 1 + change c = tpoly.val.eval stmtIn.challenges + exact extracted_t_poly_eval_eq_final_constant 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oStmtOut := oStmtIn) (stmtOut := stmtOut) + (tpoly := tpoly) + (h_extractMLP := h_extractMLP) (h_finalSumcheckStepOracleConsistency := hConsistent) + simp only [h_sumcheck_target_eq, Fin.val_last, Fin.coe_ofNat_eq_mod, MvPolynomial.C_mul] + · -- firstOracleWitnessConsistencyProp + dsimp only [finalSumcheckRbrExtractor, firstOracleWitnessConsistencyProp] + simp only [Fin.mk_zero', h_extractMLP, Fin.coe_ofNat_eq_mod, Fin.val_last, + OracleFrontierIndex.val_mkFromStmtIdx] + exact (extractMLP_eq_some_iff_pair_UDRClose 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (f := getFirstOracle 𝔽q β oStmtIn) (tpoly := tpoly)).mp h_extractMLP + · exact hConsistent.1 + | inr hBad => + -- Hybrid plan: map terminal block bad-event to incremental bad-event at m=0. + exact Or.inl ( + (badEventExistsProp_iff_incrementalBadEventExistsProp_last 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + (oStmt := oStmtIn) (challenges := stmtIn.challenges)).1 hBad + ) + toFun_full := fun ⟨stmtIn, oStmtIn⟩ tr witOut probEvent_relOut_gt_0 => by + -- Same pattern as relay: verifier output (stmtOut, oStmtOut) + h_relOut ⇒ commitKStateProp 1 + simp only [StateT.run'_eq, gt_iff_lt, probEvent_pos_iff, Prod.exists] at probEvent_relOut_gt_0 + rcases probEvent_relOut_gt_0 with ⟨stmtOut, oStmtOut, h_output_mem_V_run_support, h_relOut⟩ + have h_output_mem_V_run_support' : + some (stmtOut, oStmtOut) ∈ + (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (finalSumcheckVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑)).toVerifier)).run s).support := by + exact (OptionT.mem_support_iff + (mx := OptionT.mk (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (finalSumcheckVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑)).toVerifier)).run s)) + (x := (stmtOut, oStmtOut))).1 h_output_mem_V_run_support + simp only [support_bind, Set.mem_iUnion, exists_prop] at h_output_mem_V_run_support' + rcases h_output_mem_V_run_support' with ⟨s, hs_init, h_output_mem_V_run_support⟩ + conv at h_output_mem_V_run_support => -- same as fold step + simp only [Verifier.run, OracleVerifier.toVerifier] + -- Now unfold the foldOracleVerifier's `verify()` method + simp only [finalSumcheckVerifier] + -- dsimp only [StateT.run] + -- simp only [simulateQ_bind, simulateQ_query, simulateQ_pure] + -- oracle query unfolding + simp only [support_bind, Set.mem_iUnion] + dsimp only [StateT.run] + -- enter [1, i_1, 2, 1, x] + simp only [simulateQ_bind] + --------------------------------------- + -- Now simplify the `guard` and `ite` of StateT.map generated from it + simp only [MessageIdx, Fin.isValue, Matrix.cons_val_zero, simulateQ_pure, Message, guard_eq, + pure_bind, Function.comp_apply, simulateQ_map, simulateQ_ite, + OptionT.simulateQ_failure', bind_map_left] + simp only [MessageIdx, Message, Fin.isValue, Matrix.cons_val_zero, Matrix.cons_val_one, + bind_pure_comp, simulateQ_map, simulateQ_ite, simulateQ_pure, OptionT.simulateQ_failure', + bind_map_left, Function.comp_apply] + simp only [support_ite] + simp only [Fin.isValue, Set.mem_ite_empty_right, Set.mem_singleton_iff, Prod.mk.injEq, + exists_and_left, exists_eq', exists_eq_right, exists_and_right] + simp only [Fin.isValue, id_eq, FullTranscript.mk1_eq_snoc, support_map, Set.mem_image, + Prod.exists, exists_and_right, exists_eq_right] + erw [simulateQ_bind] + enter [1, x, 1, 1, 1, 2]; + erw [simulateQ_bind] + erw [OptionT.simulateQ_simOracle2_liftM_query_T2] + simp only [Fin.isValue, FullTranscript.mk1_eq_snoc, pure_bind, OptionT.simulateQ_map] + conv at h_output_mem_V_run_support => + simp only [Fin.isValue, FullTranscript.mk1_eq_snoc, Function.comp_apply] + erw [support_bind] at h_output_mem_V_run_support + let step := (finalSumcheckStepLogic 𝔽q β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)) + set V_check := step.verifierCheck stmtIn + (FullTranscript.mk1 (msg0 := _)) with h_V_check_def + by_cases h_V_check : V_check + · simp only [Fin.isValue, h_V_check, ↓reduceIte, OptionT.run_pure, simulateQ_pure, + Set.mem_iUnion, exists_prop, Prod.exists] at h_output_mem_V_run_support + erw [simulateQ_bind] at h_output_mem_V_run_support + simp only [simulateQ_pure, Fin.isValue, Function.comp_apply, + pure_bind] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, ↓existsAndEq, and_true, exists_eq_left, + ] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Fin.isValue, Set.mem_singleton_iff, Prod.mk.injEq, Option.some.injEq, + exists_eq_right] at h_output_mem_V_run_support + rcases h_output_mem_V_run_support with ⟨h_stmtOut_eq, h_oStmtOut_eq⟩ + simp only [Fin.reduceLast, Fin.isValue] + -- h_relOut : ((stmtOut, oStmtOut), witOut) ∈ roundRelation 𝔽q β i.succ + simp only [finalSumcheckRelOut, finalSumcheckRelOutProp, Set.mem_setOf_eq] at h_relOut + -- Goal: commitKStateProp 1 stmtIn oStmtIn tr witOut + unfold finalSumcheckKStateProp + -- Unfold the sendMessage, receiveChallenge, output logic of prover + dsimp only + -- stmtOut = stmtIn; need oStmtOut = snoc_oracle oStmtIn witOut.f so goal matches h_relOut + simp only [h_stmtOut_eq] at h_relOut ⊢ + have h_oStmtOut_eq_oStmtIn : oStmtOut = oStmtIn := by rw [h_oStmtOut_eq]; rfl + -- c equals tr.messages ⟨0, rfl⟩ + constructor + · -- First conjunct: sumcheck_target = eqTilde r challenges * c + exact h_V_check + · -- Second conjunct: + -- finalSumcheckStepFoldingStateProp ({ toStatement := stmtIn, final_constant := c }, oStmtIn) + rw [h_oStmtOut_eq_oStmtIn] at h_relOut + exact h_relOut + · simp only [Fin.isValue, h_V_check, ↓reduceIte, OptionT.run_failure, simulateQ_pure, + Set.mem_iUnion, exists_prop, Prod.exists] at h_output_mem_V_run_support + erw [simulateQ_bind] at h_output_mem_V_run_support + simp only [simulateQ_pure, Fin.isValue, Function.comp_apply, + pure_bind] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, ↓existsAndEq, and_true, exists_eq_left, + simulateQ_pure] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, reduceCtorEq, false_and, + exists_false] at h_output_mem_V_run_support -- False omit [Fintype L] [CharP L 2] in -/-- Round-by-round knowledge soundness for the final sumcheck step -/ +/-! Round-by-round knowledge soundness for the final sumcheck step -/ theorem finalSumcheckOracleVerifier_rbrKnowledgeSoundness [Fintype L] {σ : Type} + [CharP L 2] (init : ProbComp σ) (impl : QueryImpl []ₒ (StateT σ ProbComp)) : (finalSumcheckVerifier 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)).rbrKnowledgeSoundness init impl @@ -1605,8 +3305,13 @@ theorem finalSumcheckOracleVerifier_rbrKnowledgeSoundness [Fintype L] {σ : Type use finalSumcheckRbrExtractor 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) use finalSumcheckKnowledgeStateFunction 𝔽q β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) init impl - intro stmtIn witIn prover j - exact absurd j.2 (by simp [pSpecFinalSumcheckStep]) + intro stmtIn witIn prover ⟨j, hj⟩ + -- pSpecFinalSumcheckStep has 1 message (ChallengeIdx = Fin 1); same pattern as commit + cases j using Fin.cases with + | zero => simp only [pSpecFinalSumcheckStep, ne_eq, reduceCtorEq, not_false_eq_true, Fin.isValue, + Matrix.cons_val_fin_one, Direction.not_P_to_V_eq_V_to_P] at hj + -- bound for challenge index 0 (P→V only, no V challenge) + | succ j' => exact Fin.elim0 j' end FinalSumcheckStep end SingleIteratedSteps diff --git a/ArkLib/ProofSystem/Binius/FRIBinius/CoreInteractionPhase.lean b/ArkLib/ProofSystem/Binius/FRIBinius/CoreInteractionPhase.lean index b8afee769..b7572f5c0 100644 --- a/ArkLib/ProofSystem/Binius/FRIBinius/CoreInteractionPhase.lean +++ b/ArkLib/ProofSystem/Binius/FRIBinius/CoreInteractionPhase.lean @@ -5,6 +5,7 @@ Authors: Chung Thai Nguyen, Quang Dao -/ import ArkLib.ProofSystem.Binius.BinaryBasefold.CoreInteractionPhase +import ArkLib.ProofSystem.Binius.BinaryBasefold.ReductionLogic import ArkLib.ProofSystem.Binius.FRIBinius.Prelude /-! @@ -27,6 +28,19 @@ This phase combines sumcheck and FRI folding using shared challenges r'ᵢ: and decomposes `e =: Σ_{u ∈ {0,1}^κ} β_u ⊗ e_u`. `V` requires `s_{ℓ'} ?= (Σ_{u ∈ {0,1}^κ} eqTilde(u_0, ..., u_{κ-1},` `r''_0, ..., r''_{κ-1}) * e_u) * c`. + +## Oracle reduction composition + +Inside this file, `coreInteractionOracleReduction` is exactly the composition of: +1. `LiftContext(sumcheckFoldOracleReduction)` (the lifted Binary + Basefold sumcheck-fold reduction), then +2. `finalSumcheckOracleReduction`. + +`LiftContext` here is only the bridge from batching-output shape to Binary Basefold sumcheck-fold +input shape. Concretely, it maps +`SumcheckWitness (t', H)` to `BinaryBasefold.Witness (t, H, f₀)`, where +`f₀ := getMidCodewords t challenges`, and keeps the output witness unchanged (`toFunB` is +identity on `innerWitOut`). -/ namespace Binius.FRIBinius.CoreInteractionPhase @@ -65,10 +79,8 @@ def sumcheckFoldStmtLens : OracleStatement.Lens (InnerOStmtIn := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ 0) (InnerOStmtOut := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) where - -- Stmt and OStmt are same as in outer context, only witness changes toFunA := fun ⟨outerStmtIn, outerOStmtIn⟩ => ⟨outerStmtIn, outerOStmtIn⟩ - toFunB := fun ⟨_, _⟩ ⟨innerStmtOut, innerOStmtOut⟩ => ⟨innerStmtOut, innerOStmtOut⟩ /-- Oracle context lens for sumcheck fold lifting -/ @@ -93,10 +105,10 @@ def sumcheckFoldCtxLens : OracleContext.Lens toFunA := fun ⟨⟨outerStmtIn, outerOStmtIn⟩, outerWitIn⟩ => by let t : L⦃≤ 1⦄[X Fin ℓ'] := outerWitIn.t' let H : L⦃≤ 2⦄[X Fin (ℓ' - 0)] := outerWitIn.H - let P₀ : L⦃< 2^ℓ'⦄[X] := polynomialFromNovelCoeffsF₂ K β ℓ' (by omega) - (fun ω => t.val.eval (bitsOfIndex ω)) let f₀ : (sDomain K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - ⟨0, by omega⟩ → L := fun x => P₀.val.eval x.val + ⟨0, by omega⟩ → L := + BinaryBasefold.getMidCodewords K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin (ℓ' + 1))) (t := t) (challenges := outerStmtIn.challenges) exact { t := t, H := H, f := f₀ } toFunB := fun ⟨⟨outerStmtIn, outerOStmtIn⟩, outerWitIn⟩ ⟨⟨innerStmtOut, innerOStmtOut⟩, innerWitOut⟩ => innerWitOut @@ -123,7 +135,7 @@ def sumcheckFoldExtractorLens : Extractor.Lens toFunA := fun ⟨⟨outerStmtIn, outerOStmtIn⟩, outerWitOut⟩ => outerWitOut toFunB := fun ⟨⟨outerStmtIn, outerOStmtIn⟩, outerWitOut⟩ innerWitIn => by let outerWitIn : SumcheckWitness L ℓ' 0 := { - t' := outerWitOut.t + t' := innerWitIn.t H := innerWitIn.H } exact outerWitIn @@ -151,7 +163,8 @@ variable {σ : Type} {init : ProbComp σ} {impl : QueryImpl []ₒ (StateT σ Pro -- Completeness instance for the context lens instance sumcheckFoldCtxLens_complete : - (sumcheckFoldCtxLens κ L K β ℓ ℓ' 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) h_l).toContext.IsComplete + (sumcheckFoldCtxLens κ L K β ℓ ℓ' 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + h_l).toContext.IsComplete (OuterStmtIn := Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) 0 × (∀ i, BinaryBasefold.OracleStatement K (⇑β) ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) 0 i)) (OuterStmtOut := Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ') × @@ -167,9 +180,12 @@ instance sumcheckFoldCtxLens_complete : (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ') (Fin.last ℓ')) (InnerWitIn := Witness K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ') 0) (InnerWitOut := Witness K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ') (Fin.last ℓ')) - (outerRelIn := RingSwitching.sumcheckRoundRelation κ L K (booleanHypercubeBasis κ L K β) - ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn κ L K β ℓ' - 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) + (outerRelIn := RingSwitching.strictSumcheckRoundRelation κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l (𝓑 := 𝓑) + (aOStmtIn := BinaryBasefoldAbstractOStmtIn + (κ := κ) (L := L) (K := K) (β := β) + (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) (outerRelOut := BinaryBasefold.strictRoundRelation (mp := RingSwitching_SumcheckMultParam κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) K β (ϑ:=ϑ) @@ -194,9 +210,35 @@ instance sumcheckFoldCtxLens_complete : (sumcheckFoldCtxLens κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l).toContext originalReduction ) where proj_complete := fun stmtIn oStmtIn hRelIn => by - sorry + rcases stmtIn with ⟨stmtIn, oStmtIn'⟩ + rcases oStmtIn with ⟨t', H⟩ + rcases hRelIn with ⟨h_local, h_struct, h_strict_compat⟩ + refine ⟨?_, ?_⟩ + · simpa [sumcheckFoldStmtLens] using h_local + · refine ⟨?_, ?_⟩ + · refine ⟨?_, ?_⟩ + · simpa [sumcheckFoldStmtLens, RingSwitching.witnessStructuralInvariant, + BinaryBasefold.witnessStructuralInvariant] using h_struct + · rfl + · change strictOracleFoldingConsistencyProp K β + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (t := t') (i := (0 : Fin (ℓ' + 1))) + (challenges := stmtIn.challenges) (oStmt := oStmtIn') + have h_strict_compat' : + strictOracleFoldingConsistencyProp K β + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (t := t') (i := (0 : Fin (ℓ' + 1))) + (challenges := Fin.elim0) (oStmt := oStmtIn') := by + simpa [BinaryBasefoldAbstractOStmtIn, + Binius.RingSwitching.BBFSmallFieldIOPCS.bbfAbstractOStmtIn, + strictOracleFoldingConsistencyProp] using h_strict_compat + have h_challenges : stmtIn.challenges = (Fin.elim0 : Fin 0 → L) := by + funext i + exact Fin.elim0 i + simpa [h_challenges] using h_strict_compat' lift_complete := fun outerStmtIn outerWitIn innerStmtOut innerWitOut compat => by - sorry + intro _ hRelOut + simpa [sumcheckFoldStmtLens] using hRelOut omit [NeZero κ] [NeZero ℓ] in -- Perfect completeness for the lifted oracle reduction @@ -212,9 +254,9 @@ theorem sumcheckFoldOracleReduction_perfectCompleteness (hInit : NeverFail init) (WitOut := BinaryBasefold.Witness K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ') (Fin.last ℓ')) (pSpec := BinaryBasefold.pSpecSumcheckFold K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - (relIn := RingSwitching.sumcheckRoundRelation κ L K (booleanHypercubeBasis κ L K β) - ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn κ L K β ℓ' - 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) + (relIn := RingSwitching.strictSumcheckRoundRelation κ L K (booleanHypercubeBasis κ L K β) + ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn (β := β) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) (relOut := BinaryBasefold.strictRoundRelation (mp := RingSwitching_SumcheckMultParam κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) K β (ϑ:=ϑ) @@ -243,9 +285,11 @@ theorem sumcheckFoldOracleReduction_perfectCompleteness (hInit : NeverFail init) (InnerOStmtOut := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) (pSpec := BinaryBasefold.pSpecSumcheckFold K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - (outerRelIn := RingSwitching.sumcheckRoundRelation κ L K (booleanHypercubeBasis κ L K β) - ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn κ L K β ℓ' - 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) + (outerRelIn := RingSwitching.strictSumcheckRoundRelation κ L K (booleanHypercubeBasis κ L K β) + ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn + (κ := κ) (L := L) (K := K) (β := β) + (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) (outerRelOut := BinaryBasefold.strictRoundRelation (mp := RingSwitching_SumcheckMultParam κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) K β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (Fin.last ℓ')) @@ -263,8 +307,14 @@ theorem sumcheckFoldOracleReduction_perfectCompleteness (hInit : NeverFail init) (h := BinaryBasefold.CoreInteraction.sumcheckFoldOracleReduction_perfectCompleteness (hInit:=hInit) K β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)) --- Knowledge soundness instance for the extractor lens -instance sumcheckFoldExtractorLens_rbr_knowledge_soundness : +/-- Knowledge soundness instance for the extractor lens. This one is compatStmt-agnostic -/ +instance sumcheckFoldExtractorLens_rbr_knowledge_soundness + {compatStmt : + (Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) 0 × + (∀ i, BinaryBasefold.OracleStatement K (⇑β) ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) 0 i)) → + (Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ') × + (∀ i, BinaryBasefold.OracleStatement K (⇑β) ϑ + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Fin.last ℓ') i)) → Prop} : Extractor.Lens.IsKnowledgeSound (OuterStmtIn := Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) 0 × (∀ i, BinaryBasefold.OracleStatement K (⇑β) ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) 0 i)) @@ -282,8 +332,10 @@ instance sumcheckFoldExtractorLens_rbr_knowledge_soundness : (InnerWitIn := Witness K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ') 0) (InnerWitOut := Witness K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ') (Fin.last ℓ')) (outerRelIn := RingSwitching.sumcheckRoundRelation κ L K (booleanHypercubeBasis κ L K β) - ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn κ L K β ℓ' - 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) + ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn + (κ := κ) (L := L) (K := K) (β := β) + (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) (outerRelOut := BinaryBasefold.roundRelation (mp := RingSwitching_SumcheckMultParam κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) K β (ϑ:=ϑ) @@ -299,14 +351,54 @@ instance sumcheckFoldExtractorLens_rbr_knowledge_soundness : (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) K β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (Fin.last ℓ') ) - (compatStmt := fun _ _ => True) + (compatStmt := compatStmt) (compatWit := fun _ _ => True) (lens := sumcheckFoldExtractorLens κ L K β ℓ ℓ' 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) where proj_knowledgeSound := by - sorry + intro outerStmtIn innerStmtOut outerWitOut _ hOuter + simpa [sumcheckFoldExtractorLens, sumcheckFoldStmtLens] using hOuter lift_knowledgeSound := by - sorry + intro outerStmtIn outerWitOut innerWitIn _ hInner + rcases outerStmtIn with ⟨stmtIn, oStmtIn⟩ + have hInner' : + BinaryBasefold.roundRelationProp + (mp := RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) + K β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) + (0 : Fin (ℓ' + 1)) ((stmtIn, oStmtIn), innerWitIn) := by + simpa [BinaryBasefold.roundRelation, Set.mem_setOf_eq] using hInner + unfold BinaryBasefold.roundRelationProp BinaryBasefold.masterKStateProp at hInner' + have h_no_bad : + ¬ incrementalBadEventExistsProp K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (ϑ := ϑ) (stmtIdx := (0 : Fin (ℓ' + 1))) + (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (0 : Fin (ℓ' + 1))) + (oStmt := oStmtIn) (challenges := stmtIn.challenges) := by + intro h_bad + rcases h_bad with ⟨j, hj⟩ + have hj0 : j = 0 := by + apply Fin.eq_of_val_eq + have hjlt : j.val < 1 := by + simpa [BinaryBasefold.toOutCodewordsCountOf0] using j.isLt + exact Nat.lt_one_iff.mp hjlt + subst hj0 + dsimp [BinaryBasefold.oraclePositionToDomainIndex] at hj + exact absurd hj (by + apply BinaryBasefold.incrementalFoldingBadEvent_of_k_eq_0_is_false + (𝔽q := K) (β := β) + (h_k := by + simp only [Nat.zero_mod, zero_mul, tsub_self, zero_le, inf_of_le_right]) + (h_midIdx := by simp only [Nat.zero_mod, zero_mul, tsub_self, zero_le, + inf_of_le_right, add_zero])) + rcases hInner' with h_bad | h_good + · exact (h_no_bad h_bad).elim + · have h_local := h_good.1 + have h_struct := h_good.2.1 + have h_first := h_good.2.2.1 + refine ⟨h_local, ?_, ?_⟩ + · simpa [sumcheckFoldExtractorLens, RingSwitching.witnessStructuralInvariant, + BinaryBasefold.witnessStructuralInvariant] using h_struct.1 + · simpa [BinaryBasefoldAbstractOStmtIn] using h_first -- Round-by-round knowledge soundness for the lifted oracle verifier theorem sumcheckFoldOracleVerifier_rbrKnowledgeSoundness [Fintype L] : @@ -322,8 +414,10 @@ theorem sumcheckFoldOracleVerifier_rbrKnowledgeSoundness [Fintype L] : (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ') (Fin.last ℓ')) (pSpec := BinaryBasefold.pSpecSumcheckFold K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (relIn := RingSwitching.sumcheckRoundRelation κ L K (booleanHypercubeBasis κ L K β) - ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn κ L K β ℓ' - 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) + ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn + (κ := κ) (L := L) (K := K) (β := β) + (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) (relOut := BinaryBasefold.roundRelation (mp := RingSwitching_SumcheckMultParam κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) K β (ϑ:=ϑ) @@ -335,8 +429,60 @@ theorem sumcheckFoldOracleVerifier_rbrKnowledgeSoundness [Fintype L] : (impl := impl) (rbrKnowledgeError := BinaryBasefold.CoreInteraction.sumcheckFoldKnowledgeError K β (ϑ := ϑ)) := by - -- apply OracleVerifier.liftContext_rbr_knowledgeSoundness - sorry + letI : Inhabited (Statement (L := L) (ℓ := ℓ') + (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) := ⟨{ + ctx := { + t_eval_point := 0 + original_claim := 0 + s_hat := 0 + r_batching := 0 + } + sumcheck_target := 0 + challenges := 0 + }⟩ + letI : + ∀ i : Fin (toOutCodewordsCount ℓ' ϑ (i := Fin.last ℓ')), + Inhabited (BinaryBasefold.OracleStatement K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ') i) := by + intro i + exact ⟨fun _ => 0⟩ + letI : Inhabited (BinaryBasefold.Witness K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') 0) := ⟨{ + t := 0 + H := 0 + f := fun _ => 0 + }⟩ + have h_lifted := OracleVerifier.liftContext_rbr_knowledgeSoundness + (V := BinaryBasefold.CoreInteraction.sumcheckFoldOracleVerifier K β + (ϑ := ϑ) + (mp := RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) + (𝓑 := 𝓑) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (stmtLens := sumcheckFoldStmtLens κ L K β ℓ ℓ' 𝓡 ϑ + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (witLens := (sumcheckFoldExtractorLens κ L K β ℓ ℓ' 𝓡 ϑ + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).wit) + (lensKS := sumcheckFoldExtractorLens_rbr_knowledge_soundness + (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_l := h_l) (𝓑 := 𝓑) + (compatStmt := (BinaryBasefold.CoreInteraction.sumcheckFoldOracleVerifier K β + (ϑ := ϑ) + (mp := RingSwitching_SumcheckMultParam κ L K (β := booleanHypercubeBasis κ L K β) + ℓ ℓ' h_l) + (𝓑 := 𝓑) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).toVerifier.compatStatement + (sumcheckFoldStmtLens κ L K β ℓ ℓ' 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate)))) + (h := by + simpa using + (BinaryBasefold.CoreInteraction.sumcheckFoldOracleVerifier_rbrKnowledgeSoundness + (L := L) K β + (ϑ := ϑ) + (mp := RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑) + (init := init) (impl := impl))) + simpa [sumcheckFoldOracleVerifier] using h_lifted end Security end SumcheckFold @@ -346,6 +492,80 @@ section FinalSumcheckStep ## Final Sumcheck Step -/ +/-! ## Pure Logic Functions (ReductionLogicStep Infrastructure) -/ + +/-- Pure verifier check for FRI final sumcheck step. -/ +@[reducible] +def finalSumcheckVerifierCheck + (stmtIn : Statement (L := L) (ℓ := ℓ') + (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) + (c : L) : Prop := + let eq_tilde_eval : L := RingSwitching.compute_final_eq_value κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l + stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching + stmtIn.sumcheck_target = eq_tilde_eval * c + +/-- Pure verifier output for FRI final sumcheck step. -/ +@[reducible] +def finalSumcheckVerifierStmtOut + (stmtIn : Statement (L := L) (ℓ := ℓ') + (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) + (c : L) : BinaryBasefold.FinalSumcheckStatementOut (L := L) (ℓ := ℓ') := { + ctx := { + t_eval_point := getEvaluationPointSuffix κ L ℓ ℓ' h_l stmtIn.ctx.t_eval_point + original_claim := stmtIn.ctx.original_claim + } + sumcheck_target := stmtIn.sumcheck_target + challenges := stmtIn.challenges + final_constant := c + } + +/-- Pure prover message computation for FRI final sumcheck step. -/ +@[reducible] +def finalSumcheckProverComputeMsg + (witIn : BinaryBasefold.Witness K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') (Fin.last ℓ')) : L := + witIn.f ⟨0, by simp only [zero_mem]⟩ + +/-- Pure prover output witness for FRI final sumcheck step. -/ +@[reducible] +def finalSumcheckProverWitOut : Unit := () + +/-! ## ReductionLogicStep Instance -/ + +/-- The logic instance for the FRI final sumcheck step. -/ +def finalSumcheckStepLogic : + Binius.BinaryBasefold.ReductionLogicStep + (Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) + (BinaryBasefold.Witness K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') (Fin.last ℓ')) + (BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) + (BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) + (BinaryBasefold.FinalSumcheckStatementOut (L := L) (ℓ := ℓ')) + Unit + (BinaryBasefold.pSpecFinalSumcheckStep (L := L)) where + completeness_relIn := fun ((stmt, oStmt), wit) => + ((stmt, oStmt), wit) ∈ BinaryBasefold.strictRoundRelation + (mp := RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) K β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (Fin.last ℓ') + completeness_relOut := fun ((stmtOut, oStmtOut), witOut) => + ((stmtOut, oStmtOut), witOut) ∈ BinaryBasefold.strictFinalSumcheckRelOut K β + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + verifierCheck := fun stmtIn transcript => + finalSumcheckVerifierCheck κ L K β ℓ ℓ' h_l stmtIn (transcript.messages ⟨0, rfl⟩) + verifierOut := fun stmtIn transcript => + finalSumcheckVerifierStmtOut κ L K ℓ ℓ' h_l stmtIn (transcript.messages ⟨0, rfl⟩) + embed := ⟨fun j => Sum.inl j, fun a b h => by cases h; rfl⟩ + hEq := fun _ => rfl + honestProverTranscript := fun _stmtIn witIn _oStmtIn _chal => + let c : L := finalSumcheckProverComputeMsg (κ := κ) (L := L) (K := K) (β := β) + (ℓ' := ℓ') (𝓡 := 𝓡) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) witIn + FullTranscript.mk1 c + proverOut := fun stmtIn _witIn oStmtIn transcript => + let c : L := transcript.messages ⟨0, rfl⟩ + let stmtOut := finalSumcheckVerifierStmtOut κ L K ℓ ℓ' h_l stmtIn c + ((stmtOut, oStmtIn), ()) + /-- The prover for the final sumcheck step -/ noncomputable def finalSumcheckProver : OracleProver @@ -368,27 +588,17 @@ noncomputable def finalSumcheckProver : (∀ j, BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ') j) × BinaryBasefold.Witness K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ') (Fin.last ℓ') × L input := fun ⟨⟨stmt, oStmt⟩, wit⟩ => (stmt, oStmt, wit) - sendMessage | ⟨0, _⟩ => fun ⟨stmtIn, oStmtIn, witIn⟩ => do - let f_ℓ : (sDomain K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ:=ℓ') ⟨ℓ', by omega⟩) → L := witIn.f - let c : L := f_ℓ ⟨0, by simp only [zero_mem]⟩ -- f_ℓ(0, ..., 0) + let c : L := finalSumcheckProverComputeMsg (κ := κ) (L := L) (K := K) (β := β) + (ℓ' := ℓ') (𝓡 := 𝓡) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) witIn pure ⟨c, (stmtIn, oStmtIn, witIn, c)⟩ - receiveChallenge | ⟨0, h⟩ => nomatch h -- No challenges in this step - output := fun ⟨stmtIn, oStmtIn, witIn, s'⟩ => do - let stmtOut : BinaryBasefold.FinalSumcheckStatementOut (L:=L) (ℓ:=ℓ') := { - ctx := { - t_eval_point := getEvaluationPointSuffix κ L ℓ ℓ' h_l stmtIn.ctx.t_eval_point, - original_claim := stmtIn.ctx.original_claim, - }, - sumcheck_target := stmtIn.sumcheck_target, - challenges := stmtIn.challenges, - final_constant := s' - } - pure (⟨stmtOut, oStmtIn⟩, ()) + let logic := finalSumcheckStepLogic κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑) + let t := FullTranscript.mk1 (pSpec := BinaryBasefold.pSpecFinalSumcheckStep (L := L)) s' + pure (logic.proverOut stmtIn witIn oStmtIn t) /-- The verifier for the final sumcheck step -/ noncomputable def finalSumcheckVerifier : @@ -402,44 +612,17 @@ noncomputable def finalSumcheckVerifier : (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) (pSpec := BinaryBasefold.pSpecFinalSumcheckStep (L:=L)) where verify := fun stmtIn _ => do - -- Get the final constant `s'` from the prover's message let s' : L ← query (spec := [(BinaryBasefold.pSpecFinalSumcheckStep - (L:=L)).Message]ₒ) ⟨⟨0, rfl⟩, ()⟩ - - -- 8. `V` sets `e := eq̃(φ₀(r_κ), ..., φ₀(r_{ℓ-1}), φ₁(r'_0), ..., φ₁(r'_{ℓ'-1}))` and - -- decomposes `e =: Σ_{u ∈ {0,1}^κ} β_u ⊗ e_u`. - -- Then `V` computes the final eq value: `(Σ_{u ∈ {0,1}^κ} eq̃(u_0, ..., u_{κ-1},` - -- `r''_0, ..., r''_{κ-1}) ⋅ e_u)` - - let eq_tilde_eval : L := RingSwitching.compute_final_eq_value κ L K - (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l - stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching - -- 9. `V` requires `s_{ℓ'} ?= (Σ_{u ∈ {0,1}^κ} eq̃(u_0, ..., u_{κ-1},` - -- `r''_0, ..., r''_{κ-1}) ⋅ e_u) ⋅ s'`. - unless stmtIn.sumcheck_target = eq_tilde_eval * s' do - return { -- dummy stmtOut - ctx := { - t_eval_point := 0, - original_claim := 0, - }, - sumcheck_target := 0, - challenges := 0, - final_constant := 0, - } - -- Return the final sumcheck statement with the constant - let stmtOut : BinaryBasefold.FinalSumcheckStatementOut (L:=L) (ℓ:=ℓ') := { - ctx := { - t_eval_point := getEvaluationPointSuffix κ L ℓ ℓ' h_l stmtIn.ctx.t_eval_point, - original_claim := stmtIn.ctx.original_claim, - }, - sumcheck_target := stmtIn.sumcheck_target, - challenges := stmtIn.challenges, - final_constant := s', - } - pure stmtOut - - embed := ⟨fun j => Sum.inl j, fun a b h => by cases h; rfl⟩ - hEq := fun _ => rfl + (L:=L)).Message]ₒ) ⟨⟨0, by rfl⟩, (by simpa only using ())⟩ + let t := FullTranscript.mk1 (pSpec := BinaryBasefold.pSpecFinalSumcheckStep (L := L)) s' + let logic := finalSumcheckStepLogic κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑) + have : Decidable (logic.verifierCheck stmtIn t) := Classical.propDecidable _ + guard (logic.verifierCheck stmtIn t) + pure (logic.verifierOut stmtIn t) + embed := (finalSumcheckStepLogic κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l + (𝓑 := 𝓑)).embed + hEq := (finalSumcheckStepLogic κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l + (𝓑 := 𝓑)).hEq /-- The oracle reduction for the final sumcheck step -/ noncomputable def finalSumcheckOracleReduction : @@ -454,12 +637,453 @@ noncomputable def finalSumcheckOracleReduction : (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) (WitOut := Unit) (pSpec := BinaryBasefold.pSpecFinalSumcheckStep (L:=L)) where - prover := finalSumcheckProver κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l - verifier := finalSumcheckVerifier κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l + prover := finalSumcheckProver κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑) + verifier := finalSumcheckVerifier κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑) + +omit [Fintype L] [DecidableEq L] [CharP L 2] [SampleableType L] [NeZero ℓ'] in +/-- At `Fin.last ℓ'`, sumcheck consistency simplifies to a single evaluation. -/ +lemma sumcheckConsistency_at_last_simplifies + (target : L) (H : L⦃≤ 2⦄[X Fin (ℓ' - Fin.last ℓ')]) + (h_cons : BinaryBasefold.sumcheckConsistencyProp (𝓑 := 𝓑) target H) : + target = H.val.eval (fun _ => (0 : L)) := by + simp only [Fin.val_last] at H h_cons ⊢ + simp only [BinaryBasefold.sumcheckConsistencyProp] at h_cons + haveI : IsEmpty (Fin 0) := Fin.isEmpty + rw [Finset.sum_eq_single (a := fun _ => 0) + (h₀ := fun b _ hb_ne => by + exfalso + apply hb_ne + funext i + simp only [tsub_self] at i + exact i.elim0) + (h₁ := fun h_not_mem => by + exfalso + apply h_not_mem + simp only [Fintype.mem_piFinset] + intro i + simp only [tsub_self] at i + exact i.elim0)] at h_cons + exact h_cons + +omit [NeZero κ] [CharP L 2] [SampleableType L] [NeZero ℓ] in +/-- The final codeword value at `0` equals `t(challenges)`. -/ +lemma finalCodeword_zero_eq_t_eval + (stmtIn : Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) + (witIn : BinaryBasefold.Witness K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') (Fin.last ℓ')) + (h_wit_struct : BinaryBasefold.witnessStructuralInvariant K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (mp := RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) + (stmt := stmtIn) (wit := witIn)) : + witIn.f ⟨0, by simp only [zero_mem]⟩ = witIn.t.val.eval stmtIn.challenges := by + have h_f_eq_getMidCodewords_t : + witIn.f = BinaryBasefold.getMidCodewords K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := Fin.last ℓ') witIn.t stmtIn.challenges := h_wit_struct.2 + dsimp only [BinaryBasefold.getMidCodewords, Fin.coe_ofNat_eq_mod] at h_f_eq_getMidCodewords_t + rw [congr_fun h_f_eq_getMidCodewords_t ⟨0, by simp only [zero_mem]⟩] + let h_eval := BinaryBasefold.iterated_fold_to_level_ℓ_eval K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (t := witIn.t) + (destIdx := ⟨Fin.last ℓ', by omega⟩) + (h_destIdx := by simp only [Fin.val_last]) (challenges := stmtIn.challenges) + exact congr_fun h_eval ⟨0, by simp only [Fin.val_last, zero_mem]⟩ + +omit [SampleableType L] [NeZero κ] [NeZero ℓ] in +/-- Strict helper: folding the last oracle block in the final sumcheck step yields +the constant function equal to the prover message `witIn.f(0)`. -/ +lemma iterated_fold_to_const_strict + (stmtIn : Statement (L := L) (ℓ := ℓ') + (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) + (witIn : BinaryBasefold.Witness K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') (Fin.last ℓ')) + (oStmtIn : ∀ j, BinaryBasefold.OracleStatement K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ') j) + (h_strictOracleWitConsistency_In : BinaryBasefold.strictOracleWitnessConsistency K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (Context := RingSwitchingBaseContext κ L K ℓ) + (mp := RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) + (stmtIdx := Fin.last ℓ') + (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (Fin.last ℓ')) + (stmt := stmtIn) (wit := witIn) (oStmt := oStmtIn)) : + let c : L := witIn.f ⟨0, by simp only [zero_mem]⟩ + let lastDomainIdx := getLastOracleDomainIndex ℓ' ϑ (Fin.last ℓ') + let k := lastDomainIdx.val + have h_k : k = ℓ' - ϑ := by + dsimp only [k, lastDomainIdx] + rw [getLastOraclePositionIndex_last, Nat.sub_mul, Nat.one_mul, + Nat.div_mul_cancel (hdiv.out)] + let curDomainIdx : Fin (2 ^ κ) := ⟨k, by + rw [h_k] + omega + ⟩ + have h_destIdx_eq : curDomainIdx.val = lastDomainIdx.val := rfl + let f_k : OracleFunction K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) curDomainIdx := + getLastOracle (h_destIdx := h_destIdx_eq) (oracleFrontierIdx := Fin.last ℓ') + K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oStmt := oStmtIn) + let finalChallenges : Fin ϑ → L := fun cId => stmtIn.challenges ⟨k + cId, by + rw [h_k] + have h_le : ϑ ≤ ℓ' := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ') (hdiv.out) + have h_cId : cId.val < ϑ := cId.isLt + have h_last : (Fin.last ℓ').val = ℓ' := rfl + omega + ⟩ + let destDomainIdx : Fin (2 ^ κ) := ⟨k + ϑ, by + rw [h_k] + have h_le : ϑ ≤ ℓ' := by apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ') (hdiv.out) + omega + ⟩ + let folded := iterated_fold K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := curDomainIdx) (steps := ϑ) (destIdx := destDomainIdx) (h_destIdx := by rfl) + (h_destIdx_le := by + dsimp only [destDomainIdx, k, lastDomainIdx] + rw [getLastOraclePositionIndex_last, Nat.sub_mul, Nat.one_mul, + Nat.div_mul_cancel (hdiv.out)] + rw [Nat.sub_add_cancel (by + exact Nat.le_of_dvd (h := by exact Nat.pos_of_neZero ℓ') (hdiv.out))] + ) (f := f_k) + (r_challenges := finalChallenges) + ∀ y, folded y = c := by + have h_ϑ_le_ℓ' : ϑ ≤ ℓ' := by + apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ') (hdiv.out) + intro c lastDomainIdx k h_k curDomainIdx h_destIdx_eq f_k finalChallenges destDomainIdx folded + let P₀ : L[X]_(2 ^ ℓ') := polynomialFromNovelCoeffsF₂ K β ℓ' (by omega) + (fun ω => witIn.t.val.eval (bitsOfIndex ω)) + let f₀ := polyToOracleFunc K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (domainIdx := 0) (P := P₀) + have h_wit_struct := h_strictOracleWitConsistency_In.1 + have h_strict_oracle_folding := h_strictOracleWitConsistency_In.2 + dsimp only [Fin.val_last, OracleFrontierIndex.val_mkFromStmtIdx, + strictOracleFoldingConsistencyProp] at h_strict_oracle_folding + have h_eq : folded = fun x => c := by + dsimp only [folded, f_k] + have h_f_last_consistency := h_strict_oracle_folding + (j := (getLastOraclePositionIndex ℓ' ϑ (Fin.last ℓ'))) + have h_wit_f_eq : witIn.f = getMidCodewords K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) witIn.t stmtIn.challenges := h_wit_struct.2 + dsimp only [Fin.val_last, getMidCodewords] at h_wit_f_eq + dsimp only [c] + conv_rhs => + rw [h_wit_f_eq] + simp only [Fin.val_last] + have h_curDomainIdx_eq : curDomainIdx = ⟨ℓ' - ϑ, by omega⟩ := by + dsimp [curDomainIdx, k, lastDomainIdx] + simp only [Fin.mk.injEq] + rw [getLastOraclePositionIndex_last, Nat.sub_mul, Nat.div_mul_cancel (hdiv.out)] + simp only [one_mul] + let res := iterated_fold_congr_source_index K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := curDomainIdx) (i' := ⟨ℓ' - ϑ, by omega⟩) (h := h_curDomainIdx_eq) (steps := ϑ) + (destIdx := destDomainIdx) + (h_destIdx := by rfl) (h_destIdx' := by simp only [destDomainIdx, h_k]) + (h_destIdx_le := by + dsimp only [destDomainIdx] + rw [h_k] + rw [Nat.sub_add_cancel (by + exact Nat.le_of_dvd (h := by exact Nat.pos_of_neZero ℓ') (hdiv.out))] + ) (f := (getLastOracle K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) h_destIdx_eq oStmtIn)) + (r_challenges := finalChallenges) + rw [res] + dsimp only [getLastOracle, finalChallenges] + rw [h_f_last_consistency] + simp only [Fin.take_eq_self] + let k_pos_idx := getLastOraclePositionIndex ℓ' ϑ (Fin.last ℓ') + let k_steps := k_pos_idx.val * ϑ + have h_k_steps_eq : k_steps = k := by + dsimp only [k_steps, k_pos_idx, k, lastDomainIdx] + have h_cast_elim := iterated_fold_congr_dest_index K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := 0) (steps := k_steps) (destIdx := curDomainIdx) (destIdx' := ⟨k_steps, by omega⟩) + (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega) + (h_destIdx_le := by + dsimp only [curDomainIdx] + simp only [h_k, tsub_le_iff_right, le_add_iff_nonneg_right, zero_le] + ) (h_destIdx_eq_destIdx' := by rfl) + (f := f₀) + (r_challenges := getFoldingChallenges (𝓡 := 𝓡) (r := 2 ^ κ) (Fin.last ℓ') + stmtIn.challenges 0 (by simp only [zero_add, Fin.val_last]; omega)) + have h_cast_elim2 := iterated_fold_congr_dest_index K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := 0) (steps := k_steps) (destIdx := ⟨ℓ' - ϑ, by omega⟩) (destIdx' := curDomainIdx) + (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; omega) + (h_destIdx_le := by + dsimp only [curDomainIdx] + simp only [tsub_le_iff_right, le_add_iff_nonneg_right, zero_le] + ) + (h_destIdx_eq_destIdx' := by + dsimp only [curDomainIdx] + simp only [Fin.mk.injEq]; omega + ) + (f := f₀) + (r_challenges := getFoldingChallenges (𝓡 := 𝓡) (r := 2 ^ κ) (Fin.last ℓ') + stmtIn.challenges 0 (by simp only [zero_add, Fin.val_last]; omega)) + dsimp only [k_steps, k_pos_idx, f₀, P₀] at h_cast_elim + dsimp only [k_steps, k_pos_idx, f₀, P₀] at h_cast_elim2 + conv_lhs => + simp only [←h_cast_elim] + simp only [←h_cast_elim2] + simp only [←fun_eta_expansion] + have h_transitivity := iterated_fold_transitivity K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := 0) (midIdx := ⟨ℓ' - ϑ, by omega⟩) (destIdx := destDomainIdx) + (steps₁ := k_steps) (steps₂ := ϑ) + (h_midIdx := by + simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, h_k_steps_eq, h_k, zero_add] + ) + (h_destIdx := by + dsimp only [destDomainIdx, k_steps, k_pos_idx] + rw [h_k] + simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add, Nat.add_right_cancel_iff] + rw [getLastOraclePositionIndex_last] + simp only + rw [Nat.sub_mul, Nat.div_mul_cancel (hdiv.out)] + simp only [one_mul] + ) + (h_destIdx_le := by + dsimp only [destDomainIdx] + rw [h_k] + rw [Nat.sub_add_cancel (by + exact Nat.le_of_dvd (h := by exact Nat.pos_of_neZero ℓ') (hdiv.out))] + ) + (f := f₀) + (r_challenges₁ := getFoldingChallenges (𝓡 := 𝓡) (r := 2 ^ κ) (Fin.last ℓ') + stmtIn.challenges 0 (by simp only [zero_add, Fin.val_last]; omega)) + (r_challenges₂ := finalChallenges) + have h_finalChallenges_eq : finalChallenges = fun cId : Fin ϑ => stmtIn.challenges + ⟨k + cId.val, by + rw [h_k] + have h_le : ϑ ≤ ℓ' := by + apply Nat.le_of_dvd (by exact Nat.pos_of_neZero ℓ') (hdiv.out) + have h_cId : cId.val < ϑ := cId.isLt + have h_last : (Fin.last ℓ').val = ℓ' := rfl + omega + ⟩ := by + rfl + rw [h_finalChallenges_eq] at h_transitivity + rw [h_transitivity] + have h_steps_eq : k_steps + ϑ = ℓ' := by + dsimp only [k_steps, k_pos_idx, h_k_steps_eq, h_k] + rw [getLastOraclePositionIndex_last] + simp only [Nat.sub_mul, Nat.one_mul, Nat.div_mul_cancel (hdiv.out)] + rw [Nat.sub_add_cancel (by + exact Nat.le_of_dvd (h := by exact Nat.pos_of_neZero ℓ') (hdiv.out))] + have h_concat_challenges_eq : + Fin.append + (getFoldingChallenges (𝓡 := 𝓡) (r := 2 ^ κ) (ϑ := k_steps) + (Fin.last ℓ') stmtIn.challenges 0 + (by simp only [zero_add, Fin.val_last]; omega)) + finalChallenges = + fun (cIdx : Fin (k_steps + ϑ)) => stmtIn.challenges ⟨cIdx, by + simp only [Fin.val_last] + omega + ⟩ := by + funext cId + dsimp only [getFoldingChallenges, finalChallenges] + by_cases h : cId.val < k_steps + · simp only [Fin.val_last] + dsimp only [Fin.append, Fin.addCases] + simp only [h, ↓reduceDIte, getFoldingChallenges, Fin.val_last, Fin.val_castLT, zero_add] + · simp only [Fin.val_last] + dsimp only [Fin.append, Fin.addCases] + simp [h, ↓reduceDIte, Fin.val_subNat, Fin.val_cast, eq_rec_constant] + congr 1 + simp only [Fin.val_last, Fin.mk.injEq] + rw [add_comm, ←h_k_steps_eq] + omega + dsimp only [finalChallenges] at h_concat_challenges_eq + simp only [h_concat_challenges_eq] + funext y + have h_cast_elim3 := iterated_fold_congr_dest_index K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := 0) (steps := k_steps + ϑ) (destIdx := destDomainIdx) + (destIdx' := ⟨Fin.last ℓ', by omega⟩) + (h_destIdx := by simp only [Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add]; rfl) + (h_destIdx_le := by dsimp only [destDomainIdx]; omega) + (h_destIdx_eq_destIdx' := by + dsimp only [destDomainIdx] + simp only [Fin.val_last, Fin.mk.injEq] + omega + ) + (f := f₀) + (r_challenges := fun (cIdx : Fin (k_steps + ϑ)) => stmtIn.challenges ⟨cIdx, by + simp only [Fin.val_last] + omega + ⟩) + rw [h_cast_elim3] + have h_cast_elim4 := iterated_fold_congr_steps_index K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := 0) (steps := ℓ') (steps' := k_steps + ϑ) + (destIdx := ⟨Fin.last ℓ', by omega⟩) + (h_steps_eq_steps' := by simp only [h_steps_eq]) + (h_destIdx := by + dsimp only [destDomainIdx] + simp only [Fin.val_last, Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add] + ) + (h_destIdx_le := by simp only [Fin.val_last, le_refl]) + (f := f₀) (r_challenges := stmtIn.challenges) + rw [←h_cast_elim4] + set f_last := iterated_fold K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) 0 ℓ' + (destIdx := ⟨Fin.last ℓ', by omega⟩) + (h_destIdx := by + simp only [Fin.val_last, Fin.coe_ofNat_eq_mod, Nat.zero_mod, zero_add] + ) + (h_destIdx_le := by simp only [Fin.val_last, le_refl]) (f := f₀) + (r_challenges := stmtIn.challenges) + have h_eval_eq : ∀ x, f_last x = f_last ⟨0, by simp only [zero_mem]⟩ := by + intro x + apply iterated_fold_to_level_ℓ_is_constant K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (t := witIn.t) (destIdx := ⟨Fin.last ℓ', by omega⟩) + (h_destIdx := by simp only [Fin.val_last]) (challenges := stmtIn.challenges) + (x := x) (y := 0) + rw [h_eval_eq] + rfl + rw [h_eq] + intro y + rfl + +omit [NeZero κ] [CharP L 2] [SampleableType L] [DecidableEq K] h_β₀_eq_1 [NeZero ℓ] in +/-- Honest prover message in final sumcheck equals `witIn.f(0)`. -/ +lemma finalSumcheck_honest_message_eq_f_zero + (stmtIn : Statement (L := L) (ℓ := ℓ') + (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) + (witIn : BinaryBasefold.Witness K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') (Fin.last ℓ')) + (oStmtIn : ∀ j, BinaryBasefold.OracleStatement K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ') j) + (challenges : (BinaryBasefold.pSpecFinalSumcheckStep (L := L)).Challenges) : + let step := finalSumcheckStepLogic κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑) + let transcript := step.honestProverTranscript stmtIn witIn oStmtIn challenges + transcript.messages ⟨0, rfl⟩ = witIn.f ⟨0, by simp only [zero_mem]⟩ := by + simp only [finalSumcheckStepLogic, finalSumcheckProverComputeMsg] + +/-- Verifier check passes in the FRI final sumcheck logic step. -/ +lemma finalSumcheckStep_verifierCheck_passed + (stmtIn : Statement (L := L) (ℓ := ℓ') + (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) + (witIn : BinaryBasefold.Witness K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') (Fin.last ℓ')) + (oStmtIn : ∀ j, BinaryBasefold.OracleStatement K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ') j) + (challenges : (BinaryBasefold.pSpecFinalSumcheckStep (L := L)).Challenges) + (h_sumcheck_cons : BinaryBasefold.sumcheckConsistencyProp + (𝓑 := 𝓑) stmtIn.sumcheck_target witIn.H) + (h_wit_struct : BinaryBasefold.witnessStructuralInvariant K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (mp := RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) + (stmt := stmtIn) (wit := witIn)) : + let step := finalSumcheckStepLogic κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑) + let transcript := step.honestProverTranscript stmtIn witIn oStmtIn challenges + step.verifierCheck stmtIn transcript := by + intro step transcript + have h_target_eq_H_eval : + stmtIn.sumcheck_target = witIn.H.val.eval (fun _ => (0 : L)) := + sumcheckConsistency_at_last_simplifies (L := L) (ℓ' := ℓ') (𝓑 := 𝓑) + stmtIn.sumcheck_target witIn.H h_sumcheck_cons + have h_proj_eval : + (BinaryBasefold.projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witIn.t) + (m := (RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l).multpoly stmtIn.ctx) + (i := Fin.last ℓ') (challenges := stmtIn.challenges)).val.eval (fun _ => (0 : L)) = + ((RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l).multpoly stmtIn.ctx).val.eval + stmtIn.challenges * witIn.t.val.eval stmtIn.challenges := by + apply BinaryBasefold.projectToMidSumcheckPoly_at_last_eval + have h_mult_eq_eq_value : + ((RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l).multpoly stmtIn.ctx).val.eval + stmtIn.challenges = + RingSwitching.compute_final_eq_value κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l + stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching := + RingSwitching.compute_A_MLE_eval_eq_final_eq_value κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l + stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching + have h_c_eq : witIn.f ⟨0, by simp only [zero_mem]⟩ = witIn.t.val.eval stmtIn.challenges := by + exact finalCodeword_zero_eq_t_eval (κ := κ) (L := L) (K := K) (β := β) + (ℓ := ℓ) (ℓ' := ℓ') (𝓡 := 𝓡) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_l := h_l) + stmtIn witIn h_wit_struct + let cmsg : L := transcript.messages ⟨0, rfl⟩ + have h_msg_eq : cmsg = witIn.f ⟨0, by simp only [zero_mem]⟩ := + finalSumcheck_honest_message_eq_f_zero (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) + (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_l := h_l) + (𝓑 := 𝓑) stmtIn witIn oStmtIn challenges + have h_eq : stmtIn.sumcheck_target = RingSwitching.compute_final_eq_value κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l + stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching * + cmsg := by + calc + stmtIn.sumcheck_target + = witIn.H.val.eval (fun _ => (0 : L)) := h_target_eq_H_eval + _ = (BinaryBasefold.projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witIn.t) + (m := (RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l).multpoly stmtIn.ctx) + (i := Fin.last ℓ') (challenges := stmtIn.challenges)).val.eval (fun _ => (0 : L)) := by + rw [h_wit_struct.1] + _ = ((RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l).multpoly stmtIn.ctx).val.eval + stmtIn.challenges * witIn.t.val.eval stmtIn.challenges := h_proj_eval + _ = RingSwitching.compute_final_eq_value κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l + stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching * + witIn.t.val.eval stmtIn.challenges := by + rw [h_mult_eq_eq_value] + _ = RingSwitching.compute_final_eq_value κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l + stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching * + witIn.f ⟨0, by simp only [zero_mem]⟩ := by + rw [h_c_eq] + _ = RingSwitching.compute_final_eq_value κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l + stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching * + cmsg := by + rw [←h_msg_eq] + simpa [step, finalSumcheckStepLogic, finalSumcheckVerifierCheck, cmsg] using h_eq + +/-- Strong completeness of the FRI final sumcheck logic step. -/ +lemma finalSumcheckStep_is_logic_complete : + (finalSumcheckStepLogic κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l + (𝓑 := 𝓑)).IsStronglyComplete := by + intro stmtIn witIn oStmtIn challenges h_relIn + let step := finalSumcheckStepLogic κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑) + let transcript := step.honestProverTranscript stmtIn witIn oStmtIn challenges + let verifierStmtOut := step.verifierOut stmtIn transcript + let verifierOStmtOut := OracleVerifier.mkVerifierOStmtOut step.embed step.hEq + oStmtIn transcript + let proverOutput := step.proverOut stmtIn witIn oStmtIn transcript + let proverStmtOut := proverOutput.1.1 + let proverOStmtOut := proverOutput.1.2 + let proverWitOut := proverOutput.2 + simp only [finalSumcheckStepLogic, BinaryBasefold.strictRoundRelation, + BinaryBasefold.strictRoundRelationProp, Set.mem_setOf_eq] at h_relIn + obtain ⟨h_sumcheck_cons, h_strictOracleWitConsistency⟩ := h_relIn + have h_wit_struct := h_strictOracleWitConsistency.1 + let h_VCheck_passed : step.verifierCheck stmtIn transcript := + finalSumcheckStep_verifierCheck_passed (κ := κ) (L := L) (K := K) (β := β) + (ℓ := ℓ) (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (h_l := h_l) (𝓑 := 𝓑) stmtIn witIn oStmtIn challenges h_sumcheck_cons h_wit_struct + have hStmtOut_eq : proverStmtOut = verifierStmtOut := by + change (step.proverOut stmtIn witIn oStmtIn transcript).1.1 = step.verifierOut stmtIn transcript + simp only [step, finalSumcheckStepLogic, finalSumcheckVerifierStmtOut] + have hOStmtOut_eq : proverOStmtOut = verifierOStmtOut := by rfl + have hRelOut : step.completeness_relOut ((verifierStmtOut, verifierOStmtOut), proverWitOut) := by + simp only [step, finalSumcheckStepLogic] + refine ⟨witIn.t, ?_⟩ + unfold BinaryBasefold.strictfinalSumcheckStepFoldingStateProp + dsimp only [finalSumcheckVerifierStmtOut] + constructor + · exact h_strictOracleWitConsistency.2 + · funext y + have h_const := iterated_fold_to_const_strict (κ := κ) (L := L) (K := K) (β := β) + (ℓ := ℓ) (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (h_l := h_l) (stmtIn := stmtIn) (witIn := witIn) (oStmtIn := oStmtIn) + (h_strictOracleWitConsistency_In := h_strictOracleWitConsistency) y + have h_msg_eq : transcript.messages ⟨0, rfl⟩ = witIn.f ⟨0, by simp only [zero_mem]⟩ := + finalSumcheck_honest_message_eq_f_zero (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) + (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_l := h_l) + (𝓑 := 𝓑) stmtIn witIn oStmtIn challenges + simpa [transcript, step, finalSumcheckStepLogic, finalSumcheckVerifierStmtOut, h_msg_eq] + using h_const + refine ⟨?_, ?_, ?_, ?_⟩ + · exact h_VCheck_passed + · exact hRelOut + · exact hStmtOut_eq + · exact hOStmtOut_eq /-- Perfect completeness for the final sumcheck step -/ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} - (init : ProbComp σ) + (init : ProbComp σ) (hInit : NeverFail init) (impl : QueryImpl []ₒ (StateT σ ProbComp)) : OracleReduction.perfectCompleteness (pSpec := BinaryBasefold.pSpecFinalSumcheckStep (L:=L)) @@ -468,13 +1092,174 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (Fin.last ℓ')) (relOut := BinaryBasefold.strictFinalSumcheckRelOut K β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - (oracleReduction := finalSumcheckOracleReduction κ L K β ℓ ℓ' 𝓡 ϑ + (oracleReduction := finalSumcheckOracleReduction κ L K β ℓ ℓ' 𝓡 ϑ (𝓑 := 𝓑) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) h_l) (init := init) (impl := impl) := by - unfold OracleReduction.perfectCompleteness - intro stmtIn witIn h_relIn - simp only - sorry + rw [OracleReduction.unroll_1_message_reduction_perfectCompleteness_P_to_V (hInit := hInit) + (hDir0 := by rfl) + (hImplSupp := by simp only [Set.fmap_eq_image, IsEmpty.forall_iff, implies_true])] + intro stmtIn oStmtIn witIn h_relIn + -- Step 2: Convert probability 1 to universal quantification over support + rw [probEvent_eq_one_iff] + -- Step 3: Unfold protocol definitions + dsimp only [finalSumcheckOracleReduction, finalSumcheckProver, finalSumcheckVerifier, + OracleVerifier.toVerifier, FullTranscript.mk1] + let step := finalSumcheckStepLogic κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑) + let strongly_complete : step.IsStronglyComplete := finalSumcheckStep_is_logic_complete + (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_l := h_l) (𝓑 := 𝓑) + -- Step 4: Split into safety and correctness goals + refine ⟨?_, ?_⟩ + -- GOAL 1: SAFETY - Prove the verifier never crashes ([⊥|...] = 0) + · -- Peel off monadic layers to reach the core verifier logic + simp only [probFailure_bind_eq_zero_iff] + conv_lhs => + simp only [liftComp_eq_liftM, liftM_pure, probFailure_eq_zero] + rw [true_and] + intro inputState hInputState_mem_support + simp only [Fin.isValue, Message, Matrix.cons_val_zero, Fin.succ_zero_eq_one, ChallengeIdx, + Challenge, liftComp_eq_liftM, liftM_pure, support_pure, + Set.mem_singleton_iff] at hInputState_mem_support + conv_lhs => + simp only [liftM, monadLift, MonadLift.monadLift] + simp only [ChallengeIdx, Challenge, Fin.isValue, Matrix.cons_val_one, Matrix.cons_val_zero, + liftComp_eq_liftM, OptionT.probFailure_lift, HasEvalPMF.probFailure_eq_zero] + rw [true_and] + -- ⊢ ∀ x ∈ .. support, ... ∧ ... ∧ ... + intro h_prover_final_output h_prover_final_output_support + conv => + simp only [guard_eq] -- simplify the `guard` + enter [2]; + simp only [bind_pure_comp, NeverFail.probFailure_eq_zero, implies_true] + rw [and_true] + -- Pr[⊥ | (...) : OracleComp ... (Option ...)] = 0 + rw [OptionT.probFailure_liftComp_of_OracleComp_Option] -- split into two summands + conv_lhs => + enter [1] + simp only [MessageIdx, Fin.isValue, Message, Matrix.cons_val_zero, Fin.succ_zero_eq_one, + id_eq, bind_pure_comp, OptionT.run_map, HasEvalPMF.probFailure_eq_zero] + rw [zero_add] + simp only [probOutput_eq_zero_iff] + rw [OptionT.support_run_eq] + simp only [←probOutput_eq_zero_iff] + simp_all only + change Pr[= none | OptionT.run (m := (OracleComp []ₒ)) (x := (OptionT.bind _ _)) ] = 0 + rw [OptionT.probOutput_none_bind_eq_zero_iff] + conv => + enter [x] + rw [OptionT.support_run] + intro vStmtOut h_vStmtOut_mem_support + conv at h_vStmtOut_mem_support => + erw [simulateQ_bind] + -- turn the simulated oracle query into OracleInterface.answer form + rw [OptionT.simulateQ_simOracle2_liftM_query_T2] -- V queries P's message + change vStmtOut ∈ (Bind.bind (m := (OracleComp []ₒ)) _ _).support + erw [_root_.bind_pure_simulateQ_comp] + simp only [Matrix.cons_val_zero, guard_eq] + -- simp [bind_pure_comp, + -- OptionT.simulateQ_map, OptionT.simulateQ_ite, OptionT.simulateQ_pure, + -- OptionT.support_map_run, OptionT.support_ite_run, support_pure, + -- OptionT.support_failure_run, Set.mem_image, Set.mem_ite_empty_right, + -- Set.mem_singleton_iff, and_true, exists_const, Prod.mk.injEq, existsAndEq] + rw [bind_pure_comp] + dsimp only [Functor.map] + rw [OptionT.simulateQ_bind] + erw [support_bind] + rw [simulateQ_ite] + simp only [Fin.isValue, Message, Matrix.cons_val_zero, id_eq, MessageIdx, support_ite, + toPFunctor_emptySpec, Function.comp_apply, OptionT.simulateQ_pure, Set.mem_iUnion, + exists_prop] + simp only [OptionT.simulateQ_failure'] + erw [_root_.simulateQ_pure] + set V_check := step.verifierCheck stmtIn + (FullTranscript.mk1 (msg0 := _)) with h_V_check_def + obtain ⟨h_V_check, h_rel, h_agree⟩ := strongly_complete (stmtIn := stmtIn) + (witIn := witIn) (h_relIn := h_relIn) (challenges := + fun ⟨j, hj⟩ => by + match j with + | 0 => + have hj_ne : (pSpecFinalSumcheckStep (L := L)).dir 0 ≠ Direction.V_to_P := by + dsimp only [pSpecFinalSumcheckStep, Fin.isValue, Matrix.cons_val_zero] + simp only [ne_eq, reduceCtorEq, not_false_eq_true] + exfalso + exact hj_ne hj + ) + have h_V_check_is_true : V_check := h_V_check + simp only [h_V_check_is_true, ↓reduceIte, support_pure, Set.mem_singleton_iff, Fin.isValue, + Fin.val_last, exists_eq_left, OptionT.support_OptionT_pure_run] at h_vStmtOut_mem_support + rw [h_vStmtOut_mem_support] + simp only [Fin.isValue, Fin.val_last, OptionT.run_pure, probOutput_eq_zero_iff, support_pure, + Set.mem_singleton_iff, reduceCtorEq, not_false_eq_true] + · -- GOAL 2: CORRECTNESS - Prove all outputs in support satisfy the relation + intro x hx_mem_support + rcases x with ⟨⟨prvStmtOut, prvOStmtOut⟩, ⟨verStmtOut, verOStmtOut⟩, witOut⟩ + simp only + -- Step 2a: Simplify the support membership to extract the challenge + simp only [ + support_bind, support_pure, + Set.mem_iUnion, Set.mem_singleton_iff, exists_prop, Prod.exists + ] at hx_mem_support + conv at hx_mem_support => + erw [OptionT.support_mk, support_pure] + simp only [ + Set.mem_singleton_iff, Option.some.injEq, Set.setOf_eq_eq_singleton, Prod.mk.injEq, + OptionT.mem_support_iff, + OptionT.run_monadLift, support_map, Set.mem_image, exists_eq_right, Fin.succ_one_eq_two, + id_eq, guard_eq, bind_pure_comp, + toPFunctor_add, toPFunctor_emptySpec, OptionT.support_run, ↓existsAndEq, and_true, true_and, + exists_eq_right_right', liftM_pure, support_pure, exists_eq_left] + dsimp only [monadLift, MonadLift.monadLift] + simp only [Fin.isValue, Challenge, ChallengeIdx, + liftComp_eq_liftM, liftM_pure, liftComp_pure, support_pure, Set.mem_singleton_iff, + Fin.reduceLast, MessageIdx, Message] at hx_mem_support + -- Step 2b: Extract the challenge r1 and the trace equations + rcases hx_mem_support with ⟨prvWitOut, h_prvOut_mem_support, h_verOut_mem_support⟩ + conv at h_prvOut_mem_support => + dsimp only [finalSumcheckStepLogic] + simp only [Fin.val_last, Fin.isValue, Prod.mk.injEq, and_true] + -- Step 2c: Simplify the verifier computation + conv at h_verOut_mem_support => + erw [simulateQ_bind] + simp only [Set.mem_singleton_iff] + change some (verStmtOut, verOStmtOut) ∈ (liftComp _ _).support + rw [support_liftComp] + dsimp only [Functor.map] + erw [support_bind] + simp only [Fin.isValue, Fin.val_last, OptionT.simulateQ_simOracle2_liftM_query_T2, pure_bind, + OptionT.simulateQ_bind, toPFunctor_emptySpec, Function.comp_apply, OptionT.simulateQ_pure, + Set.mem_iUnion, exists_prop] + rw [simulateQ_ite]; erw [simulateQ_pure] + simp only [OptionT.simulateQ_failure'] + set V_check := step.verifierCheck stmtIn + (FullTranscript.mk1 + (msg0 := _))with h_V_check_def + -- Step 2e: Apply the logic completeness lemma + obtain ⟨h_V_check, h_rel, h_agree⟩ := strongly_complete (stmtIn := stmtIn) + (witIn := witIn) (h_relIn := h_relIn) (challenges := + fun ⟨j, hj⟩ => by + match j with + | 0 => + have hj_ne : (pSpecFinalSumcheckStep (L := L)).dir 0 ≠ Direction.V_to_P := by + dsimp only [pSpecFinalSumcheckStep, Fin.isValue, Matrix.cons_val_zero] + simp only [ne_eq, reduceCtorEq, not_false_eq_true] + exfalso + exact hj_ne hj + ) + have h_V_check_is_true : V_check := h_V_check + simp only [h_V_check_is_true, ↓reduceIte, Fin.isValue] at h_verOut_mem_support + erw [support_bind, support_pure] at h_verOut_mem_support + simp only [Set.mem_singleton_iff, Fin.isValue, Set.iUnion_iUnion_eq_left, + OptionT.support_OptionT_pure_run, exists_eq_left, Option.some.injEq, + Prod.mk.injEq] at h_verOut_mem_support + rcases h_verOut_mem_support with ⟨verStmtOut_eq, verOStmtOut_eq⟩ + obtain ⟨prvStmtOut_eq, prvOStmtOut_eq⟩ := h_prvOut_mem_support + constructor + · rw [verStmtOut_eq, verOStmtOut_eq]; + exact h_rel + · constructor + · rw [verStmtOut_eq, prvStmtOut_eq]; rfl + · rw [verOStmtOut_eq, prvOStmtOut_eq]; + exact h_agree.2 /-- RBR knowledge error for the final sumcheck step -/ def finalSumcheckKnowledgeError (m : pSpecFinalSumcheckStep (L := L).ChallengeIdx) : @@ -501,21 +1286,27 @@ noncomputable def finalSumcheckRbrExtractor : extractMid := fun m ⟨stmtMid, oStmtMid⟩ trSucc witMidSucc => by have hm : m = 0 := by omega subst hm + have _ : witMidSucc = () := by rfl -- Decode t from the first oracle f^(0) let f0 := getFirstOracle K β oStmtMid let polyOpt := extractMLP K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := ⟨0, by exact Nat.pos_of_neZero ℓ'⟩) (f := f0) + let H_constant : L⦃≤ 2⦄[X Fin (ℓ' - ↑(Fin.last ℓ'))] := ⟨MvPolynomial.C stmtMid.sumcheck_target, + by + simp only [Fin.val_last, mem_restrictDegree, MvPolynomial.mem_support_iff, + MvPolynomial.coeff_C, ne_eq, ite_eq_right_iff, Classical.not_imp, and_imp, forall_eq', + Finsupp.coe_zero, Pi.zero_apply, zero_le, implies_true]⟩ match polyOpt with - | none => -- NOTE, In proofs of toFun_next, this case would be eliminated - exact dummyLastWitness (L := L) K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + | none => + exact { + t := ⟨0, by apply zero_mem⟩, + H := H_constant, + f := fun _ => 0 + } | some tpoly => - -- Build H_ℓ from t and challenges r' exact { t := tpoly, - H := projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := tpoly) - (m := (RingSwitching_SumcheckMultParam κ L K - (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l).multpoly (ctx := stmtMid.ctx)) - (i := Fin.last ℓ') (challenges := stmtMid.challenges), + H := H_constant, f := getMidCodewords K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) tpoly stmtMid.challenges } extractOut := fun ⟨stmtIn, oStmtIn⟩ tr witOut => () @@ -531,27 +1322,28 @@ def finalSumcheckKStateProp {m : Fin (1 + 1)} (tr : Transcript m (pSpecFinalSumc (mp := RingSwitching_SumcheckMultParam κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) (stmtIdx := Fin.last ℓ') (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (Fin.last ℓ')) - (stmt := stmt) (wit := witMid) (oStmt := oStmt) (localChecks := True) + (stmt := stmt) (wit := witMid) (oStmt := oStmt) + (localChecks := sumcheckConsistencyProp (𝓑 := 𝓑) stmt.sumcheck_target witMid.H) | ⟨1, _⟩ => -- implied by relOut + local checks via extractOut proofs let tr_so_far := (pSpecFinalSumcheckStep (L := L)).take 1 (by omega) let i_msg0 : tr_so_far.MessageIdx := ⟨⟨0, by omega⟩, rfl⟩ let s' : L := (ProtocolSpec.Transcript.equivMessagesChallenges (k := 1) (pSpec := pSpecFinalSumcheckStep (L := L)) tr).1 i_msg0 let stmtOut : BinaryBasefold.FinalSumcheckStatementOut (L:=L) (ℓ:=ℓ') := { - -- Dummy unused values + -- **Dummy UNUSED values** ctx := { t_eval_point := 0, original_claim := 0 }, sumcheck_target := 0, - -- Only the last two fields are used in finalNonDoomedFoldingProp + -- **ONLY the last two fields are used in finalSumcheckStepFoldingStateProp** challenges := stmt.challenges, final_constant := s' } let sumcheckFinalCheck : Prop := stmt.sumcheck_target = compute_final_eq_value κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l stmt.ctx.t_eval_point stmt.challenges stmt.ctx.r_batching * s' - let finalFoldingProp := finalFoldingStateProp K β (ϑ := ϑ) + let finalFoldingProp := finalSumcheckStepFoldingStateProp K β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_le := by apply Nat.le_of_dvd; · exact Nat.pos_of_neZero ℓ' @@ -561,40 +1353,238 @@ def finalSumcheckKStateProp {m : Fin (1 + 1)} (tr : Transcript m (pSpecFinalSumc /-- The knowledge state function for the final sumcheck step -/ noncomputable def finalSumcheckKnowledgeStateFunction {σ : Type} (init : ProbComp σ) (impl : QueryImpl []ₒ (StateT σ ProbComp)) : - (finalSumcheckVerifier κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l).KnowledgeStateFunction init impl + (finalSumcheckVerifier κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l + (𝓑 := 𝓑)).KnowledgeStateFunction init impl (relIn := roundRelation K β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := RingSwitching_SumcheckMultParam κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) (Fin.last ℓ')) (relOut := BinaryBasefold.finalSumcheckRelOut K β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - (extractor := finalSumcheckRbrExtractor κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l) + (extractor := finalSumcheckRbrExtractor κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate) where toFun := fun m ⟨stmt, oStmt⟩ tr witMid => finalSumcheckKStateProp κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (tr := tr) (stmt := stmt) (witMid := witMid) (oStmt := oStmt) toFun_empty := fun stmt witMid => by - -- sumcheck consistency is trivial since there is no variables left - sorry - toFun_next := fun m hDir stmt tr msg witMid h => by - -- Either bad events exist, or (oracleFoldingConsistency is true so - -- the extractor can construct a satisfying witness) - sorry - toFun_full := fun stmt tr witOut h => by - sorry + rw [cast_eq] + rfl + toFun_next := fun m hDir (stmtIn, oStmtIn) tr msg witMid => by + have h_m_eq_0 : m = 0 := by + cases m using Fin.cases with + | zero => rfl + | succ m' => omega + subst h_m_eq_0 + simp only [Fin.isValue, Fin.succ_zero_eq_one, Fin.castSucc_zero] + -- In the single-message final sumcheck step, the new message `msg` *is* the final constant. + -- We use it directly rather than reconstructing a truncated transcript. + let s' : L := msg + let stmtOut : BinaryBasefold.FinalSumcheckStatementOut (L:=L) (ℓ:=ℓ') := { + ctx := { + t_eval_point := 0, + original_claim := 0 + }, + sumcheck_target := 0, + challenges := stmtIn.challenges, + final_constant := s' + } + intro h_kState_round1 + unfold finalSumcheckKStateProp BinaryBasefold.finalSumcheckStepFoldingStateProp + BinaryBasefold.masterKStateProp at h_kState_round1 ⊢ + simp only [Fin.isValue] at h_kState_round1 + obtain ⟨h_sumcheckFinalCheck, h_core⟩ := h_kState_round1 + -- Option-B shape at m=0: + -- incremental bad-event ∨ (local ∧ structural ∧ initial ∧ oracleFoldingConsistency). + cases h_core with + | inl hConsistent => + have ⟨tpoly, h_extractMLP⟩ := + BinaryBasefold.CoreInteraction.extractMLP_some_of_oracleFoldingConsistency K β + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmtOut oStmtIn hConsistent + refine Or.inr ?_ + refine ⟨?_, ?_, ?_, ?_⟩ + · -- local sumcheck consistency at m=0 + unfold finalSumcheckRbrExtractor sumcheckConsistencyProp + simp only [Fin.val_last, Fin.mk_zero', Fin.coe_ofNat_eq_mod] + split + · simp only [MvPolynomial.eval_C, sum_const, Fintype.card_piFinset, card_map, card_univ, + Fintype.card_fin, prod_const, tsub_self, Fintype.card_eq_zero, pow_zero, one_smul] + · simp only [MvPolynomial.eval_C, sum_const, Fintype.card_piFinset, card_map, card_univ, + Fintype.card_fin, prod_const, tsub_self, Fintype.card_eq_zero, pow_zero, one_smul] + · -- witnessStructuralInvariant for extracted witness + unfold finalSumcheckRbrExtractor BinaryBasefold.witnessStructuralInvariant + simp only [Fin.val_last, Fin.mk_zero', h_extractMLP, Fin.coe_ofNat_eq_mod, and_true] + refine SetLike.coe_eq_coe.mp ?_ + rw [projectToMidSumcheckPoly_at_last_eq] + have h_s'_eq : s' = tpoly.val.eval stmtIn.challenges := by + exact BinaryBasefold.CoreInteraction.extracted_t_poly_eval_eq_final_constant K β + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (oStmtOut := oStmtIn) (stmtOut := stmtOut) + (tpoly := tpoly) (h_extractMLP := h_extractMLP) + (h_finalSumcheckStepOracleConsistency := hConsistent) + have h_mult_eq : (MvPolynomial.eval stmtIn.challenges + ((RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l).multpoly stmtIn.ctx).val) = + compute_final_eq_value κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l + stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching := + compute_A_MLE_eval_eq_final_eq_value κ L K (β := booleanHypercubeBasis κ L K β) + ℓ ℓ' h_l stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching + have h_sumcheck_target_eq : stmtIn.sumcheck_target = + (MvPolynomial.eval stmtIn.challenges + ((RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l).multpoly stmtIn.ctx).val) * + (MvPolynomial.eval stmtIn.challenges tpoly.val) := by + calc + stmtIn.sumcheck_target + = compute_final_eq_value κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l + stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching * s' := + h_sumcheckFinalCheck + _ = compute_final_eq_value κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l + stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching * + (MvPolynomial.eval stmtIn.challenges tpoly.val) := by + rw [h_s'_eq] + _ = (MvPolynomial.eval stmtIn.challenges + ((RingSwitching_SumcheckMultParam κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l).multpoly stmtIn.ctx).val) * + (MvPolynomial.eval stmtIn.challenges tpoly.val) := by + rw [h_mult_eq] + simp only [h_sumcheck_target_eq, Fin.val_last, Fin.coe_ofNat_eq_mod, MvPolynomial.C_mul] + · -- initial compatibility via first-oracle consistency + dsimp only [finalSumcheckRbrExtractor, BinaryBasefold.firstOracleWitnessConsistencyProp] + simp only [Fin.mk_zero', h_extractMLP, Fin.coe_ofNat_eq_mod, Fin.val_last, + OracleFrontierIndex.val_mkFromStmtIdx] + exact (extractMLP_eq_some_iff_pair_UDRClose K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (f := getFirstOracle K β oStmtIn) (tpoly := tpoly)).mp h_extractMLP + · exact hConsistent.1 + | inr hBad => + -- Convert terminal block bad-event to incremental bad-event. + exact Or.inl ( + (BinaryBasefold.badEventExistsProp_iff_incrementalBadEventExistsProp_last K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + (oStmt := oStmtIn) (challenges := stmtIn.challenges)).1 hBad + ) + toFun_full := fun ⟨stmtIn, oStmtIn⟩ tr witOut probEvent_relOut_gt_0 => by + -- Same pattern as relay: verifier output (stmtOut, oStmtOut) + h_relOut ⇒ commitKStateProp 1 + simp only [StateT.run'_eq, gt_iff_lt, probEvent_pos_iff, Prod.exists] at probEvent_relOut_gt_0 + rcases probEvent_relOut_gt_0 with ⟨stmtOut, oStmtOut, h_output_mem_V_run_support, h_relOut⟩ + have h_output_mem_V_run_support' : + some (stmtOut, oStmtOut) ∈ + (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (finalSumcheckVerifier κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l + (𝓑 := 𝓑)).toVerifier)).run s).support := by + exact (OptionT.mem_support_iff + (mx := OptionT.mk (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (finalSumcheckVerifier κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l + (𝓑 := 𝓑)).toVerifier)).run s)) + (x := (stmtOut, oStmtOut))).1 h_output_mem_V_run_support + simp only [support_bind, Set.mem_iUnion, exists_prop] at h_output_mem_V_run_support' + rcases h_output_mem_V_run_support' with ⟨s, hs_init, h_output_mem_V_run_support⟩ + conv at h_output_mem_V_run_support => -- same as fold step + simp only [Verifier.run, OracleVerifier.toVerifier] + -- Now unfold the foldOracleVerifier's `verify()` method + simp only [finalSumcheckVerifier] + -- dsimp only [StateT.run] + -- simp only [simulateQ_bind, simulateQ_query, simulateQ_pure] + -- oracle query unfolding + simp only [support_bind, Set.mem_iUnion] + dsimp only [StateT.run] + -- enter [1, i_1, 2, 1, x] + simp only [simulateQ_bind] + --------------------------------------- + -- Now simplify the `guard` and `ite` of StateT.map generated from it + simp only [MessageIdx, Fin.isValue, Matrix.cons_val_zero, simulateQ_pure, Message, guard_eq, + pure_bind, Function.comp_apply, simulateQ_map, simulateQ_ite, + OptionT.simulateQ_failure', bind_map_left] + simp only [MessageIdx, Message, Fin.isValue, Matrix.cons_val_zero, Matrix.cons_val_one, + bind_pure_comp, simulateQ_map, simulateQ_ite, simulateQ_pure, OptionT.simulateQ_failure', + bind_map_left, Function.comp_apply] + simp only [support_ite] + simp only [Fin.isValue, Set.mem_ite_empty_right, Set.mem_singleton_iff, Prod.mk.injEq, + exists_and_left, exists_eq', exists_eq_right, exists_and_right] + simp only [Fin.isValue, id_eq, FullTranscript.mk1_eq_snoc, support_map, Set.mem_image, + Prod.exists, exists_and_right, exists_eq_right] + erw [simulateQ_bind] + enter [1, x, 1, 1, 1, 2]; + erw [simulateQ_bind] + erw [OptionT.simulateQ_simOracle2_liftM_query_T2] + simp only [Fin.isValue, FullTranscript.mk1_eq_snoc, pure_bind, OptionT.simulateQ_map] + conv at h_output_mem_V_run_support => + simp only [Fin.isValue, FullTranscript.mk1_eq_snoc, Function.comp_apply] + erw [support_bind] at h_output_mem_V_run_support + let step := (finalSumcheckStepLogic κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑)) + set V_check := step.verifierCheck stmtIn + (FullTranscript.mk1 (msg0 := _)) with h_V_check_def + by_cases h_V_check : V_check + · + simp only [Fin.isValue, h_V_check, ↓reduceIte, OptionT.run_pure, simulateQ_pure, + Set.mem_iUnion, exists_prop, Prod.exists] at h_output_mem_V_run_support + erw [simulateQ_bind] at h_output_mem_V_run_support + simp only [simulateQ_pure, Fin.isValue, Function.comp_apply, + pure_bind] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, ↓existsAndEq, and_true, exists_eq_left, + simulateQ_pure] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Fin.isValue, Set.mem_singleton_iff, Prod.mk.injEq, Option.some.injEq, + exists_eq_right] at h_output_mem_V_run_support + rcases h_output_mem_V_run_support with ⟨h_stmtOut_eq, h_oStmtOut_eq⟩ + simp only [Fin.reduceLast, Fin.isValue] + -- h_relOut : ((stmtOut, oStmtOut), witOut) ∈ roundRelation 𝔽q β i.succ + simp only [finalSumcheckRelOut, finalSumcheckRelOutProp, Set.mem_setOf_eq] at h_relOut + -- Goal: commitKStateProp 1 stmtIn oStmtIn tr witOut + unfold finalSumcheckKStateProp + -- Unfold the sendMessage, receiveChallenge, output logic of prover + dsimp only + -- stmtOut = stmtIn; need oStmtOut = snoc_oracle oStmtIn witOut.f so goal matches h_relOut + simp only [h_stmtOut_eq] at h_relOut ⊢ + have h_oStmtOut_eq_oStmtIn : oStmtOut = oStmtIn := by rw [h_oStmtOut_eq]; rfl + -- c equals tr.messages ⟨0, rfl⟩ + constructor + · -- First conjunct: sumcheck_target = eqTilde r challenges * c + exact h_V_check + · -- Second conjunct: finalSumcheckStepFoldingStateProp + -- ({ toStatement := stmtIn, final_constant := c }, oStmtIn) + rw [h_oStmtOut_eq_oStmtIn] at h_relOut + exact h_relOut + · + simp only [Fin.isValue, h_V_check, ↓reduceIte, OptionT.run_failure, simulateQ_pure, + Set.mem_iUnion, exists_prop, Prod.exists] at h_output_mem_V_run_support + erw [simulateQ_bind] at h_output_mem_V_run_support + simp only [simulateQ_pure, Fin.isValue, Function.comp_apply, + pure_bind] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, ↓existsAndEq, and_true, exists_eq_left, + simulateQ_pure] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, reduceCtorEq, false_and, + exists_false] at h_output_mem_V_run_support -- False /-- Round-by-round knowledge soundness for the final sumcheck step -/ theorem finalSumcheckOracleVerifier_rbrKnowledgeSoundness [Fintype L] {σ : Type} (init : ProbComp σ) (impl : QueryImpl []ₒ (StateT σ ProbComp)) : - (finalSumcheckVerifier κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l).rbrKnowledgeSoundness init impl + (finalSumcheckVerifier κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l + (𝓑 := 𝓑)).rbrKnowledgeSoundness init impl (relIn := roundRelation K β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (mp := RingSwitching_SumcheckMultParam κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) (Fin.last ℓ')) (relOut := BinaryBasefold.finalSumcheckRelOut K β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (rbrKnowledgeError := finalSumcheckKnowledgeError L) := by use FinalSumcheckWit κ (L := L) K β ℓ' 𝓡 (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - use finalSumcheckRbrExtractor κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l + use finalSumcheckRbrExtractor κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate use finalSumcheckKnowledgeStateFunction κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l init impl intro stmtIn witIn prover j - sorry + rcases j with ⟨j, hj⟩ + cases j using Fin.cases with + | zero => + simp only [pSpecFinalSumcheckStep, ne_eq, reduceCtorEq, not_false_eq_true, Fin.isValue, + Matrix.cons_val_fin_one, Direction.not_P_to_V_eq_V_to_P] at hj + | succ j' => + exact Fin.elim0 j' end FinalSumcheckStep @@ -614,9 +1604,10 @@ def coreInteractionOracleVerifier := (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) (pSpec₁ := BinaryBasefold.pSpecSumcheckFold K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (pSpec₂ := pSpecFinalSumcheckStep (L:=L)) - (V₁ := sumcheckFoldOracleVerifier κ (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) + (V₁ := sumcheckFoldOracleVerifier κ (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') + (h_l := h_l) (𝓑 := 𝓑) (𝓡 := 𝓡) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - (V₂ := finalSumcheckVerifier κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l) + (V₂ := finalSumcheckVerifier κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑)) /-- The final oracle reduction that composes sumcheckFold with finalSumcheckStep -/ @[reducible] @@ -637,7 +1628,7 @@ def coreInteractionOracleReduction := (pSpec₂ := BinaryBasefold.pSpecFinalSumcheckStep (L:=L)) (R₁ := sumcheckFoldOracleReduction κ L K β ℓ ℓ' 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) h_l (𝓑 := 𝓑)) - (R₂ := finalSumcheckOracleReduction κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l) + (R₂ := finalSumcheckOracleReduction κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑)) variable {σ : Type} {init : ProbComp σ} {impl : QueryImpl []ₒ (StateT σ ProbComp)} @@ -649,24 +1640,30 @@ theorem coreInteractionOracleReduction_perfectCompleteness (hInit : NeverFail in (OStmtIn := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ 0) (OStmtOut := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) - (relIn := RingSwitching.sumcheckRoundRelation κ (L := L) (K := K) + (relIn := RingSwitching.strictSumcheckRoundRelation κ (L := L) (K := K) (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l (𝓑 := 𝓑) - (aOStmtIn := BinaryBasefoldAbstractOStmtIn κ L K β ℓ' - 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) + (aOStmtIn := BinaryBasefoldAbstractOStmtIn + (κ := κ) (L := L) (K := K) (β := β) + (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) (relOut := BinaryBasefold.strictFinalSumcheckRelOut K β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (oracleReduction := coreInteractionOracleReduction κ L K β ℓ ℓ' 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate) h_l (𝓑 := 𝓑)) (init := init) (impl := impl) := by - unfold coreInteractionOracleReduction pSpecCoreInteraction + unfold coreInteractionOracleReduction Binius.BinaryBasefold.pSpecCoreInteraction apply OracleReduction.append_perfectCompleteness + (rel₂ := (strictRoundRelation K (β := β) (i := Fin.last ℓ'))) + (Wit₁ := (SumcheckWitness L ℓ' 0)) + (Wit₂ := (Witness K (β := β) (i := Fin.last ℓ'))) + (Wit₃ := Unit) · -- Perfect completeness of sumcheckFoldOracleReduction - exact sumcheckFoldOracleReduction_perfectCompleteness (hInit:=hInit) κ L K β ℓ ℓ' 𝓡 ϑ - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) h_l (𝓑 := 𝓑) (init := init) (impl := impl) + exact sumcheckFoldOracleReduction_perfectCompleteness κ L K β ℓ ℓ' 𝓡 ϑ + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) h_l (𝓑 := 𝓑) (init := init) (hInit:=hInit) (impl := impl) · -- Perfect completeness of finalSumcheckOracleReduction exact finalSumcheckOracleReduction_perfectCompleteness κ L K β ℓ ℓ' 𝓡 ϑ - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) h_l (𝓑 := 𝓑) init impl + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) h_l (𝓑 := 𝓑) init (hInit := hInit) impl def coreInteractionOracleRbrKnowledgeError (j : (BinaryBasefold.pSpecCoreInteraction K β (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).ChallengeIdx) : ℝ≥0 := @@ -678,15 +1675,18 @@ def coreInteractionOracleRbrKnowledgeError (j : (BinaryBasefold.pSpecCoreInterac /-- Round-by-round knowledge soundness for the core interaction oracle verifier -/ theorem coreInteractionOracleVerifier_rbrKnowledgeSoundness : - (coreInteractionOracleVerifier κ (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) - (𝓡 := 𝓡) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)).rbrKnowledgeSoundness init impl + (coreInteractionOracleVerifier κ (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') + (h_l := h_l) (𝓡 := 𝓡) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (𝓑 := 𝓑)).rbrKnowledgeSoundness init impl (OStmtIn := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ 0) (OStmtOut := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) (pSpec := BinaryBasefold.pSpecCoreInteraction K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (relIn := RingSwitching.sumcheckRoundRelation κ L K (booleanHypercubeBasis κ L K β) - ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn κ L K β ℓ' - 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) + ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn + (κ := κ) (L := L) (K := K) (β := β) + (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) (relOut := BinaryBasefold.finalSumcheckRelOut K β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (rbrKnowledgeError := coreInteractionOracleRbrKnowledgeError κ L K β ℓ' 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) := by @@ -698,16 +1698,22 @@ theorem coreInteractionOracleVerifier_rbrKnowledgeSoundness : (OStmt₃ := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) (init := init) (impl:=impl) + (Wit₁ := (SumcheckWitness L ℓ' 0)) + (Wit₂ := (Witness K (β := β) (i := Fin.last ℓ'))) + (Wit₃ := Unit) (rel₁ := RingSwitching.sumcheckRoundRelation κ L K (booleanHypercubeBasis κ L K β) - ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn κ L K β ℓ' - 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) + ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn + (κ := κ) (L := L) (K := K) (β := β) + (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) (rel₂ := BinaryBasefold.roundRelation (mp := RingSwitching_SumcheckMultParam κ L K (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l) K β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑:=𝓑) (Fin.last ℓ')) (rel₃ := finalSumcheckRelOut K β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - (V₁ := sumcheckFoldOracleVerifier κ (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) + (V₁ := sumcheckFoldOracleVerifier κ (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') + (h_l := h_l) (𝓑 := 𝓑) (𝓡 := 𝓡) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - (V₂ := finalSumcheckVerifier κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l) + (V₂ := finalSumcheckVerifier κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑)) (Oₛ₃:=by exact fun i ↦ by exact OracleInterface.instDefault) (rbrKnowledgeError₁ := BinaryBasefold.CoreInteraction.sumcheckFoldKnowledgeError K β (ϑ := ϑ)) diff --git a/ArkLib/ProofSystem/Binius/FRIBinius/General.lean b/ArkLib/ProofSystem/Binius/FRIBinius/General.lean index 27cb79828..f2a38458b 100644 --- a/ArkLib/ProofSystem/Binius/FRIBinius/General.lean +++ b/ArkLib/ProofSystem/Binius/FRIBinius/General.lean @@ -78,10 +78,12 @@ def batchingCoreVerifier := OracleVerifier.append (oSpec:=[]ₒ) (V₁:= RingSwitching.BatchingPhase.batchingOracleVerifier κ (L := L) (K := K) (𝓑 := 𝓑) (β:=booleanHypercubeBasis κ L K β) ℓ ℓ' h_l - (aOStmtIn := BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate)) + (aOStmtIn := (BinaryBasefoldAbstractOStmtIn (β := β) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)))) (pSpec₁ := RingSwitching.pSpecBatching κ L K) (pSpec₂:=BinaryBasefold.pSpecCoreInteraction K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - (OStmt₁ := (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate).OStmtIn) + (OStmt₁ := (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).OStmtIn) (OStmt₂ := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ 0) (OStmt₃ := BinaryBasefold.OracleStatement K β @@ -94,10 +96,12 @@ def batchingCoreReduction := OracleReduction.append (oSpec:=[]ₒ) (R₁ := RingSwitching.BatchingPhase.batchingOracleReduction κ (L := L) (K := K) (𝓑 := 𝓑) (β:=booleanHypercubeBasis κ L K β) ℓ ℓ' h_l - (aOStmtIn := BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate)) + (aOStmtIn := (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)))) (pSpec₁ := RingSwitching.pSpecBatching κ L K) (pSpec₂:=BinaryBasefold.pSpecCoreInteraction K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - (OStmt₁ := (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate).OStmtIn) + (OStmt₁ := (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).OStmtIn) (OStmt₂ := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ 0) (OStmt₃ := BinaryBasefold.OracleStatement K β @@ -110,7 +114,8 @@ def batchingCoreReduction := noncomputable def fullOracleVerifier : OracleVerifier (oSpec:=[]ₒ) (StmtIn := BatchingStmtIn (L := L) (ℓ:=ℓ)) - (OStmtIn := (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate).OStmtIn) + (OStmtIn := (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).OStmtIn) (StmtOut := Bool) (OStmtOut := fun _ : Empty => Unit) (pSpec := fullPspec κ L K β ℓ' 𝓡 ϑ γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) := @@ -118,7 +123,8 @@ noncomputable def fullOracleVerifier : (Stmt₁ := BatchingStmtIn (L := L) (ℓ:=ℓ)) (Stmt₂ := BinaryBasefold.FinalSumcheckStatementOut (L:=L) (ℓ:=ℓ')) (Stmt₃ := Bool) - (OStmt₁ := (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate).OStmtIn) + (OStmt₁ := (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).OStmtIn) (OStmt₂ := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) (OStmt₃ := fun _ : Empty => Unit) @@ -134,7 +140,8 @@ noncomputable def fullOracleVerifier : noncomputable def fullOracleReduction : OracleReduction (oSpec:=[]ₒ) (StmtIn := BatchingStmtIn (L := L) (ℓ:=ℓ)) - (OStmtIn := (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate).OStmtIn) + (OStmtIn := (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).OStmtIn) (StmtOut := Bool) (OStmtOut := fun _ : Empty => Unit) (WitIn := BatchingWitIn L K ℓ ℓ') @@ -147,7 +154,8 @@ noncomputable def fullOracleReduction : (Wit₁ := BatchingWitIn L K ℓ ℓ') (Wit₂ := Unit) (Wit₃ := Unit) - (OStmt₁ := (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate).OStmtIn) + (OStmt₁ := (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).OStmtIn) (OStmt₂ := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) (OStmt₃ := fun _ : Empty => Unit) @@ -163,7 +171,8 @@ noncomputable def fullOracleReduction : noncomputable def fullOracleProof : OracleProof []ₒ (Statement := BatchingStmtIn (L := L) (ℓ:=ℓ)) - (OStatement := (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate).OStmtIn) + (OStatement := (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).OStmtIn) (Witness := BatchingWitIn L K ℓ ℓ') (pSpec:= fullPspec κ L K β ℓ' 𝓡 ϑ γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) := fullOracleReduction κ L K β ℓ ℓ' 𝓡 ϑ γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) h_l (𝓑:=𝓑) @@ -179,8 +188,9 @@ theorem fullOracleReduction_perfectCompleteness (hInit : NeverFail init) : OracleReduction.perfectCompleteness (oracleReduction := fullOracleReduction κ L K β ℓ ℓ' 𝓡 ϑ γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) h_l (𝓑:=𝓑)) - (relIn := BatchingPhase.batchingInputRelation κ L K (β:=booleanHypercubeBasis κ L K β) - ℓ ℓ' h_l (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate)) + (relIn := BatchingPhase.strictBatchingInputRelation κ L K (β:=booleanHypercubeBasis κ L K β) + ℓ ℓ' h_l (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate))) (relOut := acceptRejectOracleRel) (init := init) (impl := impl) := @@ -188,32 +198,37 @@ theorem fullOracleReduction_perfectCompleteness (hInit : NeverFail init) : (R₁ := batchingCoreReduction κ L K β ℓ ℓ' 𝓡 ϑ h_ℓ_add_R_rate h_l (𝓑 := 𝓑)) (R₂ := QueryPhase.queryOracleReduction K β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ:=ϑ)) - (OStmt₁ := (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate).OStmtIn) + (OStmt₁ := (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).OStmtIn) (OStmt₂ := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) (OStmt₃ := fun _ : Empty => Unit) - (Oₛ₁:= (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate).Oₛᵢ) + (Oₛ₁:= (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).Oₛᵢ) (Oₛ₂:=Binius.BinaryBasefold.instOracleStatementBinaryBasefold K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) (i := Fin.last ℓ')) (Oₛ₃:=by exact fun i ↦ by exact OracleInterface.instDefault) (pSpec₁ := batchingCorePspec κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate) (pSpec₂ := BinaryBasefold.pSpecQuery K β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) - (rel₁ := BatchingPhase.batchingInputRelation κ L K (β:=booleanHypercubeBasis κ L K β) - ℓ ℓ' h_l (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate)) + (rel₁ := BatchingPhase.strictBatchingInputRelation κ L K (β:=booleanHypercubeBasis κ L K β) + ℓ ℓ' h_l (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate))) (rel₂ := BinaryBasefold.strictFinalSumcheckRelOut K β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) (rel₃ := acceptRejectOracleRel) (h₁ := by apply OracleReduction.append_perfectCompleteness - (rel₁ := BatchingPhase.batchingInputRelation κ L K (β:=booleanHypercubeBasis κ L K β) - ℓ ℓ' h_l (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate)) - (rel₂ := RingSwitching.sumcheckRoundRelation κ L K (booleanHypercubeBasis κ L K β) - ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn κ L K β ℓ' - 𝓡 ϑ (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) + (rel₁ := BatchingPhase.strictBatchingInputRelation κ L K (β:=booleanHypercubeBasis κ L K β) + ℓ ℓ' h_l (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate))) + (rel₂ := RingSwitching.strictSumcheckRoundRelation κ L K (booleanHypercubeBasis κ L K β) + ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) (rel₃ := BinaryBasefold.strictFinalSumcheckRelOut K β (ϑ:=ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) · apply BatchingPhase.batchingReduction_perfectCompleteness (hInit:=hInit) κ L K (β:=booleanHypercubeBasis κ L K β) ℓ ℓ' h_l (𝓑 := 𝓑) - (BinaryBasefoldAbstractOStmtIn κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate) + (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) · apply CoreInteractionPhase.coreInteractionOracleReduction_perfectCompleteness κ (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (𝓡 := 𝓡) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (hInit:=hInit) @@ -221,7 +236,104 @@ theorem fullOracleReduction_perfectCompleteness (hInit : NeverFail init) : (h₂ := QueryPhase.queryOracleProof_perfectCompleteness K β γ_repetitions (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ:=ϑ) (hInit:=hInit) init impl) --- TODO: state RBR KS +open scoped NNReal + +/-- Combined RBR knowledge error for batching + core interaction. -/ +def batchingCoreRbrKnowledgeError + (i : (batchingCorePspec κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate).ChallengeIdx) : ℝ≥0 := + Sum.elim + (f := fun _ => RingSwitching.BatchingPhase.batchingRBRKnowledgeError (κ := κ) (L := L)) + (g := FRIBinius.CoreInteractionPhase.coreInteractionOracleRbrKnowledgeError + (κ := κ) (L := L) (K := K) (β := β) (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (ChallengeIdx.sumEquiv.symm i) + +/-- Combined RBR knowledge error for full FRI-Binius. -/ +def fullRbrKnowledgeError + (i : (fullPspec κ L K β ℓ' 𝓡 ϑ γ_repetitions h_ℓ_add_R_rate).ChallengeIdx) : ℝ≥0 := + Sum.elim + (f := batchingCoreRbrKnowledgeError κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate) + (g := QueryPhase.queryRbrKnowledgeError K β γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (ChallengeIdx.sumEquiv.symm i) + +open FRIBinius.CoreInteractionPhase in +/-- Round-by-round knowledge soundness for the full FRI-Binius oracle verifier. -/ +theorem fullOracleVerifier_rbrKnowledgeSoundness : + (fullOracleVerifier κ L K β ℓ ℓ' 𝓡 ϑ γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_l := h_l) (𝓑 := 𝓑)).rbrKnowledgeSoundness init impl + (relIn := BatchingPhase.batchingInputRelation κ L K (β := booleanHypercubeBasis κ L K β) + ℓ ℓ' h_l (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate))) + (relOut := acceptRejectOracleRel) + (rbrKnowledgeError := fullRbrKnowledgeError κ L K β ℓ' 𝓡 ϑ γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) := by + let V₁_batchingCore : OracleVerifier (oSpec := []ₒ) + (StmtIn := BatchingStmtIn (L := L) (ℓ := ℓ)) + (OStmtIn := (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).OStmtIn) + (StmtOut := BinaryBasefold.FinalSumcheckStatementOut (L := L) (ℓ := ℓ')) + (OStmtOut := BinaryBasefold.OracleStatement K β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ (Fin.last ℓ')) + (pSpec := batchingCorePspec κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate) := + batchingCoreVerifier κ (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') + (𝓡 := 𝓡) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (h_l := h_l) (𝓑 := 𝓑) + let res := + OracleVerifier.append_rbrKnowledgeSoundness + (init := init) (impl := impl) + (rel₁ := BatchingPhase.batchingInputRelation κ L K (β := booleanHypercubeBasis κ L K β) + ℓ ℓ' h_l (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate))) + (rel₂ := BinaryBasefold.finalSumcheckRelOut K β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (rel₃ := acceptRejectOracleRel) + (V₁ := V₁_batchingCore) + (V₂ := QueryPhase.queryOracleVerifier K β γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ)) + (Oₛ₃ := by exact fun _ => OracleInterface.instDefault) + (rbrKnowledgeError₁ := batchingCoreRbrKnowledgeError κ L K β ℓ' 𝓡 ϑ h_ℓ_add_R_rate) + (rbrKnowledgeError₂ := QueryPhase.queryRbrKnowledgeError K β γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (h₁ := by + simpa [batchingCoreVerifier, batchingCorePspec, + BinaryBasefold.pSpecCoreInteraction, batchingCoreRbrKnowledgeError] using + (OracleVerifier.append_rbrKnowledgeSoundness + (init := init) (impl := impl) + (rel₁ := BatchingPhase.batchingInputRelation κ L K + (β := booleanHypercubeBasis κ L K β) ℓ ℓ' h_l + (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate))) + (rel₂ := RingSwitching.sumcheckRoundRelation κ L K (booleanHypercubeBasis κ L K β) + ℓ ℓ' h_l (𝓑 := 𝓑) (aOStmtIn := BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) 0) + (rel₃ := BinaryBasefold.finalSumcheckRelOut K β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (V₁ := RingSwitching.BatchingPhase.batchingOracleVerifier κ (L := L) (K := K) + (β := booleanHypercubeBasis κ L K β) + (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (𝓑 := 𝓑) + (aOStmtIn := (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)))) + (V₂ := coreInteractionOracleVerifier κ (L := L) (K := K) + (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (𝓡 := 𝓡) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)) + (rbrKnowledgeError₁ := fun _ => + RingSwitching.BatchingPhase.batchingRBRKnowledgeError (κ := κ) (L := L)) + (rbrKnowledgeError₂ := + coreInteractionOracleRbrKnowledgeError + (κ := κ) (L := L) (K := K) (β := β) (ℓ' := ℓ') (𝓡 := 𝓡) (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (h₁ := RingSwitching.BatchingPhase.batchingOracleVerifier_rbrKnowledgeSoundness + (κ := κ) (L := L) (K := K) (β := booleanHypercubeBasis κ L K β) + (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (𝓑 := 𝓑) + (aOStmtIn := (BinaryBasefoldAbstractOStmtIn (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)))) + (h₂ := coreInteractionOracleVerifier_rbrKnowledgeSoundness + (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) + (𝓡 := 𝓡) (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)))) + (h₂ := QueryPhase.queryOracleVerifier_rbrKnowledgeSoundness K β γ_repetitions + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) init impl) + simpa only [ChallengeIdx, fullOracleVerifier, batchingCorePspec, + BinaryBasefold.pSpecCoreInteraction, batchingCoreVerifier, MessageIdx] using res end end Binius.FRIBinius.FullFRIBinius diff --git a/ArkLib/ProofSystem/Binius/FRIBinius/Prelude.lean b/ArkLib/ProofSystem/Binius/FRIBinius/Prelude.lean index 805b8ef07..e979376ac 100644 --- a/ArkLib/ProofSystem/Binius/FRIBinius/Prelude.lean +++ b/ArkLib/ProofSystem/Binius/FRIBinius/Prelude.lean @@ -6,6 +6,7 @@ Authors: Chung Thai Nguyen, Quang Dao import ArkLib.ProofSystem.Binius.RingSwitching.Prelude import ArkLib.ProofSystem.Binius.BinaryBasefold.Spec +import ArkLib.ProofSystem.Binius.RingSwitching.BBFSmallFieldIOPCS /-! # FRI-Binius IOPCS Prelude @@ -26,7 +27,8 @@ variable (L : Type) [Field L] [Fintype L] [DecidableEq L] [CharP L 2] variable (K : Type) [Field K] [Fintype K] [DecidableEq K] variable [h_Fq_char_prime : Fact (Nat.Prime (ringChar K))] [hF₂ : Fact (Fintype.card K = 2)] variable [Algebra K L] -variable (β : Basis (Fin (2 ^ κ)) K L) +variable (β : Basis (Fin (2 ^ κ)) K L) [hβ_lin_indep : Fact (LinearIndependent K β)] + [h_β₀_eq_1 : Fact (β 0 = 1)] variable (ℓ ℓ' 𝓡 ϑ γ_repetitions : ℕ) [NeZero ℓ] [NeZero ℓ'] [NeZero 𝓡] [NeZero ϑ] variable (h_ℓ_add_R_rate : ℓ' + 𝓡 < 2 ^ κ) variable (h_l : ℓ = ℓ' + κ) @@ -48,13 +50,8 @@ instance linearIndependentBooleanHypercubeBasis : Fact (LinearIndependent K ⇑ constructor exact β.linearIndependent -def BinaryBasefoldAbstractOStmtIn : (RingSwitching.AbstractOStmtIn L ℓ') where - ιₛᵢ := Fin (BinaryBasefold.toOutCodewordsCount ℓ' ϑ (i:=0)) - OStmtIn := BinaryBasefold.OracleStatement K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) ϑ 0 - Oₛᵢ := Binius.BinaryBasefold.instOracleStatementBinaryBasefold K β - (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) (i := 0) - initialCompatibility := fun ⟨t, oStmt⟩ => - Binius.BinaryBasefold.firstOracleWitnessConsistencyProp K β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) - t (f₀ := Binius.BinaryBasefold.getFirstOracle K β oStmt) +def BinaryBasefoldAbstractOStmtIn : (RingSwitching.AbstractOStmtIn (L := L) (ℓ' := ℓ')) := + Binius.RingSwitching.BBFSmallFieldIOPCS.bbfAbstractOStmtIn (𝔽q := K) (β := β) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) end Binius.FRIBinius diff --git a/ArkLib/ProofSystem/Binius/RingSwitching/BBFSmallFieldIOPCS.lean b/ArkLib/ProofSystem/Binius/RingSwitching/BBFSmallFieldIOPCS.lean new file mode 100644 index 000000000..c14a34b49 --- /dev/null +++ b/ArkLib/ProofSystem/Binius/RingSwitching/BBFSmallFieldIOPCS.lean @@ -0,0 +1,765 @@ +/- +Copyright (c) 2025 ArkLib Contributors. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Chung Thai Nguyen, Quang Dao +-/ + +import ArkLib.ProofSystem.Binius.RingSwitching.General +import ArkLib.ProofSystem.Binius.BinaryBasefold.General +import ArkLib.ProofSystem.Binius.BinaryBasefold.Soundness +import ArkLib.OracleReduction.LiftContext.OracleReduction + +/-! +# BBF Small-Field IOPCS: Ring-Switching + Binary Basefold Composition + +This module instantiates the Ring-Switching protocol with Binary Basefold as the inner +large-field MLIOPCS, producing a **small-field IOPCS** (the standard, non-optimized composition). + +## Architecture + +The composition follows the protocol layering: +1. **Ring-switching** (outer): Reduces a small-field polynomial commitment to a large-field one +2. **Binary Basefold** (inner): Serves as the `MLIOPCS L ℓ'` for the large-field evaluation + +This is the pedagogical/reference implementation that invokes Binary Basefold as a black box, +in contrast to `FRIBinius/CoreInteractionPhase.lean` which fuses the sumcheck-fold steps. + +## Main Results + +- `bbfMLIOPCS`: Binary Basefold instantiated as an `MLIOPCS L ℓ'` +- `bbf_fullOracleReduction_perfectCompleteness`: Perfect completeness of the composed protocol +- `bbf_fullOracleVerifier_rbrKnowledgeSoundness`: RBR knowledge soundness of the composed protocol + +## References + +- [DP24] Diamond, Benjamin E., and Jim Posen. "Polylogarithmic Proofs for Multilinears over Binary + Towers." Cryptology ePrint Archive (2024). +-/ + + +namespace Binius.RingSwitching.BBFSmallFieldIOPCS + +open Binius.BinaryBasefold Binius.BinaryBasefold.FullBinaryBasefold +open Binius.RingSwitching Binius.RingSwitching.FullRingSwitching +open Polynomial MvPolynomial OracleSpec OracleComp ProtocolSpec Finset AdditiveNTT Module +open scoped NNReal + +noncomputable section + +/-! ## Part 1: Binary Basefold as MLIOPCS + +We construct an `MLIOPCS L ℓ'` by wrapping Binary Basefold's full protocol. +The construction is parameterized by Binary Basefold parameters only (no Ring-switching params). +-/ + +section BinaryBasefoldMLIOPCS + +variable {r : ℕ} [NeZero r] +variable {L : Type} [Field L] [Fintype L] [DecidableEq L] [CharP L 2] + [SampleableType L] +variable (𝔽q : Type) [Field 𝔽q] [Fintype 𝔽q] [DecidableEq 𝔽q] + [h_Fq_char_prime : Fact (Nat.Prime (ringChar 𝔽q))] [hF₂ : Fact (Fintype.card 𝔽q = 2)] +variable [Algebra 𝔽q L] +variable (β : Fin r → L) [hβ_lin_indep : Fact (LinearIndependent 𝔽q β)] + [h_β₀_eq_1 : Fact (β 0 = 1)] +variable {ℓ' 𝓡 ϑ : ℕ} (γ_repetitions : ℕ) [NeZero ℓ'] [NeZero 𝓡] [NeZero ϑ] +variable {h_ℓ_add_R_rate : ℓ' + 𝓡 < r} +variable {𝓑 : Fin 2 ↪ L} +variable [h_B01 : Fact (𝓑 0 = 0 ∧ 𝓑 1 = 1)] +variable [hdiv : Fact (ϑ ∣ ℓ')] + +instance : OracleInterface Unit := OracleInterface.instDefault + +/-! ### Type Adapters + +| MLIOPCS | BinaryBasefold | +|------------------------|------------------------------------------| +| `MLPEvalStatement L ℓ'`| `Statement (SumcheckBaseContext L ℓ') 0` | +| `WitMLP L ℓ'` | `Witness 𝔽q β 0` | +| `OStmtIn` | `OracleStatement 𝔽q β ϑ 0` | + +At round 0, `sumcheck_target = original_claim` (since `∑ x, eq(r,x) * t(x) = t(r) = s`). -/ + +/-- Convert an `MLPEvalStatement L ℓ'` produced at the end of ring-switching protocol +to a `Statement (SumcheckBaseContext L ℓ') 0` that is equal to the initial statement +of the large-field Binary Basefold protocol. -/ +def reducedMLPEvalStatement_to_BBF_Statement (stmt : MLPEvalStatement (L := L) (ℓ := ℓ')) : + Statement (L := L) (SumcheckBaseContext L ℓ') (0 : Fin (ℓ' + 1)) where + sumcheck_target := stmt.original_claim + challenges := Fin.elim0 + ctx := ⟨stmt.t_eval_point, stmt.original_claim⟩ + +/-- Convert `WitMLP L ℓ'` to `Witness 𝔽q β 0`. -/ +def MLPEvalWitness_to_BBF_Witness (stmt : MLPEvalStatement (L := L) (ℓ := ℓ')) + (wit : WitMLP (K := L) (ℓ := ℓ')) : + Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') (0 : Fin (ℓ' + 1)) where + t := wit.t + H := projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := wit.t) + (m := BBF_SumcheckMultiplierParam.multpoly ⟨stmt.t_eval_point, stmt.original_claim⟩) + (i := (0 : Fin (ℓ' + 1))) (challenges := Fin.elim0) + f := getMidCodewords 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) wit.t Fin.elim0 + +/-! ### Large-Field Invocation Wrapper + +Ring-switching ends with a large-field MLP-evaluation statement/witness pair. +This wrapper maps that pair into Binary Basefold's round-0 input context +via `LiftContext`, reusing Binary Basefold's full reduction unchanged. -/ + +/-- Statement lens for the ring-switching large-field invocation into Binary Basefold. -/ +def largeFieldInvocationStmtLens : OracleStatement.Lens + (OuterStmtIn := MLPEvalStatement (L := L) (ℓ := ℓ')) + (OuterStmtOut := Bool) + (InnerStmtIn := Statement (L := L) (SumcheckBaseContext L ℓ') (0 : Fin (ℓ' + 1))) + (InnerStmtOut := Bool) + (OuterOStmtIn := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (ℓ := ℓ') ϑ (0 : Fin (ℓ' + 1))) + (OuterOStmtOut := fun _ : Empty => Unit) + (InnerOStmtIn := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (ℓ := ℓ') ϑ (0 : Fin (ℓ' + 1))) + (InnerOStmtOut := fun _ : Empty => Unit) where + toFunA := fun ⟨stmtIn, oStmtIn⟩ => + ⟨reducedMLPEvalStatement_to_BBF_Statement stmtIn, oStmtIn⟩ + toFunB := fun _ ⟨stmtOut, oStmtOut⟩ => ⟨stmtOut, oStmtOut⟩ + +/-- Context lens for the ring-switching large-field invocation into Binary Basefold. -/ +def largeFieldInvocationCtxLens : OracleContext.Lens + (OuterStmtIn := MLPEvalStatement (L := L) (ℓ := ℓ')) + (OuterStmtOut := Bool) + (InnerStmtIn := Statement (L := L) (SumcheckBaseContext L ℓ') (0 : Fin (ℓ' + 1))) + (InnerStmtOut := Bool) + (OuterOStmtIn := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (ℓ := ℓ') ϑ (0 : Fin (ℓ' + 1))) + (OuterOStmtOut := fun _ : Empty => Unit) + (InnerOStmtIn := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (ℓ := ℓ') ϑ (0 : Fin (ℓ' + 1))) + (InnerOStmtOut := fun _ : Empty => Unit) + (OuterWitIn := WitMLP (K := L) (ℓ := ℓ')) + (OuterWitOut := Unit) + (InnerWitIn := Witness (L := L) 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') (0 : Fin (ℓ' + 1))) + (InnerWitOut := Unit) where + stmt := largeFieldInvocationStmtLens 𝔽q β + wit := { + toFunA := fun ⟨⟨stmtIn, _oStmtIn⟩, witIn⟩ => + MLPEvalWitness_to_BBF_Witness 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmtIn witIn + toFunB := fun _ _ => () + } + +/-- Binary Basefold oracle reduction lifted to the ring-switching large-field invocation context. -/ +def largeFieldInvocationOracleReduction : + OracleReduction (oSpec := []ₒ) + (StmtIn := MLPEvalStatement (L := L) (ℓ := ℓ')) + (OStmtIn := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') ϑ + (0 : Fin (ℓ' + 1))) + (StmtOut := Bool) + (OStmtOut := fun _ : Empty => Unit) + (WitIn := WitMLP (K := L) (ℓ := ℓ')) + (WitOut := Unit) + (pSpec := fullPSpec 𝔽q β γ_repetitions (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) := + (FullBinaryBasefold.fullOracleReduction 𝔽q β γ_repetitions (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (ℓ := ℓ')).liftContext + (lens := largeFieldInvocationCtxLens 𝔽q β) + +omit [SampleableType L] in +/-- Uniqueness of the polynomial witness from first-oracle UDR-compatibility. -/ +lemma firstOracleWitnessConsistency_unique + (oStmt : ∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') ϑ + (0 : Fin (ℓ' + 1)) j) + {t₁ t₂ : MultilinearPoly L ℓ'} + (h₁ : firstOracleWitnessConsistencyProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') + t₁ (getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt)) + (h₂ : firstOracleWitnessConsistencyProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') + t₂ (getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt)) : + t₁ = t₂ := by + have h₁_some : + extractMLP 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') 0 + (getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt) = some t₁ := + (extractMLP_eq_some_iff_pair_UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (ℓ := ℓ') (f := getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt) + (tpoly := t₁)).2 h₁ + have h₂_some : + extractMLP 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') 0 + (getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt) = some t₂ := + (extractMLP_eq_some_iff_pair_UDRClose 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (ℓ := ℓ') (f := getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt) + (tpoly := t₂)).2 h₂ + rw [h₁_some] at h₂_some + simpa using h₂_some + +lemma map_eval_sumToIter_rename_finSum_zero + (p : MvPolynomial (Fin ℓ') L) : + (MvPolynomial.map (MvPolynomial.eval (σ := Fin 0) Fin.elim0) + ((sumToIter L (Fin ℓ') (Fin 0)) + (MvPolynomial.rename + (f := ⇑(finSumFinEquiv (m := ℓ') (n := 0)).symm) p))) = p := by + have h_sumToIter : + (sumToIter L (Fin ℓ') (Fin 0)) + (MvPolynomial.rename + (f := ⇑(finSumFinEquiv (m := ℓ') (n := 0)).symm) p) = + MvPolynomial.map (MvPolynomial.C) p := by + have h_ren_fun : + (fun i : Fin ℓ' => (finSumFinEquiv (m := ℓ') (n := 0)).symm i) = Sum.inl := by + funext i + simpa using (finSumFinEquiv_symm_apply_castAdd (m := ℓ') (n := 0) i) + have h_ren : + MvPolynomial.rename + (f := ⇑(finSumFinEquiv (m := ℓ') (n := 0)).symm) p = + MvPolynomial.rename (f := Sum.inl) p := by + simpa using congrArg (fun f => MvPolynomial.rename (f := f) p) h_ren_fun + rw [h_ren] + have h_comp := MvPolynomial.sumAlgEquiv_comp_rename_inl + (R := L) (S₁ := Fin ℓ') (S₂ := Fin 0) + have h_eval_comp := congrArg (fun f => f p) h_comp + simpa using h_eval_comp + rw [h_sumToIter] + rw [MvPolynomial.map_map] + have h_eval_comp_id : + (MvPolynomial.eval (σ := Fin 0) Fin.elim0).comp MvPolynomial.C = RingHom.id L := by + ext a + simp + rw [h_eval_comp_id] + exact MvPolynomial.map_id p + +lemma fixFirstVariablesOfMQP_zero_eq + (H : MvPolynomial (Fin ℓ') L) : + fixFirstVariablesOfMQP (L := L) (ℓ := ℓ') (v := (0 : Fin (ℓ' + 1))) H + (challenges := Fin.elim0) = H := by + simpa [fixFirstVariablesOfMQP] using + (map_eval_sumToIter_rename_finSum_zero (L := L) (p := H)) + +lemma witnessStructuralInvariant_MLPEvalWitness_to_BBF_Witness + (stmt : MLPEvalStatement (L := L) (ℓ := ℓ')) + (wit : WitMLP (K := L) (ℓ := ℓ')) : + Binius.BinaryBasefold.witnessStructuralInvariant 𝔽q β + (mp := BBF_SumcheckMultiplierParam) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (reducedMLPEvalStatement_to_BBF_Statement (L := L) (ℓ' := ℓ') stmt) + (MLPEvalWitness_to_BBF_Witness 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmt wit) := by + simpa [Binius.BinaryBasefold.witnessStructuralInvariant, + reducedMLPEvalStatement_to_BBF_Statement, MLPEvalWitness_to_BBF_Witness] + +/-- If `t(r) = s` for the outer MLP statement, then the mapped round-0 BBF witness +satisfies the BBF round-0 sumcheck consistency identity. -/ +lemma sumcheckConsistency_MLPEvalWitness_to_BBF_Witness_of_eval + (stmt : MLPEvalStatement (L := L) (ℓ := ℓ')) + (wit : WitMLP (K := L) (ℓ := ℓ')) + (h_eval : wit.t.val.eval stmt.t_eval_point = stmt.original_claim) : + sumcheckConsistencyProp (𝓑 := 𝓑) + (reducedMLPEvalStatement_to_BBF_Statement (L := L) (ℓ' := ℓ') stmt).sumcheck_target + (MLPEvalWitness_to_BBF_Witness 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmt wit).H := by + rw [sumcheckConsistencyProp] + dsimp [reducedMLPEvalStatement_to_BBF_Statement, MLPEvalWitness_to_BBF_Witness, + computeInitialSumcheckPoly, BBF_SumcheckMultiplierParam, BBF_eq_multiplier] + rw [← h_eval] + let castEmb : Fin 2 ↪ L := ⟨fun b => (b : L), by + intro a b h + fin_cases a <;> fin_cases b <;> simp at h <;> simp [h]⟩ + have h_Beq : 𝓑 = castEmb := by + ext b + fin_cases b <;> simp [castEmb, h_B01.out.1, h_B01.out.2] + subst h_Beq + have h_H0 : + projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := wit.t) + (m := BBF_SumcheckMultiplierParam.multpoly ⟨stmt.t_eval_point, stmt.original_claim⟩) + (i := (0 : Fin (ℓ' + 1))) (challenges := Fin.elim0) = + computeInitialSumcheckPoly (ℓ := ℓ') wit.t + (BBF_SumcheckMultiplierParam.multpoly ⟨stmt.t_eval_point, stmt.original_claim⟩) := by + have h_fix0 : + fixFirstVariablesOfMQP (L := L) (ℓ := ℓ') + (v := (0 : Fin (ℓ' + 1))) + (H := (computeInitialSumcheckPoly (ℓ := ℓ') wit.t + (BBF_SumcheckMultiplierParam.multpoly + ⟨stmt.t_eval_point, stmt.original_claim⟩)).val) + (challenges := Fin.elim0) = + (computeInitialSumcheckPoly (ℓ := ℓ') wit.t + (BBF_SumcheckMultiplierParam.multpoly + ⟨stmt.t_eval_point, stmt.original_claim⟩)).val := + fixFirstVariablesOfMQP_zero_eq (L := L) + (H := (computeInitialSumcheckPoly (ℓ := ℓ') wit.t + (BBF_SumcheckMultiplierParam.multpoly + ⟨stmt.t_eval_point, stmt.original_claim⟩)).val) + apply Subtype.ext + simpa [projectToMidSumcheckPoly] using h_fix0 + let mEq : MultilinearPoly L ℓ' := BBF_eq_multiplier (L := L) stmt.t_eval_point + have h_H0' : + projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := wit.t) + (m := mEq) (i := (0 : Fin (ℓ' + 1))) (challenges := Fin.elim0) = + computeInitialSumcheckPoly (ℓ := ℓ') wit.t mEq := by + simpa [mEq, BBF_SumcheckMultiplierParam, BBF_eq_multiplier] using h_H0 + change MvPolynomial.eval stmt.t_eval_point wit.t.val = + ∑ x ∈ Fintype.piFinset (fun _ : Fin ℓ' => Finset.map castEmb (Finset.univ : Finset (Fin 2))), + MvPolynomial.eval x + (projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := wit.t) + (m := mEq) (i := (0 : Fin (ℓ' + 1))) (challenges := Fin.elim0)).val + rw [h_H0'] + change MvPolynomial.eval stmt.t_eval_point wit.t.val = + ∑ x ∈ Fintype.piFinset (fun _ : Fin ℓ' => Finset.map castEmb (Finset.univ : Finset (Fin 2))), + MvPolynomial.eval x (mEq.val * wit.t.val) + have h_pi : + Fintype.piFinset (fun _ : Fin ℓ' => Finset.map castEmb (Finset.univ : Finset (Fin 2))) = + (Finset.univ : Finset (Fin ℓ' → Fin 2)).image + (fun b : Fin ℓ' → Fin 2 => fun i => castEmb (b i)) := by + simpa [Finset.map_eq_image, Fintype.piFinset_univ] using + (Fintype.piFinset_image + (f := fun _ : Fin ℓ' => castEmb) + (s := fun _ : Fin ℓ' => (Finset.univ : Finset (Fin 2)))) + rw [h_pi, Finset.sum_image] + · simp only [MvPolynomial.eval_mul] + have h_sum_symm : + (∑ x : Fin ℓ' → Fin 2, + MvPolynomial.eval (fun i => castEmb (x i)) mEq.val * + MvPolynomial.eval (fun i => castEmb (x i)) wit.t.val) = + (∑ x : Fin ℓ' → Fin 2, + MvPolynomial.eval stmt.t_eval_point (MvPolynomial.eqPolynomial (fun i => castEmb (x i))) * + MvPolynomial.eval (fun i => castEmb (x i)) wit.t.val) := by + apply Finset.sum_congr rfl + intro x hx + have h_mEq : MvPolynomial.eval (fun i => castEmb (x i)) mEq.val = MvPolynomial.eval + (fun i => castEmb (x i)) (MvPolynomial.eqPolynomial stmt.t_eval_point) := by + simp only [BBF_eq_multiplier, map_prod, map_add, map_mul, map_sub, map_one, + MvPolynomial.eval_C, MvPolynomial.eval_X, mEq] + rw [h_mEq] + congr 1 + simpa using (MvPolynomial.eqPolynomial_symm + (x := fun i => castEmb (x i)) (y := stmt.t_eval_point)).symm + rw [h_sum_symm] + have h_multilinear : MvPolynomial.MLE + (fun x : Fin ℓ' → Fin 2 => MvPolynomial.eval (x : Fin ℓ' → L) wit.t.val) = wit.t.val := by + exact (MvPolynomial.is_multilinear_iff_eq_evals_zeroOne (p := wit.t.val)).mp wit.t.property + calc + MvPolynomial.eval stmt.t_eval_point wit.t.val = + MvPolynomial.eval stmt.t_eval_point + (MvPolynomial.MLE + (fun x : Fin ℓ' → Fin 2 => MvPolynomial.eval (x : Fin ℓ' → L) wit.t.val)) := by + rw [h_multilinear] + _ = ∑ x : Fin ℓ' → Fin 2, + MvPolynomial.eval stmt.t_eval_point (MvPolynomial.eqPolynomial (x : Fin ℓ' → L)) * + MvPolynomial.eval (x : Fin ℓ' → L) wit.t.val := by + unfold MvPolynomial.MLE + simp only [MvPolynomial.eval_sum, MvPolynomial.eval_mul, MvPolynomial.eval_C] + _ = ∑ x : Fin ℓ' → Fin 2, + MvPolynomial.eval stmt.t_eval_point (MvPolynomial.eqPolynomial (fun i => castEmb (x i))) * + MvPolynomial.eval (fun i => castEmb (x i)) wit.t.val := by + apply Finset.sum_congr rfl + intro x hx + rfl + · intro x hx y hy hxy + funext i + apply castEmb.injective + exact congrFun hxy i + +/-! ### AbstractOStmtIn + +Following the pattern from `FRIBinius/Prelude.lean` (`BinaryBasefoldAbstractOStmtIn`). -/ + +/-- The `AbstractOStmtIn` for Binary Basefold. + +The oracle statement type is `OracleStatement 𝔽q β ϑ 0`, representing initial committed +codewords. The compatibility relations tie the polynomial `t'` to the oracle commitments +via first-oracle witness consistency + oracle folding consistency (relaxed), +and exact equality (strict). -/ +def bbfAbstractOStmtIn : AbstractOStmtIn L ℓ' where + ιₛᵢ := Fin (toOutCodewordsCount ℓ' ϑ (0 : Fin (ℓ' + 1))) + OStmtIn := OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') ϑ + (0 : Fin (ℓ' + 1)) + Oₛᵢ := instOracleStatementBinaryBasefold 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) (i := (0 : Fin (ℓ' + 1))) + -- Relaxed input compatibility at round 0 (RBR-KS style). + initialCompatibility := fun ⟨t', oStmt⟩ => + firstOracleWitnessConsistencyProp 𝔽q β (ℓ := ℓ') + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) t' + (getFirstOracle 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) oStmt) + -- Strict compatibility: exact oracle folding consistency (implies UDR-closeness). + strictInitialCompatibility := fun ⟨t', oStmt⟩ => + strictOracleFoldingConsistencyProp 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (t := t') (i := (0 : Fin (ℓ' + 1))) + (challenges := Fin.elim0) (oStmt := oStmt) + -- Strict (exact equality) implies relaxed (UDR-closeness). + strictInitialCompatibility_implies_initialCompatibility := by + intro oStmt t h_compat_strict + -- strictOracleFoldingConsistencyProp implies f₀ = getFirstOracle + have h_eq := Binius.BinaryBasefold.QueryPhase.polyToOracleFunc_eq_getFirstOracle + 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (t := t) (i := (0 : Fin (ℓ' + 1))) + (challenges := Fin.elim0) (oStmt := oStmt) h_compat_strict + -- Exact equality implies UDR-closeness (hamming distance 0) + dsimp only [firstOracleWitnessConsistencyProp] + rw [← h_eq] + dsimp only [pair_UDRClose] + have h_dist_pos : + 0 < BBF_CodeDistance 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (i := (0 : Fin r)) := by + rw [BBF_CodeDistance_eq 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (i := (0 : Fin r)) (h_i := by simp)] + omega + simp only [hammingDist_self, mul_zero, h_dist_pos] + -- Unique polynomial determination from oracle (via UDR-closeness) + initialCompatibility_unique := fun oStmt t₁ t₂ h₁ h₂ => by + exact firstOracleWitnessConsistency_unique 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) oStmt h₁ h₂ + +instance largeFieldInvocationCtxLens_complete : + (largeFieldInvocationCtxLens 𝔽q β).toContext.IsComplete + (outerRelIn := (bbfAbstractOStmtIn 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ)).toStrictRelInput) + (innerRelIn := strictRoundRelation (mp := BBF_SumcheckMultiplierParam) 𝔽q β + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (0 : Fin (ℓ' + 1))) + (outerRelOut := acceptRejectOracleRel) + (innerRelOut := acceptRejectOracleRel) + (compat := Reduction.compatContext (oSpec := []ₒ) + (pSpec := fullPSpec 𝔽q β γ_repetitions (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (largeFieldInvocationCtxLens 𝔽q β).toContext + ((FullBinaryBasefold.fullOracleReduction 𝔽q β γ_repetitions (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (ℓ := ℓ')).toReduction)) where + proj_complete := fun stmtIn witIn hRelIn => by + rcases stmtIn with ⟨stmtIn, oStmtIn⟩ + rcases hRelIn with ⟨h_eval, h_compat⟩ + refine ⟨?_, ?_⟩ + · exact sumcheckConsistency_MLPEvalWitness_to_BBF_Witness_of_eval 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) stmtIn witIn h_eval + · refine ⟨?_, ?_⟩ + · exact witnessStructuralInvariant_MLPEvalWitness_to_BBF_Witness 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmtIn witIn + · simpa [reducedMLPEvalStatement_to_BBF_Statement, strictOracleFoldingConsistencyProp] using + h_compat + lift_complete := fun outerStmtIn outerWitIn innerStmtOut innerWitOut hCompat hRelIn hRelOut => by + cases innerWitOut + simpa [largeFieldInvocationCtxLens, largeFieldInvocationStmtLens] using hRelOut + +variable {σ : Type} {init : ProbComp σ} {impl : QueryImpl []ₒ (StateT σ ProbComp)} + +theorem largeFieldInvocationOracleReduction_perfectCompleteness (hInit : NeverFail init) : + OracleReduction.perfectCompleteness + (oracleReduction := largeFieldInvocationOracleReduction 𝔽q β γ_repetitions (𝓑 := 𝓑)) + (relIn := (bbfAbstractOStmtIn 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ)).toStrictRelInput) + (relOut := acceptRejectOracleRel) + (init := init) + (impl := impl) := by + let innerReduction := FullBinaryBasefold.fullOracleReduction 𝔽q β γ_repetitions + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (ℓ := ℓ') + letI : (largeFieldInvocationCtxLens 𝔽q β).toContext.IsComplete + (outerRelIn := (bbfAbstractOStmtIn 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ)).toStrictRelInput) + (innerRelIn := strictRoundRelation (mp := BBF_SumcheckMultiplierParam) 𝔽q β + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (0 : Fin (ℓ' + 1))) + (outerRelOut := acceptRejectOracleRel) + (innerRelOut := acceptRejectOracleRel) + (compat := Reduction.compatContext (oSpec := []ₒ) + (pSpec := fullPSpec 𝔽q β γ_repetitions (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (largeFieldInvocationCtxLens 𝔽q β).toContext + innerReduction.toReduction) := by + infer_instance + have h_inner := FullBinaryBasefold.fullOracleReduction_perfectCompleteness + 𝔽q β γ_repetitions (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) + (init := init) (impl := impl) hInit + simpa [largeFieldInvocationOracleReduction, innerReduction] using + (OracleReduction.liftContext_perfectCompleteness + (R := innerReduction) + (lens := largeFieldInvocationCtxLens 𝔽q β) + (outerRelIn := (bbfAbstractOStmtIn 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ)).toStrictRelInput) + (innerRelIn := strictRoundRelation (mp := BBF_SumcheckMultiplierParam) 𝔽q β + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (0 : Fin (ℓ' + 1))) + (outerRelOut := acceptRejectOracleRel) + (innerRelOut := acceptRejectOracleRel) + (init := init) + (impl := impl) + h_inner) + +lemma MLPEvalRelation_of_round0_local_and_structural + (stmt : MLPEvalStatement (L := L) (ℓ := ℓ')) + (wit : Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') + (0 : Fin (ℓ' + 1))) + (h_local : sumcheckConsistencyProp (𝓑 := 𝓑) + (reducedMLPEvalStatement_to_BBF_Statement (L := L) (ℓ' := ℓ') stmt).sumcheck_target wit.H) + (h_struct : Binius.BinaryBasefold.witnessStructuralInvariant 𝔽q β + (mp := BBF_SumcheckMultiplierParam) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (reducedMLPEvalStatement_to_BBF_Statement (L := L) (ℓ' := ℓ') stmt) wit) : + wit.t.val.eval stmt.t_eval_point = stmt.original_claim := by + let stmt_eval : MLPEvalStatement (L := L) (ℓ := ℓ') := { + t_eval_point := stmt.t_eval_point + original_claim := wit.t.val.eval stmt.t_eval_point + } + let wit_eval : WitMLP (K := L) (ℓ := ℓ') := { t := wit.t } + have h_eval_stmt_eval : wit_eval.t.val.eval stmt_eval.t_eval_point + = stmt_eval.original_claim := by rfl + have h_local_eval : + sumcheckConsistencyProp (𝓑 := 𝓑) + (reducedMLPEvalStatement_to_BBF_Statement (L := L) (ℓ' := ℓ') stmt_eval).sumcheck_target + (MLPEvalWitness_to_BBF_Witness 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmt_eval wit_eval).H := + sumcheckConsistency_MLPEvalWitness_to_BBF_Witness_of_eval 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) stmt_eval wit_eval h_eval_stmt_eval + have h_H_eq : + wit.H = (MLPEvalWitness_to_BBF_Witness 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmt_eval wit_eval).H := by + calc + wit.H = projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := wit.t) + (m := BBF_SumcheckMultiplierParam.multpoly ⟨stmt.t_eval_point, stmt.original_claim⟩) + (i := (0 : Fin (ℓ' + 1))) (challenges := Fin.elim0) := by + simpa [Binius.BinaryBasefold.witnessStructuralInvariant, + reducedMLPEvalStatement_to_BBF_Statement] using h_struct.1 + _ = projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := wit.t) + (m := BBF_SumcheckMultiplierParam.multpoly ⟨stmt_eval.t_eval_point, + stmt_eval.original_claim⟩) + (i := (0 : Fin (ℓ' + 1))) (challenges := Fin.elim0) := by + simp [stmt_eval, BBF_SumcheckMultiplierParam, BBF_eq_multiplier] + _ = (MLPEvalWitness_to_BBF_Witness 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmt_eval wit_eval).H := by + simp [MLPEvalWitness_to_BBF_Witness, wit_eval] + have h_sum_eq_claim : + stmt.original_claim = ∑ x ∈ (univ.map 𝓑) ^ᶠ (ℓ'), + wit.H.val.eval x := by + simpa [sumcheckConsistencyProp, reducedMLPEvalStatement_to_BBF_Statement] using h_local + have h_sum_eq_eval : + wit.t.val.eval stmt.t_eval_point = ∑ x ∈ (univ.map 𝓑) ^ᶠ (ℓ'), + wit.H.val.eval x := by + have h_sum_eq_eval' : + wit.t.val.eval stmt.t_eval_point = ∑ x ∈ (univ.map 𝓑) ^ᶠ (ℓ'), + (MLPEvalWitness_to_BBF_Witness 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) stmt_eval wit_eval).H.val.eval x := by + simpa [sumcheckConsistencyProp, reducedMLPEvalStatement_to_BBF_Statement, stmt_eval] using + h_local_eval + simpa [h_H_eq] using h_sum_eq_eval' + calc + wit.t.val.eval stmt.t_eval_point = ∑ x ∈ (univ.map 𝓑) ^ᶠ (ℓ'), wit.H.val.eval x := h_sum_eq_eval + _ = stmt.original_claim := h_sum_eq_claim.symm + +/-- Extractor lens for lifting Binary Basefold RBR-KS to the large-field invocation wrapper. -/ +def largeFieldInvocationExtractorLens : Extractor.Lens + (OuterStmtIn := MLPEvalStatement (L := L) (ℓ := ℓ') × + (∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') ϑ + (0 : Fin (ℓ' + 1)) j)) + (OuterStmtOut := Bool × (∀ j : Empty, Unit)) + (InnerStmtIn := Statement (L := L) (SumcheckBaseContext L ℓ') (0 : Fin (ℓ' + 1)) × + (∀ j, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') ϑ + (0 : Fin (ℓ' + 1)) j)) + (InnerStmtOut := Bool × (∀ j : Empty, Unit)) + (OuterWitIn := WitMLP (K := L) (ℓ := ℓ')) + (OuterWitOut := Unit) + (InnerWitIn := Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') + (0 : Fin (ℓ' + 1))) + (InnerWitOut := Unit) where + stmt := largeFieldInvocationStmtLens 𝔽q β + wit := { + toFunA := fun _ => () + toFunB := fun ⟨⟨_stmtIn, _oStmtIn⟩, _outerWitOut⟩ innerWitIn => ⟨innerWitIn.t⟩ + } + +instance largeFieldInvocationExtractorLens_rbr_knowledge_soundness + {compatStmt : + (MLPEvalStatement (L := L) (ℓ := ℓ') × + (∀ i, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') ϑ + (0 : Fin (ℓ' + 1)) i)) → + (Bool × (∀ i : Empty, Unit)) → Prop} : + Extractor.Lens.IsKnowledgeSound + (OuterStmtIn := MLPEvalStatement (L := L) (ℓ := ℓ') × + (∀ i, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') ϑ + (0 : Fin (ℓ' + 1)) i)) + (OuterStmtOut := Bool × (∀ i : Empty, Unit)) + (InnerStmtIn := Statement (L := L) (SumcheckBaseContext L ℓ') (0 : Fin (ℓ' + 1)) × + (∀ i, OracleStatement 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') ϑ + (0 : Fin (ℓ' + 1)) i)) + (InnerStmtOut := Bool × (∀ i : Empty, Unit)) + (OuterWitIn := WitMLP (K := L) (ℓ := ℓ')) + (OuterWitOut := Unit) + (InnerWitIn := Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (ℓ := ℓ') (0 : Fin (ℓ' + 1))) + (InnerWitOut := Unit) + (outerRelIn := (bbfAbstractOStmtIn 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ)).toRelInput) + (innerRelIn := roundRelation (mp := BBF_SumcheckMultiplierParam) 𝔽q β (ϑ := ϑ) + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) (0 : Fin (ℓ' + 1))) + (outerRelOut := acceptRejectOracleRel) + (innerRelOut := acceptRejectOracleRel) + (compatStmt := compatStmt) + (compatWit := fun _ _ => True) + (lens := largeFieldInvocationExtractorLens 𝔽q β) where + proj_knowledgeSound := by + intro outerStmtIn innerStmtOut outerWitOut _ hOuter + simpa [largeFieldInvocationExtractorLens, largeFieldInvocationStmtLens] using hOuter + lift_knowledgeSound := by + intro outerStmtIn outerWitOut innerWitIn _ hInner + rcases outerStmtIn with ⟨stmtIn, oStmtIn⟩ + have hInner' : + roundRelationProp (mp := BBF_SumcheckMultiplierParam) 𝔽q β + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) + (0 : Fin (ℓ' + 1)) + ((reducedMLPEvalStatement_to_BBF_Statement (L := L) (ℓ' := ℓ') stmtIn, + oStmtIn), innerWitIn) := by + simpa [roundRelation, Set.mem_setOf_eq] using hInner + unfold roundRelationProp Binius.BinaryBasefold.masterKStateProp at hInner' + have h_no_bad : + ¬ incrementalBadEventExistsProp 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + (ϑ := ϑ) (stmtIdx := (0 : Fin (ℓ' + 1))) + (oracleIdx := OracleFrontierIndex.mkFromStmtIdx (0 : Fin (ℓ' + 1))) + (oStmt := oStmtIn) + (challenges := (reducedMLPEvalStatement_to_BBF_Statement (L := L) + (ℓ' := ℓ') stmtIn).challenges) := by + intro h_bad + rcases h_bad with ⟨j, hj⟩ + have hj0 : j = 0 := by + apply Fin.eq_of_val_eq + have hjlt : j.val < 1 := by + simpa [toOutCodewordsCountOf0] using j.isLt + exact Nat.lt_one_iff.mp hjlt + subst hj0 + dsimp [oraclePositionToDomainIndex] at hj + exact absurd hj (by + apply BinaryBasefold.incrementalFoldingBadEvent_of_k_eq_0_is_false (𝔽q := 𝔽q) (β := β) + (h_k := by + simp only [Nat.zero_mod, zero_mul, tsub_self, zero_le, inf_of_le_right]) + (h_midIdx := by simp only [Nat.zero_mod, zero_mul, tsub_self, zero_le, + inf_of_le_right, add_zero]) + ) + rcases hInner' with h_bad | h_good + · exact (h_no_bad h_bad).elim + · have h_local := h_good.1 + have h_struct := h_good.2.1 + have h_first := h_good.2.2.1 + refine ⟨?_, ?_⟩ + · exact MLPEvalRelation_of_round0_local_and_structural 𝔽q β + (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) + stmtIn innerWitIn h_local h_struct + · simpa [bbfAbstractOStmtIn] using h_first + +/-! ### MLIOPCS Instance -/ + +/-- Binary Basefold as an `MLIOPCS L ℓ'`. + +This wraps the full Binary Basefold protocol (core interaction + query phase) +as a multilinear polynomial commitment scheme over the large field `L`. -/ +def bbfMLIOPCS : MLIOPCS L ℓ' where + toAbstractOStmtIn := bbfAbstractOStmtIn 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ϑ := ϑ) + numRounds := _ -- inferred from fullPSpec + pSpec := fullPSpec 𝔽q β γ_repetitions (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + Oₘ := inferInstance + O_challenges := inferInstance + oracleReduction := largeFieldInvocationOracleReduction 𝔽q β γ_repetitions (𝓑 := 𝓑) + perfectCompleteness := by + intro σ init impl hInit + exact largeFieldInvocationOracleReduction_perfectCompleteness 𝔽q β γ_repetitions (𝓑 := 𝓑) + (init := init) (impl := impl) hInit + strictPerfectCompleteness := by + intro σ init impl hInit + exact largeFieldInvocationOracleReduction_perfectCompleteness 𝔽q β γ_repetitions (𝓑 := 𝓑) + (init := init) (impl := impl) hInit + rbrKnowledgeError := + fullRbrKnowledgeError 𝔽q β γ_repetitions (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) + rbrKnowledgeSoundness := by + intro σ init impl + have h_bbf := FullBinaryBasefold.fullOracleVerifier_rbrKnowledgeSoundness + 𝔽q β γ_repetitions (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) + (init := init) (impl := impl) + letI : + Inhabited (Witness (L := L) 𝔽q β (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (ℓ := ℓ') + (0 : Fin (ℓ' + 1))) := ⟨{ + t := 0 + H := 0 + f := fun _ => 0 + }⟩ + letI : ∀ i : Empty, Inhabited ((fun _ : Empty => Unit) i) := by + intro i + exact (i.elim) + have h_lifted := OracleVerifier.liftContext_rbr_knowledgeSoundness + (V := FullBinaryBasefold.fullOracleVerifier 𝔽q β γ_repetitions (ϑ := ϑ) + (𝓑 := 𝓑) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)) + (stmtLens := largeFieldInvocationStmtLens 𝔽q β) + (witLens := (largeFieldInvocationExtractorLens 𝔽q β).wit) + (lensKS := largeFieldInvocationExtractorLens_rbr_knowledge_soundness + (𝔽q := 𝔽q) (β := β) + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑) + (compatStmt := (FullBinaryBasefold.fullOracleVerifier 𝔽q β γ_repetitions (ϑ := ϑ) + (𝓑 := 𝓑) (h_ℓ_add_R_rate := h_ℓ_add_R_rate)).toVerifier.compatStatement + (largeFieldInvocationStmtLens 𝔽q β))) + (h := by simpa using h_bbf) + simpa [largeFieldInvocationOracleReduction] using h_lifted + +end BinaryBasefoldMLIOPCS + +/-! ## Part 2: End-to-End Composition + +Compose Ring-switching with the Binary Basefold MLIOPCS using the existing +infrastructure in `RingSwitching/General.lean`. +-/ + +section Composition + +variable {r : ℕ} [NeZero r] +variable {L : Type} [Field L] [Fintype L] [DecidableEq L] [CharP L 2] + [SampleableType L] +variable (𝔽q : Type) [Field 𝔽q] [Fintype 𝔽q] [DecidableEq 𝔽q] + [h_Fq_char_prime : Fact (Nat.Prime (ringChar 𝔽q))] [hF₂ : Fact (Fintype.card 𝔽q = 2)] +variable [Algebra 𝔽q L] +variable (β : Fin r → L) [hβ_lin_indep : Fact (LinearIndependent 𝔽q β)] + [h_β₀_eq_1 : Fact (β 0 = 1)] +variable {ℓ' 𝓡 ϑ : ℕ} (γ_repetitions : ℕ) [NeZero ℓ'] [NeZero 𝓡] [NeZero ϑ] +variable {h_ℓ_add_R_rate : ℓ' + 𝓡 < r} +variable {𝓑 : Fin 2 ↪ L} +variable [h_B01 : Fact (𝓑 0 = 0 ∧ 𝓑 1 = 1)] +variable [hdiv : Fact (ϑ ∣ ℓ')] + +-- Ring-switching variables +variable (κ : ℕ) [NeZero κ] +variable (K : Type) [Field K] [Fintype K] [DecidableEq K] +variable [Algebra K L] +variable (β_rs : Basis (Fin κ → Fin 2) K L) +variable (ℓ : ℕ) [NeZero ℓ] +variable (h_l : ℓ = ℓ' + κ) + +variable {σ : Type} (init : ProbComp σ) {impl : QueryImpl []ₒ (StateT σ ProbComp)} + +/-- Perfect completeness of the composed protocol: +Ring-switching + Binary Basefold as MLIOPCS. + +This is a direct instantiation of `fullOracleReduction_perfectCompleteness` from +`RingSwitching/General.lean` with the Binary Basefold MLIOPCS. -/ +theorem bbf_fullOracleReduction_perfectCompleteness (hInit : NeverFail init) : + OracleReduction.perfectCompleteness + (oracleReduction := FullRingSwitching.fullOracleReduction κ L K β_rs ℓ ℓ' h_l + (𝓑 := 𝓑) + (bbfMLIOPCS 𝔽q β γ_repetitions + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑))) + (relIn := BatchingPhase.strictBatchingInputRelation + κ L K β_rs ℓ ℓ' h_l + (bbfMLIOPCS 𝔽q β γ_repetitions + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)).toAbstractOStmtIn) + (relOut := acceptRejectOracleRel) + (init := init) (impl := impl) := + FullRingSwitching.fullOracleReduction_perfectCompleteness κ L K β_rs ℓ ℓ' h_l + (𝓑 := 𝓑) + (bbfMLIOPCS 𝔽q β γ_repetitions + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)) + init hInit + +/-- RBR knowledge soundness of the composed protocol: +Ring-switching + Binary Basefold as MLIOPCS. + +This is a direct instantiation of `fullOracleVerifier_rbrKnowledgeSoundness` from +`RingSwitching/General.lean` with the Binary Basefold MLIOPCS. -/ +theorem bbf_fullOracleVerifier_rbrKnowledgeSoundness : + OracleVerifier.rbrKnowledgeSoundness + (verifier := FullRingSwitching.fullOracleVerifier κ L K β_rs ℓ ℓ' (𝓑 := 𝓑) h_l + (bbfMLIOPCS 𝔽q β γ_repetitions + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑))) + (init := init) (impl := impl) + (relIn := BatchingPhase.batchingInputRelation + κ L K β_rs ℓ ℓ' h_l + (bbfMLIOPCS 𝔽q β γ_repetitions + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)).toAbstractOStmtIn) + (relOut := acceptRejectOracleRel) + (rbrKnowledgeError := fun i => FullRingSwitching.fullRbrKnowledgeError κ L K ℓ' + (bbfMLIOPCS 𝔽q β γ_repetitions + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)) i) := + FullRingSwitching.fullOracleVerifier_rbrKnowledgeSoundness κ L K β_rs ℓ ℓ' h_l + (𝓑 := 𝓑) + (bbfMLIOPCS 𝔽q β γ_repetitions + (ϑ := ϑ) (h_ℓ_add_R_rate := h_ℓ_add_R_rate) (𝓑 := 𝓑)) + init + +end Composition + +end +end Binius.RingSwitching.BBFSmallFieldIOPCS diff --git a/ArkLib/ProofSystem/Binius/RingSwitching/BatchingPhase.lean b/ArkLib/ProofSystem/Binius/RingSwitching/BatchingPhase.lean index 39787cd36..775237ad3 100644 --- a/ArkLib/ProofSystem/Binius/RingSwitching/BatchingPhase.lean +++ b/ArkLib/ProofSystem/Binius/RingSwitching/BatchingPhase.lean @@ -7,11 +7,13 @@ Authors: Chung Thai Nguyen, Quang Dao import ArkLib.ProofSystem.Binius.RingSwitching.Prelude import ArkLib.ProofSystem.Binius.RingSwitching.Spec import ArkLib.OracleReduction.Basic +import ArkLib.OracleReduction.Completeness import ArkLib.ProofSystem.Binius.BinaryBasefold.ReductionLogic import CompPoly.Fields.Binary.Tower.TensorAlgebra +import ArkLib.Data.Probability.Instances open OracleSpec OracleComp ProtocolSpec Finset AdditiveNTT Polynomial MvPolynomial - Module Binius.BinaryBasefold TensorProduct Nat Matrix + Module Binius.BinaryBasefold TensorProduct Nat Matrix ProbabilityTheory open scoped NNReal /-! @@ -85,12 +87,34 @@ def batchingInputRelationProp (stmt : BatchingStmtIn L ℓ) wit.t' = packMLE κ L K ℓ ℓ' h_l β wit.t ∧ stmt.original_claim = wit.t.val.aeval stmt.t_eval_point ∧ aOStmtIn.initialCompatibility ⟨wit.t', oStmt⟩ +def strictBatchingInputRelationProp (stmt : BatchingStmtIn L ℓ) + (oStmt : ∀ j, aOStmtIn.OStmtIn j) (wit : BatchingWitIn L K ℓ ℓ') : Prop := + wit.t' = packMLE κ L K ℓ ℓ' h_l β wit.t ∧ stmt.original_claim = wit.t.val.aeval stmt.t_eval_point + ∧ aOStmtIn.strictInitialCompatibility ⟨wit.t', oStmt⟩ + /-- Input relation: the witness `t` and `t'` are consistent, and `t` satisfies the original claim. -/ def batchingInputRelation : Set ((BatchingStmtIn L ℓ × (∀ j, aOStmtIn.OStmtIn j)) × BatchingWitIn L K ℓ ℓ') := {⟨⟨stmt, oStmt⟩, wit⟩ | batchingInputRelationProp κ L K β ℓ ℓ' h_l aOStmtIn stmt oStmt wit } +/-- Strict input relation for completeness proofs. -/ +def strictBatchingInputRelation : + Set ((BatchingStmtIn L ℓ × (∀ j, aOStmtIn.OStmtIn j)) × BatchingWitIn L K ℓ ℓ') := + {⟨⟨stmt, oStmt⟩, wit⟩ | + strictBatchingInputRelationProp κ L K β ℓ ℓ' h_l aOStmtIn stmt oStmt wit } + +omit [NeZero κ] [Fintype L] [DecidableEq L] [CharP L 2] [SampleableType L] [Fintype K] + [DecidableEq K] [NeZero ℓ] [NeZero ℓ'] in +lemma strictBatchingInputRelation_subset_batchingInputRelation : + strictBatchingInputRelation κ L K β ℓ ℓ' h_l aOStmtIn ⊆ + batchingInputRelation κ L K β ℓ ℓ' h_l aOStmtIn := by + intro input h_input + rcases input with ⟨⟨stmt, oStmt⟩, wit⟩ + rcases h_input with ⟨h_pack, h_claim, h_strict_compat⟩ + exact ⟨h_pack, h_claim, + aOStmtIn.strictInitialCompatibility_implies_initialCompatibility oStmt wit.t' h_strict_compat⟩ + /-! ## Pure Logic Functions (ReductionLogicStep Infrastructure) -/ /-- Pure verifier check: validates that the prover's ŝ satisfies Check 1. @@ -153,7 +177,7 @@ def batchingProverWitOut (stmtIn : BatchingStmtIn L ℓ) (witIn : BatchingWitIn This encapsulates the pure logic of the batching phase, separating it from the monadic oracle operations. -/ def batchingStepLogic : - Binius.BinaryBasefold.CoreInteraction.ReductionLogicStep + Binius.BinaryBasefold.ReductionLogicStep -- In/Out Types (BatchingStmtIn L ℓ) (BatchingWitIn L K ℓ ℓ') @@ -164,30 +188,26 @@ def batchingStepLogic : -- Protocol Spec (pSpecBatching (κ:=κ) (L:=L) (K:=K)) where - -- 1. Relations (using strict relations for completeness) completeness_relIn := fun ((s, o), w) => - ((s, o), w) ∈ batchingInputRelation κ L K β ℓ ℓ' h_l aOStmtIn + ((s, o), w) ∈ strictBatchingInputRelation κ L K β ℓ ℓ' h_l aOStmtIn completeness_relOut := fun ((s, o), w) => - ((s, o), w) ∈ sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn 0 - + ((s, o), w) ∈ strictSumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn 0 -- 2. Verifier Logic (Using extracted kernels) verifierCheck := fun stmtIn transcript => - batchingVerifierCheck (κ:=κ) (L:=L) (K:=K) (β:=β) (ℓ:=ℓ) (ℓ':=ℓ') (h_l:=h_l) (stmtIn := stmtIn) (transcript.messages ⟨0, rfl⟩) - + batchingVerifierCheck (κ:=κ) (L:=L) (K:=K) (β:=β) (ℓ:=ℓ) (ℓ':=ℓ') (h_l:=h_l) (stmtIn := stmtIn) + (transcript.messages ⟨0, rfl⟩) verifierOut := fun stmtIn transcript => - batchingVerifierStmtOut (κ:=κ) (L:=L) (K:=K) (β:=β) (ℓ:=ℓ) (ℓ':=ℓ') (stmtIn := stmtIn) (msg0 := transcript.messages ⟨0, rfl⟩) (r_batching := transcript.challenges ⟨1, rfl⟩) - + batchingVerifierStmtOut (κ:=κ) (L:=L) (K:=K) (β:=β) (ℓ:=ℓ) (ℓ':=ℓ') (stmtIn := stmtIn) + (msg0 := transcript.messages ⟨0, rfl⟩) (r_batching := transcript.challenges ⟨1, rfl⟩) -- 2b. Oracle Embedding (must match oracleVerifier) embed := ⟨fun j => Sum.inl j, fun a b h => by cases h; rfl⟩ hEq := fun i => rfl - -- 3. Honest Prover Logic (Constructing the transcript) honestProverTranscript := fun stmtIn witIn _oStmtIn chal => let msg : TensorAlgebra K L := batchingProverComputeMsg (κ := κ) (L := L) (K := K) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) stmtIn witIn FullTranscript.mk2 msg (chal ⟨1, rfl⟩) - -- 4. Prover Output (State Update) proverOut := fun stmtIn witIn oStmtIn transcript => let msg0 : TensorAlgebra K L := transcript.messages ⟨0, rfl⟩ @@ -227,23 +247,19 @@ lemma batchingStep_is_logic_complete : let proverStmtOut := proverOutput.1.1 let proverOStmtOut := proverOutput.1.2 let proverWitOut := proverOutput.2 - - -- Extract properties from h_relIn (batchingInputRelation) - simp only [batchingStepLogic, batchingInputRelation, batchingInputRelationProp, + -- Extract properties from h_relIn (strictBatchingInputRelation) + simp only [batchingStepLogic, strictBatchingInputRelation, strictBatchingInputRelationProp, Set.mem_setOf_eq] at h_relIn obtain ⟨h_t'_eq_t_packed, h_original_evaluation_claim, h_compat⟩ := h_relIn - -- The message computed by the honest prover let msg0 := batchingProverComputeMsg κ L K ℓ ℓ' h_l stmtIn witIn let r_batching := challenges ⟨1, rfl⟩ - have h_s_hat_eq : transcript.messages ⟨0, rfl⟩ = embedded_MLP_eval κ L K ℓ ℓ' h_l (packMLE κ L K ℓ ℓ' h_l β witIn.t) stmtIn.t_eval_point := by dsimp only [transcript, step, batchingStepLogic] unfold FullTranscript.mk2 dsimp only [batchingProverComputeMsg] rw [h_t'_eq_t_packed] - -- Fact 1: Verifier check passes let hVCheck_passed : step.verifierCheck stmtIn transcript := by simp only [step, batchingStepLogic, batchingVerifierCheck] @@ -252,28 +268,25 @@ lemma batchingStep_is_logic_complete : rw [←h_s_hat_eq] at res rw [h_original_evaluation_claim] exact res - - -- Fact 2: Output relation holds (sumcheckRoundRelation 0) + -- Fact 2: Output relation holds (strictSumcheckRoundRelation 0) let hRelOut : step.completeness_relOut ((verifierStmtOut, verifierOStmtOut), proverWitOut) := by - simp only [step, batchingStepLogic, sumcheckRoundRelation, sumcheckRoundRelationProp, + simp only [step, batchingStepLogic, strictSumcheckRoundRelation, strictSumcheckRoundRelationProp, Set.mem_setOf_eq] -- batching_target_consistency - dsimp only [masterKStateProp, Fin.coe_ofNat_eq_mod]; rw [true_and] + dsimp only [masterStrictKStateProp, Fin.coe_ofNat_eq_mod]; constructor - · -- ⊢ witnessStructuralInvariant κ L K β ℓ ℓ' h_l verifierStmtOut proverWitOut - rfl - · constructor - · -- ⊢ sumcheckConsistencyProp verifierStmtOut.sumcheck_target proverWitOut.H + · -- ⊢ sumcheckConsistencyProp verifierStmtOut.sumcheck_target proverWitOut.H exact batching_target_consistency κ L K β ℓ ℓ' h_l (𝓑:=𝓑) witIn.t' (transcript.messages ⟨0, rfl⟩) r_batching verifierStmtOut.ctx + · constructor + · -- ⊢ witnessStructuralInvariant κ L K β ℓ ℓ' h_l verifierStmtOut proverWitOut + rfl · -- ⊢ aOStmtIn.initialCompatibility (proverWitOut.t', verifierOStmtOut) exact h_compat - -- Fact 3: Prover and verifier statements agree have hStmtOut_eq : proverStmtOut = verifierStmtOut := by simp only [step, batchingStepLogic, proverStmtOut, verifierStmtOut] rfl - -- Fact 4: Prover and verifier oracle statements agree have hOStmtOut_eq : proverOStmtOut = verifierOStmtOut := by simp only [step, batchingStepLogic, proverOStmtOut, verifierOStmtOut] @@ -281,7 +294,6 @@ lemma batchingStep_is_logic_complete : simp only [OracleVerifier.mkVerifierOStmtOut] -- Oracle statements are unchanged (all map via Sum.inl) rfl - -- Combine all facts refine ⟨?_, ?_, ?_, ?_⟩ · exact hVCheck_passed @@ -307,26 +319,22 @@ noncomputable def batchingOracleProver : (WitOut := SumcheckWitness L ℓ' 0) (pSpec := pSpecBatching (κ:=κ) (L:=L) (K:=K)) where PrvState := PrvState κ L K ℓ ℓ' aOStmtIn - input := fun ⟨⟨stmt, oStmt⟩, wit⟩ => (stmt, oStmt, wit) - sendMessage | ⟨0, _⟩ => fun (stmt, oStmt, wit) => do -- USE THE SHARED KERNEL (Guarantees match with batchingStepLogic) let s_hat := batchingProverComputeMsg (κ:=κ) (L:=L) (K:=K) (ℓ:=ℓ) (ℓ':=ℓ') (h_l:=h_l) stmt wit return ⟨s_hat, (stmt, oStmt, wit, s_hat)⟩ | ⟨1, h⟩ => fun _ => do nomatch h -- V to P round - receiveChallenge | ⟨0, h⟩ => nomatch h -- i.e. contradiction | ⟨1, _⟩ => fun ⟨stmt, oStmt, wit, s_hat⟩ => do return fun r_batching => (stmt, oStmt, wit, s_hat, r_batching) - output := fun ⟨stmt, oStmt, wit, (s_hat : TensorAlgebra K L), (r_batching : Fin κ → L)⟩ => do -- Construct the transcript that the honest prover produces -- This matches logic.honestProverTranscript exactly - let logic := (batchingStepLogic (κ := κ) (L := L) (K := K) (β := β) (𝓑 := 𝓑) (ℓ := ℓ) (ℓ' := ℓ') - (h_l := h_l) (aOStmtIn := aOStmtIn)) + let logic := (batchingStepLogic (κ := κ) (L := L) (K := K) (β := β) (𝓑 := 𝓑) (ℓ := ℓ) + (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn)) let challenges : (pSpecBatching (κ:=κ) (L:=L) (K:=K)).Challenges := fun ⟨j, hj⟩ => by match j with @@ -350,14 +358,13 @@ noncomputable def batchingOracleVerifier : (L:=L) (K:=K).Message]ₒ) ⟨⟨0, by rfl⟩, (by simpa using ())⟩ let r_batching : Fin κ → L := pSpec_batching_challenges ⟨1, by rfl⟩ -- Reconstruct the transcript (matches what honestProverTranscript produces) - let logic := (batchingStepLogic (κ := κ) (L := L) (K := K) (β := β) (𝓑 := 𝓑) (ℓ := ℓ) (ℓ' := ℓ') - (h_l := h_l) (aOStmtIn := aOStmtIn)) + let logic := (batchingStepLogic (κ := κ) (L := L) (K := K) (β := β) (𝓑 := 𝓑) (ℓ := ℓ) + (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn)) -- Note: We can't call honestProverTranscript directly because we don't have the witness -- But we know the transcript structure must match it let t := FullTranscript.mk2 s_hat r_batching guard (logic.verifierCheck stmtIn t) pure (logic.verifierOut stmtIn t) - -- Reuse embed and hEq from batchingStepLogic to ensure consistency embed := (batchingStepLogic (κ := κ) (L := L) (K := K) (β := β) (𝓑 := 𝓑) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn)).embed @@ -403,10 +410,7 @@ noncomputable def batchingRbrExtractor : /-- RBR knowledge soundness error for the batching phase. The only verifier randomness is `r''`. A collision has probability related to `κ/|L|`. For simplicity, we can set a placeholder value. -/ -def batchingRBRKnowledgeError (i : (pSpecBatching (κ := κ) (L := L) (K := K)).ChallengeIdx) : ℝ≥0 := - match i with - | ⟨1, _⟩ => (κ : ℝ≥0) / (Fintype.card L : ℝ≥0) -- Schwartz-Zippel error - | _ => 0 -- No other challenges +def batchingRBRKnowledgeError : ℝ≥0 := (κ : ℝ≥0) / (Fintype.card L : ℝ≥0) -- Schwartz-Zippel error def batchingKStateProp {m : Fin (2 + 1)} (tr : Transcript m (pSpecBatching (κ := κ) (L := L) (K := K))) @@ -416,29 +420,23 @@ def batchingKStateProp {m : Fin (2 + 1)} match m with | ⟨0, _⟩ => -- equiv s relIn batchingInputRelationProp κ L K β ℓ ℓ' h_l aOStmtIn stmt oStmt witMid - | ⟨1, _⟩ => by -- P sends hᵢ(X) - let ⟨msgsUpTo, _⟩ := Transcript.equivMessagesChallenges (k := 1) - (pSpec := pSpecBatching (κ:=κ) (L:=L) (K:=K)) tr - let i_msg1 : ((pSpecBatching (κ:=κ) (L:=L) (K:=K)).take 1 (by omega)).MessageIdx := - ⟨⟨0, Nat.lt_of_succ_le (by omega)⟩, by simp [pSpecBatching]; rfl⟩ - let s_hat: TensorAlgebra K L := msgsUpTo i_msg1 + | ⟨1, _⟩ => by -- P sends ŝ + let s_hat : TensorAlgebra K L := tr.messages ⟨0, rfl⟩ exact witMid.t' = packMLE κ L K ℓ ℓ' h_l β witMid.t -- implied by `extractMid` - -- The last two constraints are equivalent to `t(r) = s` + -- `P's computation: ŝ := φ₁(t')(φ₀(r_κ), ..., φ₀(r_{ℓ-1}))` ∧ embedded_MLP_eval κ L K ℓ ℓ' h_l witMid.t' stmt.t_eval_point = s_hat + -- The last two constraints are equivalent to `t(r) = s` + -- `V's check: s ?= Σ_{v ∈ {0,1}^κ} eqTilde(v, r_{0..κ-1}) ⋅ ŝ_v.` ∧ performCheckOriginalEvaluation κ L K β ℓ ℓ' h_l stmt.original_claim stmt.t_eval_point s_hat -- local V check + -- The passed-through oracle compatibility condition of `t'`, i.e. carried through the whole + -- ring-switching protocol + ∧ aOStmtIn.initialCompatibility ⟨witMid.t', oStmt⟩ | ⟨2, _⟩ => by -- implied by relOut simp only [batchingWitMid] at witMid - let ⟨msgsUpTo, chalsUpTo⟩ := Transcript.equivMessagesChallenges (k := 2) - (pSpec := pSpecBatching (κ:=κ) (L:=L) (K:=K)) tr - let i_msg1 : ((pSpecBatching (κ:=κ) (L:=L) (K:=K)).take 2 (by omega)).MessageIdx := - ⟨⟨0, Nat.lt_of_succ_le (by omega)⟩, by simp [pSpecBatching]; rfl⟩ - let s_hat: TensorAlgebra K L := msgsUpTo i_msg1 - let i_msg2 : ((pSpecBatching (κ:=κ) (L:=L) (K:=K)).take 2 (by omega)).ChallengeIdx := - ⟨⟨1, Nat.lt_of_succ_le (by omega)⟩, by simp [pSpecBatching]; rfl⟩ - let batching_challenges: Fin κ → L := chalsUpTo i_msg2 - + let s_hat : TensorAlgebra K L := tr.messages ⟨0, rfl⟩ + let batching_challenges : Fin κ → L := tr.challenges ⟨1, rfl⟩ let ctx : RingSwitchingBaseContext κ L K ℓ := { t_eval_point := stmt.t_eval_point, original_claim := stmt.original_claim, @@ -459,33 +457,228 @@ def batchingKStateProp {m : Fin (2 + 1)} exact sumcheckRoundRelationProp κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn (i:=0) stmtOut oStmt witOut ∧ performCheckOriginalEvaluation κ L K β ℓ ℓ' h_l stmt.original_claim - stmt.t_eval_point s_hat -- local V check + stmt.t_eval_point s_hat -- local V check (kept in m=2 for doom proof; see + -- batching_doom_escape_probability_bound) ∧ aOStmtIn.initialCompatibility ⟨witMid.t', oStmt⟩ /-- Knowledge state function for the batching phase. -/ noncomputable def batchingKnowledgeStateFunction : - (batchingOracleVerifier κ L K β ℓ ℓ' h_l (𝓑:=𝓑) (aOStmtIn:=aOStmtIn)).KnowledgeStateFunction init impl + (batchingOracleVerifier κ L K β ℓ ℓ' h_l (𝓑:=𝓑) (aOStmtIn:=aOStmtIn)).KnowledgeStateFunction + init impl (relIn := batchingInputRelation κ L K β ℓ ℓ' h_l aOStmtIn) (relOut := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn 0) (batchingRbrExtractor κ L K β ℓ ℓ' h_l (aOStmtIn:=aOStmtIn)) where toFun := fun m ⟨stmt, oStmt⟩ tr witMid => batchingKStateProp κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn tr stmt witMid oStmt toFun_empty _ _ := by rfl - toFun_next := fun m hDir stmtIn tr msg witMid => - match m with - | ⟨0, _⟩ => by -- from accumulative KState - intro hSuccTrue - simp only [batchingKStateProp, Fin.zero_eta, Fin.isValue, Fin.succ_zero_eq_one, - Equiv.toFun_as_coe, Transcript.equivMessagesChallenges_apply, Fin.castSucc_zero, - batchingRbrExtractor, Fin.mk_one, Fin.succ_one_eq_two, - batchingInputRelationProp] at ⊢ hSuccTrue - rw [hSuccTrue.1] - simp only [true_and] - set s_hat := (Transcript.concat msg tr).toMessagesChallenges.1 ⟨(0 : Fin (0 + 1)), by rfl⟩ - -- ⊢ stmtIn.1.original_claim = (MvPolynomial.aeval stmtIn.1.t_eval_point) ↑witMid.t - sorry - | ⟨1, h⟩ => nomatch h - toFun_full := fun ⟨stmtLast, oStmtLast⟩ tr witOut => by sorry + toFun_next := fun m hDir stmtIn tr msg witMid => by + have h_m_eq_0 : m = 0 := by + cases m using Fin.cases with + | zero => rfl + | succ m' => simp only [ne_eq, reduceCtorEq, not_false_eq_true, Matrix.cons_val_succ, + Matrix.cons_val_fin_one, Direction.not_V_to_P_eq_P_to_V] at hDir + subst h_m_eq_0 + intro hSuccTrue + unfold batchingKStateProp at hSuccTrue ⊢ + simp only [Fin.zero_eta, Fin.isValue, Fin.succ_zero_eq_one, Fin.castSucc_zero, + batchingRbrExtractor, Fin.mk_one, Fin.succ_one_eq_two] at hSuccTrue ⊢ + obtain ⟨h_t'_eq, h_embed_eq, h_check_true, h_compat⟩ := hSuccTrue + simp only [batchingInputRelationProp] + constructor + · exact h_t'_eq + · constructor + · -- stmt.original_claim = witMid.t.val.aeval stmt.t_eval_point from check + s_hat = φ(t')(r) + have h_check_stmt : + performCheckOriginalEvaluation κ L K β ℓ ℓ' h_l + stmtIn.1.original_claim stmtIn.1.t_eval_point + (embedded_MLP_eval κ L K ℓ ℓ' h_l witMid.t' stmtIn.1.t_eval_point) = true := by + simpa [h_embed_eq] using h_check_true + have h_check_wit : + performCheckOriginalEvaluation κ L K β ℓ ℓ' h_l + (witMid.t.val.aeval stmtIn.1.t_eval_point) stmtIn.1.t_eval_point + (embedded_MLP_eval κ L K ℓ ℓ' h_l witMid.t' stmtIn.1.t_eval_point) = true := by + have h_honest := + batching_check_correctness (κ := κ) (L := L) (K := K) (β := β) + (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (t := witMid.t) + (eval_point := stmtIn.1.t_eval_point) + simpa [h_t'_eq] using h_honest + have hs₁ := (decide_eq_true_eq.mp h_check_stmt) + have hs₂ := (decide_eq_true_eq.mp h_check_wit) + exact hs₁.trans hs₂.symm + · -- aOStmtIn.initialCompatibility + exact h_compat + toFun_full := fun ⟨stmtIn, oStmtIn⟩ tr witOut h_relOut => by + -- h_relOut: ∃ stmtOut oStmtOut, verifier outputs (stmtOut, oStmtOut) with prob > 0 + -- and ((stmtOut, oStmtOut), witOut) ∈ foldStepRelOut + simp only [StateT.run'_eq, gt_iff_lt, probEvent_pos_iff, Prod.exists] at h_relOut + rcases h_relOut with ⟨stmtOut, oStmtOut, h_output_mem_V_run_support, h_relOut⟩ + have h_output_mem_V_run_support' : + some (stmtOut, oStmtOut) ∈ + (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (batchingOracleVerifier κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) + (aOStmtIn := aOStmtIn)).toVerifier)).run s).support := by + exact (OptionT.mem_support_iff + (mx := OptionT.mk (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (batchingOracleVerifier κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) + (aOStmtIn := aOStmtIn)).toVerifier)).run s)) + (x := (stmtOut, oStmtOut))).1 h_output_mem_V_run_support + simp only [support_bind, Set.mem_iUnion, exists_prop] at h_output_mem_V_run_support' + rcases h_output_mem_V_run_support' with ⟨s, hs_init, h_output_mem_V_run_support⟩ + conv at h_output_mem_V_run_support => + simp only [Verifier.run, OracleVerifier.toVerifier] + -- Now unfold the foldOracleVerifier's `verify()` method + simp only [batchingOracleVerifier] + -- dsimp only [StateT.run] + -- simp only [simulateQ_bind, simulateQ_query, simulateQ_pure] + -- oracle query unfolding + simp only [support_bind, Set.mem_iUnion] + dsimp only [StateT.run] + -- enter [1, i_1, 2, 1, x] + simp only [simulateQ_bind] + unfold OracleInterface.answer + --------------------------------------- + -- Now simplify the `guard` and `ite` of StateT.map generated from it + simp only [MessageIdx, Fin.isValue, Matrix.cons_val_zero, simulateQ_pure, Message, guard_eq, + pure_bind, Function.comp_apply, simulateQ_map, simulateQ_ite, + OptionT.simulateQ_failure', bind_map_left] + simp only [MessageIdx, Message, Fin.isValue, Matrix.cons_val_zero, Matrix.cons_val_one, + bind_pure_comp, simulateQ_map, simulateQ_ite, simulateQ_pure, OptionT.simulateQ_failure', + bind_map_left, Function.comp_apply] + simp only [support_ite] + simp only [Fin.isValue, Set.mem_ite_empty_right, Set.mem_singleton_iff, Prod.mk.injEq, + exists_and_left, exists_eq', exists_eq_right, exists_and_right] + erw [simulateQ_bind] + enter [1, x, 1, 2, 1, 2]; + erw [simulateQ_bind] + erw [OptionT.simulateQ_simOracle2_liftM_query_T2] + simp only [Fin.isValue, FullTranscript.mk1_eq_snoc, pure_bind, OptionT.simulateQ_map] + conv at h_output_mem_V_run_support => + simp only [Fin.isValue, cons_val_zero, id_eq, Function.comp_apply, support_map, Set.mem_image, + Prod.exists, exists_and_right, exists_eq_right] + simp only [show OptionT.pure (m := (OracleComp ([]ₒ))) = pure by rfl] + at h_output_mem_V_run_support + erw [support_bind] at h_output_mem_V_run_support + set V_check := (batchingStepLogic κ L K β ℓ ℓ' h_l aOStmtIn).verifierCheck stmtIn + (FullTranscript.mk2 (msg0 := + (OracleInterface.answer (Message := (TensorAlgebra K L)) + (FullTranscript.messages tr ⟨(0 : Fin 2), rfl⟩) ()) + ) (msg1 := FullTranscript.challenges tr ⟨(1 : Fin 2), rfl⟩)) with h_V_check_def + erw [←h_V_check_def] at h_output_mem_V_run_support + by_cases h_V_check : V_check + · simp only [Fin.isValue, h_V_check, ↓reduceIte, OptionT.run_pure, simulateQ_pure, + Set.mem_iUnion, exists_prop, Prod.exists] at h_output_mem_V_run_support + erw [simulateQ_bind] at h_output_mem_V_run_support + simp only [simulateQ_pure, Fin.isValue, Function.comp_apply, + pure_bind] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Fin.isValue, Set.mem_singleton_iff, Prod.mk.injEq, ↓existsAndEq, and_true, + exists_eq_left] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Fin.isValue, Set.mem_singleton_iff, Option.some.injEq, + Prod.mk.injEq] at h_output_mem_V_run_support + -- simp only [support_map, Set.mem_image, exists_prop] at h_output_mem_V_run_support + rcases h_output_mem_V_run_support with ⟨init_value, ⟨h_stmtOut_eq, h_oStmtOut_eq⟩, + h_initValue_eq⟩ + simp only [Fin.reduceLast, Fin.isValue] + dsimp only [sumcheckRoundRelation, sumcheckRoundRelationProp, Set.mem_setOf_eq] at h_relOut + unfold batchingKStateProp + simp only [Fin.isValue, batchingWitMid] + set s_hat := tr.messages ⟨(0 : Fin 2), rfl⟩ with _h_s_hat + set batching_challenges := tr.challenges ⟨(1 : Fin 2), rfl⟩ with _h_chal + set ctx : RingSwitchingBaseContext κ L K ℓ := + { t_eval_point := stmtIn.t_eval_point, original_claim := stmtIn.original_claim, + s_hat := s_hat, r_batching := batching_challenges } with h_ctx_def + set stmtOut_computed : Statement (RingSwitchingBaseContext κ L K ℓ) 0 := + batchingVerifierStmtOut (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') + stmtIn s_hat batching_challenges with _h_stmtOut + have h_stmtOut_eq_computed : stmtOut = stmtOut_computed := by + rw [h_stmtOut_eq]; rfl + rw [h_stmtOut_eq_computed] at h_relOut + constructor + · -- masterKStateProp with oStmtOut/witOut → sumcheckRoundRelationProp with + -- oStmtLast/extractOut; extractOut = id; need embed compatibility + have h_oStmtOut_eq_oStmtIn : oStmtOut = oStmtIn := by + rw [h_oStmtOut_eq] + funext j + simp [OracleVerifier.mkVerifierOStmtOut, batchingStepLogic] + rw [h_oStmtOut_eq_oStmtIn] at h_relOut + have h_stmt_goal_eq : + (({ + sumcheck_target := compute_s0 κ L K β s_hat batching_challenges, + challenges := Fin.elim0, ctx := ctx }) : + Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) 0) + = stmtOut_computed := by + subst ctx + simp only [batchingVerifierStmtOut, stmtOut_computed] + rw [h_stmt_goal_eq] + unfold sumcheckRoundRelationProp masterKStateProp + have h_cons : sumcheckConsistencyProp (𝓑 := 𝓑) stmtOut_computed.sumcheck_target witOut.H + := h_relOut.1 + have h_wit_struct : + witOut.H = projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witOut.t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly + stmtOut_computed.ctx) + (i := 0) (challenges := stmtOut_computed.challenges) := h_relOut.2.1 + have h_h_goal_eq : + projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witOut.t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly ctx) + (i := 0) (challenges := Fin.elim0) = + projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witOut.t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly + stmtOut_computed.ctx) + (i := 0) (challenges := stmtOut_computed.challenges) := by + subst ctx + simp only [Fin.coe_ofNat_eq_mod, stmtOut_computed] + have h_wit_H_eq_goal : + witOut.H = + projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witOut.t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly ctx) + (i := 0) (challenges := Fin.elim0) := by + calc + witOut.H + = projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witOut.t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly + stmtOut_computed.ctx) + (i := 0) (challenges := stmtOut_computed.challenges) := h_wit_struct + _ = projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witOut.t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly ctx) + (i := 0) (challenges := Fin.elim0) := by + simpa using h_h_goal_eq.symm + constructor + · simpa only [Fin.coe_ofNat_eq_mod, zero_mod, Nat.sub_zero, batchingRbrExtractor, + Fin.zero_eta, Fin.isValue, Fin.succ_zero_eq_one, Fin.mk_one, Fin.succ_one_eq_two, + h_wit_H_eq_goal] using h_cons + · constructor + · unfold witnessStructuralInvariant + simp only [Fin.coe_ofNat_eq_mod, zero_mod, Nat.sub_zero, batchingRbrExtractor, + Fin.zero_eta, Fin.isValue, Fin.succ_zero_eq_one, Fin.mk_one, Fin.succ_one_eq_two] + exact h_h_goal_eq + · simpa only [batchingRbrExtractor, Fin.zero_eta, Fin.isValue, Fin.succ_zero_eq_one, + Fin.mk_one, Fin.succ_one_eq_two] using h_relOut.2.2 + · constructor + · exact h_V_check -- verifierCheck is performCheckOriginalEvaluation ... = true + · rw [h_oStmtOut_eq] at h_relOut + simpa only [batchingRbrExtractor, Fin.zero_eta, Fin.isValue, Fin.succ_zero_eq_one, + Fin.mk_one, Fin.succ_one_eq_two] using h_relOut.2.2 + · simp only [Fin.isValue, h_V_check, ↓reduceIte, OptionT.run_failure, simulateQ_pure, + Set.mem_iUnion, exists_prop, Prod.exists] at h_output_mem_V_run_support + erw [simulateQ_bind] at h_output_mem_V_run_support + simp only [simulateQ_pure, Fin.isValue, Function.comp_apply, + pure_bind] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, ↓existsAndEq, and_true, exists_eq_left, + simulateQ_pure] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, reduceCtorEq, false_and, + exists_false] at h_output_mem_V_run_support -- False /-! ## Security Properties -/ @@ -509,8 +702,8 @@ always succeeds (with probability 1) and produces valid outputs. theorem batchingReduction_perfectCompleteness (hInit : NeverFail init) : OracleReduction.perfectCompleteness (oracleReduction := batchingOracleReduction κ L K β ℓ ℓ' h_l (𝓑:=𝓑) (aOStmtIn:=aOStmtIn)) - (relIn := batchingInputRelation κ L K β ℓ ℓ' h_l aOStmtIn) - (relOut := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn 0) + (relIn := strictBatchingInputRelation κ L K β ℓ ℓ' h_l aOStmtIn) + (relOut := strictSumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn 0) (init := init) (impl := impl) := by -- Step 1: Unroll the 2-message reduction to convert from probability to logic -- **NOTE**: this requires `ProtocolSpec.challengeOracleInterface` to avoid conflict @@ -523,12 +716,12 @@ theorem batchingReduction_perfectCompleteness (hInit : NeverFail init) : -- Step 2: Convert probability 1 to universal quantification over support rw [probEvent_eq_one_iff] -- Step 3: Unfold protocol definitions - dsimp only [batchingOracleReduction, batchingOracleProver, batchingOracleVerifier, OracleVerifier.toVerifier, - FullTranscript.mk2] + dsimp only [batchingOracleReduction, batchingOracleProver, batchingOracleVerifier, + OracleVerifier.toVerifier, FullTranscript.mk2] let step := (batchingStepLogic (κ := κ) (L := L) (K := K) (β := β) (𝓑 := 𝓑) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn)) - let strongly_complete : step.IsStronglyComplete := batchingStep_is_logic_complete (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn) - + let strongly_complete : step.IsStronglyComplete := batchingStep_is_logic_complete (κ := κ) + (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn) -- Step 4: Split into safety and correctness goals refine ⟨?_, ?_⟩ -- GOAL 1: SAFETY - Prove the verifier never crashes ([⊥|...] = 0) @@ -546,7 +739,6 @@ theorem batchingReduction_perfectCompleteness (hInit : NeverFail init) : simp only [ChallengeIdx, Challenge, Fin.isValue, Matrix.cons_val_one, Matrix.cons_val_zero, liftComp_eq_liftM, OptionT.probFailure_lift, HasEvalPMF.probFailure_eq_zero] rw [true_and] - intro r_i' h_r_i'_mem_query_1_support conv => enter [1]; @@ -556,7 +748,6 @@ theorem batchingReduction_perfectCompleteness (hInit : NeverFail init) : Fin.succ_one_eq_two, Message, Fin.succ_zero_eq_one, Fin.castSucc_one, liftComp_eq_liftM, OptionT.probFailure_lift, HasEvalPMF.probFailure_eq_zero] rw [true_and] - intro h_receive_challenge_fn h_receive_challenge_fn_mem_support conv => enter [1]; @@ -614,7 +805,8 @@ theorem batchingReduction_perfectCompleteness (hInit : NeverFail init) : set V_check := step.verifierCheck stmtIn (FullTranscript.mk2 (msg0 := _) - (msg1 := (FullTranscript.mk2 (batchingProverComputeMsg κ L K ℓ ℓ' h_l stmtIn witIn) r_i').challenges ⟨1, rfl⟩)) + (msg1 := (FullTranscript.mk2 (batchingProverComputeMsg κ L K ℓ ℓ' h_l stmtIn witIn) + r_i').challenges ⟨1, rfl⟩)) with h_V_check_def obtain ⟨h_V_check, h_rel, h_agree⟩ := strongly_complete (stmtIn := stmtIn) (witIn := witIn) (h_relIn := h_relIn) (challenges := @@ -656,9 +848,7 @@ theorem batchingReduction_perfectCompleteness (hInit : NeverFail init) : Fin.reduceLast, MessageIdx, Message, exists_eq_left] at hx_mem_support -- Step 2b: Extract the challenge r1 and the trace equations obtain ⟨r1, ⟨_h_r1_mem_challenge_support, h_trace_support⟩⟩ := hx_mem_support - rcases h_trace_support with ⟨prvOut_eq, h_verOut_mem_support⟩ - -- Step 2c: Simplify the verifier computation conv at h_verOut_mem_support => erw [simulateQ_bind] @@ -677,7 +867,8 @@ theorem batchingReduction_perfectCompleteness (hInit : NeverFail init) : set V_check := step.verifierCheck stmtIn (FullTranscript.mk2 (msg0 := _) - (msg1 := (FullTranscript.mk2 (batchingProverComputeMsg κ L K ℓ ℓ' h_l stmtIn witIn) r1).challenges ⟨1, rfl⟩)) + (msg1 := (FullTranscript.mk2 (batchingProverComputeMsg κ L K ℓ ℓ' h_l stmtIn witIn) + r1).challenges ⟨1, rfl⟩)) with h_V_check_def obtain ⟨h_V_check, h_rel, h_agree⟩ := strongly_complete (stmtIn := stmtIn) (witIn := witIn) (h_relIn := h_relIn) (challenges := @@ -691,7 +882,6 @@ theorem batchingReduction_perfectCompleteness (hInit : NeverFail init) : exact hj_ne hj | 1 => exact r1 ) - have h_V_check_is_true : V_check := h_V_check simp only [h_V_check_is_true, ↓reduceIte, Fin.isValue, pure_bind] at h_verOut_mem_support erw [simulateQ_pure, liftM_pure] at h_verOut_mem_support @@ -700,9 +890,7 @@ theorem batchingReduction_perfectCompleteness (hInit : NeverFail init) : rcases h_verOut_mem_support with ⟨verStmtOut_eq, verOStmtOut_eq⟩ dsimp only [batchingStepLogic, batchingProverComputeMsg, step] at prvOut_eq rw [Prod.mk.injEq, Prod.mk.injEq] at prvOut_eq - obtain ⟨⟨prvStmtOut_eq, prvOStmtOut_eq⟩, prvWitOut_eq⟩ := prvOut_eq - constructor · rw [prvWitOut_eq, verStmtOut_eq, verOStmtOut_eq]; exact h_rel @@ -711,6 +899,409 @@ theorem batchingReduction_perfectCompleteness (hInit : NeverFail init) : · rw [verOStmtOut_eq, prvOStmtOut_eq]; exact h_agree.2 +#check ProtocolSpec.challengeOracleInterface + +/-- Repacking the unpacked polynomial is identity for multilinear `t'`. -/ +lemma batching_pack_unpack_id (t' : MultilinearPoly L ℓ') : + packMLE κ L K ℓ ℓ' h_l β (unpackMLE κ L K ℓ ℓ' h_l β t') = t' := by + apply Subtype.ext + simp [packMLE, unpackMLE] + simpa [MvPolynomial.toEvalsZeroOne] using + (MvPolynomial.is_multilinear_iff_eq_evals_zeroOne (p := t'.val)).mp t'.property + +/-- `compute_s0` is evaluation of the row-MLE at the batching challenge. -/ +lemma batching_compute_s0_eq_eval_MLE + (s_hat : TensorAlgebra K L) (y : Fin κ → L) : + compute_s0 κ L K β s_hat y = + MvPolynomial.eval y + (MvPolynomial.MLE (fun u : Fin κ → Fin 2 => + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) s_hat u)) := by + classical + rw [compute_s0, MvPolynomial.MLE] + simp_rw [Binius.BinaryBasefold.eqTilde] + simp [MvPolynomial.eval_mul, MvPolynomial.eval_C] + apply Finset.sum_congr rfl + intro u hu + congr 1 + apply Finset.prod_congr rfl + intro x hx + by_cases hux : u x = 1 + · simp [hux] + · have hux0 : u x = 0 := by + have hix : ((u x : Fin 2) : ℕ) = 0 ∨ ((u x : Fin 2) : ℕ) = 1 := by omega + rcases hix with h0 | h1 + · exact Fin.ext h0 + · exfalso + exact hux (Fin.ext h1) + simp only [hux0, Fin.isValue, zero_ne_one, ↓reduceIte, sub_zero, one_mul, map_zero, add_zero, + Fin.coe_ofNat_eq_mod, zero_mod, cast_zero, zero_mul] + +/-- Mismatch polynomial from row-decomposition difference `msg0 - s_bar`. -/ +def batchingMismatchPoly (msg0 s_bar : TensorAlgebra K L) : MvPolynomial (Fin κ) L := + MvPolynomial.MLE (fun u : Fin κ → Fin 2 => + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) msg0 u - + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) s_bar u) + +/-- The mismatch polynomial evaluates to the `compute_s0` difference. -/ +lemma batching_compute_s0_sub_eq_eval_mismatch + (msg0 s_bar : TensorAlgebra K L) (y : Fin κ → L) : + compute_s0 κ L K β msg0 y - compute_s0 κ L K β s_bar y = + MvPolynomial.eval y + (batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar) := by + rw [batching_compute_s0_eq_eval_MLE (κ := κ) (L := L) (K := K) (β := β) + (s_hat := msg0) (y := y)] + rw [batching_compute_s0_eq_eval_MLE (κ := κ) (L := L) (K := K) (β := β) + (s_hat := s_bar) (y := y)] + unfold batchingMismatchPoly + simp [MvPolynomial.MLE, MvPolynomial.eval_sum, MvPolynomial.eval_mul, MvPolynomial.eval_C, + sub_eq_add_neg] + rw [← Finset.sum_neg_distrib] + rw [← Finset.sum_add_distrib] + apply Finset.sum_congr rfl + intro x hx + ring + +/-- Degree bound for mismatch polynomial: multilinear in `κ` vars, so total degree ≤ `κ`. -/ +lemma batchingMismatchPoly_totalDegree_le + (msg0 s_bar : TensorAlgebra K L) : + (batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar).totalDegree ≤ κ := by + let P := batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar + have h_mem : P ∈ MvPolynomial.restrictDegree (Fin κ) L 1 := by + dsimp [P, batchingMismatchPoly] + exact (MvPolynomial.MLE_mem_restrictDegree (σ := Fin κ) (R := L) + (evals := fun u : Fin κ → Fin 2 => + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) msg0 u - + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) s_bar u)) + have h_degOf : ∀ i : Fin κ, MvPolynomial.degreeOf i P ≤ 1 := by + intro i + exact (MvPolynomial.mem_restrictDegree_iff_degreeOf_le (p := P) (n := 1)).1 h_mem i + rw [MvPolynomial.totalDegree_eq] + apply Finset.sup_le + intro m hm + rw [Finsupp.card_toMultiset] + have hm_le_one : ∀ i ∈ m.support, m i ≤ 1 := by + intro i hi + exact le_trans (MvPolynomial.monomial_le_degreeOf i hm) (h_degOf i) + calc + m.sum (fun _ e => e) ≤ m.sum (fun _ _ => (1 : ℕ)) := by + exact Finsupp.sum_le_sum hm_le_one + _ = m.support.card := by + rw [Finsupp.sum] + simp + _ ≤ κ := by + simpa using (Finset.card_le_univ (s := m.support)) + +/-- If embedded evaluation mismatches `msg0`, the mismatch polynomial is nonzero. -/ +lemma batchingMismatchPoly_nonzero_of_embed_ne + (stmt : BatchingStmtIn L ℓ) + (msg0 : TensorAlgebra K L) + (t' : MultilinearPoly L ℓ') + (h_embed_ne : embedded_MLP_eval κ L K ℓ ℓ' h_l t' stmt.t_eval_point ≠ msg0) : + batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 + (embedded_MLP_eval κ L K ℓ ℓ' h_l t' stmt.t_eval_point) ≠ 0 := by + let s_bar := embedded_MLP_eval κ L K ℓ ℓ' h_l t' stmt.t_eval_point + have h_rows_ne : + (decompose_tensor_algebra_rows (L := L) (K := K) (β := β) msg0) ≠ + (decompose_tensor_algebra_rows (L := L) (K := K) (β := β) s_bar) := by + intro h_eq + have h_repr_eq : (β.baseChange L).repr msg0 = (β.baseChange L).repr s_bar := by + ext u + exact congrFun h_eq u + have hs : msg0 = s_bar := (β.baseChange L).repr.injective h_repr_eq + exact h_embed_ne (by simpa [s_bar] using hs.symm) + have h_diff_ne : + (fun u : Fin κ → Fin 2 => + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) msg0 u - + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) s_bar u) ≠ 0 := by + intro h_zero + apply h_rows_ne + funext u + exact sub_eq_zero.mp (congrFun h_zero u) + intro h_poly_zero + have h_poly_zero' : + batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar = 0 := by + simpa [s_bar] using h_poly_zero + apply h_diff_ne + funext u + have hu_eval_zero : + MvPolynomial.eval (fun i => ((u i : Fin 2) : L)) + (batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar) = 0 := by + rw [h_poly_zero'] + simp + have hu_eval_mle : + MvPolynomial.eval (fun i => ((u i : Fin 2) : L)) + (batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar) = + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) msg0 u - + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) s_bar u := by + simp [batchingMismatchPoly, MvPolynomial.MLE_eval_zeroOne] + rw [hu_eval_mle] at hu_eval_zero + exact hu_eval_zero + +omit [NeZero κ] [Fintype L] [DecidableEq L] [CharP L 2] [SampleableType L] + [Fintype K] [DecidableEq K] in +/-- If `msg0 ≠ s_bar` in the tensor algebra, the mismatch polynomial is nonzero. + Generalization of `batchingMismatchPoly_nonzero_of_embed_ne`. -/ +lemma batchingMismatchPoly_nonzero_of_ne + (msg0 s_bar : TensorAlgebra K L) + (h_ne : msg0 ≠ s_bar) : + batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar ≠ 0 := by + have h_rows_ne : + (decompose_tensor_algebra_rows (L := L) (K := K) (β := β) msg0) ≠ + (decompose_tensor_algebra_rows (L := L) (K := K) (β := β) s_bar) := by + intro h_eq + have h_repr_eq : (β.baseChange L).repr msg0 = (β.baseChange L).repr s_bar := by + ext u; exact congrFun h_eq u + exact h_ne ((β.baseChange L).repr.injective h_repr_eq) + have h_diff_ne : + (fun u : Fin κ → Fin 2 => + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) msg0 u - + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) s_bar u) ≠ 0 := by + intro h_zero + apply h_rows_ne + funext u + exact sub_eq_zero.mp (congrFun h_zero u) + intro h_poly_zero + apply h_diff_ne + funext u + have hu_eval_zero : + MvPolynomial.eval (fun i => ((u i : Fin 2) : L)) + (batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar) = 0 := by + rw [h_poly_zero]; simp + have hu_eval_mle : + MvPolynomial.eval (fun i => ((u i : Fin 2) : L)) + (batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar) = + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) msg0 u - + decompose_tensor_algebra_rows (L := L) (K := K) (β := β) s_bar u := by + simp [batchingMismatchPoly, MvPolynomial.MLE_eval_zeroOne] + rw [hu_eval_mle] at hu_eval_zero + exact hu_eval_zero + +/-- From `KState 2` truth, derive equality of the two `compute_s0` forms. -/ +lemma batching_compute_eq_from_hafter + (stmtOStmtIn : (BatchingStmtIn L ℓ) × (∀ j, aOStmtIn.OStmtIn j)) + (msg0 : (pSpecBatching (κ := κ) (L := L) (K := K)).Message ⟨0, rfl⟩) + (y : Fin κ → L) + (witMid : batchingWitMid L K ℓ ℓ' 2) + (h_after : batchingKStateProp (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') + (h_l := h_l) (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (tr := FullTranscript.mk2 msg0 y) stmtOStmtIn.1 + witMid stmtOStmtIn.2) : + compute_s0 κ L K β msg0 y = + compute_s0 κ L K β + (embedded_MLP_eval κ L K ℓ ℓ' h_l witMid.t' stmtOStmtIn.1.t_eval_point) y := by + dsimp [batchingKStateProp] at h_after + have h_sumcheck_msg0 := h_after.1 + dsimp [sumcheckRoundRelationProp] at h_sumcheck_msg0 + dsimp [Binius.RingSwitching.masterKStateProp] at h_sumcheck_msg0 + have h_msg : + sumcheckConsistencyProp (𝓑 := 𝓑) + (compute_s0 κ L K β msg0 y) + (projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witMid.t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly + { t_eval_point := stmtOStmtIn.1.t_eval_point, + original_claim := stmtOStmtIn.1.original_claim, + s_hat := msg0, + r_batching := y }) + (i := 0) (challenges := Fin.elim0)) := h_sumcheck_msg0.1 + have h_bar : + sumcheckConsistencyProp (𝓑 := 𝓑) + (compute_s0 κ L K β + (embedded_MLP_eval κ L K ℓ ℓ' h_l witMid.t' stmtOStmtIn.1.t_eval_point) y) + (projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witMid.t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly + { t_eval_point := stmtOStmtIn.1.t_eval_point, + original_claim := stmtOStmtIn.1.original_claim, + s_hat := msg0, + r_batching := y }) + (i := 0) (challenges := Fin.elim0)) := by + simpa using + (batching_target_consistency (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) + (ℓ' := ℓ') (h_l := h_l) (𝓑 := 𝓑) (t' := witMid.t') + (msg0 := embedded_MLP_eval κ L K ℓ ℓ' h_l witMid.t' stmtOStmtIn.1.t_eval_point) + (r_batching := y) + (ctx := { t_eval_point := stmtOStmtIn.1.t_eval_point, + original_claim := stmtOStmtIn.1.original_claim, + s_hat := msg0, + r_batching := y })) + unfold sumcheckConsistencyProp at h_msg h_bar + exact h_msg.trans h_bar.symm + +/-- The "bad batching event": the prover's ŝ (`msg0`) disagrees with the honest ŝ (`s_bar`), + but their `compute_s0` values agree at the batching challenges `y`. + Corresponds to $S(r''_0, \ldots, r''_{\kappa-1}) = 0$ in Theorem 3.5 of the spec, where + $S(X) := \sum_{u \in \mathcal{B}_\kappa} (\hat{s}_u - \bar{s}_u) \cdot \widetilde{eq}(u, X)$. -/ +def badBatchingEventProp (y : Fin κ → L) (msg0 s_bar : TensorAlgebra K L) : Prop := + msg0 ≠ s_bar ∧ compute_s0 κ L K β msg0 y = compute_s0 κ L K β s_bar y + +/-- **Schwartz-Zippel bound for the bad batching event.** + When `msg0 = s_bar`, the event never holds (first conjunct is `False`). + When `msg0 ≠ s_bar`, the mismatch polynomial $S$ is nonzero with `totalDegree ≤ κ`, + so Schwartz-Zippel gives `Pr[S(y) = 0] ≤ κ / |L|`. -/ +lemma probability_bound_badBatchingEventProp + (msg0 s_bar : TensorAlgebra K L) : + Pr_{ let y ← $ᵖ (Fin κ → L) }[ + badBatchingEventProp (κ := κ) (L := L) (K := K) (β := β) y msg0 s_bar ] ≤ + batchingRBRKnowledgeError (κ := κ) (L := L) := by + classical + unfold badBatchingEventProp + by_cases h_ne : msg0 ≠ s_bar + · -- msg0 ≠ s_bar: reduce to S.eval(y) = 0, apply Schwartz-Zippel + simp only [ne_eq, h_ne, not_false_eq_true, true_and] + -- Rewrite compute_s0 equality as mismatch polynomial root + have h_mono := prob_mono (D := $ᵖ (Fin κ → L)) + (f := fun y => compute_s0 κ L K β msg0 y = compute_s0 κ L K β s_bar y) + (g := fun y => MvPolynomial.eval y + (batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar) = 0) + (h_imp := by + intro y h_eq + rw [← batching_compute_s0_sub_eq_eval_mismatch (κ := κ) (L := L) (K := K) (β := β) + (msg0 := msg0) (s_bar := s_bar) (y := y)] + exact sub_eq_zero.mpr h_eq) + apply le_trans h_mono + have h_nonzero : batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar ≠ 0 := + batchingMismatchPoly_nonzero_of_ne (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar h_ne + have h_sz := prob_schwartz_zippel_mv_polynomial + (P := batchingMismatchPoly (κ := κ) (L := L) (K := K) (β := β) msg0 s_bar) h_nonzero + (batchingMismatchPoly_totalDegree_le (κ := κ) (L := L) (K := K) (β := β) + (msg0 := msg0) (s_bar := s_bar)) + conv_rhs => + dsimp only [batchingRBRKnowledgeError] + rw [ENNReal.coe_div (hr := by simp only [ne_eq, Nat.cast_eq_zero, Fintype.card_ne_zero, + not_false_eq_true])] + simp only [ENNReal.coe_ofNat, ENNReal.coe_natCast] + exact h_sz + · -- msg0 = s_bar: event is False ∧ _, which never holds + simp only [h_ne, false_and] + simp only [PMF.monad_pure_eq_pure, PMF.monad_bind_eq_bind, PMF.bind_const, PMF.pure_apply, + eq_iff_iff, iff_false, not_true_eq_false, ↓reduceIte, _root_.zero_le] + +/-- Extraction failure implies a witness-dependent bad batching event. + The extracted `witMid` also carries oracle compatibility at the same `oStmt`. -/ +lemma batching_rbrExtractionFailureEvent_imply_badBatchingEvent + (stmtOStmtIn : (BatchingStmtIn L ℓ) × (∀ j, aOStmtIn.OStmtIn j)) + (msg0 : (pSpecBatching (κ := κ) (L := L) (K := K)).Message ⟨0, rfl⟩) + (y : Fin κ → L) + (doomEscape : rbrExtractionFailureEvent + (kSF := batchingKnowledgeStateFunction (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) + (ℓ' := ℓ') (h_l := h_l) (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (init := init) (impl := impl)) + (extractor := batchingRbrExtractor (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) + (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn)) + ⟨1, rfl⟩ stmtOStmtIn (FullTranscript.mk1 msg0) y) : + ∃ witMid : batchingWitMid L K ℓ ℓ' 2, + aOStmtIn.initialCompatibility ⟨witMid.t', stmtOStmtIn.2⟩ ∧ + let s_bar := embedded_MLP_eval κ L K ℓ ℓ' h_l witMid.t' stmtOStmtIn.1.t_eval_point + badBatchingEventProp (κ := κ) (L := L) (K := K) (β := β) y msg0 s_bar := by + classical + unfold rbrExtractionFailureEvent at doomEscape + rcases doomEscape with ⟨witMid, h_kState_before_false, h_kState_after_true⟩ + have h_after : + batchingKStateProp (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') + (h_l := h_l) (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (tr := FullTranscript.mk2 msg0 y) + stmtOStmtIn.1 witMid stmtOStmtIn.2 := by + simpa [batchingKnowledgeStateFunction] using h_kState_after_true + have h_before_false := by + simpa [batchingKnowledgeStateFunction] using h_kState_before_false + have h_compute_eq := + batching_compute_eq_from_hafter (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') + (h_l := h_l) (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (stmtOStmtIn := stmtOStmtIn) (msg0 := msg0) + (y := y) (witMid := witMid) (h_after := h_after) + dsimp [batchingKStateProp] at h_after + have h_check_true : + performCheckOriginalEvaluation κ L K β ℓ ℓ' h_l stmtOStmtIn.1.original_claim + stmtOStmtIn.1.t_eval_point msg0 = true := h_after.2.1 + have h_compat_mid : aOStmtIn.initialCompatibility ⟨witMid.t', stmtOStmtIn.2⟩ := h_after.2.2 + have h_embed_ne : + embedded_MLP_eval κ L K ℓ ℓ' h_l witMid.t' stmtOStmtIn.1.t_eval_point ≠ msg0 := by + intro h_embed_eq + apply h_before_false + dsimp [batchingKStateProp] + refine ⟨?_, ?_, h_check_true, h_compat_mid⟩ + · simp [batchingRbrExtractor, batching_pack_unpack_id] + · simpa using h_embed_eq + have h_msg0_ne : + msg0 ≠ embedded_MLP_eval κ L K ℓ ℓ' h_l witMid.t' stmtOStmtIn.1.t_eval_point := by + intro h_eq + exact h_embed_ne h_eq.symm + have h_bad : + badBatchingEventProp (κ := κ) (L := L) (K := K) (β := β) y msg0 + (embedded_MLP_eval κ L K ℓ ℓ' h_l witMid.t' stmtOStmtIn.1.t_eval_point) := by + exact ⟨h_msg0_ne, h_compute_eq⟩ + refine ⟨witMid, h_compat_mid, ?_⟩ + exact h_bad + +/-- Per-transcript batching bound: for a fixed prover message `msg0`, the probability + (over batching challenges `y : Fin κ → L`) that extraction fails is bounded by + `batchingRBRKnowledgeError`. + **Proof strategy** (follows `foldStep_doom_escape_probability_bound`): + 1. **Implication**: Show that extraction failure implies the + `badBatchingEventProp` (Theorem 3.5, $S(r'') = 0$). + 2. **Monotonicity**: Conclude `Pr[doom] ≤ Pr[badBatchingEvent]` via `prob_mono`. + 3. **Schwartz–Zippel**: Bound `Pr[badBatchingEvent]` by `κ/|L|`. -/ +lemma batching_doom_escape_probability_bound + (stmtOStmtIn : (BatchingStmtIn L ℓ) × (∀ j, aOStmtIn.OStmtIn j)) + (msg0 : (pSpecBatching (κ := κ) (L := L) (K := K)).Message ⟨0, rfl⟩) : + Pr_{ let y ← $ᵖ (Fin κ → L) }[ + rbrExtractionFailureEvent + (kSF := batchingKnowledgeStateFunction (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) + (ℓ' := ℓ') (h_l := h_l) (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (init := init) (impl := impl)) + (extractor := batchingRbrExtractor (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) + (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn)) + ⟨1, rfl⟩ stmtOStmtIn (FullTranscript.mk1 msg0) y ] ≤ + batchingRBRKnowledgeError (κ := κ) (L := L) := by + classical + let compatPred : MultilinearPoly L ℓ' → Prop := fun t => + aOStmtIn.initialCompatibility ⟨t, stmtOStmtIn.2⟩ + by_cases hCompat : ∃ t : MultilinearPoly L ℓ', compatPred t + · rcases hCompat with ⟨t_fixed, h_t_fixed_compat⟩ + let s_bar_fixed := + embedded_MLP_eval κ L K ℓ ℓ' h_l t_fixed stmtOStmtIn.1.t_eval_point + have h_prob_mono := prob_mono (D := $ᵖ (Fin κ → L)) + (f := fun y => rbrExtractionFailureEvent + (kSF := batchingKnowledgeStateFunction (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) + (ℓ' := ℓ') (h_l := h_l) (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (init := init) + (impl := impl)) + (extractor := batchingRbrExtractor (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) + (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn)) + ⟨1, rfl⟩ stmtOStmtIn (FullTranscript.mk1 msg0) y) + (g := fun y => + badBatchingEventProp (κ := κ) (L := L) (K := K) (β := β) y msg0 s_bar_fixed) + (h_imp := by + -- Uniqueness proof of `witMid` and `s_bar_fixed` + intro y h_doomEscape + obtain ⟨witMid, h_mid_compat, h_bad_extracted⟩ := + batching_rbrExtractionFailureEvent_imply_badBatchingEvent + (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) + (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (init := init) (impl := impl) + (stmtOStmtIn := stmtOStmtIn) (msg0 := msg0) (y := y) + (doomEscape := h_doomEscape) + have h_t_eq : witMid.t' = t_fixed := + aOStmtIn.initialCompatibility_unique stmtOStmtIn.2 witMid.t' t_fixed + h_mid_compat h_t_fixed_compat + simpa [s_bar_fixed, h_t_eq] using h_bad_extracted) + apply le_trans h_prob_mono + exact probability_bound_badBatchingEventProp (κ := κ) (L := L) (K := K) (β := β) + (msg0 := msg0) (s_bar := s_bar_fixed) + · have h_prob_mono_false := prob_mono (D := $ᵖ (Fin κ → L)) + (f := fun y => rbrExtractionFailureEvent + (kSF := batchingKnowledgeStateFunction (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) + (ℓ' := ℓ') (h_l := h_l) (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (init := init) + (impl := impl)) + (extractor := batchingRbrExtractor (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) + (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn)) + ⟨1, rfl⟩ stmtOStmtIn (FullTranscript.mk1 msg0) y) + (g := fun _ => False) + (h_imp := by + intro y h_doomEscape + obtain ⟨witMid, h_mid_compat, _h_bad_extracted⟩ := + batching_rbrExtractionFailureEvent_imply_badBatchingEvent + (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) + (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (init := init) (impl := impl) + (stmtOStmtIn := stmtOStmtIn) (msg0 := msg0) (y := y) + (doomEscape := h_doomEscape) + exact (hCompat ⟨witMid.t', h_mid_compat⟩).elim) + refine le_trans h_prob_mono_false ?_ + simp only [PMF.monad_pure_eq_pure, PMF.monad_bind_eq_bind, PMF.bind_const, PMF.pure_apply, + eq_iff_iff, iff_false, not_true_eq_false, ↓reduceIte, _root_.zero_le] + /-- RBR knowledge soundness for the batching phase oracle verifier. -/ theorem batchingOracleVerifier_rbrKnowledgeSoundness : OracleVerifier.rbrKnowledgeSoundness @@ -718,18 +1309,94 @@ theorem batchingOracleVerifier_rbrKnowledgeSoundness : (init := init) (impl := impl) (relIn := batchingInputRelation κ L K β ℓ ℓ' h_l aOStmtIn) (relOut := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn 0) - (rbrKnowledgeError := batchingRBRKnowledgeError (κ:=κ) (L:=L) (K:=K)) := by - -- Proof follows by constructing the extractor and knowledge state function. - use batchingWitMid L K ℓ ℓ' - use batchingRbrExtractor κ L K β ℓ ℓ' h_l (aOStmtIn:=aOStmtIn) - use batchingKnowledgeStateFunction κ L K β ℓ ℓ' h_l (aOStmtIn:=aOStmtIn) (init:=init) (impl:=impl) - intro stmtIn witIn prover iChal - -- `KState 1 = (t' = packMLE t) ∧ (ŝ := φ₁(t')(φ₀(r_κ), ..., φ₀(r_{ℓ-1})))` - -- `∧ (s ?= Σ_{v ∈ {0,1}^κ} eqTilde(v, r_{0..κ-1}) ⋅ ŝ_v.` - -- `KState 2 = (s ?= Σ_{v ∈ {0,1}^κ} eqTilde(v, r_{0..κ-1}) ⋅ ŝ_v) ∧` - -- `h = projectSumcheckPoly t' 0 r r' ∧ s_0 = Σ_{w ∈ {0,1}^{ℓ'}} h(w)` - -- ⊢ `Pr[KState(2, witMidSucc) ∧ ¬KState(1, extractMid(iChal, witMidSucc))] ≤ (κ/|L|)` - sorry + (rbrKnowledgeError := fun _ => batchingRBRKnowledgeError (κ:=κ) (L:=L)) := by + classical + -- -- Proof follows by constructing the extractor and knowledge state function. + -- -- `KState 1 = (t' = packMLE t) ∧ (ŝ := φ₁(t')(φ₀(r_κ), ..., φ₀(r_{ℓ-1})))` + -- -- `∧ (s ?= Σ_{v ∈ {0,1}^κ} eqTilde(v, r_{0..κ-1}) ⋅ ŝ_v.` + -- -- `KState 2 = (s ?= Σ_{v ∈ {0,1}^κ} eqTilde(v, r_{0..κ-1}) ⋅ ŝ_v) ∧` + -- -- `h = projectSumcheckPoly t' 0 r r' ∧ s_0 = Σ_{w ∈ {0,1}^{ℓ'}} h(w)` + -- -- ⊢ `Pr[KState(2, witMidSucc) ∧ ¬KState(1, extractMid(iChal, witMidSucc))] ≤ (κ/|L|)` + apply OracleReduction.unroll_rbrKnowledgeSoundness + (kSF := batchingKnowledgeStateFunction κ L K β ℓ ℓ' h_l (aOStmtIn:=aOStmtIn) + (init:=init) (impl:=impl)) + intro stmtOStmtIn witIn prover j initState + let P := rbrExtractionFailureEvent + (kSF := batchingKnowledgeStateFunction κ L K β (𝓑 := 𝓑) ℓ ℓ' h_l (aOStmtIn:=aOStmtIn) + (init:=init) (impl:=impl)) + (extractor := batchingRbrExtractor κ L K β ℓ ℓ' h_l (aOStmtIn:=aOStmtIn)) + (i := j) (stmtIn := stmtOStmtIn) + rw [OracleReduction.probEvent_soundness_goal_unroll_log' + (pSpec := pSpecBatching (κ:=κ) (L:=L) (K:=K)) + (P := P) (impl := impl) (prover := prover) (i := j) (stmt := stmtOStmtIn) + (wit := witIn) (s := initState)] + have h_j_eq_1 : j = ⟨1, rfl⟩ := by + match j with + | ⟨0, h0⟩ => nomatch h0 + | ⟨1, _⟩ => rfl + subst h_j_eq_1 + conv_lhs => simp only [Fin.isValue, Fin.castSucc_one]; + rw [OracleReduction.soundness_unroll_runToRound_1_P_to_V_pSpec_2 + (pSpec := pSpecBatching (κ:=κ) (L:=L) (K:=K)) (prover := prover) (hDir0 := by rfl)] + simp only [Fin.isValue, Challenge, Matrix.cons_val_one, Matrix.cons_val_zero, ChallengeIdx, + QueryImpl.addLift_def, QueryImpl.liftTarget_self, Message, Fin.succ_zero_eq_one, Nat.reduceAdd, + Fin.coe_ofNat_eq_mod, Nat.reduceMod, FullTranscript.mk1_eq_snoc, bind_pure_comp, + liftComp_eq_liftM, bind_map_left, simulateQ_bind, simulateQ_map, StateT.run'_eq, + StateT.run_bind, StateT.run_map, map_bind, Functor.map_map] + rw [probEvent_bind_eq_tsum] + apply OracleReduction.ENNReal.tsum_mul_le_of_le_of_sum_le_one + · -- Bound the conditional probability for each transcript + intro x + -- rw [OracleComp.probEvent_map] + simp only [Fin.isValue, probEvent_map] + let q : OracleQuery [(pSpecBatching (κ := κ) (L := L) (K := K)).Challenge]ₒ _ + := query ⟨⟨1, by rfl⟩, ()⟩ + erw [OracleReduction.probEvent_StateT_run_ignore_state + (comp := simulateQ (impl.addLift challengeQueryImpl) (liftM (query q.input))) + (s := x.2) + (P := fun a => P (FullTranscript.mk1 x.1.1) (q.cont a))] + rw [probEvent_eq_tsum_ite] + erw [simulateQ_query] + simp only [ChallengeIdx, Challenge, Fin.isValue, Nat.reduceAdd, Fin.castSucc_one, + Fin.coe_ofNat_eq_mod, Nat.reduceMod, monadLift_self, + QueryImpl.addLift_def, QueryImpl.liftTarget_self, StateT.run'_eq, StateT.run_map, + Functor.map_map, ge_iff_le] + have h_L_inhabited : Inhabited L := ⟨0⟩ + conv_lhs => + enter [1, x_1, 2, 1, 2] + rw [addLift_challengeQueryImpl_input_run_eq_liftM_run (impl := impl) (q := q) (s := x.2)] + erw [StateT.run_monadLift, monadLift_self, liftComp_id] + rw [bind_pure_comp] + conv => + enter [1, 1, x_1, 2] + rw [Functor.map_map] + rw [← probEvent_eq_eq_probOutput] + rw [probEvent_map] + rw [OracleQuery.cont_apply] + dsimp only [MonadLift.monadLift] + rw [OracleQuery.cont_apply] + dsimp only [q] + simp_rw [OracleQuery.input_query, OracleQuery.snd_query] + conv_lhs => change (∑' (x_1 : (Fin κ → L)), _) + simp only [Function.comp_id] + conv => + enter [1, 1, x_1, 2] + rw [probEvent_eq_eq_probOutput] + change Pr[=x_1 | $ᵗ (Fin κ → L)] + rw [OracleReduction.probOutput_uniformOfFintype_eq_Pr (L := _) (x := x_1)] + rw [OracleReduction.tsum_uniform_Pr_eq_Pr (L := (Fin κ → L)) + (P := fun x_1 => P (FullTranscript.mk1 x.1.1) (q.2 x_1))] + -- Now the goal is in do-notation form, which is exactly what Pr_ notation expands to + -- Make this explicit using change + -- Convert the sum domain from [pSpecFold.Challenge]ₒ.range to L using h_L_eq + conv_lhs => change (∑' (x_1 : (Fin κ → L)), _) + -- Apply the per-transcript bound + exact batching_doom_escape_probability_bound (κ := κ) (L := L) (K := K) + (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) + (stmtOStmtIn := stmtOStmtIn) (msg0 := x.1.1) + (impl := impl) (init := init) + · -- Prove: ∑' x, [=x|transcript computation] ≤ 1 + apply tsum_probOutput_le_one end BatchingPhase end Binius.RingSwitching diff --git a/ArkLib/ProofSystem/Binius/RingSwitching/General.lean b/ArkLib/ProofSystem/Binius/RingSwitching/General.lean index c41a1502d..4bccbac4b 100644 --- a/ArkLib/ProofSystem/Binius/RingSwitching/General.lean +++ b/ArkLib/ProofSystem/Binius/RingSwitching/General.lean @@ -94,22 +94,18 @@ def fullOracleProof : variable [∀ i, SampleableType (mlIOPCS.pSpec.Challenge i)] -/-- Input relation for the full ring-switching protocol -/ -abbrev fullInputRelation := BatchingPhase.batchingInputRelation κ L K β ℓ ℓ' - h_l mlIOPCS.toAbstractOStmtIn -abbrev fullOutputRelation := acceptRejectOracleRel - open scoped NNReal section SecurityProperties variable {σ : Type} (init : ProbComp σ) {impl : QueryImpl []ₒ (StateT σ ProbComp)} -omit [(i : mlIOPCS.pSpec.ChallengeIdx) → SampleableType (mlIOPCS.pSpec.Challenge i)] in +omit [∀ i, SampleableType (mlIOPCS.pSpec.Challenge i)] in lemma batchingCore_perfectCompleteness (hInit : NeverFail init) : (batchingCoreReduction κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) mlIOPCS).perfectCompleteness (pSpec := pSpecLargeFieldReduction κ L K ℓ') - (relIn := BatchingPhase.batchingInputRelation κ L K β ℓ ℓ' h_l mlIOPCS.toAbstractOStmtIn) - (relOut := mlIOPCS.toRelInput) + (relIn := BatchingPhase.strictBatchingInputRelation κ L K β ℓ ℓ' h_l + mlIOPCS.toAbstractOStmtIn) + (relOut := mlIOPCS.toStrictRelInput) (init:=init) (impl:=impl) := by apply OracleReduction.append_perfectCompleteness · exact BatchingPhase.batchingReduction_perfectCompleteness (hInit:=hInit) κ L K β ℓ ℓ' h_l @@ -117,22 +113,24 @@ lemma batchingCore_perfectCompleteness (hInit : NeverFail init) : · exact SumcheckPhase.coreInteraction_perfectCompleteness (hInit:=hInit) κ L K β ℓ ℓ' h_l mlIOPCS.toAbstractOStmtIn (impl:=impl) -omit [(i : mlIOPCS.pSpec.ChallengeIdx) → SampleableType (mlIOPCS.pSpec.Challenge i)] in +omit [∀ i, SampleableType (mlIOPCS.pSpec.Challenge i)] in theorem fullOracleReduction_perfectCompleteness (hInit : NeverFail init) : (fullOracleReduction κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) mlIOPCS).perfectCompleteness - (relIn := BatchingPhase.batchingInputRelation κ L K β ℓ ℓ' h_l mlIOPCS.toAbstractOStmtIn) + (relIn := BatchingPhase.strictBatchingInputRelation κ L K β ℓ ℓ' h_l + mlIOPCS.toAbstractOStmtIn) (relOut := acceptRejectOracleRel) (init := init) (impl := impl) := by apply OracleReduction.append_perfectCompleteness (Oₛ₃:=by exact fun _ ↦ OracleInterface.instDefault) - · exact batchingCore_perfectCompleteness (hInit:=hInit) κ L K β ℓ ℓ' h_l mlIOPCS init - · exact mlIOPCS.perfectCompleteness + · exact batchingCore_perfectCompleteness κ L K β ℓ ℓ' h_l + (𝓑 := 𝓑) mlIOPCS init (hInit := hInit) (impl := impl) + · exact mlIOPCS.strictPerfectCompleteness hInit def batchingCoreRbrKnowledgeError (i : (pSpecBatching κ L K ++ₚ pSpecCoreInteraction L ℓ').ChallengeIdx) : ℝ≥0 := - Sum.elim (f:=BatchingPhase.batchingRBRKnowledgeError κ L K) + Sum.elim (f:=fun _ => BatchingPhase.batchingRBRKnowledgeError (κ:=κ) (L:=L)) (g:=SumcheckPhase.coreInteractionRbrKnowledgeError L ℓ') (ChallengeIdx.sumEquiv.symm i) @@ -141,49 +139,49 @@ def fullRbrKnowledgeError (i : (fullPspec κ L K ℓ' mlIOPCS).ChallengeIdx) : (g:=mlIOPCS.rbrKnowledgeError) (ChallengeIdx.sumEquiv.symm i) -variable [SampleableType L] - +omit [∀ i, SampleableType (mlIOPCS.pSpec.Challenge i)] in /-- Round-by-round knowledge soundness for the full ring-switching oracle verifier -/ theorem fullOracleVerifier_rbrKnowledgeSoundness {𝓑 : Fin 2 ↪ L} : OracleVerifier.rbrKnowledgeSoundness (verifier := fullOracleVerifier κ L K β ℓ ℓ' (𝓑 := 𝓑) h_l mlIOPCS) (init := init) (impl := impl) - (relIn := fullInputRelation κ L K β ℓ ℓ' h_l mlIOPCS) - (relOut := fullOutputRelation) + (relIn := BatchingPhase.batchingInputRelation κ L K β ℓ ℓ' + h_l mlIOPCS.toAbstractOStmtIn) + (relOut := acceptRejectOracleRel) (rbrKnowledgeError := fun i => fullRbrKnowledgeError κ L K ℓ' mlIOPCS i) := by unfold fullOracleVerifier fullRbrKnowledgeError have batchInteractionRBRKS := OracleVerifier.append_rbrKnowledgeSoundness (init:=init) (impl:=impl) - (rel₁:=fullInputRelation κ L K β ℓ ℓ' h_l mlIOPCS) + (rel₁:=BatchingPhase.batchingInputRelation κ L K β ℓ ℓ' + h_l mlIOPCS.toAbstractOStmtIn) (rel₂:=sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) mlIOPCS.toAbstractOStmtIn 0) (rel₃:=mlIOPCS.toRelInput) (V₁:=BatchingPhase.batchingOracleVerifier κ L K β ℓ ℓ' h_l mlIOPCS.toAbstractOStmtIn) (V₂:=SumcheckPhase.coreInteractionOracleVerifier κ L K β ℓ ℓ' h_l mlIOPCS.toAbstractOStmtIn) - (rbrKnowledgeError₁:=BatchingPhase.batchingRBRKnowledgeError κ L K) + (rbrKnowledgeError₁:=fun _ => BatchingPhase.batchingRBRKnowledgeError (κ:=κ) (L:=L)) (rbrKnowledgeError₂:=SumcheckPhase.coreInteractionRbrKnowledgeError L ℓ') (h₁:=BatchingPhase.batchingOracleVerifier_rbrKnowledgeSoundness κ L K β ℓ ℓ' h_l mlIOPCS.toAbstractOStmtIn) (h₂:=SumcheckPhase.coreInteraction_rbrKnowledgeSoundness κ L K β ℓ ℓ' h_l mlIOPCS.toAbstractOStmtIn) - have res := OracleVerifier.append_rbrKnowledgeSoundness (init:=init) (impl:=impl) - (rel₁:=fullInputRelation κ L K β ℓ ℓ' h_l mlIOPCS) + (rel₁:=BatchingPhase.batchingInputRelation κ L K β ℓ ℓ' + h_l mlIOPCS.toAbstractOStmtIn) (rel₂:=mlIOPCS.toRelInput) - (rel₃:=fullOutputRelation) - (V₁:=batchingCoreVerifier κ L K β ℓ ℓ' h_l mlIOPCS) + (rel₃:=acceptRejectOracleRel) + (V₁:=batchingCoreVerifier κ L K β (𝓑 := 𝓑) ℓ ℓ' h_l mlIOPCS) (V₂:=mlIOPCS.oracleReduction.verifier) (Oₛ₃:=by exact fun i ↦ OracleInterface.instDefault) (rbrKnowledgeError₁:=batchingCoreRbrKnowledgeError κ L K ℓ') (rbrKnowledgeError₂:=mlIOPCS.rbrKnowledgeError) - (h₁:=batchInteractionRBRKS) (h₂:=by - convert mlIOPCS.rbrKnowledgeSoundness (L:=L) (ℓ' := ℓ') (init:=init) (impl:=impl) - · sorry - ) - convert res - · simp only [ChallengeIdx, Challenge, instSampleableTypeChallengeFullPspec] - sorry + (h₁:=batchInteractionRBRKS) + (h₂:= mlIOPCS.rbrKnowledgeSoundness) + exact OracleVerifier.rbrKnowledgeSoundness_of_eq_error + (init := init) (impl := impl) + (h_ε := by intro i; rfl) + (h := res) end SecurityProperties end diff --git a/ArkLib/ProofSystem/Binius/RingSwitching/Prelude.lean b/ArkLib/ProofSystem/Binius/RingSwitching/Prelude.lean index 3bd6754d4..e0a3e539d 100644 --- a/ArkLib/ProofSystem/Binius/RingSwitching/Prelude.lean +++ b/ArkLib/ProofSystem/Binius/RingSwitching/Prelude.lean @@ -121,10 +121,8 @@ def packMLE (β : Basis (Fin κ → Fin 2) K L) (t : MultilinearPoly K ℓ) : w ⟨i.val - κ, by omega⟩ -- Evaluate the small-field polynomial `t` at this point. MvPolynomial.eval (fun i => ↑(concatenated_point i)) t.val - -- b. Use `equivFun.symm` = ∑ v, (coeffs_for_w v) • (β v). β.equivFun.symm coeffs_for_w - -- 2. The packed polynomial `t'` is the multilinear extension of this function. ⟨MvPolynomial.MLE packing_func, MLE_mem_restrictDegree packing_func⟩ @@ -144,17 +142,14 @@ def unpackMLE (β : Basis (Fin κ → Fin 2) K L) (t' : MultilinearPoly L ℓ') -- a. Deconstruct the evaluation point `p` into `v` (first κ bits) and `w` (last ℓ' bits). let v (i : Fin κ) : Fin 2 := p ⟨i.val, by omega⟩ let w (i : Fin ℓ') : Fin 2 := p ⟨i.val + κ, by { rw [h_l]; omega }⟩ - -- b. Evaluate the large-field polynomial `t'` at the point `w`. let t'_eval_at_w : L := MvPolynomial.eval (fun i => ↑(w i)) t'.val - -- c. Get the K-coefficients of this L-element with respect to the basis `β`. -- `β.repr/β.equivFun` maps an element of L to its coordinate function `(Fin κ → Fin 2) → K`. let coeffs : (Fin κ → Fin 2) → K := β.repr t'_eval_at_w -- d. The desired evaluation t(p) = t(v,w) -- is the coefficient corresponding to the basis vector `β_v`. coeffs v - -- 2. The unpacked polynomial `t` is the multilinear extension of this evaluation function. ⟨MvPolynomial.MLE unpacked_evals, MLE_mem_restrictDegree unpacked_evals⟩ @@ -195,10 +190,6 @@ We define the Statement and Witness types at the boundaries of each phase following the enhanced specification. -/ --- Initial Input (Input to Batching Phase) -abbrev MLPEvalStatement := - Binius.BinaryBasefold.InitialStatement (L := L) (ℓ := ℓ) - structure WitMLP where t : MultilinearPoly K ℓ @@ -230,7 +221,7 @@ structure MLIOPCSStmt where def MLPEvalRelation (ιₛᵢ : Type) (OStmtIn : ιₛᵢ → Type) (input : ((MLPEvalStatement L ℓ') × (∀ j, OStmtIn j)) × (WitMLP L ℓ')) : Prop := let ⟨⟨stmt, _⟩, wit⟩ := input - stmt.original_claim = wit.t.val.eval stmt.t_eval_point + wit.t.val.eval stmt.t_eval_point = stmt.original_claim structure AbstractOStmtIn where ιₛᵢ : Type @@ -239,13 +230,43 @@ structure AbstractOStmtIn where -- The abstract initial compatibility relation, which along with -- MLPEvalRelation, forms the initial input relation for the MLIOPCS. initialCompatibility : (MultilinearPoly L ℓ') × (∀ j, OStmtIn j) → Prop - + -- Strict compatibility relation used by perfect-completeness statements. + strictInitialCompatibility : (MultilinearPoly L ℓ') × (∀ j, OStmtIn j) → Prop + -- Strict compatibility is stronger and should imply the relaxed one. + strictInitialCompatibility_implies_initialCompatibility : + ∀ (oStmt : ∀ j, OStmtIn j) (t : MultilinearPoly L ℓ'), + strictInitialCompatibility ⟨t, oStmt⟩ → initialCompatibility ⟨t, oStmt⟩ + -- The ideal oracle **(Functionality 2.4, 2.5, 2.6)** stores the exact vector, so the + -- oracle commitment uniquely determines the polynomial t'. + -- **NOTE**: This captures `|Λ| = 1` (i.e. set of compatible witnesses + -- compatible with oracles) in the WARP paper's terminology. + initialCompatibility_unique : ∀ (oStmt : ∀ j, OStmtIn j) (t₁ t₂ : MultilinearPoly L ℓ'), + initialCompatibility ⟨t₁, oStmt⟩ → initialCompatibility ⟨t₂, oStmt⟩ → t₁ = t₂ + +/-- Relaxed relation used for RBR knowledge-soundness statements. -/ def AbstractOStmtIn.toRelInput (aOStmtIn : AbstractOStmtIn L ℓ') : Set (((MLPEvalStatement L ℓ') × (∀ j, aOStmtIn.OStmtIn j)) × (WitMLP L ℓ')) := {input | MLPEvalRelation L ℓ' aOStmtIn.ιₛᵢ aOStmtIn.OStmtIn input ∧ aOStmtIn.initialCompatibility ⟨input.2.t, input.1.2⟩} +/-- Strict relation used for perfect-completeness statements. -/ +def AbstractOStmtIn.toStrictRelInput (aOStmtIn : AbstractOStmtIn L ℓ') : + Set (((MLPEvalStatement L ℓ') × (∀ j, aOStmtIn.OStmtIn j)) × (WitMLP L ℓ')) := + {input | + MLPEvalRelation L ℓ' aOStmtIn.ιₛᵢ aOStmtIn.OStmtIn input + ∧ aOStmtIn.strictInitialCompatibility ⟨input.2.t, input.1.2⟩} + +omit [Fintype L] [DecidableEq L] [CharP L 2] [NeZero ℓ'] in +lemma AbstractOStmtIn.toStrictRelInput_subset_toRelInput (aOStmtIn : AbstractOStmtIn L ℓ') : + aOStmtIn.toStrictRelInput ⊆ aOStmtIn.toRelInput := by + intro input h_input + rcases input with ⟨⟨stmt, oStmt⟩, wit⟩ + rcases h_input with ⟨h_eval, h_compat_strict⟩ + exact ⟨h_eval, + aOStmtIn.strictInitialCompatibility_implies_initialCompatibility oStmt wit.t + h_compat_strict⟩ + structure MLIOPCS extends (AbstractOStmtIn L ℓ') where /-- Protocol specification -/ numRounds : ℕ @@ -260,11 +281,22 @@ structure MLIOPCS extends (AbstractOStmtIn L ℓ') where (pSpec := pSpec) -- Security properties perfectCompleteness : ∀ {σ : Type} {init : ProbComp σ} {impl : QueryImpl []ₒ (StateT σ ProbComp)}, + NeverFail init → OracleReduction.perfectCompleteness (oSpec:=[]ₒ) (StmtIn:=MLPEvalStatement L ℓ') (OStmtIn:=OStmtIn) (StmtOut:=Bool) (OStmtOut:=fun _: Empty => Unit) (WitIn:=WitMLP L ℓ') (WitOut:=Unit) (pSpec:=pSpec) (init:=init) (impl:=impl) - (relIn := toAbstractOStmtIn.toRelInput) + (relIn := toAbstractOStmtIn.toStrictRelInput) + (relOut := acceptRejectOracleRel) + (oracleReduction := oracleReduction) + strictPerfectCompleteness : ∀ {σ : Type} {init : ProbComp σ} + {impl : QueryImpl []ₒ (StateT σ ProbComp)}, + NeverFail init → + OracleReduction.perfectCompleteness (oSpec:=[]ₒ) + (StmtIn:=MLPEvalStatement L ℓ') (OStmtIn:=OStmtIn) + (StmtOut:=Bool) (OStmtOut:=fun _: Empty => Unit) + (WitIn:=WitMLP L ℓ') (WitOut:=Unit) (pSpec:=pSpec) (init:=init) (impl:=impl) + (relIn := toAbstractOStmtIn.toStrictRelInput) (relOut := acceptRejectOracleRel) (oracleReduction := oracleReduction) -- RBR knowledge error function for the MLIOPCS @@ -339,7 +371,8 @@ lemma batching_check_correctness (eval_point : Fin ℓ → L) : performCheckOriginalEvaluation κ L K β ℓ ℓ' h_l (t.val.aeval eval_point) - (r := eval_point) (s_hat := embedded_MLP_eval κ (L := L) (K := K) ℓ ℓ' h_l (packMLE κ (L := L) (K := K) ℓ ℓ' h_l β t) eval_point) = true := by + (r := eval_point) (s_hat := embedded_MLP_eval κ (L := L) (K := K) ℓ ℓ' h_l + (packMLE κ (L := L) (K := K) ℓ ℓ' h_l β t) eval_point) = true := by -- Unfold the check definition unfold performCheckOriginalEvaluation simp only [decide_eq_true_eq] @@ -427,7 +460,8 @@ lemma compute_A_MLE_eval_eq_final_eq_value (r_eval : Fin ℓ → L) (r'_challenges : Fin ℓ' → L) (r''_batching : Fin κ → L) : - (compute_A_MLE κ L K β ℓ' (getEvaluationPointSuffix κ L ℓ ℓ' h_l r_eval) r''_batching).val.eval r'_challenges = + (compute_A_MLE κ L K β ℓ' (getEvaluationPointSuffix κ L ℓ ℓ' h_l r_eval) + r''_batching).val.eval r'_challenges = compute_final_eq_value κ L K β ℓ ℓ' h_l r_eval r'_challenges r''_batching := by -- Unfold definitions simp only [compute_A_MLE, compute_final_eq_value, getEvaluationPointSuffix] @@ -452,17 +486,27 @@ def masterKStateProp (aOStmtIn : AbstractOStmtIn L ℓ') (stmtIdx : Fin (ℓ' + (stmt : Statement (L := L) (RingSwitchingBaseContext κ L K ℓ) stmtIdx) (oStmt : ∀ j, aOStmtIn.OStmtIn j) (wit : SumcheckWitness L ℓ' stmtIdx) - (localChecks : Prop := True) : Prop := + (localChecks : Prop) : Prop := localChecks + -- Should witnessStructuralInvariant be part of localChecks? ∧ witnessStructuralInvariant κ L K β ℓ ℓ' h_l stmt wit - ∧ sumcheckConsistencyProp (𝓑:=𝓑) stmt.sumcheck_target wit.H ∧ aOStmtIn.initialCompatibility ⟨wit.t', oStmt⟩ +def masterStrictKStateProp (aOStmtIn : AbstractOStmtIn L ℓ') (stmtIdx : Fin (ℓ' + 1)) + (stmt : Statement (L := L) (RingSwitchingBaseContext κ L K ℓ) stmtIdx) + (oStmt : ∀ j, aOStmtIn.OStmtIn j) + (wit : SumcheckWitness L ℓ' stmtIdx) + (localChecks : Prop) : Prop := + localChecks + ∧ witnessStructuralInvariant κ L K β ℓ ℓ' h_l stmt wit + ∧ aOStmtIn.strictInitialCompatibility ⟨wit.t', oStmt⟩ + def sumcheckRoundRelationProp (aOStmtIn : AbstractOStmtIn L ℓ') (i : Fin (ℓ' + 1)) (stmt : Statement (L := L) (RingSwitchingBaseContext κ L K ℓ) i) (oStmt : ∀ j, aOStmtIn.OStmtIn j) (wit : SumcheckWitness L ℓ' i) : Prop := - masterKStateProp κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn i stmt oStmt wit + masterKStateProp κ L K β ℓ ℓ' h_l aOStmtIn i stmt oStmt wit + (localChecks := sumcheckConsistencyProp (𝓑:=𝓑) stmt.sumcheck_target wit.H) /-- Input relation for single round: proper sumcheck statement -/ def sumcheckRoundRelation (aOStmtIn : AbstractOStmtIn L ℓ') (i : Fin (ℓ' + 1)) : @@ -471,6 +515,33 @@ def sumcheckRoundRelation (aOStmtIn : AbstractOStmtIn L ℓ') (i : Fin (ℓ' + 1 { ((stmt, oStmt), wit) | sumcheckRoundRelationProp κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn i stmt oStmt wit } +def strictSumcheckRoundRelationProp (aOStmtIn : AbstractOStmtIn L ℓ') (i : Fin (ℓ' + 1)) + (stmt : Statement (L := L) (RingSwitchingBaseContext κ L K ℓ) i) + (oStmt : ∀ j, aOStmtIn.OStmtIn j) + (wit : SumcheckWitness L ℓ' i) : Prop := + masterStrictKStateProp κ L K β ℓ ℓ' h_l aOStmtIn i stmt oStmt wit + (localChecks := sumcheckConsistencyProp (𝓑:=𝓑) stmt.sumcheck_target wit.H) + +/-- Strict round relation for completeness proofs. -/ +def strictSumcheckRoundRelation (aOStmtIn : AbstractOStmtIn L ℓ') (i : Fin (ℓ' + 1)) : + Set (((Statement (L := L) (RingSwitchingBaseContext κ L K ℓ) i) × + (∀ j, aOStmtIn.OStmtIn j)) × SumcheckWitness L ℓ' i) := + { ((stmt, oStmt), wit) | strictSumcheckRoundRelationProp κ L K β ℓ ℓ' h_l (𝓑:=𝓑) + aOStmtIn i stmt oStmt wit } + +omit [Fintype L] [DecidableEq L] [CharP L 2] [SampleableType L] [Fintype K] [DecidableEq K] + [NeZero ℓ] [NeZero ℓ'] in +lemma strictSumcheckRoundRelation_subset_sumcheckRoundRelation (aOStmtIn : AbstractOStmtIn L ℓ') + (i : Fin (ℓ' + 1)) : + strictSumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn i ⊆ + sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn i := by + intro input h_input + rcases input with ⟨⟨stmt, oStmt⟩, wit⟩ + rcases h_input with ⟨h_local, h_struct, h_strict_compat⟩ + exact ⟨h_local, h_struct, + aOStmtIn.strictInitialCompatibility_implies_initialCompatibility oStmt wit.t' + h_strict_compat⟩ + /-- **Consistency of the Batching Target** This lemma proves that the batched target value `s₀` computed by the verifier @@ -482,7 +553,9 @@ lemma batching_target_consistency (r_batching : Fin κ → L) (ctx : RingSwitchingBaseContext κ L K ℓ) : let s₀ := compute_s0 κ L K β msg0 r_batching - let H := projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := t') (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly ctx) (i := 0) (challenges := Fin.elim0) + let H := projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly ctx) (i := 0) + (challenges := Fin.elim0) sumcheckConsistencyProp (𝓑:=𝓑) s₀ H := by -- This lemma proves that s₀ = Σ_{x ∈ {0,1}^ℓ'} H(x) -- It follows from the definition of compute_s0, H, and A_MLE. diff --git a/ArkLib/ProofSystem/Binius/RingSwitching/Spec.lean b/ArkLib/ProofSystem/Binius/RingSwitching/Spec.lean index 70d135610..8e6280189 100644 --- a/ArkLib/ProofSystem/Binius/RingSwitching/Spec.lean +++ b/ArkLib/ProofSystem/Binius/RingSwitching/Spec.lean @@ -4,9 +4,11 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Chung Thai Nguyen, Quang Dao -/ import ArkLib.ProofSystem.Binius.RingSwitching.Prelude +import ArkLib.ProofSystem.Binius.BinaryBasefold.Spec import ArkLib.ToVCVio.Oracle namespace Binius.RingSwitching +open Binius.BinaryBasefold /-! ## Protocol Specs for Ring-Switching @@ -58,7 +60,8 @@ def pSpecSumcheckRound : ProtocolSpec 2 := ⟨![Direction.P_to_V, Direction.V_to def pSpecSumcheckLoop := ProtocolSpec.seqCompose (fun (_: Fin ℓ') => pSpecSumcheckRound L) -def pSpecFinalSumcheck : ProtocolSpec 1 := ⟨![Direction.P_to_V], ![L]⟩ +@[reducible] +def pSpecFinalSumcheck := pSpecFinalSumcheckStep (L := L) @[reducible] def pSpecCoreInteraction := (pSpecSumcheckLoop L ℓ') ++ₚ (pSpecFinalSumcheck L) @@ -79,11 +82,12 @@ instance : ∀ j, OracleInterface ((pSpecBatching κ L K).Message j) | ⟨1, _⟩ => OracleInterface.instDefault -- r'' ∈ L^κ instance : ∀ j, OracleInterface ((pSpecBatching κ L K).Challenge j) := - ProtocolSpec.challengeOracleInterface + fun _ => OracleInterface.instDefault + -- NOTE: this is same as ProtocolSpec.challengeOracleInterface (pSpec := pSpecBatching κ L K) -instance : ∀ j, OracleInterface ((pSpecSumcheckRound (L:=L)).Message j) - | ⟨0, _⟩ => OracleInterface.instDefault -- h_i(X) polynomial - | ⟨1, _⟩ => OracleInterface.instDefault -- challenge r'_i +instance instOracleInterfaceMessagePSpecSumcheckRound : + ∀ j, OracleInterface ((pSpecSumcheckRound (L:=L)).Message j) := + fun _ => OracleInterface.instDefault instance : ∀ j, OracleInterface ((pSpecSumcheckRound (L:=L)).Challenge j) := ProtocolSpec.challengeOracleInterface @@ -91,12 +95,6 @@ instance : ∀ j, OracleInterface ((pSpecSumcheckRound (L:=L)).Challenge j) := instance : ∀ j, OracleInterface ((pSpecSumcheckLoop (L:=L) ℓ').Message j) := instOracleInterfaceMessageSeqCompose -instance : ∀ i, OracleInterface ((pSpecFinalSumcheck (L:=L)).Message i) - | ⟨0, _⟩ => OracleInterface.instDefault -- final constant c - -instance : ∀ i, OracleInterface ((pSpecFinalSumcheck (L:=L)).Challenge i) - := ProtocolSpec.challengeOracleInterface - instance : ∀ i, OracleInterface ((pSpecCoreInteraction (L:=L) (ℓ':=ℓ')).Message i) := instOracleInterfaceMessageAppend @@ -125,9 +123,6 @@ instance : ∀ j, SampleableType ((pSpecSumcheckRound (L:=L)).Challenge j) instance : ∀ j, SampleableType ((pSpecSumcheckLoop (L:=L) ℓ').Challenge j) := instSampleableTypeChallengeSeqCompose -instance : ∀ i, SampleableType ((pSpecFinalSumcheck (L:=L)).Challenge i) - | ⟨0, h0⟩ => by nomatch h0 -- P->V message has no challenge - instance : ∀ i, SampleableType ((pSpecCoreInteraction (L:=L) (ℓ':=ℓ')).Challenge i) := instSampleableTypeChallengeAppend diff --git a/ArkLib/ProofSystem/Binius/RingSwitching/SumcheckPhase.lean b/ArkLib/ProofSystem/Binius/RingSwitching/SumcheckPhase.lean index e9ba3b9a7..27988a6d7 100644 --- a/ArkLib/ProofSystem/Binius/RingSwitching/SumcheckPhase.lean +++ b/ArkLib/ProofSystem/Binius/RingSwitching/SumcheckPhase.lean @@ -10,9 +10,10 @@ import ArkLib.OracleReduction.Composition.Sequential.General import ArkLib.OracleReduction.Composition.Sequential.Append import ArkLib.OracleReduction.Security.RoundByRound import ArkLib.ProofSystem.Binius.BinaryBasefold.ReductionLogic +import ArkLib.ProofSystem.Binius.BinaryBasefold.Soundness open OracleSpec OracleComp ProtocolSpec Finset AdditiveNTT Polynomial MvPolynomial - Module Binius.BinaryBasefold TensorProduct Nat Matrix + Module Binius.BinaryBasefold TensorProduct Nat Matrix ProbabilityTheory open scoped NNReal /-! @@ -41,8 +42,8 @@ source of RBR knowledge soundness error. 7. `P` computes `s' := t'(r'_0, ..., r'_{ℓ'-1})` and sends `V` `s'`. 8. `V` sets `e := eq̃(φ₀(r_κ), ..., φ₀(r_{ℓ-1}), φ₁(r'_0), ..., φ₁(r'_{ℓ'-1}))` and decomposes `e =: Σ_{u ∈ {0,1}^κ} β_u ⊗ e_u`. -9. `V` requires `s_{ℓ'} ?= (Σ_{u ∈ {0,1}^κ} eq̃(u_0, ..., u_{κ-1}, r''_0, ..., r''_{κ-1}) ⋅ e_u) ⋅ s'`. --/ +9. `V` requires `s_{ℓ'} ?=` + `(Σ_{u ∈ {0,1}^κ} eq̃(u_0, ..., u_{κ-1}, r''_0, ..., r''_{κ-1}) ⋅ e_u) ⋅ s'`. -/ namespace Binius.RingSwitching.SumcheckPhase noncomputable section @@ -65,14 +66,14 @@ variable (i : Fin ℓ') /-- Pure verifier check: validates that s = h(0) + h(1). -/ @[reducible] -def sumcheckVerifierCheck (stmtIn : Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) i.castSucc) - (h_i : L⦃≤ 2⦄[X]) : Prop := +def sumcheckVerifierCheck (stmtIn : Statement (L := L) (ℓ := ℓ') + (RingSwitchingBaseContext κ L K ℓ) i.castSucc) (h_i : L⦃≤ 2⦄[X]) : Prop := h_i.val.eval (𝓑 0) + h_i.val.eval (𝓑 1) = stmtIn.sumcheck_target /-- Pure verifier output: computes the output statement given the transcript. -/ @[reducible] -def sumcheckVerifierStmtOut (stmtIn : Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) i.castSucc) - (h_i : L⦃≤ 2⦄[X]) (r_i' : L) : +def sumcheckVerifierStmtOut (stmtIn : Statement (L := L) (ℓ := ℓ') + (RingSwitchingBaseContext κ L K ℓ) i.castSucc) (h_i : L⦃≤ 2⦄[X]) (r_i' : L) : Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) i.succ := { ctx := stmtIn.ctx, sumcheck_target := h_i.val.eval r_i', @@ -86,7 +87,8 @@ def sumcheckProverComputeMsg (witIn : SumcheckWitness L ℓ' i.castSucc) : L⦃ /-- Pure prover output: computes the output witness given the transcript. -/ @[reducible] -def sumcheckProverWitOut (_stmtIn : Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) i.castSucc) +def sumcheckProverWitOut (_stmtIn : Statement (L := L) (ℓ := ℓ') + (RingSwitchingBaseContext κ L K ℓ) i.castSucc) (witIn : SumcheckWitness L ℓ' i.castSucc) (r_i' : L) : SumcheckWitness L ℓ' i.succ := { t' := witIn.t', @@ -97,7 +99,7 @@ def sumcheckProverWitOut (_stmtIn : Statement (L := L) (ℓ := ℓ') (RingSwitch /-- The Logic Instance for the i-th round of Ring Switching Sumcheck. -/ def sumcheckStepLogic : - Binius.BinaryBasefold.CoreInteraction.ReductionLogicStep + Binius.BinaryBasefold.ReductionLogicStep (Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) i.castSucc) (SumcheckWitness L ℓ' i.castSucc) (aOStmtIn.OStmtIn) @@ -105,26 +107,24 @@ def sumcheckStepLogic : (Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) i.succ) (SumcheckWitness L ℓ' i.succ) (pSpecSumcheckRound L) where - completeness_relIn := fun ((stmt, oStmt), wit) => - ((stmt, oStmt), wit) ∈ sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn i.castSucc + ((stmt, oStmt), wit) ∈ strictSumcheckRoundRelation κ L K β ℓ ℓ' h_l + (𝓑 := 𝓑) aOStmtIn i.castSucc completeness_relOut := fun ((stmt, oStmt), wit) => - ((stmt, oStmt), wit) ∈ sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn i.succ - + ((stmt, oStmt), wit) ∈ strictSumcheckRoundRelation κ L K β ℓ ℓ' h_l + (𝓑 := 𝓑) aOStmtIn i.succ verifierCheck := fun stmtIn transcript => - sumcheckVerifierCheck (κ:=κ) (L:=L) (K:=K) (ℓ:=ℓ) (ℓ':=ℓ') (𝓑:=𝓑) i stmtIn (transcript.messages ⟨0, rfl⟩) - + sumcheckVerifierCheck (κ:=κ) (L:=L) (K:=K) (ℓ:=ℓ) (ℓ':=ℓ') (𝓑:=𝓑) + i stmtIn (transcript.messages ⟨0, rfl⟩) verifierOut := fun stmtIn transcript => - sumcheckVerifierStmtOut (κ:=κ) (L:=L) (K:=K) (ℓ:=ℓ) (ℓ':=ℓ') i stmtIn (transcript.messages ⟨0, rfl⟩) (transcript.challenges ⟨1, rfl⟩) - + sumcheckVerifierStmtOut (κ:=κ) (L:=L) (K:=K) (ℓ:=ℓ) (ℓ':=ℓ') i stmtIn + (transcript.messages ⟨0, rfl⟩) (transcript.challenges ⟨1, rfl⟩) embed := ⟨fun j => Sum.inl j, fun a b h => by cases h; rfl⟩ hEq := fun i => rfl - -- honestProverTranscript is the concatenation of sendMessage & receiveChallenge methods honestProverTranscript := fun _stmtIn witIn _oStmtIn chal => let msg := sumcheckProverComputeMsg (L:=L) (ℓ':=ℓ') (𝓑:=𝓑) i witIn FullTranscript.mk2 msg (chal ⟨1, rfl⟩) - proverOut := fun stmtIn witIn oStmtIn transcript => let h_i := transcript.messages ⟨0, rfl⟩ let r_i' := transcript.challenges ⟨1, rfl⟩ @@ -154,22 +154,17 @@ noncomputable def iteratedSumcheckOracleProver (i : Fin ℓ') : (OStmtOut := aOStmtIn.OStmtIn) (WitOut := SumcheckWitness L ℓ' i.succ) (pSpec := pSpecSumcheckRound L) where - PrvState := iteratedSumcheckPrvState κ L K ℓ ℓ' aOStmtIn i - input := fun ⟨⟨stmt, oStmt⟩, wit⟩ => (stmt, oStmt, wit) - sendMessage -- There are 2 messages in the pSpec | ⟨0, _⟩ => fun ⟨stmt, oStmt, wit⟩ => do let h_i := sumcheckProverComputeMsg (L:=L) (ℓ':=ℓ') (𝓑:=𝓑) i wit pure ⟨h_i, (stmt, oStmt, wit, h_i)⟩ | ⟨1, _⟩ => by contradiction - receiveChallenge | ⟨0, h⟩ => nomatch h -- i.e. contradiction | ⟨1, _⟩ => fun ⟨stmt, oStmt, wit, h_i⟩ => do pure (fun r_i' => (stmt, oStmt, wit, h_i, r_i')) - -- output : PrvState → StmtOut × (∀i, OracleStatement i) × WitOut output := fun finalPrvState => let (stmt, oStmt, wit, h_i, r_i') := finalPrvState @@ -189,7 +184,6 @@ noncomputable def iteratedSumcheckOracleVerifier (i : Fin ℓ') : (StmtOut := Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) i.succ) (OStmtOut := aOStmtIn.OStmtIn) (pSpec := pSpecSumcheckRound L) where - -- The core verification logic. Takes the input statement `stmtIn` and the transcript. verify := fun stmtIn pSpecChallenges => do -- Message 0 : Receive h_i(X) from prover @@ -202,7 +196,6 @@ noncomputable def iteratedSumcheckOracleVerifier (i : Fin ℓ') : (𝓑:=𝓑) (aOStmtIn:=aOStmtIn) i guard (logic.verifierCheck stmtIn t) pure (logic.verifierOut stmtIn t) - embed := ⟨fun j => Sum.inl j, fun a b h => by cases h; rfl⟩ hEq := fun _ => rfl @@ -237,22 +230,17 @@ lemma sumcheckStep_is_logic_complete (i : Fin ℓ') : let proverStmtOut := proverOutput.1.1 let proverOStmtOut := proverOutput.1.2 let proverWitOut := proverOutput.2 - - dsimp only [sumcheckStepLogic, sumcheckRoundRelation, - sumcheckRoundRelationProp, masterKStateProp] at h_relIn - + dsimp only [sumcheckStepLogic, strictSumcheckRoundRelation, + strictSumcheckRoundRelationProp, masterStrictKStateProp] at h_relIn -- We'll need sumcheck consistency for Fact 1, so extract it from either branch have h_sumcheck_cons : sumcheckConsistencyProp (𝓑 := 𝓑) stmtIn.sumcheck_target witIn.H - := h_relIn.2.2.1 - + := h_relIn.1 -- Fact 1: Verifier check passes let h_VCheck_passed : step.verifierCheck stmtIn transcript := by simp only [sumcheckStepLogic, step, sumcheckVerifierCheck] rw [h_sumcheck_cons] apply getSumcheckRoundPoly_sum_eq (𝓑 := 𝓑) (i := i) (h := witIn.H) - have hStmtOut_eq : proverStmtOut = verifierStmtOut := rfl - have hOStmtOut_eq : proverOStmtOut = verifierOStmtOut := by change (step.proverOut stmtIn witIn oStmtIn transcript).1.2 = OracleVerifier.mkVerifierOStmtOut step.embed step.hEq oStmtIn transcript @@ -267,17 +255,15 @@ lemma sumcheckStep_is_logic_complete (i : Fin ℓ') : rfl · rename_i heq simp only [MessageIdx, Function.Embedding.coeFn_mk, reduceCtorEq] at heq - -- Key fact: Oracle statements are unchanged in the fold step -- (all oracle indices map via Sum.inl in the embedding) have h_verifierOStmtOut_eq : verifierOStmtOut = oStmtIn := by rw [← hOStmtOut_eq] simp only [proverOStmtOut, proverOutput, step, sumcheckStepLogic] - let hRelOut : step.completeness_relOut ((verifierStmtOut, verifierOStmtOut), proverWitOut) := by - -- Fact 2: Output relation holds (sumcheckRoundRelation) - simp only [step, sumcheckStepLogic, sumcheckRoundRelation, - sumcheckRoundRelationProp, masterKStateProp] + -- Fact 2: Output relation holds (strictSumcheckRoundRelation) + simp only [step, sumcheckStepLogic, strictSumcheckRoundRelation, + strictSumcheckRoundRelationProp, masterStrictKStateProp] let r_i' := challenges ⟨1, rfl⟩ simp only [Fin.val_succ, true_and, Set.mem_setOf_eq] simp only [Fin.val_castSucc] at h_relIn @@ -285,29 +271,27 @@ lemma sumcheckStep_is_logic_complete (i : Fin ℓ') : rw [h_verifierOStmtOut_eq]; dsimp only [strictOracleWitnessConsistency] at h_oracleWitConsistency_In ⊢ -- Extract the three components from the input - obtain ⟨h_wit_struct_In, ⟨h_sumcheck_cons_In, h_oStmtIn_compat⟩⟩ := - h_oracleWitConsistency_In - + obtain ⟨h_wit_struct_In, h_oStmtIn_compat⟩ := + h_oracleWitConsistency_In constructor - · -- Component 1: witnessStructuralInvariant - unfold witnessStructuralInvariant at ⊢ h_wit_struct_In - let h_H_In := h_wit_struct_In - conv_lhs => - dsimp only [proverWitOut, proverOutput, step, - sumcheckStepLogic] - conv_lhs => - rw [h_H_In] - rw [←projectToMidSumcheckPoly_succ] - rfl + · -- sumcheckConsistencyProp + unfold sumcheckConsistencyProp + dsimp only [verifierStmtOut, proverWitOut, proverOutput] + simp only [step, sumcheckStepLogic, transcript] + apply projectToNextSumcheckPoly_sum_eq · constructor - · -- Part 2.1: sumcheckConsistencyProp - unfold sumcheckConsistencyProp - dsimp only [verifierStmtOut, proverWitOut, proverOutput] - simp only [step, sumcheckStepLogic, transcript] - apply projectToNextSumcheckPoly_sum_eq + · -- Component 1: witnessStructuralInvariant + unfold witnessStructuralInvariant at ⊢ h_wit_struct_In + let h_H_In := h_wit_struct_In + conv_lhs => + dsimp only [proverWitOut, proverOutput, step, + sumcheckStepLogic] + conv_lhs => + rw [h_H_In] + rw [←projectToMidSumcheckPoly_succ] + rfl · --Part 2.2: initialCompatibility exact h_oStmtIn_compat - -- Prove the four required facts refine ⟨?_, ?_, ?_, ?_⟩ · exact h_VCheck_passed @@ -323,11 +307,11 @@ variable {σ : Type} {init : ProbComp σ} {impl : QueryImpl []ₒ (StateT σ Pro theorem iteratedSumcheckOracleReduction_perfectCompleteness (i : Fin ℓ') (hInit : NeverFail init) : OracleReduction.perfectCompleteness (pSpec := pSpecSumcheckRound L) - (relIn := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn i.castSucc) - (relOut := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn i.succ) - (oracleReduction := iteratedSumcheckOracleReduction κ (L:=L) (K:=K) (ℓ:=ℓ) (ℓ':=ℓ') (𝓑:=𝓑) (β := β) (h_l := h_l) aOStmtIn i) - (init := init) - (impl := impl) := by + (relIn := strictSumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn i.castSucc) + (relOut := strictSumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn i.succ) + (oracleReduction := iteratedSumcheckOracleReduction κ (L:=L) + (K:=K) (ℓ:=ℓ) (ℓ':=ℓ') (𝓑:=𝓑) (β := β) (h_l := h_l) aOStmtIn i) + (init := init) (impl := impl) := by -- Step 1: Unroll the 2-message reduction to convert from probability to logic -- **NOTE**: this requires `ProtocolSpec.challengeOracleInterface` to avoid conflict rw [OracleReduction.unroll_2_message_reduction_perfectCompleteness (oSpec := []ₒ) @@ -339,12 +323,12 @@ theorem iteratedSumcheckOracleReduction_perfectCompleteness (i : Fin ℓ') (hIni -- Step 2: Convert probability 1 to universal quantification over support rw [probEvent_eq_one_iff] -- Step 3: Unfold protocol definitions - dsimp only [iteratedSumcheckOracleReduction, iteratedSumcheckOracleProver, iteratedSumcheckOracleVerifier, OracleVerifier.toVerifier, - FullTranscript.mk2] + dsimp only [iteratedSumcheckOracleReduction, iteratedSumcheckOracleProver, + iteratedSumcheckOracleVerifier, OracleVerifier.toVerifier, FullTranscript.mk2] let step := (sumcheckStepLogic (κ := κ) (L := L) (K := K) (β := β) (𝓑 := 𝓑) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn)) (i := i) - let strongly_complete : step.IsStronglyComplete := sumcheckStep_is_logic_complete (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn) (i := i) - + let strongly_complete : step.IsStronglyComplete := sumcheckStep_is_logic_complete (κ := κ) + (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn) (i := i) -- Step 4: Split into safety and correctness goals refine ⟨?_, ?_⟩ -- GOAL 1: SAFETY - Prove the verifier never crashes ([⊥|...] = 0) @@ -362,7 +346,6 @@ theorem iteratedSumcheckOracleReduction_perfectCompleteness (i : Fin ℓ') (hIni simp only [ChallengeIdx, Challenge, Fin.isValue, Matrix.cons_val_one, Matrix.cons_val_zero, liftComp_eq_liftM, OptionT.probFailure_lift, HasEvalPMF.probFailure_eq_zero] rw [true_and] - intro r_i' h_r_i'_mem_query_1_support conv => enter [1]; @@ -372,7 +355,6 @@ theorem iteratedSumcheckOracleReduction_perfectCompleteness (i : Fin ℓ') (hIni Fin.succ_one_eq_two, Message, Fin.succ_zero_eq_one, Fin.castSucc_one, liftComp_eq_liftM, OptionT.probFailure_lift, HasEvalPMF.probFailure_eq_zero] rw [true_and] - intro h_receive_challenge_fn h_receive_challenge_fn_mem_support conv => enter [1]; @@ -430,7 +412,8 @@ theorem iteratedSumcheckOracleReduction_perfectCompleteness (i : Fin ℓ') (hIni set V_check := step.verifierCheck stmtIn (FullTranscript.mk2 (msg0 := _) - (msg1 := (FullTranscript.mk2 (sumcheckProverComputeMsg L ℓ' i witIn) r_i').challenges ⟨1, rfl⟩)) + (msg1 := (FullTranscript.mk2 (sumcheckProverComputeMsg L ℓ' i witIn) r_i').challenges + ⟨1, rfl⟩)) with h_V_check_def obtain ⟨h_V_check, h_rel, h_agree⟩ := strongly_complete (stmtIn := stmtIn) (witIn := witIn) (h_relIn := h_relIn) (challenges := @@ -472,9 +455,7 @@ theorem iteratedSumcheckOracleReduction_perfectCompleteness (i : Fin ℓ') (hIni Fin.reduceLast, MessageIdx, Message, exists_eq_left] at hx_mem_support -- Step 2b: Extract the challenge r1 and the trace equations obtain ⟨r1, ⟨_h_r1_mem_challenge_support, h_trace_support⟩⟩ := hx_mem_support - rcases h_trace_support with ⟨prvOut_eq, h_verOut_mem_support⟩ - -- Step 2c: Simplify the verifier computation conv at h_verOut_mem_support => erw [simulateQ_bind] @@ -493,8 +474,8 @@ theorem iteratedSumcheckOracleReduction_perfectCompleteness (i : Fin ℓ') (hIni set V_check := step.verifierCheck stmtIn (FullTranscript.mk2 (msg0 := _) - (msg1 := (FullTranscript.mk2 (sumcheckProverComputeMsg L ℓ' i witIn) r1).challenges ⟨1, rfl⟩)) - with h_V_check_def + (msg1 := (FullTranscript.mk2 (sumcheckProverComputeMsg L ℓ' i witIn) r1).challenges + ⟨1, rfl⟩)) with h_V_check_def obtain ⟨h_V_check, h_rel, h_agree⟩ := strongly_complete (stmtIn := stmtIn) (witIn := witIn) (h_relIn := h_relIn) (challenges := fun ⟨j, hj⟩ => by @@ -507,7 +488,6 @@ theorem iteratedSumcheckOracleReduction_perfectCompleteness (i : Fin ℓ') (hIni exact hj_ne hj | 1 => exact r1 ) - have h_V_check_is_true : V_check := h_V_check simp only [h_V_check_is_true, ↓reduceIte, Fin.isValue, pure_bind] at h_verOut_mem_support erw [simulateQ_pure, liftM_pure] at h_verOut_mem_support @@ -516,9 +496,7 @@ theorem iteratedSumcheckOracleReduction_perfectCompleteness (i : Fin ℓ') (hIni rcases h_verOut_mem_support with ⟨verStmtOut_eq, verOStmtOut_eq⟩ dsimp only [sumcheckStepLogic, sumcheckProverComputeMsg, step] at prvOut_eq rw [Prod.mk.injEq, Prod.mk.injEq] at prvOut_eq - obtain ⟨⟨prvStmtOut_eq, prvOStmtOut_eq⟩, prvWitOut_eq⟩ := prvOut_eq - constructor · rw [prvWitOut_eq, verStmtOut_eq, verOStmtOut_eq]; exact h_rel @@ -529,6 +507,16 @@ theorem iteratedSumcheckOracleReduction_perfectCompleteness (i : Fin ℓ') (hIni def iteratedSumcheckRoundKnowledgeError (_ : Fin ℓ') : ℝ≥0 := (2 : ℝ≥0) / (Fintype.card L) +/-- Witness type at each message index for the iterated sumcheck step + (counterpart of BBF `foldWitMid`). + At m=0,1 we have input-round witness; at m=2 we have output-round witness so extractOut can + be identity. -/ +def iteratedSumcheckWitMid (i : Fin ℓ') : Fin (2 + 1) → Type := + fun m => match m with + | ⟨0, _⟩ => SumcheckWitness L ℓ' i.castSucc + | ⟨1, _⟩ => SumcheckWitness L ℓ' i.castSucc + | ⟨2, _⟩ => SumcheckWitness L ℓ' i.succ + noncomputable def iteratedSumcheckRbrExtractor (i : Fin ℓ') : Extractor.RoundByRound []ₒ (StmtIn := (Statement (L := L) (ℓ := ℓ') @@ -536,124 +524,475 @@ noncomputable def iteratedSumcheckRbrExtractor (i : Fin ℓ') : (WitIn := SumcheckWitness L ℓ' i.castSucc) (WitOut := SumcheckWitness L ℓ' i.succ) (pSpec := pSpecSumcheckRound L) - (WitMid := fun _messageIdx => SumcheckWitness L ℓ' i.castSucc) where + (WitMid := iteratedSumcheckWitMid (L := L) (ℓ' := ℓ') (i := i)) where eqIn := rfl - extractMid := fun _ _ _ witMidSucc => witMidSucc - extractOut := fun ⟨stmtIn, oStmtIn⟩ fullTranscript witOut => by - exact { - t' := witOut.t', - H := projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witOut.t') - (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly (ctx := stmtIn.ctx)) - (i := i.castSucc) (challenges := stmtIn.challenges) - } - -/-- This follows the KState of `foldKStateProp` -/ + extractMid := fun m ⟨stmtIn, _⟩ _tr witMidSucc => + match m with + | ⟨0, _⟩ => witMidSucc -- WitMid 1 → WitMid 0, both SumcheckWitness i.castSucc + | ⟨1, _⟩ => + -- WitMid 2 → WitMid 1: extract backward from output witness using input challenges + { + t' := witMidSucc.t', + H := projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witMidSucc.t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly (ctx := stmtIn.ctx)) + (i := i.castSucc) (challenges := stmtIn.challenges) + } + extractOut := fun _stmtIn _fullTranscript witOut => witOut + +/-- KState for the iterated sumcheck step, matching the structure of Binary Basefold's +`foldKStateProp`: +- m=0: same as relIn (masterKStateProp at i.castSucc with sumcheckConsistencyProp) +- m=1: after P sends hᵢ(X), before V sends r'ᵢ (explicitVCheck ∧ localizedRoundPolyCheck) +- m=2: after V sends r'ᵢ — OUTPUT state (masterKStateProp at i.succ with stmtOut, witMid, + sumcheckConsistencyProp) + At m=2, witMid has type SumcheckWitness i.succ (via iteratedSumcheckWitMid). -/ def iteratedSumcheckKStateProp (i : Fin ℓ') (m : Fin (2 + 1)) (tr : Transcript m (pSpecSumcheckRound L)) - (stmt : Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) i.castSucc) - (witMid : SumcheckWitness L ℓ' i.castSucc) - (oStmt : ∀ j, aOStmtIn.OStmtIn j) : + (stmtMid : Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) i.castSucc) + (witMid : iteratedSumcheckWitMid (L := L) (ℓ' := ℓ') (i := i) m) + (oStmtMid : ∀ j, aOStmtIn.OStmtIn j) : Prop := - -- Ground-truth polynomial from witness - let h_star : ↥L⦃≤ 2⦄[X] := getSumcheckRoundPoly ℓ' 𝓑 (i := i) (h := witMid.H) - -- Checks available after message 1 (P -> V : hᵢ(X)) - let get_Hᵢ := fun (m: Fin (2 + 1)) (tr: Transcript m (pSpecSumcheckRound L)) (hm: 1 ≤ m.val) => - let ⟨msgsUpTo, _⟩ := Transcript.equivMessagesChallenges (k := m) - (pSpec := pSpecSumcheckRound L) tr - let i_msg1 : ((pSpecSumcheckRound L).take m m.is_le).MessageIdx := - ⟨⟨0, Nat.lt_of_succ_le hm⟩, by simp [pSpecSumcheckRound]; rfl⟩ - let h_i : L⦃≤ 2⦄[X] := msgsUpTo i_msg1 - h_i - - let get_rᵢ' := fun (m: Fin (2 + 1)) (tr: Transcript m (pSpecSumcheckRound L)) (hm: 2 ≤ m.val) => - let ⟨msgsUpTo, chalsUpTo⟩ := Transcript.equivMessagesChallenges (k := m) - (pSpec := pSpecSumcheckRound L) tr - let i_msg1 : ((pSpecSumcheckRound L).take m m.is_le).MessageIdx := - ⟨⟨0, Nat.lt_of_succ_le (Nat.le_trans (by decide) hm)⟩, by simp; rfl⟩ - let h_i : L⦃≤ 2⦄[X] := msgsUpTo i_msg1 - let i_msg2 : ((pSpecSumcheckRound L).take m m.is_le).ChallengeIdx := - ⟨⟨1, Nat.lt_of_succ_le hm⟩, by simp only [Nat.reduceAdd]; rfl⟩ - let r_i' : L := chalsUpTo i_msg2 - r_i' - match m with - | ⟨0, _⟩ => -- equiv s relIn - RingSwitching.masterKStateProp κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) + | ⟨0, _⟩ => -- Same as relIn (sumcheckRoundRelation at i.castSucc) + RingSwitching.masterKStateProp κ L K β ℓ ℓ' h_l aOStmtIn (stmtIdx := i.castSucc) - (stmt := stmt) (oStmt := oStmt) (wit := witMid) - (localChecks := True) - | ⟨1, h1⟩ => -- P sends hᵢ(X) - RingSwitching.masterKStateProp κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn + (stmt := stmtMid) (oStmt := oStmtMid) (wit := witMid) + (localChecks := sumcheckConsistencyProp (𝓑 := 𝓑) stmtMid.sumcheck_target witMid.H) + | ⟨1, _⟩ => -- After P sends hᵢ(X), before V sends r'ᵢ + let h_star : ↥L⦃≤ 2⦄[X] := getSumcheckRoundPoly ℓ' 𝓑 (i := i) (h := witMid.H) + let h_i : ↥L⦃≤ 2⦄[X] := tr.messages ⟨0, rfl⟩ + RingSwitching.masterKStateProp κ L K β ℓ ℓ' h_l aOStmtIn (stmtIdx := i.castSucc) - (stmt := stmt) (oStmt := oStmt) (wit := witMid) + (stmt := stmtMid) (oStmt := oStmtMid) (wit := witMid) (localChecks := - let h_i := get_Hᵢ (m := ⟨1, h1⟩) (tr := tr) (hm := by simp only [le_refl]) - let explicitVCheck := h_i.val.eval (𝓑 0) + h_i.val.eval (𝓑 1) = stmt.sumcheck_target + let explicitVCheck := h_i.val.eval (𝓑 0) + h_i.val.eval (𝓑 1) = stmtMid.sumcheck_target let localizedRoundPolyCheck := h_i = h_star explicitVCheck ∧ localizedRoundPolyCheck ) - | ⟨2, h2⟩ => -- implied by (relOut + V's check) - -- The bad-folding-event of `fᵢ` is also introduced internaly by `masterKStateProp` - RingSwitching.masterKStateProp κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn - (stmtIdx := i.castSucc) - (stmt := stmt) (oStmt := oStmt) (wit := witMid) + | ⟨2, _⟩ => -- After V sends r'ᵢ: use OUTPUT state (witMid is already SumcheckWitness i.succ) + let h_i : ↥L⦃≤ 2⦄[X] := tr.messages ⟨0, rfl⟩ + let r_i' : L := tr.challenges ⟨1, rfl⟩ + let stmtOut : Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) i.succ := + sumcheckVerifierStmtOut (κ := κ) (L := L) (K := K) (ℓ := ℓ) (ℓ' := ℓ') i stmtMid h_i r_i' + let oStmtOut := oStmtMid + let witOut := witMid + RingSwitching.masterKStateProp κ L K β ℓ ℓ' h_l aOStmtIn + (stmtIdx := i.succ) + (stmt := stmtOut) (oStmt := oStmtOut) (wit := witOut) (localChecks := - let h_i := get_Hᵢ (m := ⟨2, h2⟩) (tr := tr) (hm := by simp only [Nat.one_le_ofNat]) - let r_i' := get_rᵢ' (m := ⟨2, h2⟩) (tr := tr) (hm := by simp only [le_refl]) - let localizedRoundPolyCheck := h_i = h_star - let nextSumcheckTargetCheck := -- this presents sumcheck of next round (sᵢ = s^*ᵢ) - h_i.val.eval r_i' = h_star.val.eval r_i' - localizedRoundPolyCheck ∧ nextSumcheckTargetCheck - ) -- this holds the constraint for witOut in relOut + let explicitVCheck := h_i.val.eval (𝓑 0) + h_i.val.eval (𝓑 1) = stmtMid.sumcheck_target + explicitVCheck ∧ + sumcheckConsistencyProp (𝓑 := 𝓑) stmtOut.sumcheck_target witOut.H + ) /-- Knowledge state function (KState) for single round -/ def iteratedSumcheckKnowledgeStateFunction (i : Fin ℓ') : - (iteratedSumcheckOracleVerifier κ (L := L) (K := K) (ℓ := ℓ) (ℓ' := ℓ') (𝓑 := 𝓑) (β := β) (h_l := h_l) aOStmtIn i).KnowledgeStateFunction init impl + (iteratedSumcheckOracleVerifier κ (L := L) (K := K) (ℓ := ℓ) (ℓ' := ℓ') (𝓑 := 𝓑) (β := β) + (h_l := h_l) aOStmtIn i).KnowledgeStateFunction init impl (relIn := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn i.castSucc) (relOut := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn i.succ) (extractor := iteratedSumcheckRbrExtractor κ L K β ℓ ℓ' h_l aOStmtIn i) where toFun := fun m ⟨stmt, oStmt⟩ tr witMid => iteratedSumcheckKStateProp κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) - (i := i) (m := m) (tr := tr) (stmt := stmt) (witMid := witMid) (oStmt := oStmt) - toFun_empty := fun _ _ => by - simp only [sumcheckRoundRelation, sumcheckRoundRelationProp, Fin.val_castSucc, cast_eq, - Set.mem_setOf_eq, iteratedSumcheckKStateProp, masterKStateProp, true_and] - toFun_next := fun m hDir stmtIn tr msg witMid => by - obtain ⟨stmt, oStmt⟩ := stmtIn - fin_cases m - · -- m = 0: succ = 1, castSucc = 0 - unfold iteratedSumcheckKStateProp - simp only [masterKStateProp, iteratedSumcheckRbrExtractor, true_and] - simp only [Fin.succ_mk, Fin.castSucc_mk, Fin.castAdd_mk] - tauto - · -- m = 1: dir 1 = V_to_P, contradicts hDir - simp [pSpecSumcheckRound] at hDir - toFun_full := fun ⟨stmtLast, oStmtLast⟩ tr witOut => by - intro h_relOut - simp at h_relOut - rcases h_relOut with ⟨stmtOut, ⟨oStmtOut, h_conj⟩⟩ - have h_simulateQ := h_conj.1 - have h_SumcheckStepRelOut := h_conj.2 - set witLast := (iteratedSumcheckRbrExtractor κ L K β ℓ ℓ' h_l aOStmtIn i).extractOut - ⟨stmtLast, oStmtLast⟩ tr witOut - simp only [Fin.reduceLast, Fin.isValue] - -- ⊢ iteratedSumcheckKStateProp 𝔽q β 2 tr stmtLast witLast oStmtLast - -- TODO : prove this via the relations between stmtLast & stmtOut, - -- witLast & witOut, oStmtLast & oStmtOut - sorry + (i := i) (m := m) (tr := tr) (stmtMid := stmt) (witMid := witMid) (oStmtMid := oStmt) + toFun_empty := fun ⟨stmtIn, oStmtIn⟩ witMid => by + simp only [iteratedSumcheckKStateProp, sumcheckRoundRelation, sumcheckRoundRelationProp, + Set.mem_setOf_eq, Fin.val_castSucc, cast_eq] + toFun_next := fun m hDir ⟨stmtMid, oStmtMid⟩ tr msg witMid => by + -- For pSpecFold, the only P_to_V message is at index 0 + -- So m = 0, m.succ = 1, m.castSucc = 0 + have h_m_eq_0 : m = 0 := by + cases m using Fin.cases with + | zero => rfl + | succ m' => simp only [ne_eq, reduceCtorEq, not_false_eq_true, Matrix.cons_val_succ, + Matrix.cons_val_fin_one, Direction.not_V_to_P_eq_P_to_V] at hDir + subst h_m_eq_0 + intro h_kState_round1 + unfold iteratedSumcheckKStateProp at h_kState_round1 ⊢ + simp only [Fin.isValue, Fin.succ_zero_eq_one, Nat.reduceAdd, Fin.mk_one, + Fin.coe_ofNat_eq_mod, Nat.reduceMod] at h_kState_round1 + simp only [Fin.castSucc_zero] + -- At round 1: masterKStateProp with (explicitVCheck ∧ localizedRoundPolyCheck) + -- At round 0: masterKStateProp with sumcheckConsistencyProp + -- Extract the checks from round 1 + obtain ⟨⟨h_explicit, h_localized⟩, h_core⟩ := h_kState_round1 + -- Key: h_localized says h_i = h_star, and h_explicit says h_i(0) + h_i(1) = s + -- Therefore h_star(0) + h_star(1) = s, which is what Lemma 1.1 gives us + constructor + · -- Prove sumcheckConsistencyProp at round 0 + simp_rw [h_localized] at h_explicit + rw [h_explicit.symm] + apply getSumcheckRoundPoly_sum_eq + · -- The core (badEventExists ∨ oracleWitnessConsistency) is preserved + exact h_core + toFun_full := fun ⟨stmtIn, oStmtIn⟩ tr witOut probEvent_relOut_gt_0 => by + -- h_relOut: ∃ stmtOut oStmtOut, verifier outputs (stmtOut, oStmtOut) with prob > 0 + -- and ((stmtOut, oStmtOut), witOut) ∈ foldStepRelOut + simp only [StateT.run'_eq, gt_iff_lt, probEvent_pos_iff, Prod.exists] at probEvent_relOut_gt_0 + rcases probEvent_relOut_gt_0 with ⟨stmtOut, oStmtOut, h_output_mem_V_run_support, h_relOut⟩ + have h_output_mem_V_run_support' : + some (stmtOut, oStmtOut) ∈ + (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (iteratedSumcheckOracleVerifier κ (L := L) (K := K) (ℓ := ℓ) (ℓ' := ℓ') (𝓑 := 𝓑) + (β := β) (h_l := h_l) aOStmtIn i).toVerifier).run s)).support:= by + exact (OptionT.mem_support_iff + (mx := OptionT.mk (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (iteratedSumcheckOracleVerifier κ (L := L) (K := K) (ℓ := ℓ) (ℓ' := ℓ') (𝓑 := 𝓑) + (β := β) (h_l := h_l) aOStmtIn i).toVerifier).run s))) + (x := (stmtOut, oStmtOut))).1 h_output_mem_V_run_support + simp only [support_bind, Set.mem_iUnion, exists_prop] at h_output_mem_V_run_support' + rcases h_output_mem_V_run_support' with ⟨s, hs_init, h_output_mem_V_run_support⟩ + conv at h_output_mem_V_run_support => + simp only [Verifier.run, OracleVerifier.toVerifier] + -- Now unfold the foldOracleVerifier's `verify()` method + simp only [iteratedSumcheckOracleVerifier] + -- dsimp only [StateT.run] + -- simp only [simulateQ_bind, simulateQ_query, simulateQ_pure] + -- oracle query unfolding + simp only [support_bind, Set.mem_iUnion] + dsimp only [StateT.run] + -- enter [1, i_1, 2, 1, x] + simp only [simulateQ_bind] + unfold OracleInterface.answer + --------------------------------------- + -- Now simplify the `guard` and `ite` of StateT.map generated from it + simp only [MessageIdx, Fin.isValue, Matrix.cons_val_zero, simulateQ_pure, Message, guard_eq, + pure_bind, Function.comp_apply, simulateQ_map, simulateQ_ite, + OptionT.simulateQ_failure', bind_map_left] + simp only [MessageIdx, Message, Fin.isValue, Matrix.cons_val_zero, Matrix.cons_val_one, + bind_pure_comp, simulateQ_map, simulateQ_ite, simulateQ_pure, OptionT.simulateQ_failure', + bind_map_left, Function.comp_apply] + simp only [support_ite] + simp only [Fin.isValue, Set.mem_ite_empty_right, Set.mem_singleton_iff, Prod.mk.injEq, + exists_and_left, exists_eq', exists_eq_right, exists_and_right] + erw [simulateQ_bind] + enter [1, x, 1, 2, 1, 2]; + erw [simulateQ_bind] + erw [OptionT.simulateQ_simOracle2_liftM_query_T2] + simp only [Fin.isValue, FullTranscript.mk1_eq_snoc, pure_bind, OptionT.simulateQ_map] + conv at h_output_mem_V_run_support => + simp only [Fin.isValue, FullTranscript.mk1_eq_snoc, Function.comp_apply] + erw [support_bind] at h_output_mem_V_run_support + let step := (sumcheckStepLogic (κ := κ) (L := L) (K := K) (β := β) (𝓑 := 𝓑) (ℓ := ℓ) (ℓ' := ℓ') + (h_l := h_l) (aOStmtIn := aOStmtIn)) (i := i) + set V_check := step.verifierCheck stmtIn + (FullTranscript.mk2 (msg0 := _) (msg1 := _)) with h_V_check_def + by_cases h_V_check : V_check + · + simp only [Fin.isValue, Matrix.cons_val_zero, id_eq, h_V_check, ↓reduceIte, OptionT.run_pure, + simulateQ_pure, Function.comp_apply, Set.mem_iUnion, exists_prop, Prod.exists, + exists_and_right] at h_output_mem_V_run_support + erw [simulateQ_bind] at h_output_mem_V_run_support + simp only [simulateQ_pure, Fin.isValue, Function.comp_apply, + pure_bind] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Fin.isValue, Set.mem_singleton_iff, Prod.mk.injEq, exists_eq_right, + exists_eq_left] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Fin.isValue, Set.mem_singleton_iff, Option.some.injEq, + Prod.mk.injEq] at h_output_mem_V_run_support + rcases h_output_mem_V_run_support with ⟨h_stmtOut_eq, h_oStmtOut_eq⟩ + simp only [Fin.reduceLast, Fin.isValue] -- simp the `match` + dsimp only [sumcheckRoundRelation, sumcheckRoundRelationProp, masterKStateProp] at h_relOut + simp only [Fin.val_succ, Set.mem_setOf_eq] at h_relOut + dsimp only [iteratedSumcheckKStateProp] + set h_i : ↥L⦃≤ 2⦄[X] := tr.messages ⟨(0 : Fin 2), rfl⟩ with h_i_def + set r_i' : L := tr.challenges ⟨(1 : Fin 2), rfl⟩ with h_i_def + set extractedWitLast : SumcheckWitness L ℓ' i.succ := + (iteratedSumcheckRbrExtractor κ (L:=L) (K:=K) (ℓ:=ℓ) (ℓ':=ℓ') (β := β) + (h_l := h_l) aOStmtIn i).extractOut (stmtIn, oStmtIn) tr witOut + have h_oStmtOut_eq_oStmtIn : oStmtOut = oStmtIn := by + rw [h_oStmtOut_eq] + funext j + simp only [MessageIdx, Function.Embedding.coeFn_mk, Sum.inl.injEq, + OracleVerifier.mkVerifierOStmtOut_inl, cast_eq] + rw [h_oStmtOut_eq_oStmtIn] at h_relOut + dsimp only [sumcheckVerifierStmtOut] + have h_stmtOut_sumcheck_target_eq : stmtOut.sumcheck_target = (Polynomial.eval r_i' ↑h_i) + := by rw [h_stmtOut_eq]; rfl + dsimp only [masterKStateProp] + constructor + · constructor + · simpa [h_i_def] using h_V_check + · rw [h_stmtOut_sumcheck_target_eq] at h_relOut + exact h_relOut.1 + · obtain ⟨h_wit_struct_In, h_oStmtIn_compat⟩ := h_relOut.2 + constructor + · -- witnessStructuralInvariant + unfold witnessStructuralInvariant at h_wit_struct_In ⊢ + dsimp only [Fin.val_succ] + rw [h_stmtOut_eq] at h_wit_struct_In + exact h_wit_struct_In + · -- initialCompatibility + exact h_oStmtIn_compat + · simp only [Fin.isValue, h_V_check, ↓reduceIte, OptionT.run_failure, simulateQ_pure, + Set.mem_iUnion, exists_prop, Prod.exists] at h_output_mem_V_run_support + erw [simulateQ_bind] at h_output_mem_V_run_support + simp only [simulateQ_pure, Fin.isValue, Function.comp_apply, + pure_bind] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, ↓existsAndEq, and_true, exists_eq_left, + simulateQ_pure] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, reduceCtorEq, false_and, + exists_false] at h_output_mem_V_run_support -- False + +/-- Extraction failure implies a witness-dependent bad sumcheck event (no folding here). + The extracted `witMid` also carries oracle compatibility at the same `oStmt`. -/ +lemma iteratedSumcheck_rbrExtractionFailureEvent_imply_badSumcheck (i : Fin ℓ') + (stmtOStmtIn : (Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) i.castSucc) + × (∀ j, aOStmtIn.OStmtIn j)) + (h_i : (pSpecSumcheckRound L).Message ⟨0, rfl⟩) (r_i' : L) + (doomEscape : rbrExtractionFailureEvent + (kSF := iteratedSumcheckKnowledgeStateFunction (κ := κ) (L := L) (K := K) (ℓ := ℓ) (ℓ' := ℓ') + (𝓑 := 𝓑) (β := β) (h_l := h_l) aOStmtIn (init := init) (impl := impl) i) + (extractor := iteratedSumcheckRbrExtractor κ L K β ℓ ℓ' h_l aOStmtIn i) + (i := ⟨1, rfl⟩) (stmtIn := stmtOStmtIn) (transcript := FullTranscript.mk1 h_i) + (challenge := r_i')) : + ∃ witMid : SumcheckWitness L ℓ' i.succ, + aOStmtIn.initialCompatibility (witMid.t', stmtOStmtIn.2) ∧ + let witBefore : SumcheckWitness L ℓ' i.castSucc := + (iteratedSumcheckRbrExtractor.{0, 0} κ L K β ℓ ℓ' h_l aOStmtIn i).extractMid + (m := 1) stmtOStmtIn (FullTranscript.mk2 h_i r_i') witMid + let h_star : L⦃≤ 2⦄[X] := getSumcheckRoundPoly ℓ' 𝓑 (i := i) (h := witBefore.H) + badSumcheckEventProp r_i' h_i h_star := by + classical + unfold rbrExtractionFailureEvent at doomEscape + rcases doomEscape with ⟨witMid, h_kState_before_false, h_kState_after_true⟩ + simp only [iteratedSumcheckKnowledgeStateFunction] at h_kState_before_false h_kState_after_true + unfold iteratedSumcheckKStateProp at h_kState_before_false h_kState_after_true + simp only [Fin.isValue, Fin.castSucc_one, Fin.succ_one_eq_two, Nat.reduceAdd] + at h_kState_before_false h_kState_after_true + simp only [Transcript.concat, sumcheckVerifierStmtOut] + at h_kState_before_false h_kState_after_true + unfold masterKStateProp witnessStructuralInvariant at h_kState_before_false h_kState_after_true + simp only [iteratedSumcheckRbrExtractor, Fin.isValue] + at h_kState_before_false h_kState_after_true + have h_explicit_after : + h_i.val.eval (𝓑 0) + h_i.val.eval (𝓑 1) = stmtOStmtIn.1.sumcheck_target := by + simpa using h_kState_after_true.1.1 + have h_sumcheck_after : + sumcheckConsistencyProp (𝓑 := 𝓑) (Polynomial.eval r_i' h_i.val) witMid.H := by + simpa using h_kState_after_true.1.2 + have h_wit_struct_after : + witMid.H = projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witMid.t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly stmtOStmtIn.1.ctx) + (i := i.succ) (challenges := Fin.snoc stmtOStmtIn.1.challenges r_i') := by + simpa using h_kState_after_true.2.1 + have h_init_compat : aOStmtIn.initialCompatibility (witMid.t', stmtOStmtIn.2) + := h_kState_after_true.2.2 + let H_before : L⦃≤ 2⦄[X Fin (ℓ' - i.castSucc)] := + projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witMid.t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly stmtOStmtIn.1.ctx) + (i := i.castSucc) (challenges := stmtOStmtIn.1.challenges) + let h_star_extracted : L⦃≤ 2⦄[X] := getSumcheckRoundPoly ℓ' 𝓑 (i := i) (h := H_before) + have h_eval_eq_extracted : Polynomial.eval r_i' h_i.val + = Polynomial.eval r_i' h_star_extracted.val := by + unfold sumcheckConsistencyProp at h_sumcheck_after + rw [h_wit_struct_after] at h_sumcheck_after + rw [projectToMidSumcheckPoly_succ (L := L) (ℓ := ℓ') (t := witMid.t') + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly stmtOStmtIn.1.ctx) + (i := i) (challenges := stmtOStmtIn.1.challenges) (r_i' := r_i')] at h_sumcheck_after + have h_sum_eq := + projectToNextSumcheckPoly_sum_eq (L := L) (𝓑 := 𝓑) (ℓ := ℓ') + (i := i) (Hᵢ := H_before) (rᵢ := r_i') + have h_sum_eq' : + Polynomial.eval r_i' h_star_extracted.val = + ∑ x ∈ (univ.map 𝓑) ^ᶠ (ℓ' - i.succ), + (projectToNextSumcheckPoly (L := L) (ℓ := ℓ') (i := i) (Hᵢ := H_before) + (rᵢ := r_i')).val.eval x := by + simpa [h_star_extracted] using h_sum_eq + calc + Polynomial.eval r_i' h_i.val + = ∑ x ∈ (univ.map 𝓑) ^ᶠ (ℓ' - i.succ), + (projectToNextSumcheckPoly (L := L) (ℓ := ℓ') (i := i) (Hᵢ := H_before) + (rᵢ := r_i')).val.eval x := h_sumcheck_after + _ = Polynomial.eval r_i' h_star_extracted.val := by + symm + exact h_sum_eq' + have h_hi_ne_extracted : h_i ≠ h_star_extracted := by + intro h_eq + apply h_kState_before_false + constructor + · constructor + · exact h_explicit_after + · simpa [h_star_extracted] using h_eq + · constructor + · -- The middle conjunct at m=1 simplifies to `True`. + trivial + · -- initialCompatibility is preserved by extractMid(m=1) since t' is unchanged. + simpa [iteratedSumcheckRbrExtractor, Fin.isValue] using h_init_compat + have h_bad_extracted : badSumcheckEventProp r_i' h_i h_star_extracted := by + refine ⟨h_hi_ne_extracted, h_eval_eq_extracted⟩ + refine ⟨witMid, h_init_compat, ?_⟩ + simpa [h_star_extracted, H_before, iteratedSumcheckRbrExtractor, Fin.isValue] + using h_bad_extracted + +/-- Per-transcript bound: for prover message h_i, the probability (over verifier challenge y) + that extraction fails is at most iteratedSumcheckRoundKnowledgeError (2/|L|). + Counterpart of BBF `foldStep_doom_escape_probability_bound`; no folding bad event here. -/ +lemma iteratedSumcheck_doom_escape_probability_bound (i : Fin ℓ') + (stmtOStmtIn : (Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) i.castSucc) + × (∀ j, aOStmtIn.OStmtIn j)) + (h_i : (pSpecSumcheckRound L).Message ⟨0, rfl⟩) : + Pr_{ let y ← $ᵖ L }[ + rbrExtractionFailureEvent + (kSF := iteratedSumcheckKnowledgeStateFunction (κ := κ) (L := L) (K := K) (ℓ := ℓ) + (ℓ' := ℓ') (𝓑 := 𝓑) (β := β) (h_l := h_l) aOStmtIn (init := init) (impl := impl) i) + (extractor := iteratedSumcheckRbrExtractor κ L K β ℓ ℓ' h_l aOStmtIn i) + ⟨1, rfl⟩ stmtOStmtIn (FullTranscript.mk1 h_i) y ] ≤ + iteratedSumcheckRoundKnowledgeError L ℓ' i := by + classical + let compatPred : MultilinearPoly L ℓ' → Prop := fun t => + aOStmtIn.initialCompatibility (t, stmtOStmtIn.2) + by_cases hCompat : ∃ t : MultilinearPoly L ℓ', compatPred t + · rcases hCompat with ⟨t_fixed, h_t_fixed_compat⟩ + let H_fixed : L⦃≤ 2⦄[X Fin (ℓ' - i.castSucc)] := + projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := t_fixed) + (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly stmtOStmtIn.1.ctx) + (i := i.castSucc) (challenges := stmtOStmtIn.1.challenges) + let h_star_fixed : L⦃≤ 2⦄[X] := getSumcheckRoundPoly ℓ' 𝓑 (i := i) (h := H_fixed) + have h_prob_mono := prob_mono (D := $ᵖ L) + (f := fun y => rbrExtractionFailureEvent + (kSF := iteratedSumcheckKnowledgeStateFunction (κ := κ) (L := L) (K := K) (ℓ := ℓ) + (ℓ' := ℓ') (𝓑 := 𝓑) (β := β) (h_l := h_l) aOStmtIn (init := init) (impl := impl) i) + (extractor := iteratedSumcheckRbrExtractor κ L K β ℓ ℓ' h_l aOStmtIn i) + ⟨1, rfl⟩ stmtOStmtIn (FullTranscript.mk1 h_i) y) + (g := fun y => badSumcheckEventProp y h_i h_star_fixed) + (h_imp := by + intro y h_doom + obtain ⟨witMid, h_mid_compat, h_bad_extracted⟩ := + iteratedSumcheck_rbrExtractionFailureEvent_imply_badSumcheck + (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) + (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (impl := impl) (init := init) + (i := i) (stmtOStmtIn := stmtOStmtIn) (h_i := h_i) (r_i' := y) + (doomEscape := h_doom) + have h_t_eq : witMid.t' = t_fixed := + aOStmtIn.initialCompatibility_unique stmtOStmtIn.2 witMid.t' t_fixed + h_mid_compat h_t_fixed_compat + simpa [h_star_fixed, H_fixed, iteratedSumcheckRbrExtractor, Fin.isValue, h_t_eq] + using h_bad_extracted) + apply le_trans h_prob_mono + have h_sz := probability_bound_badSumcheckEventProp (h_i := h_i) (h_star := h_star_fixed) + conv_rhs => + dsimp only [iteratedSumcheckRoundKnowledgeError] + rw [ENNReal.coe_div (hr := by simp only [ne_eq, Nat.cast_eq_zero, Fintype.card_ne_zero, + not_false_eq_true])] + simp only [ENNReal.coe_ofNat, ENNReal.coe_natCast] + exact h_sz + · have h_prob_mono_false := prob_mono (D := $ᵖ L) + (f := fun y => rbrExtractionFailureEvent + (kSF := iteratedSumcheckKnowledgeStateFunction (κ := κ) (L := L) (K := K) (ℓ := ℓ) + (ℓ' := ℓ') (𝓑 := 𝓑) (β := β) (h_l := h_l) aOStmtIn (init := init) (impl := impl) i) + (extractor := iteratedSumcheckRbrExtractor κ L K β ℓ ℓ' h_l aOStmtIn i) + ⟨1, rfl⟩ stmtOStmtIn (FullTranscript.mk1 h_i) y) + (g := fun _ => False) + (h_imp := by + intro y h_doom + obtain ⟨witMid, h_mid_compat, _h_bad_extracted⟩ := + iteratedSumcheck_rbrExtractionFailureEvent_imply_badSumcheck + (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) + (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (impl := impl) (init := init) + (i := i) (stmtOStmtIn := stmtOStmtIn) (h_i := h_i) (r_i' := y) + (doomEscape := h_doom) + exact (hCompat ⟨witMid.t', h_mid_compat⟩).elim) + refine le_trans h_prob_mono_false ?_ + simp only [PMF.monad_pure_eq_pure, PMF.monad_bind_eq_bind, PMF.bind_const, PMF.pure_apply, + eq_iff_iff, iff_false, not_true_eq_false, ↓reduceIte, _root_.zero_le] /-- RBR knowledge soundness for a single round oracle verifier -/ theorem iteratedSumcheckOracleVerifier_rbrKnowledgeSoundness (i : Fin ℓ') : - (iteratedSumcheckOracleVerifier κ (L:=L) (K:=K) (ℓ:=ℓ) (ℓ':=ℓ') (𝓑:=𝓑) (β := β) (h_l := h_l) aOStmtIn i).rbrKnowledgeSoundness init impl + (iteratedSumcheckOracleVerifier κ (L:=L) (K:=K) (ℓ:=ℓ) (ℓ':=ℓ') + (𝓑:=𝓑) (β := β) (h_l := h_l) aOStmtIn i).rbrKnowledgeSoundness init impl (relIn := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn i.castSucc) (relOut := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn i.succ) - (fun j => iteratedSumcheckRoundKnowledgeError L ℓ' i) := by - use fun _ => SumcheckWitness L ℓ' i.castSucc - use iteratedSumcheckRbrExtractor κ (L:=L) (K:=K) (ℓ:=ℓ) (ℓ':=ℓ') (β := β) (h_l := h_l) aOStmtIn i - use iteratedSumcheckKnowledgeStateFunction (κ := κ) (L := L) (K := K) (ℓ := ℓ) (ℓ' := ℓ') (𝓑 := 𝓑) (β := β) (h_l := h_l) aOStmtIn i - intro stmtIn witIn prover j - sorry + (rbrKnowledgeError := fun _ => iteratedSumcheckRoundKnowledgeError L ℓ' i) := by + classical + apply OracleReduction.unroll_rbrKnowledgeSoundness + (kSF := iteratedSumcheckKnowledgeStateFunction (κ := κ) (L := L) (K := K) + (ℓ := ℓ) (ℓ' := ℓ') (𝓑 := 𝓑) (β := β) (h_l := h_l) aOStmtIn i) + intro stmtOStmtIn witIn prover j initState + let P := rbrExtractionFailureEvent + (kSF := iteratedSumcheckKnowledgeStateFunction (κ := κ) (L := L) (K := K) + (ℓ := ℓ) (ℓ' := ℓ') (β := β) (𝓑 := 𝓑) (h_l := h_l) aOStmtIn (impl := impl) (init := init) i) + (iteratedSumcheckRbrExtractor κ (L:=L) (K:=K) (ℓ:=ℓ) (ℓ':=ℓ') (β := β) (h_l := h_l) aOStmtIn i) + j + stmtOStmtIn + rw [OracleReduction.probEvent_soundness_goal_unroll_log' (pSpec := pSpecFold + (L := L)) (P := P) (impl := impl) (prover := prover) (i := j) (stmt := stmtOStmtIn) + (wit := witIn) (s := initState)] + have h_j_eq_1 : j = ⟨1, rfl⟩ := by + match j with + | ⟨0, h0⟩ => nomatch h0 + | ⟨1, _⟩ => rfl + subst h_j_eq_1 + conv_lhs => simp only [Fin.isValue, Fin.castSucc_one]; + rw [OracleReduction.soundness_unroll_runToRound_1_P_to_V_pSpec_2 + (pSpec := pSpecFold (L := L)) (prover := prover) (hDir0 := rfl)] + simp only [Fin.isValue, Challenge, Matrix.cons_val_one, Matrix.cons_val_zero, ChallengeIdx, + QueryImpl.addLift_def, QueryImpl.liftTarget_self, Message, Fin.succ_zero_eq_one, Nat.reduceAdd, + Fin.coe_ofNat_eq_mod, Nat.reduceMod, FullTranscript.mk1_eq_snoc, bind_pure_comp, + liftComp_eq_liftM, bind_map_left, simulateQ_bind, simulateQ_map, StateT.run'_eq, + StateT.run_bind, StateT.run_map, map_bind, Functor.map_map] + rw [probEvent_bind_eq_tsum] + apply OracleReduction.ENNReal.tsum_mul_le_of_le_of_sum_le_one + · -- Bound the conditional probability for each transcript + intro x + -- rw [OracleComp.probEvent_map] + simp only [Fin.isValue, probEvent_map] + let q : OracleQuery [(pSpecFold (L := L)).Challenge]ₒ _ := query ⟨⟨1, by rfl⟩, ()⟩ + erw [OracleReduction.probEvent_StateT_run_ignore_state + (comp := simulateQ (impl.addLift challengeQueryImpl) (liftM (query q.input))) + (s := x.2) + (P := fun a => P (FullTranscript.mk1 x.1.1) (q.cont a))] + rw [probEvent_eq_tsum_ite] + erw [simulateQ_query] + simp only [ChallengeIdx, Challenge, Fin.isValue, Nat.reduceAdd, Fin.castSucc_one, + Fin.coe_ofNat_eq_mod, Nat.reduceMod, monadLift_self, + QueryImpl.addLift_def, QueryImpl.liftTarget_self, StateT.run'_eq, StateT.run_map, + Functor.map_map, ge_iff_le] + have h_L_inhabited : Inhabited L := ⟨0⟩ + conv_lhs => + enter [1, x_1, 2, 1, 2] + rw [addLift_challengeQueryImpl_input_run_eq_liftM_run (impl := impl) (q := q) (s := x.2)] + erw [StateT.run_monadLift, monadLift_self, liftComp_id] + rw [bind_pure_comp] + conv => + enter [1, 1, x_1, 2] + rw [Functor.map_map] + rw [← probEvent_eq_eq_probOutput] + rw [probEvent_map] + rw [OracleQuery.cont_apply] + dsimp only [MonadLift.monadLift] + rw [OracleQuery.cont_apply] + dsimp only [q] + simp_rw [OracleQuery.input_query, OracleQuery.snd_query] + conv_lhs => change (∑' (x_1 : L), _) + simp only [Function.comp_id] + conv => + enter [1, 1, x_1, 2] + rw [probEvent_eq_eq_probOutput] + change Pr[=x_1 | $ᵗ L] + rw [OracleReduction.probOutput_uniformOfFintype_eq_Pr (L := _) (x := x_1)] + rw [OracleReduction.tsum_uniform_Pr_eq_Pr (L := L) (P := + fun x_1 => P (FullTranscript.mk1 x.1.1) (q.2 x_1))] + -- Make this explicit using change + -- Apply the per-transcript bound (Ring-switching counterpart of + -- foldStep_doom_escape_probability_bound) + exact iteratedSumcheck_doom_escape_probability_bound (κ := κ) (L := L) (K := K) (ℓ := ℓ) + (ℓ' := ℓ') (𝓑 := 𝓑) (β := β) (h_l := h_l) (aOStmtIn := aOStmtIn) (i := i) + (stmtOStmtIn := stmtOStmtIn) (h_i := x.1.1) + · -- Prove: ∑' x, [=x|transcript computation] ≤ 1 + apply tsum_probOutput_le_one end IteratedSumcheckStep @@ -705,34 +1044,28 @@ def finalSumcheckProverWitOut (witIn : SumcheckWitness L ℓ' (Fin.last ℓ')) : /-- The Logic Instance for the final sumcheck step. This is a 1-message protocol where the prover sends the final constant s'. -/ def finalSumcheckStepLogic : - Binius.BinaryBasefold.CoreInteraction.ReductionLogicStep + Binius.BinaryBasefold.ReductionLogicStep (Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) (SumcheckWitness L ℓ' (Fin.last ℓ')) (aOStmtIn.OStmtIn) (aOStmtIn.OStmtIn) (MLPEvalStatement L ℓ') (WitMLP L ℓ') - (pSpecFinalSumcheck L) where - + (pSpecFinalSumcheckStep (L := L)) where completeness_relIn := fun ((stmt, oStmt), wit) => - ((stmt, oStmt), wit) ∈ sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn (Fin.last ℓ') - + ((stmt, oStmt), wit) ∈ strictSumcheckRoundRelation κ L K β ℓ ℓ' h_l + (𝓑 := 𝓑) aOStmtIn (Fin.last ℓ') completeness_relOut := fun ((stmtOut, oStmtOut), witOut) => - ((stmtOut, oStmtOut), witOut) ∈ aOStmtIn.toRelInput - + ((stmtOut, oStmtOut), witOut) ∈ aOStmtIn.toStrictRelInput verifierCheck := fun stmtIn transcript => finalSumcheckVerifierCheck κ L K β ℓ ℓ' h_l stmtIn (transcript.messages ⟨0, rfl⟩) - verifierOut := fun stmtIn transcript => finalSumcheckVerifierStmtOut κ L K ℓ ℓ' stmtIn (transcript.messages ⟨0, rfl⟩) - embed := ⟨fun j => Sum.inl j, fun a b h => by cases h; rfl⟩ hEq := fun _ => rfl - honestProverTranscript := fun stmtIn witIn _oStmtIn _chal => let s' : L := finalSumcheckProverComputeMsg κ L K ℓ ℓ' witIn stmtIn FullTranscript.mk1 s' - proverOut := fun stmtIn witIn oStmtIn transcript => let s' : L := transcript.messages ⟨0, rfl⟩ let stmtOut := finalSumcheckVerifierStmtOut κ L K ℓ ℓ' stmtIn s' @@ -741,15 +1074,17 @@ def finalSumcheckStepLogic : /-! ## Helper Lemmas for Strong Completeness -/ -/-- At `Fin.last ℓ'`, the sumcheck consistency sum is over 0 variables, simplifying to a single evaluation. -This is analogous to Binary Basefold's simplification of `𝓑^ᶠ(0) = {∅}`. -/ +omit [Fintype L] [DecidableEq L] [CharP L 2] [SampleableType L] [NeZero ℓ'] in +/-- At `Fin.last ℓ'`, the sumcheck consistency sum is over 0 variables, +simplifying to a single evaluation. This is analogous to Binary Basefold's +simplification of `𝓑^ᶠ(0) = {∅}`. -/ lemma sumcheckConsistency_at_last_simplifies (target : L) (H : L⦃≤ 2⦄[X Fin (ℓ' - Fin.last ℓ')]) (h_cons : sumcheckConsistencyProp (𝓑 := 𝓑) target H) : target = H.val.eval (fun _ => (0 : L)) := by -- Since ℓ' - Fin.last ℓ' = 0, the sum is over Fin 0 which has only one element - simp only [Fin.val_last, tsub_self] at H h_cons ⊢ - simp only [sumcheckConsistencyProp, Fin.val_last, tsub_self] at h_cons + simp only [Fin.val_last] at H h_cons ⊢ + simp only [sumcheckConsistencyProp] at h_cons -- The piFinset over Fin 0 has only one element: fun _ => 0 haveI : IsEmpty (Fin 0) := Fin.isEmpty rw [Finset.sum_eq_single (a := fun _ => 0) @@ -764,12 +1099,14 @@ lemma sumcheckConsistency_at_last_simplifies intro i; simp only [tsub_self] at i; exact i.elim0)] at h_cons exact h_cons +omit [NeZero κ] [Fintype L] [DecidableEq L] [CharP L 2] [SampleableType L] + [Fintype K] [DecidableEq K] [NeZero ℓ] [NeZero ℓ'] in /-- The honest prover's message in the final sumcheck step equals `t'(challenges)`. -/ lemma finalSumcheck_honest_message_eq_t'_eval (stmtIn : Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) (witIn : SumcheckWitness L ℓ' (Fin.last ℓ')) (oStmtIn : ∀ j, aOStmtIn.OStmtIn j) - (challenges : (pSpecFinalSumcheck L).Challenges) : + (challenges : (pSpecFinalSumcheckStep (L := L)).Challenges) : let step := finalSumcheckStepLogic κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn let transcript := step.honestProverTranscript stmtIn witIn oStmtIn challenges transcript.messages ⟨0, rfl⟩ = witIn.t'.val.eval stmtIn.challenges := by @@ -796,7 +1133,7 @@ lemma finalSumcheckStep_verifierCheck_passed (stmtIn : Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) (witIn : SumcheckWitness L ℓ' (Fin.last ℓ')) (oStmtIn : ∀ j, aOStmtIn.OStmtIn j) - (challenges : (pSpecFinalSumcheck L).Challenges) + (challenges : (pSpecFinalSumcheckStep (L := L)).Challenges) (h_sumcheck_cons : sumcheckConsistencyProp (𝓑 := 𝓑) stmtIn.sumcheck_target witIn.H) (h_wit_struct : witnessStructuralInvariant κ L K β ℓ ℓ' h_l stmtIn witIn) : let step := finalSumcheckStepLogic κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn @@ -805,32 +1142,31 @@ lemma finalSumcheckStep_verifierCheck_passed intro step transcript -- Step 1: Simplify sumcheck consistency to single evaluation have h_target_eq_H_eval : stmtIn.sumcheck_target = witIn.H.val.eval (fun _ => 0) := - sumcheckConsistency_at_last_simplifies (L := L) (ℓ' := ℓ') (𝓑 := 𝓑) stmtIn.sumcheck_target witIn.H h_sumcheck_cons - + sumcheckConsistency_at_last_simplifies (L := L) (ℓ' := ℓ') (𝓑 := 𝓑) + stmtIn.sumcheck_target witIn.H h_sumcheck_cons -- Step 2: Use witnessStructuralInvariant to connect H to projected poly have h_H_eq : witIn.H = projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witIn.t') (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly stmtIn.ctx) (i := Fin.last ℓ') (challenges := stmtIn.challenges) := h_wit_struct - -- Step 3: Apply projectToMidSumcheckPoly_at_last_eval have h_proj_eval : (projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witIn.t') (m := (RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly stmtIn.ctx) (i := Fin.last ℓ') (challenges := stmtIn.challenges)).val.eval (fun _ => 0) = - ((RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly stmtIn.ctx).val.eval stmtIn.challenges * - witIn.t'.val.eval stmtIn.challenges := by + ((RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly + stmtIn.ctx).val.eval stmtIn.challenges * witIn.t'.val.eval stmtIn.challenges := by apply projectToMidSumcheckPoly_at_last_eval - -- Step 4: Connect multiplier poly to compute_final_eq_value -- This requires showing that A_MLE.eval(challenges) = compute_final_eq_value - have h_mult_eq_eq_value : ((RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly stmtIn.ctx).val.eval stmtIn.challenges = - compute_final_eq_value κ L K β ℓ ℓ' h_l stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching := + have h_mult_eq_eq_value : ((RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly + stmtIn.ctx).val.eval stmtIn.challenges = + compute_final_eq_value κ L K β ℓ ℓ' h_l stmtIn.ctx.t_eval_point stmtIn.challenges + stmtIn.ctx.r_batching := compute_A_MLE_eval_eq_final_eq_value κ L K β ℓ ℓ' h_l stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching - -- Step 5: Get the honest message have h_msg_eq : transcript.messages ⟨0, rfl⟩ = witIn.t'.val.eval stmtIn.challenges := - finalSumcheck_honest_message_eq_t'_eval κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn stmtIn witIn oStmtIn challenges - + finalSumcheck_honest_message_eq_t'_eval κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn stmtIn witIn + oStmtIn challenges -- Step 6: Combine everything simp only [step, finalSumcheckStepLogic, finalSumcheckVerifierCheck] rw [h_target_eq_H_eval, Subtype.val_inj.mpr h_H_eq, h_proj_eval, h_mult_eq_eq_value, h_msg_eq] @@ -839,16 +1175,19 @@ lemma finalSumcheckStep_verifierCheck_passed /-- Final sumcheck step logic is strongly complete. **Key Proof Obligations:** -1. **Verifier Check**: Show that `stmtIn.sumcheck_target = eq_tilde_eval * s'` where `s' = witIn.t'.val.eval stmtIn.challenges` +1. **Verifier Check**: Show that `stmtIn.sumcheck_target = eq_tilde_eval * s'` where + `s' = witIn.t'.val.eval stmtIn.challenges` - This should follow from `h_relIn` (sumcheckRoundRelation) which includes `masterKStateProp` - The `masterKStateProp` includes: * `witnessStructuralInvariant`: `wit.H = projectToMidSumcheckPoly ...` - * `sumcheckConsistencyProp`: `stmt.sumcheck_target = ∑ x ∈ (univ.map 𝓑) ^ᶠ (ℓ' - Fin.last ℓ'), wit.H.val.eval x` - For `i = Fin.last ℓ'`, we have `ℓ' - Fin.last ℓ' = 0`, so this is a sum over 0 variables (a constant) + * `sumcheckConsistencyProp`: `stmt.sumcheck_target =` + `∑ x ∈ (univ.map 𝓑) ^ᶠ (ℓ' - Fin.last ℓ'), wit.H.val.eval x` + For `i = Fin.last ℓ'`, we have `ℓ' - Fin.last ℓ' = 0`, so this is a sum over 0 variables + (a constant) - Need to connect these properties to show the verifier check passes -2. **Relation Out**: Show that the output satisfies `aOStmtIn.toRelInput` - - This involves showing `MLPEvalRelation` and `initialCompatibility` hold for the output +2. **Relation Out**: Show that the output satisfies `aOStmtIn.toStrictRelInput` + - This involves showing `MLPEvalRelation` and `strictInitialCompatibility` hold for the output -/ lemma finalSumcheckStep_is_logic_complete : (finalSumcheckStepLogic κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn).IsStronglyComplete := by @@ -863,26 +1202,21 @@ lemma finalSumcheckStep_is_logic_complete : let proverOStmtOut := proverOutput.1.2 let proverWitOut := proverOutput.2 let s' := transcript.messages ⟨0, rfl⟩ - -- Extract properties from h_relIn BEFORE any simp changes its structure - simp only [finalSumcheckStepLogic, sumcheckRoundRelation, sumcheckRoundRelationProp, - Set.mem_setOf_eq, masterKStateProp] at h_relIn - obtain ⟨_, h_wit_struct, h_sumcheck_cons, h_oStmtIn_compat⟩ := h_relIn - + simp only [finalSumcheckStepLogic, strictSumcheckRoundRelation, + strictSumcheckRoundRelationProp, Set.mem_setOf_eq, masterStrictKStateProp] at h_relIn + obtain ⟨h_sumcheck_cons, h_wit_struct, h_oStmtIn_compat⟩ := h_relIn -- Fact 1: Verifier check passes (using the helper lemma) let h_VCheck_passed : step.verifierCheck stmtIn transcript := finalSumcheckStep_verifierCheck_passed κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn stmtIn witIn oStmtIn challenges h_sumcheck_cons h_wit_struct - -- Fact 2: Prover and verifier statements agree have hStmtOut_eq : proverStmtOut = verifierStmtOut := by change (step.proverOut stmtIn witIn oStmtIn transcript).1.1 = step.verifierOut stmtIn transcript simp only [step, finalSumcheckStepLogic, finalSumcheckVerifierStmtOut, finalSumcheckProverWitOut] - -- Fact 3: Prover and verifier oracle statements agree (no new oracles added) have hOStmtOut_eq : proverOStmtOut = verifierOStmtOut := by rfl - -- Fact 4: Output relation holds have hRelOut : step.completeness_relOut ((verifierStmtOut, verifierOStmtOut), proverWitOut) := by simp only [step, finalSumcheckStepLogic] @@ -891,7 +1225,6 @@ lemma finalSumcheckStep_is_logic_complete : rfl · -- initial Compatibility exact h_oStmtIn_compat - -- Prove the four required facts refine ⟨?_, ?_, ?_, ?_⟩ · exact h_VCheck_passed @@ -911,25 +1244,22 @@ noncomputable def finalSumcheckProver : (StmtOut := MLPEvalStatement L ℓ') (OStmtOut := aOStmtIn.OStmtIn) (WitOut := WitMLP L ℓ') - (pSpec := pSpecFinalSumcheck L) where + (pSpec := pSpecFinalSumcheckStep (L := L)) where PrvState := fun | 0 => Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ') × (∀ j, aOStmtIn.OStmtIn j) × SumcheckWitness L ℓ' (Fin.last ℓ') | _ => Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ') × (∀ j, aOStmtIn.OStmtIn j) × SumcheckWitness L ℓ' (Fin.last ℓ') × L input := fun ⟨⟨stmt, oStmt⟩, wit⟩ => (stmt, oStmt, wit) - sendMessage | ⟨0, _⟩ => fun ⟨stmtIn, oStmtIn, witIn⟩ => do let s' := finalSumcheckProverComputeMsg κ L K ℓ ℓ' witIn stmtIn pure ⟨s', (stmtIn, oStmtIn, witIn, s')⟩ - receiveChallenge | ⟨0, h⟩ => nomatch h -- No challenges in this step - output := fun ⟨stmtIn, oStmtIn, witIn, s'⟩ => do let logic := finalSumcheckStepLogic κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn - let t := FullTranscript.mk1 (pSpec := pSpecFinalSumcheck L) s' + let t := FullTranscript.mk1 (pSpec := pSpecFinalSumcheckStep (L := L)) s' pure (logic.proverOut stmtIn witIn oStmtIn t) /-- The verifier for the final sumcheck step -/ @@ -940,20 +1270,19 @@ noncomputable def finalSumcheckVerifier : (OStmtIn := aOStmtIn.OStmtIn) (StmtOut := MLPEvalStatement L ℓ') (OStmtOut := aOStmtIn.OStmtIn) - (pSpec := pSpecFinalSumcheck L) where + (pSpec := pSpecFinalSumcheckStep (L := L)) where verify := fun stmtIn _ => do -- Get the final constant `s'` from the prover's message - let s' : L ← query (spec := [(pSpecFinalSumcheck L).Message]ₒ) + let s' : L ← query (spec := [(pSpecFinalSumcheckStep (L := L)).Message]ₒ) ⟨⟨0, by rfl⟩, (by simpa using ())⟩ -- Construct the transcript - let t := FullTranscript.mk1 (pSpec := pSpecFinalSumcheck L) s' + let t := FullTranscript.mk1 (pSpec := pSpecFinalSumcheckStep (L := L)) s' -- Get the logic instance let logic := finalSumcheckStepLogic κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn -- Use guard for verifier check (fails if check doesn't pass) have : Decidable (logic.verifierCheck stmtIn t) := Classical.propDecidable _ guard (logic.verifierCheck stmtIn t) pure (logic.verifierOut stmtIn t) - embed := (finalSumcheckStepLogic κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn).embed hEq := (finalSumcheckStepLogic κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn).hEq @@ -967,7 +1296,7 @@ noncomputable def finalSumcheckOracleReduction : (StmtOut := MLPEvalStatement L ℓ') (OStmtOut := aOStmtIn.OStmtIn) (WitOut := WitMLP L ℓ') - (pSpec := pSpecFinalSumcheck L) where + (pSpec := pSpecFinalSumcheckStep (L := L)) where prover := finalSumcheckProver κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn verifier := finalSumcheckVerifier κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn @@ -976,26 +1305,24 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} (init : ProbComp σ) (hInit : NeverFail init) (impl : QueryImpl []ₒ (StateT σ ProbComp)) : OracleReduction.perfectCompleteness - (pSpec := pSpecFinalSumcheck L) - (relIn := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn (Fin.last ℓ')) - (relOut := aOStmtIn.toRelInput) + (pSpec := pSpecFinalSumcheckStep (L := L)) + (relIn := strictSumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn (Fin.last ℓ')) + (relOut := aOStmtIn.toStrictRelInput) (oracleReduction := finalSumcheckOracleReduction κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn) (init := init) (impl := impl) := by -- Step 1: Unroll the 2-message reduction to convert from probability to logic rw [OracleReduction.unroll_1_message_reduction_perfectCompleteness_P_to_V (hInit := hInit) (hDir0 := by rfl) (hImplSupp := by simp only [Set.fmap_eq_image, IsEmpty.forall_iff, implies_true])] - intro stmtIn oStmtIn witIn h_relIn -- Step 2: Convert probability 1 to universal quantification over support rw [probEvent_eq_one_iff] -- Step 3: Unfold protocol definitions dsimp only [finalSumcheckOracleReduction, finalSumcheckProver, finalSumcheckVerifier, OracleVerifier.toVerifier, FullTranscript.mk1] - let step := (finalSumcheckStepLogic κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn) - let strongly_complete : step.IsStronglyComplete := finalSumcheckStep_is_logic_complete (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn) - + let strongly_complete : step.IsStronglyComplete := finalSumcheckStep_is_logic_complete + (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) (aOStmtIn := aOStmtIn) -- Step 4: Split into safety and correctness goals refine ⟨?_, ?_⟩ -- GOAL 1: SAFETY - Prove the verifier never crashes ([⊥|...] = 0) @@ -1072,7 +1399,6 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} exfalso exact hj_ne hj ) - have h_V_check_is_true : V_check := h_V_check simp only [h_V_check_is_true, ↓reduceIte, support_pure, Set.mem_singleton_iff, Fin.isValue, Fin.val_last, exists_eq_left, OptionT.support_OptionT_pure_run] at h_vStmtOut_mem_support @@ -1118,11 +1444,9 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} Set.mem_iUnion, exists_prop] rw [simulateQ_ite]; erw [simulateQ_pure] simp only [OptionT.simulateQ_failure'] - set V_check := step.verifierCheck stmtIn (FullTranscript.mk1 (msg0 := _))with h_V_check_def - -- Step 2e: Apply the logic completeness lemma obtain ⟨h_V_check, h_rel, h_agree⟩ := strongly_complete (stmtIn := stmtIn) (witIn := witIn) (h_relIn := h_relIn) (challenges := @@ -1135,7 +1459,6 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} exfalso exact hj_ne hj ) - have h_V_check_is_true : V_check := h_V_check simp only [h_V_check_is_true, ↓reduceIte, Fin.isValue] at h_verOut_mem_support erw [support_bind, support_pure] at h_verOut_mem_support @@ -1143,9 +1466,7 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} OptionT.support_OptionT_pure_run, exists_eq_left, Option.some.injEq, Prod.mk.injEq] at h_verOut_mem_support rcases h_verOut_mem_support with ⟨verStmtOut_eq, verOStmtOut_eq⟩ - obtain ⟨⟨prvStmtOut_eq, prvOStmtOut_eq⟩, prvWitOut_eq⟩ := h_prvOut_mem_support - constructor · rw [verStmtOut_eq, verOStmtOut_eq, prvWitOut_eq]; exact h_rel @@ -1156,20 +1477,24 @@ theorem finalSumcheckOracleReduction_perfectCompleteness {σ : Type} /-- RBR knowledge error for the final sumcheck step -/ -def finalSumcheckRbrKnowledgeError : ℝ≥0 := (1 : ℝ≥0) / (Fintype.card L) +def finalSumcheckKnowledgeError (m : pSpecFinalSumcheckStep (L := L).ChallengeIdx) : + ℝ≥0 := + match m with + | ⟨0, h0⟩ => nomatch h0 -/-- The round-by-round extractor for the final sumcheck step -/ +/-- The round-by-round extractor for the final sumcheck step. + We do not collapse the witness away (unlike BBF): WitMid stays as full SumcheckWitness, + and we pass the polynomial t' (WitMLP) plus MLPEvalStatement to a final PCS invocation. -/ noncomputable def finalSumcheckRbrExtractor : Extractor.RoundByRound []ₒ (StmtIn := Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ') × (∀ j, aOStmtIn.OStmtIn j)) (WitIn := SumcheckWitness L ℓ' (Fin.last ℓ')) (WitOut := WitMLP L ℓ') - (pSpec := pSpecFinalSumcheck L) + (pSpec := pSpecFinalSumcheckStep (L := L)) (WitMid := fun _m => SumcheckWitness L ℓ' (Fin.last ℓ')) where eqIn := rfl extractMid := fun _m ⟨_, _⟩ _trSucc witMidSucc => witMidSucc - extractOut := fun ⟨stmtIn, _⟩ _tr witOut => { t' := witOut.t, H := projectToMidSumcheckPoly (L := L) (ℓ := ℓ') (t := witOut.t) @@ -1177,36 +1502,35 @@ noncomputable def finalSumcheckRbrExtractor : (i := Fin.last ℓ') (challenges := stmtIn.challenges) } -/- This follows the KState of `finalSumcheckKStateProp` in `BinaryBasefold`. -though the multiplier poly is different. -/ -def finalSumcheckKStateProp {m : Fin (1 + 1)} (tr : Transcript m (pSpecFinalSumcheck L)) - (stmt : Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) +/-- KState for the final sumcheck step, in the same style as BBF `finalSumcheckKStateProp`: + m=0: same as relIn (masterKStateProp with sumcheckConsistencyProp). + m=1: name prover message as `c`, build output statement `stmtOut`, then + sumcheckFinalCheck ∧ finalEvalCheck ∧ oracleCompatProp + (no folding; RS has only sumcheck + oracle compat). -/ +def finalSumcheckKStateProp {m : Fin (1 + 1)} (tr : Transcript m (pSpecFinalSumcheckStep (L := L))) + (stmtIn : Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) (Fin.last ℓ')) (witMid : SumcheckWitness L ℓ' (Fin.last ℓ')) - (oStmt : ∀ j, aOStmtIn.OStmtIn j) : Prop := + (oStmtIn : ∀ j, aOStmtIn.OStmtIn j) : Prop := match m with | ⟨0, _⟩ => -- same as relIn - RingSwitching.masterKStateProp κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn + RingSwitching.masterKStateProp κ L K β ℓ ℓ' h_l aOStmtIn (stmtIdx := Fin.last ℓ') - (stmt := stmt) (oStmt := oStmt) (wit := witMid) - (localChecks := True) + (stmt := stmtIn) (oStmt := oStmtIn) (wit := witMid) + (localChecks := sumcheckConsistencyProp (𝓑 := 𝓑) stmtIn.sumcheck_target witMid.H) | ⟨1, _⟩ => -- implied by relOut + local checks via extractOut proofs - let tr_so_far := (pSpecFinalSumcheck L).take 1 (by omega) - let i_msg0 : tr_so_far.MessageIdx := ⟨⟨0, by omega⟩, rfl⟩ - let c : L := (ProtocolSpec.Transcript.equivMessagesChallenges (k := 1) - (pSpec := pSpecFinalSumcheck L) tr).1 i_msg0 - + let c : L := tr.messages ⟨0, rfl⟩ let stmtOut : MLPEvalStatement L ℓ' := { - t_eval_point := stmt.challenges, + t_eval_point := stmtIn.challenges, original_claim := c } - let sumcheckFinalLocalCheck : Prop := + let sumcheckFinalVCheck : Prop := let eq_tilde_eval : L := compute_final_eq_value κ L K β ℓ ℓ' h_l - stmt.ctx.t_eval_point stmt.challenges stmt.ctx.r_batching - stmt.sumcheck_target = eq_tilde_eval * c - - let final_eval : Prop := witMid.t'.val.eval stmt.challenges = c - sumcheckFinalLocalCheck ∧ final_eval - ∧ aOStmtIn.initialCompatibility ⟨witMid.t', oStmt⟩ + stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching + stmtIn.sumcheck_target = eq_tilde_eval * c + let finalEvalCheck : Prop := witMid.t'.val.eval stmtOut.t_eval_point = stmtOut.original_claim + let oracleCompatProp : Prop := aOStmtIn.initialCompatibility ⟨witMid.t', oStmtIn⟩ + let witnessStructProp : Prop := witnessStructuralInvariant κ L K β ℓ ℓ' h_l stmtIn witMid + sumcheckFinalVCheck ∧ finalEvalCheck ∧ oracleCompatProp ∧ witnessStructProp /-- The knowledge state function for the final sumcheck step -/ noncomputable def finalSumcheckKnowledgeStateFunction {σ : Type} (init : ProbComp σ) @@ -1217,28 +1541,189 @@ noncomputable def finalSumcheckKnowledgeStateFunction {σ : Type} (init : ProbCo (extractor := finalSumcheckRbrExtractor κ L K β ℓ ℓ' h_l aOStmtIn) where toFun := fun m ⟨stmt, oStmt⟩ tr witMid => - finalSumcheckKStateProp κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) - (m := m) (tr := tr) (stmt := stmt) (witMid := witMid) (oStmt := oStmt) - toFun_empty := fun stmt witMid => by - simp only [sumcheckRoundRelation, sumcheckRoundRelationProp, Fin.val_last, cast_eq, - Set.mem_setOf_eq, finalSumcheckKStateProp, masterKStateProp, true_and] - toFun_next := fun m hDir stmt tr msg witMid h => by - sorry - toFun_full := fun stmt tr witOut h => by - sorry + finalSumcheckKStateProp κ L K β ℓ ℓ' h_l + (m := m) (tr := tr) (stmtIn := stmt) (witMid := witMid) (oStmtIn := oStmt) + toFun_empty := fun stmt witMid => by rfl + toFun_next := fun m hDir (stmtIn, oStmtIn) tr msg witMid => by + -- Only round is m=0 → m=1; extractMid is identity (RS keeps full SumcheckWitness). + have h_m_eq_0 : m = 0 := by + cases m using Fin.cases with + | zero => rfl + | succ m' => omega + subst h_m_eq_0 + simp only [Fin.isValue, Fin.succ_zero_eq_one, Fin.castSucc_zero] + intro h_kState_round1 + unfold finalSumcheckKStateProp at h_kState_round1 ⊢ + simp only [Fin.isValue, Nat.reduceAdd, Fin.mk_one, Fin.coe_ofNat_eq_mod, Nat.reduceMod] + at h_kState_round1 + -- m=1 gives: sumcheckFinalCheck ∧ finalEvalCheck ∧ oracleCompatProp ∧ witnessStructProp + obtain ⟨h_sumcheckFinalCheck, h_finalEvalCheck, h_oracleCompat, h_witStruct⟩ := h_kState_round1 + -- Goal: masterKStateProp at m=0 = sumcheckConsistencyProp ∧ witnessStructuralInvariant + -- ∧ initialCompatibility + unfold RingSwitching.masterKStateProp + constructor + · -- sumcheckConsistencyProp: at Fin.last ℓ' the sum is a single term + -- witMid.H.val.eval (fun _ => 0) + unfold sumcheckConsistencyProp + simp only [Fin.val_last] + -- haveI : IsEmpty (Fin 0) := Fin.isEmpty + rw [Finset.sum_eq_single (a := fun _ => (0 : L)) + (h₀ := fun b _ hb_ne => by + exfalso; apply hb_ne + funext i; simp only [tsub_self] at i; exact i.elim0) + (h₁ := fun h_not_mem => by + exfalso; apply h_not_mem + simp only [Fintype.mem_piFinset] + intro i; simp only [tsub_self] at i; exact i.elim0)] + simp only [finalSumcheckRbrExtractor] + have h_H_eval : witMid.H.val.eval (fun _ => 0) = + ((RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly stmtIn.ctx).val.eval + stmtIn.challenges * witMid.t'.val.eval stmtIn.challenges := by + rw [h_witStruct] + apply projectToMidSumcheckPoly_at_last_eval + have h_mult_eq : + ((RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly stmtIn.ctx).val.eval + stmtIn.challenges = compute_final_eq_value κ L K β ℓ ℓ' h_l stmtIn.ctx.t_eval_point + stmtIn.challenges stmtIn.ctx.r_batching := + compute_A_MLE_eval_eq_final_eq_value κ L K β ℓ ℓ' h_l + stmtIn.ctx.t_eval_point stmtIn.challenges stmtIn.ctx.r_batching + refine Eq.trans h_sumcheckFinalCheck ?_ + let msgIdx : (pSpecFinalSumcheckStep (L := L)).MessageIdx := ⟨⟨0, Nat.zero_lt_succ 0⟩, hDir⟩ + let c : L := FullTranscript.messages (Transcript.concat msg tr) msgIdx + have h_eq : (MvPolynomial.eval stmtIn.challenges) + (((RingSwitching_SumcheckMultParam κ L K β ℓ ℓ' h_l).multpoly stmtIn.ctx).val) * + (MvPolynomial.eval stmtIn.challenges) witMid.t'.val = + (compute_final_eq_value κ L K β ℓ ℓ' h_l stmtIn.ctx.t_eval_point stmtIn.challenges + stmtIn.ctx.r_batching) * c := by + rw [h_mult_eq] + simp_rw [Fin.val_last] + rw [h_finalEvalCheck]; rfl + exact (h_H_eval.trans h_eq).symm + · constructor + · exact h_witStruct + · exact h_oracleCompat + toFun_full := fun ⟨stmtIn, oStmtIn⟩ tr witOut probEvent_relOut_gt_0 => by + -- Same pattern as relay: verifier output (stmtOut, oStmtOut) + h_relOut ⇒ commitKStateProp 1 + simp only [StateT.run'_eq, gt_iff_lt, probEvent_pos_iff, Prod.exists] at probEvent_relOut_gt_0 + rcases probEvent_relOut_gt_0 with ⟨stmtOut, oStmtOut, h_output_mem_V_run_support, h_relOut⟩ + have h_output_mem_V_run_support' : + some (stmtOut, oStmtOut) ∈ + (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (finalSumcheckVerifier κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn).toVerifier)).run + s).support := by + exact (OptionT.mem_support_iff + (mx := OptionT.mk (do + let s ← init + Prod.fst <$> + (simulateQ impl + (Verifier.run (stmtIn, oStmtIn) tr + (finalSumcheckVerifier κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn).toVerifier)).run s)) + (x := (stmtOut, oStmtOut))).1 h_output_mem_V_run_support + simp only [support_bind, Set.mem_iUnion, exists_prop] at h_output_mem_V_run_support' + rcases h_output_mem_V_run_support' with ⟨s, hs_init, h_output_mem_V_run_support⟩ + conv at h_output_mem_V_run_support => -- same as fold step + simp only [Verifier.run, OracleVerifier.toVerifier] + -- Now unfold the foldOracleVerifier's `verify()` method + simp only [finalSumcheckVerifier] + -- dsimp only [StateT.run] + -- simp only [simulateQ_bind, simulateQ_query, simulateQ_pure] + -- oracle query unfolding + simp only [support_bind, Set.mem_iUnion] + dsimp only [StateT.run] + -- enter [1, i_1, 2, 1, x] + simp only [simulateQ_bind] + --------------------------------------- + -- Now simplify the `guard` and `ite` of StateT.map generated from it + simp only [MessageIdx, Fin.isValue, Matrix.cons_val_zero, simulateQ_pure, Message, guard_eq, + pure_bind, Function.comp_apply, simulateQ_map, simulateQ_ite, + OptionT.simulateQ_failure', bind_map_left] + simp only [MessageIdx, Message, Fin.isValue, Matrix.cons_val_zero, Matrix.cons_val_one, + bind_pure_comp, simulateQ_map, simulateQ_ite, simulateQ_pure, OptionT.simulateQ_failure', + bind_map_left, Function.comp_apply] + simp only [support_ite] + simp only [Fin.isValue, Set.mem_ite_empty_right, Set.mem_singleton_iff, Prod.mk.injEq, + exists_and_left, exists_eq', exists_eq_right, exists_and_right] + simp only [Fin.isValue, id_eq, FullTranscript.mk1_eq_snoc, support_map, Set.mem_image, + Prod.exists, exists_and_right, exists_eq_right] + erw [simulateQ_bind] + enter [1, x, 1, 1, 1, 2]; + erw [simulateQ_bind] + erw [OptionT.simulateQ_simOracle2_liftM_query_T2] + simp only [Fin.isValue, FullTranscript.mk1_eq_snoc, pure_bind, OptionT.simulateQ_map] + conv at h_output_mem_V_run_support => + simp only [Fin.isValue, FullTranscript.mk1_eq_snoc, Function.comp_apply] + erw [support_bind] at h_output_mem_V_run_support + let step := (finalSumcheckStepLogic κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn) + set V_check := step.verifierCheck stmtIn + (FullTranscript.mk1 (msg0 := _)) with h_V_check_def + by_cases h_V_check : V_check + · + simp only [Fin.isValue, h_V_check, ↓reduceIte, OptionT.run_pure, simulateQ_pure, + Set.mem_iUnion, exists_prop, Prod.exists] at h_output_mem_V_run_support + erw [simulateQ_bind] at h_output_mem_V_run_support + simp only [simulateQ_pure, Fin.isValue, Function.comp_apply, + pure_bind] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, ↓existsAndEq, and_true, exists_eq_left, + simulateQ_pure] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Fin.isValue, Set.mem_singleton_iff, Prod.mk.injEq, Option.some.injEq, + exists_eq_right] at h_output_mem_V_run_support + rcases h_output_mem_V_run_support with ⟨h_stmtOut_eq, h_oStmtOut_eq⟩ + simp only [Fin.reduceLast, Fin.isValue] + -- h_relOut : ((stmtOut, oStmtOut), witOut) ∈ roundRelation 𝔽q β i.succ + simp only [AbstractOStmtIn.toRelInput, MLPEvalRelation, Set.mem_setOf_eq] at h_relOut + -- Goal: commitKStateProp 1 stmtIn oStmtIn tr witOut + unfold finalSumcheckKStateProp + -- Unfold the sendMessage, receiveChallenge, output logic of prover + dsimp only + -- stmtOut = stmtIn; need oStmtOut = snoc_oracle oStmtIn witOut.f so goal matches h_relOut + simp only [h_stmtOut_eq] at h_relOut ⊢ + have h_oStmtOut_eq_oStmtIn : oStmtOut = oStmtIn := by rw [h_oStmtOut_eq]; rfl + -- c equals tr.messages ⟨0, rfl⟩ + constructor + · -- First conjunct: V checks sumcheck_target = eqTilde r challenges * c + exact h_V_check + · -- Second conjunct: finalSumcheckStepFoldingStateProp + -- ({ toStatement := stmtIn, final_constant := c }, oStmtIn) + rw [h_oStmtOut_eq_oStmtIn] at h_relOut + constructor + · exact h_relOut.1 + · constructor + · exact h_relOut.2 + · dsimp only [witnessStructuralInvariant, finalSumcheckRbrExtractor] + · simp only [Fin.isValue, h_V_check, ↓reduceIte, OptionT.run_failure, simulateQ_pure, + Set.mem_iUnion, exists_prop, Prod.exists] at h_output_mem_V_run_support + erw [simulateQ_bind] at h_output_mem_V_run_support + simp only [simulateQ_pure, Fin.isValue, Function.comp_apply, + pure_bind] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, ↓existsAndEq, and_true, exists_eq_left, + simulateQ_pure] at h_output_mem_V_run_support + erw [support_pure] at h_output_mem_V_run_support + simp only [Set.mem_singleton_iff, Prod.mk.injEq, reduceCtorEq, false_and, + exists_false] at h_output_mem_V_run_support -- False /-- Round-by-round knowledge soundness for the final sumcheck step -/ -theorem finalSumcheckOracleVerifier_rbrKnowledgeSoundness [Fintype L] {σ : Type} +theorem finalSumcheckOracleVerifier_rbrKnowledgeSoundness {σ : Type} (init : ProbComp σ) (impl : QueryImpl []ₒ (StateT σ ProbComp)) : (finalSumcheckVerifier κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn).rbrKnowledgeSoundness init impl (relIn := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn (Fin.last ℓ')) (relOut := aOStmtIn.toRelInput) - (rbrKnowledgeError := fun _ => finalSumcheckRbrKnowledgeError (L := L)) := by + (rbrKnowledgeError := finalSumcheckKnowledgeError (L := L)) := by use (fun _ => SumcheckWitness L ℓ' (Fin.last ℓ')) use finalSumcheckRbrExtractor κ L K β ℓ ℓ' h_l aOStmtIn use finalSumcheckKnowledgeStateFunction κ L K β ℓ ℓ' h_l aOStmtIn init impl - intro stmtIn witIn prover j - sorry + intro stmtIn witIn prover ⟨j, hj⟩ + -- pSpecFinalSumcheckStep has 1 message (ChallengeIdx = Fin 1); same pattern as commit + cases j using Fin.cases with + | zero => simp only [pSpecFinalSumcheckStep, ne_eq, reduceCtorEq, not_false_eq_true, Fin.isValue, + Matrix.cons_val_fin_one, Direction.not_P_to_V_eq_V_to_P] at hj + | succ j' => exact Fin.elim0 j' end FinalSumcheckStep @@ -1265,10 +1750,14 @@ def sumcheckLoopOracleReduction : (WitIn := SumcheckWitness L ℓ' 0) (WitOut := SumcheckWitness L ℓ' (Fin.last ℓ')) := OracleReduction.seqCompose (m:=ℓ') (oSpec:=[]ₒ) - (OStmt := fun _ => aOStmtIn.OStmtIn) + (OStmt := fun _ => (aOStmtIn.OStmtIn (L := L) (ℓ' := ℓ'))) (Stmt := Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ)) (Wit := fun i => SumcheckWitness L ℓ' i) - (R := fun (i: Fin ℓ') => iteratedSumcheckOracleReduction κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn i) + (pSpec := fun _ => pSpecSumcheckRound L) + (Oₘ := fun _ j => instOracleInterfaceMessagePSpecSumcheckRound L j) + (R := fun (i : Fin ℓ') => + iteratedSumcheckOracleReduction (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) + (ℓ' := ℓ') (h_l := h_l) (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (i := i)) /-- Large-field reduction verifier: Sumcheck seqCompose, then append FinalSum -/ @[reducible] @@ -1277,7 +1766,7 @@ def coreInteractionOracleVerifier := (V₁:=sumcheckLoopOracleVerifier κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn) (pSpec₁:=pSpecSumcheckLoop L ℓ') (V₂:=finalSumcheckVerifier κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn) - (pSpec₂:=pSpecFinalSumcheck L) + (pSpec₂:=pSpecFinalSumcheckStep (L := L)) /-- Large-field reduction: Sumcheck seqCompose, then append FinalSum -/ @[reducible] @@ -1286,7 +1775,7 @@ def coreInteractionOracleReduction := (R₁ := sumcheckLoopOracleReduction κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn) (pSpec₁:=pSpecSumcheckLoop L ℓ') (R₂ := finalSumcheckOracleReduction κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn) - (pSpec₂:=pSpecFinalSumcheck L) + (pSpec₂:=pSpecFinalSumcheckStep (L := L)) /-! ## RBR Knowledge Soundness Components for Single Round @@ -1304,15 +1793,15 @@ theorem coreInteraction_perfectCompleteness (hInit : NeverFail init) : (OStmtOut := aOStmtIn.OStmtIn) (WitIn := SumcheckWitness L ℓ' 0) (WitOut := WitMLP L ℓ') - (relIn := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn 0) - (relOut := aOStmtIn.toRelInput) + (relIn := strictSumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn 0) + (relOut := aOStmtIn.toStrictRelInput) (init := init) (impl := impl) := by -- Follows from append_perfectCompleteness of interactionPhase and finalSumcheck apply OracleReduction.append_perfectCompleteness - (rel₂ := (sumcheckRoundRelation κ L K β ℓ ℓ' h_l aOStmtIn (Fin.last ℓ'))) + (rel₂ := (strictSumcheckRoundRelation κ L K β ℓ ℓ' h_l aOStmtIn (Fin.last ℓ'))) · exact OracleReduction.seqCompose_perfectCompleteness - (rel := fun i => sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn i) + (rel := fun i => strictSumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn i) (R := fun i => iteratedSumcheckOracleReduction κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn i) (h := fun i => iteratedSumcheckOracleReduction_perfectCompleteness (hInit:=hInit) (κ:=κ) (L:=L) (K:=K) @@ -1324,12 +1813,28 @@ theorem coreInteraction_perfectCompleteness (hInit : NeverFail init) : /-- standard sumcheck error -/ def coreInteractionRbrKnowledgeError (_ : (pSpecCoreInteraction L ℓ').ChallengeIdx) : ℝ≥0 := - (2 : ℝ≥0) / (Fintype.card L) - --- TODO: iteratedSumcheckLoop_rbrKnowledgeSoundness + (2 : ℝ≥0) / (Fintype.card L) -- this terms comes from the sumcheck + -- steps, i.e. iteratedSumcheckRoundKnowledgeError + +/-- RBR knowledge soundness for the sumcheck loop (seqCompose over ℓ'). -/ +theorem sumcheckLoopOracleVerifier_rbrKnowledgeSoundness : + (sumcheckLoopOracleVerifier κ (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) + (𝓑 := 𝓑) aOStmtIn).rbrKnowledgeSoundness (init := init) (impl := impl) + (relIn := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn 0) + (relOut := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn (Fin.last ℓ')) + (rbrKnowledgeError := fun _ => (2 : ℝ≥0) / Fintype.card L) := + OracleVerifier.seqCompose_rbrKnowledgeSoundness + (init := init) (impl := impl) + (rel := fun i => sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn i) + (V := fun i => iteratedSumcheckOracleVerifier κ (𝓑 := 𝓑) L K β ℓ ℓ' h_l aOStmtIn i) + (rbrKnowledgeError := fun roundIdx _challengeIdx => + -- Each round has exactly one challenge, so _challengeIdx is not used in the error + iteratedSumcheckRoundKnowledgeError L ℓ' roundIdx) + (h := fun i => + iteratedSumcheckOracleVerifier_rbrKnowledgeSoundness κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn i) /-- RBR knowledge soundness for large-field reduction (Sumcheck ++ FinalSum) -/ -theorem coreInteraction_rbrKnowledgeSoundness: +theorem coreInteraction_rbrKnowledgeSoundness : OracleVerifier.rbrKnowledgeSoundness (verifier := coreInteractionOracleVerifier κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn) (StmtIn := Statement (L := L) (ℓ := ℓ') (RingSwitchingBaseContext κ L K ℓ) 0) @@ -1343,7 +1848,43 @@ theorem coreInteraction_rbrKnowledgeSoundness: (relIn := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑:=𝓑) aOStmtIn 0) (relOut := aOStmtIn.toRelInput) (rbrKnowledgeError := coreInteractionRbrKnowledgeError (L:=L) (ℓ':=ℓ')) := by - sorry + let hAppend := OracleVerifier.append_rbrKnowledgeSoundness + (init := init) (impl := impl) + (rel₁ := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn 0) + (rel₂ := sumcheckRoundRelation κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn (Fin.last ℓ')) + (rel₃ := aOStmtIn.toRelInput) + (V₁ := sumcheckLoopOracleVerifier κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn) + (V₂ := finalSumcheckVerifier κ L K β ℓ ℓ' h_l (𝓑 := 𝓑) aOStmtIn) + (rbrKnowledgeError₁ := fun _ => (2 : ℝ≥0) / Fintype.card L) + (rbrKnowledgeError₂ := finalSumcheckKnowledgeError (L := L)) + (h₁ := by + simpa using (sumcheckLoopOracleVerifier_rbrKnowledgeSoundness + (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) + (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (init := init) (impl := impl)) + ) + (h₂ := by + simpa using (finalSumcheckOracleVerifier_rbrKnowledgeSoundness + (κ := κ) (L := L) (K := K) (β := β) (ℓ := ℓ) (ℓ' := ℓ') (h_l := h_l) + (𝓑 := 𝓑) (aOStmtIn := aOStmtIn) (init := init) (impl := impl)) + ) + exact OracleVerifier.rbrKnowledgeSoundness_of_eq_error + (init := init) (impl := impl) + (h_ε := by + intro i + change (2 : ℝ≥0) / Fintype.card L = + ((fun _ => (2 : ℝ≥0) / Fintype.card L) ⊕ᵥ finalSumcheckKnowledgeError (L := L)) + (ChallengeIdx.sumEquiv.symm i) + rcases h_idx : ChallengeIdx.sumEquiv.symm i with i₁ | i₂ + · simp + · rcases i₂ with ⟨j, hj⟩ + cases j using Fin.cases with + | zero => + exfalso + simp only [pSpecFinalSumcheckStep, ne_eq, reduceCtorEq, not_false_eq_true, + Fin.isValue, Matrix.cons_val_fin_one, Direction.not_P_to_V_eq_V_to_P] at hj + | succ j' => exact Fin.elim0 j' + ) + (h := hAppend) end LargeFieldReduction end diff --git a/ArkLib/ToMathlib/MvPolynomial/Equiv.lean b/ArkLib/ToMathlib/MvPolynomial/Equiv.lean new file mode 100644 index 000000000..734993e82 --- /dev/null +++ b/ArkLib/ToMathlib/MvPolynomial/Equiv.lean @@ -0,0 +1,34 @@ +import CompPoly.ToMathlib.MvPolynomial.Equiv +import CompPoly.ToMathlib.Finsupp.Fin + +section ToMvPolynomial + +variable {σ : Type*} {R : Type*} [CommRing R] + +/-- `Polynomial.toMvPolynomial` preserves non-zero property. -/ +lemma Polynomial.toMvPolynomial_ne_zero_iff (p : Polynomial R) (i : σ) : + (Polynomial.toMvPolynomial i) p ≠ 0 ↔ p ≠ 0 := by + constructor + · intro h hp + rw [hp, map_zero] at h + exact h rfl + · intro hp h + apply hp + rw [← map_zero (Polynomial.toMvPolynomial i)] at h + exact (Polynomial.toMvPolynomial_injective i) h + +/-- The total degree of `toMvPolynomial p` is at most the natural degree of `p`. -/ +lemma Polynomial.toMvPolynomial_totalDegree_le [Nontrivial R] (p : Polynomial R) (i : σ) : + ((Polynomial.toMvPolynomial i) p).totalDegree ≤ p.natDegree := by + rw [p.as_sum_support] + rw [map_sum] + refine MvPolynomial.totalDegree_finsetSum_le ?_ + intro n hn + simp only [Polynomial.toMvPolynomial, Polynomial.aeval_monomial, MvPolynomial.algebraMap_eq] + rw [← Polynomial.as_sum_support p] + apply (MvPolynomial.totalDegree_mul _ _).trans + simp only [MvPolynomial.totalDegree_C, zero_add] + rw [MvPolynomial.totalDegree_X_pow] -- this requires [Nontrivial R] + exact Polynomial.le_natDegree_of_mem_supp n hn + +end ToMvPolynomial diff --git a/ArkLib/ToVCVio/Lemmas.lean b/ArkLib/ToVCVio/Lemmas.lean index e3651aaa3..07ba42306 100644 --- a/ArkLib/ToVCVio/Lemmas.lean +++ b/ArkLib/ToVCVio/Lemmas.lean @@ -801,4 +801,78 @@ by simp [StateT.run'] alias run'_pure_bind := run'_bind_pure +-- @[simp] +-- lemma mem_support_pure_stateT_run {ι σ α : Type} {spec : OracleSpec ι} (x : α) (s s' : σ) (y : α) : +-- (y, s') ∈ OracleComp.support ((pure x : StateT σ (OracleComp spec) α).run s) ↔ y = x ∧ s' = s := +-- by simp only [StateT.run_pure, OracleComp.support_pure, Set.mem_singleton_iff, Prod.mk.injEq] + +-- @[simp] +-- lemma mem_support_pure_state {ι σ α : Type} {spec : OracleSpec ι} (x : α) (s s' : σ) (y : α) : +-- (y, s') ∈ OracleComp.support ((pure x : StateT σ (OracleComp spec) α) s) ↔ y = x ∧ s' = s := by +-- apply mem_support_pure_stateT_run + +-- section MapLemmas + +-- variable {ι : Type} {spec : OracleSpec ι} + +-- /-- StateT.map distributes over if-then-else. -/ +-- @[simp] +-- theorem map_ite [Monad m] (f : α → β) (p : Prop) [Decidable p] +-- (ma ma' : StateT σ m α) (s : σ) : +-- (if p then ma else ma').map f s = +-- if p then ma.map f s else ma'.map f s := by +-- split_ifs <;> rfl + +-- /-- StateT.map over pure for OracleComp. -/ +-- @[simp] +-- theorem map_pure (f : α → β) (a : α) (s : σ) : +-- (pure a : StateT σ (OracleComp spec) α).map f s = +-- pure (f a, s) := by +-- rfl + +-- /-- StateT.map over failure for OracleComp. -/ +-- @[simp] +-- theorem map_failure (f : α → β) (s : σ) : +-- (failure : StateT σ (OracleComp spec) α).map f s = +-- failure := by +-- rfl + +-- /-- Support of StateT.map over pure. -/ +-- @[simp] +-- theorem support_map_pure (f : α → β) (a : α) (s : σ) : +-- ((pure a : StateT σ (OracleComp spec) α).map f s).support = +-- {(f a, s)} := by +-- simp only [map_pure, support_pure] + +-- /-- Support of StateT.map over failure. -/ +-- @[simp] +-- theorem support_map_failure (f : α → β) (s : σ) : +-- ((failure : StateT σ (OracleComp spec) α).map f s).support = +-- ∅ := by +-- simp only [map_failure, support_failure] + +-- /-- StateT.map with constant function over pure. Useful for `map (fun _ => result) (pure x)`. -/ +-- @[simp] +-- theorem map_const_pure {γ : Type u} (result : β) (x : γ) (s : σ) : +-- (pure x : StateT σ (OracleComp spec) γ).map (fun _ => result) s = +-- pure (result, s) := by +-- rfl + +-- /-- Support of StateT.map with constant function over pure. -/ +-- @[simp] +-- theorem support_map_const_pure {γ : Type u} (result : β) (x : γ) (s : σ) : +-- ((pure x : StateT σ (OracleComp spec) γ).map (fun _ => result) s).support = +-- {(result, s)} := by +-- simp only [map_const_pure, support_pure] + +-- /-- Support distributes over if-then-else for StateT with OracleComp. -/ +-- @[simp] +-- theorem support_ite (p : Prop) [Decidable p] +-- (ma ma' : StateT σ (OracleComp spec) α) (s : σ) : +-- ((if p then ma else ma') s).support = +-- if p then (ma s).support else (ma' s).support := by +-- split_ifs <;> rfl + +-- end MapLemmas + end StateT diff --git a/ArkLib/ToVCVio/Simulation.lean b/ArkLib/ToVCVio/Simulation.lean index 18db11652..69f19e8a4 100644 --- a/ArkLib/ToVCVio/Simulation.lean +++ b/ArkLib/ToVCVio/Simulation.lean @@ -13,6 +13,7 @@ import Mathlib.Data.ENNReal.Basic import VCVio.OracleComp.EvalDist import ArkLib.OracleReduction.OracleInterface import VCVio.EvalDist.Instances.OptionT +import ArkLib.Data.Probability.Instances /-! ## Monad-to-Logic Bridge Lemmas @@ -547,7 +548,6 @@ by congr 1 funext proverOutput rw [map_eq_bind_pure_comp] - ext simp only [Option.getM, map_eq_bind_pure_comp, OptionT.run_bind, Function.comp_apply, OptionT.run_pure] @@ -725,6 +725,21 @@ lemma support_simulateQ_run'_eq simp only [QueryImpl.mapQuery_query] at h_pair exact ⟨(x', s'), h_pair, h_y_sim'⟩ +/-- OptionT run-level wrapper of `support_simulateQ_run'_eq` (stateful implementation). -/ +@[simp] +lemma OptionT.support_run_simulateQ_run'_eq + {oSpec : OracleSpec ι} [oSpec.Fintype] [oSpec.Inhabited] {σ α : Type} + (impl : QueryImpl oSpec (StateT σ ProbComp)) + (oa : OptionT (OracleComp oSpec) α) (s : σ) + (hImplSupp : ∀ {β} (q : OracleQuery oSpec β) s, + Prod.fst <$> support ((QueryImpl.mapQuery impl q).run s) + = support (liftM q : OracleComp oSpec β)) : + support (m := ProbComp) (α := Option α) ((simulateQ impl oa).run' s) = + support (m := OracleComp oSpec) (α := Option α) oa := by + simpa using + (support_simulateQ_run'_eq (impl := impl) (oa := oa) (s := s) + (hImplSupp := hImplSupp)) + /-- OptionT-wrapper version of `neverFails_of_simulateQ` for option-valued computations. -/ lemma neverFails_of_simulateQ_mk {spec : OracleSpec ι} [spec.Fintype] [spec.Inhabited] @@ -1262,7 +1277,6 @@ lemma probFailure_forIn_of_relations Pr[⊥ | forIn l init f] = 0 := by -- Instead of using `probFailure_forIn_of_invariant` which has a weaker inductive hypothesis -- (it quantifies ∀ x ∈ l, losing the index information), we use a direct recursive helper. - -- Helper: Proves safety for a suffix `xs` starting at index `k`. -- k: The current index in the original list `l`. -- xs: The suffix of `l` remaining to process. @@ -1478,21 +1492,6 @@ lemma liftComp_forIn {ι ι' : Type} {spec : OracleSpec ι} {superSpec : OracleS congr; funext s cases s <;> simp only [liftComp_pure, forIn'_eq_forIn, ih] -/-- `OptionT` variant of `liftComp_forIn`. -/ -@[simp] -lemma OptionT.liftComp_forIn {ι ι' : Type _} {spec : OracleSpec ι} {superSpec : OracleSpec ι'} - [spec.Fintype] [superSpec.Fintype] - [MonadLift (OracleQuery spec) (OracleQuery superSpec)] - {α β : Type _} (l : List α) (init : β) - (f : α → β → OptionT (OracleComp spec) (ForInStep β)) : - (forIn l init f).liftComp superSpec = - forIn l init (fun a b ↦ OptionT.mk ((f a b).liftComp superSpec)) := by - induction l generalizing init with - | nil => rfl - | cons x xs ih => - simp only [forIn, List.forIn'_cons] - sorry - /-- Distributes `simulateQ` over a `forIn` loop. This allows us to verify the body of the loop under simulation. -/ @@ -1533,8 +1532,480 @@ lemma OptionT.simulateQ_forIn rfl | cons x xs ih => simp only [forIn, List.forIn'_cons] - sorry + change simulateQ so (OptionT.bind (f x init) (fun step => + match step with + | ForInStep.done b => pure b + | ForInStep.yield b => forIn' xs b (fun a' _ b => f a' b))) = + OptionT.bind (simulateQ so (f x init)) (fun step => + match step with + | ForInStep.done b => pure b + | ForInStep.yield b => forIn' xs b (fun a' _ b => simulateQ so (f a' b))) + rw [OptionT.simulateQ_bind] + apply bind_congr + intro step + cases step with + | none => simp + | some step => + cases step with + | done res => rfl + | yield res => simpa [forIn'_eq_forIn] using ih res + +/-- Stateful version of simulateQ_forIn. + Distributes simulation over a loop where the oracle implementation itself has state. -/ +@[simp] +lemma simulateQ_forIn_stateful_run_eq {ι : Type} {spec : OracleSpec ι} + {σ α β : Type} (impl : QueryImpl spec (StateT σ ProbComp)) + (l : List α) (init : β) (f : α → β → OracleComp spec (ForInStep β)) (s : σ) : + (simulateQ impl (forIn l init f)).run s = + (forIn l init (fun a b => simulateQ impl (f a b))).run s := by + induction l generalizing init s with + | nil => + -- Base case: both sides reduce to pure init + simp only [forIn, List.forIn'_nil, simulateQ_pure, StateT.run_pure] + | cons x xs ih => + -- Inductive case: x :: xs + simp only [forIn, List.forIn'_cons, simulateQ_bind, StateT.run_bind] + congr + funext pair + rcases pair with ⟨step, s'⟩ + cases step with + | done res => simp only [forIn'_eq_forIn, Function.comp_apply, simulateQ_pure, + StateT.run_pure] + | yield res => exact ih res s' + +/-- Distributes stateful simulation over forIn at the StateT level. -/ +@[simp] +lemma simulateQ_forIn_stateful_comp {ι : Type} {spec : OracleSpec ι} + {σ α β : Type} (impl : QueryImpl spec (StateT σ ProbComp)) + (l : List α) (init : β) (f : α → β → OracleComp spec (ForInStep β)) : + simulateQ impl (forIn l init f) = + forIn l init (fun a b => simulateQ impl (f a b)) := by + -- Proof is by induction on l, matching the structure of forIn unrolling + induction l generalizing init with + | nil => simp [forIn, simulateQ_pure] + | cons x xs ih => + simp only [forIn, List.forIn'_cons, simulateQ_bind] + congr; funext step + cases step with + | done res => simp [simulateQ_pure] + | yield res => exact ih res + +/-- Stateful `StateT` specialization of `OptionT.simulateQ_bind`. -/ +@[simp] +lemma OptionT.simulateQ_bind_stateful {ι : Type} {spec : OracleSpec ι} + {σ α β : Type} (impl : QueryImpl spec (StateT σ ProbComp)) + (mx : OptionT (OracleComp spec) α) (my : α → OptionT (OracleComp spec) β) : + simulateQ impl (OptionT.bind mx my) = + OptionT.bind (m := StateT σ ProbComp) (simulateQ impl mx) + (fun x => simulateQ impl (my x)) := by + change + simulateQ impl (mx >>= fun z => match z with | some a => my a | none => pure none) = + OptionT.bind (m := StateT σ ProbComp) (simulateQ impl mx) (fun x => simulateQ impl (my x)) + rw [_root_.simulateQ_bind] + simp only [OptionT.bind, OptionT.mk] + apply bind_congr + intro z + cases z <;> rfl + +/-- `OptionT` version of `simulateQ_forIn_stateful_comp`. -/ +@[simp] +lemma OptionT.simulateQ_forIn_stateful_comp {ι : Type} {spec : OracleSpec ι} + {σ α β : Type} (impl : QueryImpl spec (StateT σ ProbComp)) + (l : List α) (init : β) (f : α → β → OptionT (OracleComp spec) (ForInStep β)) : + simulateQ impl (forIn l init f : OptionT (OracleComp spec) β) = + (forIn l init (fun a b => simulateQ impl (f a b)) : + OptionT (StateT σ ProbComp) β) := by + induction l generalizing init with + | nil => + rfl + | cons x xs ih => + simp only [forIn, List.forIn'_cons] + change simulateQ impl (OptionT.bind (f x init) (fun step => + match step with + | ForInStep.done b => pure b + | ForInStep.yield b => forIn' xs b (fun a' _ b => f a' b))) = + OptionT.bind (m := StateT σ ProbComp) (simulateQ impl (f x init)) (fun step => + match step with + | ForInStep.done b => pure b + | ForInStep.yield b => forIn' xs b (fun a' _ b => simulateQ impl (f a' b))) + rw [OptionT.simulateQ_bind_stateful] + apply bind_congr + intro step + cases step with + | none => simp + | some step => + cases step with + | done res => rfl + | yield res => simpa [forIn'_eq_forIn] using ih res + +/-- **Loop Path Extraction**: + If a stateful forIn loop over PUnit reaches a final state, then for every element + in the list, there must exist a local start state and end state such that the + body of that iteration succeeded. + **Important:** this requires the loop body to be yield-only on support + (i.e. no early `.done`). -/ +lemma exists_path_of_mem_support_forIn_unit {σ α : Type} [spec.Fintype] + (l : List α) (f : α → PUnit → StateT σ ProbComp (ForInStep PUnit)) + (s_init s_final : σ) (u : PUnit) + (h_yield : ∀ (x : α) (s_pre : σ) (res_step : ForInStep PUnit × σ), + res_step ∈ ((f x PUnit.unit).run s_pre).support → + res_step.1 = ForInStep.yield PUnit.unit) + (h_mem : (u, s_final) ∈ ((forIn l PUnit.unit f).run s_init).support) : + ∀ x ∈ l, ∃ s_pre s_post, + (ForInStep.yield PUnit.unit, s_post) ∈ ((f x PUnit.unit).run s_pre).support := by + induction l generalizing s_init s_final u with + | nil => simp + | cons a t ih => + simp only [forIn, List.forIn'_cons] at h_mem + intro x hx + simp only [List.mem_cons] at hx + simp only [StateT.run_bind, support_bind, + Set.mem_iUnion, exists_prop, Prod.exists] at h_mem + obtain ⟨step, s_mid, h_step_mem, h_rest⟩ := h_mem + have h_y := h_yield a s_init (step, s_mid) h_step_mem + simp only at h_y; subst h_y + simp only [forIn'_eq_forIn] at h_rest + rcases hx with rfl | hx + · exact ⟨s_init, s_mid, h_step_mem⟩ + · exact ih s_mid s_final u h_rest x hx + +lemma OptionT.exists_path_of_mem_support_forIn_unit {σ α : Type} [spec.Fintype] + (l : List α) (f : α → PUnit → OptionT (StateT σ ProbComp) (ForInStep PUnit)) + (s_init s_final : σ) (u : PUnit) + (h_yield : ∀ (x : α) (s_pre : σ) (res_step : ForInStep PUnit × σ), + (some res_step.1, res_step.2) ∈ ((f x PUnit.unit).run s_pre).support → + res_step.1 = ForInStep.yield PUnit.unit) + (h_mem : (some u, s_final) ∈ ((forIn l PUnit.unit f).run s_init).support) : + ∀ x ∈ l, ∃ s_pre s_post, + (some (ForInStep.yield PUnit.unit), s_post) ∈ + ((f x PUnit.unit).run s_pre).support := by + induction l generalizing s_init s_final u with + | nil => simp + | cons a t ih => + simp only [forIn, List.forIn'_cons] at h_mem + rw [OptionT.run_bind] at h_mem + simp only [Option.elimM, OptionT.run] at h_mem + rw [show ∀ (m : StateT σ ProbComp _) + (g : _ → StateT σ ProbComp _) (s : σ), + (m >>= g) s = m.run s >>= fun p => (g p.1).run p.2 + from fun _ _ _ => rfl] at h_mem + rw [_root_.mem_support_bind_iff] at h_mem + obtain ⟨⟨opt_step, s_mid⟩, h_step_mem, h_rest⟩ := h_mem + cases h_opt : opt_step with + | none => + simp [h_opt] at h_rest + | some step => + have h_step_some_mem : (some step, s_mid) ∈ ((f a PUnit.unit).run s_init).support := by + simpa [h_opt] using h_step_mem + have h_step_yield : step = ForInStep.yield PUnit.unit := + h_yield a s_init (step, s_mid) h_step_some_mem + cases h_step_yield + simp only [h_opt, Option.elim, forIn'_eq_forIn] at h_rest + intro x hx + simp only [List.mem_cons] at hx + rcases hx with rfl | hx + · exact ⟨s_init, s_mid, h_step_some_mem⟩ + · exact ih s_mid s_final u h_rest x hx + +/-- **Stateful forIn: path + relation from support** (combines path extraction and +relation induction). +Given a stateful forIn loop, a relation `rel : Fin (l.length + 1) → β → σ → Prop`, base case and +step preservation, and a result `res` in the loop support, this lemma provides: +1. The final relation holds: `rel ⟨l.length, _⟩ res.1 res.2` +2. A constructive path: sequences `bs` and `ss` such that `(bs 0, ss 0) = (init, s)`, + `(bs ⟨l.length, _⟩, ss ⟨l.length, _⟩) = (res.1, res.2)`, each step + `(.yield (bs k.succ), ss k.succ)` is in the support of the body run from + `(bs k.castSucc, ss k.castSucc)`, and `rel k (bs k) (ss k)` for all `k`. +So you get both "exists_path_of_mem_support_forIn_unit"-style per-step membership and +"support_forIn_stateful_of_relations"-style relation at every index (including the final one). +**Note:** This also assumes the loop body is yield-only on support, so the loop does not stop +early via `.done`. +The loop's `.run` support is `Set (β × σ)` (the accumulated value and state); each body step's +support is `Set (ForInStep β × σ)`, hence `h_step` uses `ForInStep.state res_step.1`. -/ +@[simp] +lemma exists_rel_path_of_mem_support_forIn_stateful {ι : Type} {spec : OracleSpec ι} [spec.Fintype] + {α σ β : Type} (l : List α) (init : β) (f : α → β → StateT σ ProbComp (ForInStep β)) + (s : σ) + (rel : Fin (l.length + 1) → β → σ → Prop) + (h_start : rel 0 init s) + (h_step : ∀ (k : Fin l.length) (b : β) (s_curr : σ), + rel k.castSucc b s_curr → + ∀ res_step ∈ ((f (l.get k) b).run s_curr).support, + rel k.succ (ForInStep.state res_step.1) res_step.2) + (h_yield : ∀ (x : α) (b : β) (s_curr : σ) (res_step : ForInStep β × σ), + res_step ∈ ((f x b).run s_curr).support → + ∃ b', res_step.1 = ForInStep.yield b') + (res : β × σ) + (h_mem : res ∈ ((forIn l init f).run s).support) : + rel ⟨l.length, by omega⟩ res.1 res.2 ∧ + ∃ (bs : Fin (l.length + 1) → β) (ss : Fin (l.length + 1) → σ), + bs 0 = init ∧ ss 0 = s ∧ + bs ⟨l.length, by omega⟩ = res.1 ∧ ss ⟨l.length, by omega⟩ = res.2 ∧ + (∀ k : Fin l.length, + (ForInStep.yield (bs k.succ), ss k.succ) ∈ + ((f (l.get k) (bs k.castSucc)).run (ss k.castSucc)).support) ∧ + (∀ k : Fin (l.length + 1), rel k (bs k) (ss k)) := by + -- Helper: suffix induction parameterized by k (number of elements already processed). + -- Simultaneously constructs the relation proof and the path witnesses. + let rec aux (k : ℕ) (xs : List α) (b₀ : β) (s₀ : σ) + (h_suffix : l.drop k = xs) + (h_len : k + xs.length = l.length) + (h_rel : rel ⟨k, by omega⟩ b₀ s₀) + -- Accumulated path from step 0 to step k + (bs_acc : Fin (k + 1) → β) (ss_acc : Fin (k + 1) → σ) + (h_bs0 : bs_acc 0 = init) (h_ss0 : ss_acc 0 = s) + (h_bsk : bs_acc ⟨k, by omega⟩ = b₀) (h_ssk : ss_acc ⟨k, by omega⟩ = s₀) + (h_acc_steps : ∀ j : Fin k, (j.val < l.length) → + (ForInStep.yield (bs_acc ⟨j.val + 1, by omega⟩), ss_acc ⟨j.val + 1, by omega⟩) ∈ + ((f (l.get ⟨j.val, by omega⟩) (bs_acc ⟨j.val, by omega⟩)).run + (ss_acc ⟨j.val, by omega⟩)).support) + (h_acc_rels : ∀ j : Fin (k + 1), rel ⟨j.val, by omega⟩ (bs_acc j) (ss_acc j)) + (res' : β × σ) + (h_mem' : res' ∈ ((forIn xs b₀ f : StateT σ ProbComp β).run s₀).support) : + rel ⟨l.length, by omega⟩ res'.1 res'.2 ∧ + ∃ (bs : Fin (l.length + 1) → β) (ss : Fin (l.length + 1) → σ), + bs 0 = init ∧ ss 0 = s ∧ + bs ⟨l.length, by omega⟩ = res'.1 ∧ ss ⟨l.length, by omega⟩ = res'.2 ∧ + (∀ j : Fin l.length, + (ForInStep.yield (bs j.succ), ss j.succ) ∈ + ((f (l.get j) (bs j.castSucc)).run (ss j.castSucc)).support) ∧ + (∀ j : Fin (l.length + 1), rel j (bs j) (ss j)) := by + induction xs generalizing k b₀ s₀ bs_acc ss_acc with + | nil => + -- xs = [], so k = l.length + simp only [List.length_nil, add_zero] at h_len + have h_k_eq : k = l.length := h_len + have h_run : ((forIn ([] : List α) b₀ f : StateT σ ProbComp β).run s₀) = + pure (b₀, s₀) := rfl + rw [h_run, support_pure, Set.mem_singleton_iff] at h_mem' + subst h_mem' + subst h_k_eq + refine ⟨h_rel, ?_⟩ + -- Extend bs_acc and ss_acc to Fin (l.length + 1) — already the right size + exact ⟨bs_acc, ss_acc, h_bs0, h_ss0, h_bsk, h_ssk, + fun j => h_acc_steps ⟨j.val, by exact j.isLt⟩ j.isLt, + fun j => by exact h_acc_rels j⟩ + | cons y ys ih => + -- Unfold forIn for (y :: ys) + simp only [forIn, List.forIn'_cons, support_bind, Set.mem_iUnion, exists_prop, + StateT.run_bind] at h_mem' + obtain ⟨⟨step, s'⟩, h_step_sup, h_rest_sup⟩ := h_mem' + -- step must be yield + obtain ⟨b', h_yield_eq⟩ := h_yield y b₀ s₀ _ h_step_sup + subst h_yield_eq + -- Simplify match in h_rest_sup + simp only [ForInStep.casesOn] at h_rest_sup + -- y = l.get ⟨k, ...⟩ + have h_k_lt : k < l.length := by simp only [List.length_cons] at h_len; omega + have h_y_eq : y = l.get ⟨k, h_k_lt⟩ := by + have h_len_drop : 0 < (l.drop k).length := by rw [h_suffix]; exact Nat.zero_lt_succ _ + have : l[k]'h_k_lt = (l.drop k)[0]'h_len_drop := by + simp only [List.getElem_drop, Nat.add_zero] + simp only [List.get_eq_getElem, this, h_suffix, List.getElem_cons_zero] + -- From h_step: relation advances + have h_rel_next : rel ⟨k + 1, by omega⟩ b' s' := by + have h_app := h_step ⟨k, h_k_lt⟩ b₀ s₀ + h_rel + ⟨.yield b', s'⟩ + (by + have hn := h_step_sup + rw [← h_y_eq] + exact hn) + simp only [Fin.succ, ForInStep.state] at h_app + exact h_app + -- Extend accumulated path by one step + let bs_next : Fin (k + 1 + 1) → β := fun j => + if h : j.val ≤ k then bs_acc ⟨j.val, by omega⟩ else b' + let ss_next : Fin (k + 1 + 1) → σ := fun j => + if h : j.val ≤ k then ss_acc ⟨j.val, by omega⟩ else s' + have h_suffix_ys : l.drop (k + 1) = ys := by + rw [← List.drop_drop, h_suffix]; rfl + -- Apply IH + exact ih (k + 1) b' s' + h_suffix_ys + (by simp only [List.length_cons] at h_len; omega) + h_rel_next + bs_next ss_next + (by simp [bs_next, h_bs0]) + (by simp [ss_next, h_ss0]) + (by simp [bs_next]) + (by simp [ss_next]) + (by + intro j h_j_lt + by_cases hj : j.val < k + · have h_from_acc := h_acc_steps ⟨j.val, by omega⟩ (by omega) + simp only [bs_next, ss_next, + show j.val + 1 ≤ k from by omega, show j.val ≤ k from by omega, ↓reduceDIte] + exact h_from_acc + · have h_j_eq : j.val = k := by omega + simp only [h_j_eq, List.get_eq_getElem, le_refl, ↓reduceDIte, add_le_iff_nonpos_right, + nonpos_iff_eq_zero, one_ne_zero, bs_next, ss_next] + have hn := h_step_sup + rw [h_y_eq, h_bsk.symm, h_ssk.symm] at hn + exact hn) + (by + intro j + by_cases hj : j.val ≤ k + · simp only [bs_next, ss_next, hj, ↓reduceDIte] + exact h_acc_rels ⟨j.val, by omega⟩ + · have h_j_eq : j.val = k + 1 := by omega + have h_neg : ¬ (k + 1 ≤ k) := by omega + simp only [bs_next, ss_next, h_j_eq, h_neg, ↓reduceDIte] + exact h_rel_next) + h_rest_sup + -- Apply the helper starting at index 0 + exact aux 0 l init s rfl (by omega) h_start + (fun _ => init) (fun _ => s) rfl rfl rfl rfl + (fun j => by exact Fin.elim0 j) + (fun j => by have : j = 0 := Fin.eq_zero j; subst this; simpa using h_start) + res h_mem +/-- `OptionT` variant of `exists_rel_path_of_mem_support_forIn_stateful`. + +This keeps the same path/relation conclusion over `β`, while all support facts are +expressed through the `some` branch of `OptionT.run`. -/ +@[simp] +lemma OptionT.exists_rel_path_of_mem_support_forIn_stateful {ι : Type} {spec : OracleSpec ι} + [spec.Fintype] + {α σ β : Type} (l : List α) (init : β) + (f : α → β → OptionT (StateT σ ProbComp) (ForInStep β)) + (s : σ) + (rel : Fin (l.length + 1) → Option β → σ → Prop) + (h_start : rel 0 (some init) s) + (h_step : ∀ (k : Fin l.length) (b : β) (s_curr : σ), + rel k.castSucc (some b) s_curr → + ∀ (res_step : ForInStep β × σ), + (some res_step.1, res_step.2) ∈ ((f (l.get k) b).run s_curr).support → + rel k.succ (some (ForInStep.state res_step.1)) res_step.2) + (h_yield : ∀ (x : α) (b : β) (s_curr : σ) (res_step : ForInStep β × σ), + (some res_step.1, res_step.2) ∈ ((f x b).run s_curr).support → + ∃ b', res_step.1 = ForInStep.yield b') + (res : β × σ) + (h_mem : (some res.1, res.2) ∈ ((forIn l init f).run s).support) : + rel ⟨l.length, by omega⟩ (some res.1) res.2 ∧ + ∃ (bs : Fin (l.length + 1) → β) (ss : Fin (l.length + 1) → σ), + bs 0 = init ∧ ss 0 = s ∧ + bs ⟨l.length, by omega⟩ = res.1 ∧ ss ⟨l.length, by omega⟩ = res.2 ∧ + (∀ k : Fin l.length, + (some (ForInStep.yield (bs k.succ)), ss k.succ) ∈ + ((f (l.get k) (bs k.castSucc)).run (ss k.castSucc)).support) ∧ + (∀ k : Fin (l.length + 1), rel k (some (bs k)) (ss k)) := by + let rec aux (k : ℕ) (xs : List α) (b₀ : β) (s₀ : σ) + (h_suffix : l.drop k = xs) + (h_len : k + xs.length = l.length) + (h_rel : rel ⟨k, by omega⟩ (some b₀) s₀) + (bs_acc : Fin (k + 1) → β) (ss_acc : Fin (k + 1) → σ) + (h_bs0 : bs_acc 0 = init) (h_ss0 : ss_acc 0 = s) + (h_bsk : bs_acc ⟨k, by omega⟩ = b₀) (h_ssk : ss_acc ⟨k, by omega⟩ = s₀) + (h_acc_steps : ∀ j : Fin k, (j.val < l.length) → + (some (ForInStep.yield (bs_acc ⟨j.val + 1, by omega⟩)), ss_acc ⟨j.val + 1, by omega⟩) ∈ + ((f (l.get ⟨j.val, by omega⟩) (bs_acc ⟨j.val, by omega⟩)).run + (ss_acc ⟨j.val, by omega⟩)).support) + (h_acc_rels : ∀ j : Fin (k + 1), rel ⟨j.val, by omega⟩ (some (bs_acc j)) (ss_acc j)) + (res' : β × σ) + (h_mem' : (some res'.1, res'.2) ∈ + ((forIn xs b₀ f : OptionT (StateT σ ProbComp) β).run s₀).support) : + rel ⟨l.length, by omega⟩ (some res'.1) res'.2 ∧ + ∃ (bs : Fin (l.length + 1) → β) (ss : Fin (l.length + 1) → σ), + bs 0 = init ∧ ss 0 = s ∧ + bs ⟨l.length, by omega⟩ = res'.1 ∧ ss ⟨l.length, by omega⟩ = res'.2 ∧ + (∀ j : Fin l.length, + (some (ForInStep.yield (bs j.succ)), ss j.succ) ∈ + ((f (l.get j) (bs j.castSucc)).run (ss j.castSucc)).support) ∧ + (∀ j : Fin (l.length + 1), rel j (some (bs j)) (ss j)) := by + induction xs generalizing k b₀ s₀ bs_acc ss_acc with + | nil => + simp only [List.length_nil, add_zero] at h_len + have h_k_eq : k = l.length := h_len + have h_run : ((forIn ([] : List α) b₀ f : OptionT (StateT σ ProbComp) β).run s₀) = + pure (some b₀, s₀) := rfl + rw [h_run, support_pure, Set.mem_singleton_iff] at h_mem' + have h_res_eq : res'.1 = b₀ ∧ res'.2 = s₀ := by + simpa [Prod.mk.injEq, Option.some.injEq] using h_mem' + rcases h_res_eq with ⟨h_res1, h_res2⟩ + subst h_res1; subst h_res2 + subst h_k_eq + refine ⟨h_rel, ?_⟩ + exact ⟨bs_acc, ss_acc, h_bs0, h_ss0, h_bsk, h_ssk, + fun j => h_acc_steps ⟨j.val, by exact j.isLt⟩ j.isLt, + fun j => by exact h_acc_rels j⟩ + | cons y ys ih => + simp only [forIn, List.forIn'_cons] at h_mem' + rw [OptionT.run_bind] at h_mem' + simp only [Option.elimM, OptionT.run] at h_mem' + rw [show ∀ (m : StateT σ ProbComp _) + (g : _ → StateT σ ProbComp _) (s0 : σ), + (m >>= g) s0 = m.run s0 >>= fun p => (g p.1).run p.2 + from fun _ _ _ => rfl] at h_mem' + rw [_root_.mem_support_bind_iff] at h_mem' + obtain ⟨⟨opt_step, s'⟩, h_step_sup, h_rest_sup⟩ := h_mem' + cases h_opt : opt_step with + | none => + simp [h_opt] at h_rest_sup + | some step => + have h_step_some_mem : (some step, s') ∈ ((f y b₀).run s₀).support := by + simpa [h_opt] using h_step_sup + obtain ⟨b', h_yield_eq⟩ := h_yield y b₀ s₀ (step, s') h_step_some_mem + subst h_yield_eq + simp [h_opt] at h_rest_sup + have h_k_lt : k < l.length := by simp only [List.length_cons] at h_len; omega + have h_y_eq : y = l.get ⟨k, h_k_lt⟩ := by + have h_len_drop : 0 < (l.drop k).length := by rw [h_suffix]; exact Nat.zero_lt_succ _ + have : l[k]'h_k_lt = (l.drop k)[0]'h_len_drop := by + simp only [List.getElem_drop, Nat.add_zero] + simp only [List.get_eq_getElem, this, h_suffix, List.getElem_cons_zero] + have h_rel_next : rel ⟨k + 1, by omega⟩ (some b') s' := by + have h_app := h_step ⟨k, h_k_lt⟩ b₀ s₀ h_rel + (ForInStep.yield b', s') + (by + have hn := h_step_some_mem + rw [h_y_eq] at hn + exact hn) + simp [ForInStep.state, Fin.succ] at h_app + exact h_app + let bs_next : Fin (k + 1 + 1) → β := fun j => + if h : j.val ≤ k then bs_acc ⟨j.val, by omega⟩ else b' + let ss_next : Fin (k + 1 + 1) → σ := fun j => + if h : j.val ≤ k then ss_acc ⟨j.val, by omega⟩ else s' + have h_suffix_ys : l.drop (k + 1) = ys := by + rw [← List.drop_drop, h_suffix]; rfl + exact ih (k + 1) b' s' + h_suffix_ys + (by simp only [List.length_cons] at h_len; omega) + h_rel_next + bs_next ss_next + (by simp [bs_next, h_bs0]) + (by simp [ss_next, h_ss0]) + (by simp [bs_next]) + (by simp [ss_next]) + (by + intro j h_j_lt + by_cases hj : j.val < k + · have h_from_acc := h_acc_steps ⟨j.val, by omega⟩ (by omega) + simp only [bs_next, ss_next, show j.val + 1 ≤ k from by omega, + show j.val ≤ k from by omega, ↓reduceDIte] + exact h_from_acc + · have h_j_eq : j.val = k := by omega + simp only [h_j_eq, List.get_eq_getElem, le_refl, ↓reduceDIte, add_le_iff_nonpos_right, + nonpos_iff_eq_zero, one_ne_zero, bs_next, ss_next] + have hn := h_step_some_mem + rw [h_y_eq, h_bsk.symm, h_ssk.symm] at hn + exact hn) + (by + intro j + by_cases hj : j.val ≤ k + · simp only [bs_next, ss_next, hj, ↓reduceDIte] + exact h_acc_rels ⟨j.val, by omega⟩ + · have h_j_eq : j.val = k + 1 := by omega + have h_neg : ¬ (k + 1 ≤ k) := by omega + simp only [bs_next, ss_next, h_j_eq, h_neg, ↓reduceDIte] + exact h_rel_next) + h_rest_sup + exact aux 0 l init s rfl (by omega) h_start + (fun _ => init) (fun _ => s) rfl rfl rfl rfl + (fun j => by exact Fin.elim0 j) + (fun j => by have : j = 0 := Fin.eq_zero j; subst this; simpa using h_start) + res h_mem /-- Distributes `simulateQ` over `Vector.mapM`. TODO: This proof is non-trivial because `Vector.mapM` is implemented via an auxiliary @@ -1658,3 +2129,138 @@ lemma mem_support_vector_mapM_pure {α β : Type} {n : ℕ} simp only [Fin.getElem_fin, support_pure, Vector.getElem_map, Set.mem_singleton_iff] end ForInLemmas + + +/-! +## Probability Notation Bridge Lemmas + +This section contains lemmas to bridge between VCVio's `probEvent` notation `[p | oa]` +and ArkLib's `Pr_{...}[...]` PMF-based notation, enabling the use of probability +tools from `Instances.lean` (like Schwartz-Zippel) in security proofs. + +### Key Strategy + +Use `OracleComp.probEvent_bind_eq_tsum` to factor complex probability statements: +```lean +[q | oa >>= ob] = ∑' x : α, [= x | oa] * [q | ob x] +``` +-/ + +section NestedSimulateQSupport +open OracleComp OracleSpec OracleQuery SimOracle + +variable {ι : Type} {oSpec oSpec' : OracleSpec ι} + [oSpec.Fintype] [oSpec'.Fintype] + +omit [oSpec.Fintype] in +/-- **Support of simulateQ through bind with StateT** + +For stateful oracle implementations, the support of `(simulateQ impl oa >>= f).run s` can be +related to the spec support by unfolding through the monadic structure. + +This handles the case where we have a bind after simulateQ, which is common in verifier +executions that continue with additional stateful computations. -/ +lemma support_simulateQ_bind_run_eq + {σ α β : Type} + (impl : QueryImpl oSpec (StateT σ ProbComp)) + (oa : OracleComp oSpec α) (f : α → StateT σ ProbComp β) (s : σ) : + ((simulateQ impl oa >>= f).run s).support = + (do let ⟨x, s'⟩ ← (simulateQ impl oa).run s; (f x).run s').support := by + simp only [StateT.run]; rfl + +-- OptionT (StateT σ ProbComp) PUnit.{1} +/-- **Support of StateT bind (run form)** +Membership in `support ((m >>= g).run s)` is equivalent to: there exists `out_forIn ∈ support (m.run s)` +such that `x` is in the support of continuing with `g` from that result (i.e. `(g out_forIn.1).run out_forIn.2`). +Useful to "peel" the outer bind and get an existential over the forIn (or first part) outcome. -/ +lemma mem_support_StateT_bind_run {σ α β : Type} + (ma : StateT σ ProbComp α) (f : α → StateT σ ProbComp β) (s : σ) (x : β × σ) : + x ∈ ((ma >>= f).run s).support ↔ + ∃ (y : α) (s' : σ), (y, s') ∈ (ma.run s).support ∧ x ∈ ((f y).run s').support := by + simp only [StateT.run_bind, support_bind, Set.mem_iUnion, exists_prop, Prod.exists] + +-- StateT σ ProbComp (Option (ForInStep PUnit.{1})) +lemma OptionT.mem_support_StateT_bind_run {σ α β : Type} + (ma : StateT σ ProbComp α) (f : α → StateT σ ProbComp (Option β)) (s : σ) (x : Option (β) × σ) : + x ∈ ((ma >>= f).run s).support ↔ + ∃ (y : α) (s' : σ), (y, s') ∈ (ma.run s).support ∧ x ∈ ((f y).run s').support := by + simp only [StateT.run_bind, _root_.support_bind, Set.mem_iUnion, exists_prop, Prod.exists] + +lemma support_StateT_ite_apply {σ α : Type} + (ma ma' : StateT σ ProbComp α) (p : Prop) [Decidable p] (s : σ) : + support ((ite p ma ma') s) = ite p (support (ma s)) (support (ma' s)) := by + by_cases hp : p <;> simp [hp] + +end NestedSimulateQSupport + + +section QueryImplSimplification + +open ENNReal NNReal + +open OracleSpec OracleComp ProtocolSpec ProbComp QueryImpl +open scoped ProbabilityTheory + +variable {ι : Type} {oSpec : OracleSpec ι} [oSpec.Fintype] + {StmtIn WitIn StmtOut WitOut : Type} + {n : ℕ} {pSpec : ProtocolSpec n} + [∀ i, SampleableType (pSpec.Challenge i)] + [∀ i, Fintype (pSpec.Challenge i)] [∀ i, Inhabited (pSpec.Challenge i)] + {σ : Type} + +/-- **Simplification: QueryImpl append for Sum.inr queries (challenge queries)** + +When appending a `QueryImpl` with `challengeQueryImpl`, queries to `Sum.inr` (challenge queries) +are routed to `challengeQueryImpl`, which samples uniformly. + +This lemma simplifies `(impl ++ₛₒ challengeQueryImpl).impl (query (Sum.inr i) ())` to +show it samples uniformly from the challenge space. + +**Note**: The `++ₛₒ` operator implicitly lifts `challengeQueryImpl` from `ProbComp` to `StateT σ ProbComp`. +-/ +theorem QueryImpl_append_impl_inr_stateful + (impl : QueryImpl oSpec (StateT σ ProbComp)) + (i : pSpec.ChallengeIdx) (s : σ) : + ((QueryImpl.addLift impl challengeQueryImpl : + QueryImpl (oSpec + [pSpec.Challenge]ₒ) (StateT σ ProbComp)) (.inr ⟨i, ()⟩)) s = + (liftM (challengeQueryImpl ⟨i, ()⟩) : StateT σ ProbComp _).run s := by + rfl + +/-- **Simplification: QueryImpl append for Sum.inr queries (challenge queries) - run' version** + +Same as `QueryImpl_append_impl_inr_stateful` but using `run'` which discards the state. +-/ +theorem QueryImpl_append_impl_inr_stateful_run' + (impl : QueryImpl oSpec (StateT σ ProbComp)) + (i : pSpec.ChallengeIdx) (s : σ) : + (((QueryImpl.addLift impl challengeQueryImpl : + QueryImpl (oSpec + [pSpec.Challenge]ₒ) (StateT σ ProbComp)) (.inr ⟨i, ()⟩)).run') s = + (liftM (challengeQueryImpl ⟨i, ()⟩) : StateT σ ProbComp _).run' s := by + rfl + +/-- For challenge queries, `monadLift` on `OracleQuery` lands in the `.inr` branch. -/ +lemma addLift_challengeQueryImpl_input_run_eq_liftM_run + {σ α : Type} {pSpec : ProtocolSpec n} + [∀ i, SampleableType (pSpec.Challenge i)] + (impl : QueryImpl []ₒ (StateT σ ProbComp)) + (q : OracleQuery [pSpec.Challenge]ₒ α) + (s : σ) : + ((impl + QueryImpl.liftTarget (StateT σ ProbComp) (challengeQueryImpl (pSpec := pSpec))) + ((MonadLift.monadLift (query q.input) : + OracleQuery ([]ₒ + [pSpec.Challenge]ₒ) + ([pSpec.Challenge]ₒ.Range q.input)).input)).run s = + ((liftM (challengeQueryImpl (pSpec := pSpec) q.input)) : + StateT σ ProbComp ([pSpec.Challenge]ₒ.Range q.input)).run s := by rfl + +end QueryImplSimplification + +section MapLemmas + +variable {ι : Type} {spec : OracleSpec ι} {α β : Type} + +/-- Map over pure reduces to pure of the mapped value. -/ +@[simp] +lemma map_pure (f : α → β) (a : α) : + (f <$> pure a : OracleComp spec β) = pure (f a) := rfl + +end MapLemmas