diff --git a/include/mls/crypto.h b/include/mls/crypto.h index 6d54db51..a4273b57 100644 --- a/include/mls/crypto.h +++ b/include/mls/crypto.h @@ -31,6 +31,8 @@ struct KeyAndNonce { bytes key; bytes nonce; + + TLS_SERIALIZABLE(key, nonce) }; // opaque HashReference; @@ -272,7 +274,7 @@ struct SignaturePrivateKey void set_public_key(CipherSuite suite); std::string to_jwk(CipherSuite suite) const; - TLS_SERIALIZABLE(data) + TLS_SERIALIZABLE(data, public_key) private: SignaturePrivateKey(bytes priv_data, bytes pub_data); diff --git a/include/mls/key_schedule.h b/include/mls/key_schedule.h index 85cd7aba..a4e295ee 100644 --- a/include/mls/key_schedule.h +++ b/include/mls/key_schedule.h @@ -20,6 +20,8 @@ struct HashRatchet size_t nonce_size; size_t secret_size; + TLS_SERIALIZABLE(suite, next_secret, next_generation, cache, key_size, nonce_size, secret_size); + // These defaults are necessary for use with containers HashRatchet() = default; HashRatchet(const HashRatchet& other) = default; @@ -51,13 +53,23 @@ struct SecretTree NodeIndex root; std::map secrets; size_t secret_size; + + TLS_SERIALIZABLE(suite, group_size, root, secrets, secret_size); + + friend tls::ostream& operator<<(tls::ostream& str, const SecretTree& obj); + friend tls::istream& operator>>(tls::istream& str, SecretTree& obj); }; +tls::ostream& +operator<<(tls::ostream& str, const SecretTree& obj); +tls::istream& +operator>>(tls::istream& str, SecretTree& obj); + using ReuseGuard = std::array; struct GroupKeySource { - enum struct RatchetType + enum struct RatchetType : uint8_t { handshake, application, @@ -89,14 +101,22 @@ struct GroupKeySource HashRatchet& chain(ContentType type, LeafIndex sender); static const std::array all_ratchet_types; + + TLS_SERIALIZABLE(suite, secret_tree, chains); + + friend tls::ostream& operator<<(tls::ostream& str, const GroupKeySource& obj); + friend tls::istream& operator>>(tls::istream& str, GroupKeySource& obj); }; +tls::ostream& +operator<<(tls::ostream& str, const GroupKeySource& obj); +tls::istream& +operator>>(tls::istream& str, GroupKeySource& obj); + struct KeyScheduleEpoch { -private: CipherSuite suite; -public: bytes joiner_secret; bytes epoch_secret; @@ -113,6 +133,8 @@ struct KeyScheduleEpoch HPKEPrivateKey external_priv; + TLS_SERIALIZABLE(suite, joiner_secret, epoch_secret, sender_data_secret, encryption_secret, exporter_secret, epoch_authenticator, external_secret, confirmation_key, confirmation_tag, membership_key, resumption_psk, init_secret, external_priv); + KeyScheduleEpoch() = default; // Full initializer, used by invited joiner @@ -191,6 +213,10 @@ struct TranscriptHash bytes confirmed; bytes interim; + TLS_SERIALIZABLE(suite, confirmed, interim); + + TranscriptHash() = default; + // For a new group TranscriptHash(CipherSuite suite_in); diff --git a/include/mls/state.h b/include/mls/state.h index ae5dc1ea..33def2f6 100644 --- a/include/mls/state.h +++ b/include/mls/state.h @@ -46,6 +46,9 @@ class State /// Constructors /// + // For deserialization purposes only + State() = default; + // Initialize an empty group State(bytes group_id, CipherSuite suite, @@ -267,6 +270,8 @@ class State using EpochRef = std::tuple; std::map _resumption_psks; + TLS_SERIALIZABLE(_suite, _group_id, _epoch, _tree, _tree_priv, _transcript_hash, _extensions, _key_schedule, _keys, _index, _identity_priv, _external_psks, _resumption_psks); + // Cache of Proposals and update secrets struct CachedProposal { @@ -467,6 +472,15 @@ class State bool has_path, const std::vector& psks, const std::optional& force_init_secret) const; + + // Friend operators for serialization + friend tls::ostream& operator<<(tls::ostream& str, const State& obj); + friend tls::istream& operator>>(tls::istream& str, State& obj); }; +tls::ostream& +operator<<(tls::ostream& str, const State& obj); +tls::istream& +operator>>(tls::istream& str, State& obj); + } // namespace MLS_NAMESPACE diff --git a/include/mls/treekem.h b/include/mls/treekem.h index 7825f43f..7d2ebb12 100644 --- a/include/mls/treekem.h +++ b/include/mls/treekem.h @@ -81,6 +81,8 @@ struct TreeKEMPrivateKey std::map path_secrets; std::map private_key_cache; + TLS_SERIALIZABLE(suite, index, update_secret, path_secrets, private_key_cache); + static TreeKEMPrivateKey solo(CipherSuite suite, LeafIndex index, HPKEPrivateKey leaf_priv); @@ -250,7 +252,7 @@ struct TreeKEMPublicKey OptionalNode& node_at(LeafIndex n); const OptionalNode& node_at(LeafIndex n) const; - TLS_SERIALIZABLE(nodes) + TLS_SERIALIZABLE(suite, size, nodes, hashes) #if ENABLE_TREE_DUMP void dump() const; @@ -291,6 +293,8 @@ struct TreeKEMPublicKey OptionalNode blank_node; friend struct TreeKEMPrivateKey; + friend tls::ostream& operator<<(tls::ostream& str, const TreeKEMPublicKey& obj); + friend tls::istream& operator>>(tls::istream& str, TreeKEMPublicKey& obj); }; tls::ostream& diff --git a/lib/tls_syntax/include/tls/tls_syntax.h b/lib/tls_syntax/include/tls/tls_syntax.h index a5c87202..6db99239 100644 --- a/lib/tls_syntax/include/tls/tls_syntax.h +++ b/lib/tls_syntax/include/tls/tls_syntax.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -67,6 +68,9 @@ class ostream template friend ostream& operator<<(ostream& out, const std::vector& data); + template + friend ostream& operator<<(ostream& out, const std::map& data); + friend struct varint; }; @@ -115,6 +119,9 @@ class istream template friend istream& operator>>(istream& in, std::vector& data); + template + friend istream& operator>>(istream& in, std::map& data); + friend struct varint; }; @@ -312,6 +319,98 @@ operator>>(istream& str, std::vector& vec) return str; } +// Map writer +template +ostream& +operator<<(ostream& str, const std::map& map) +{ + ostream temp; + for (const auto& [key, value] : map) { + temp << key << value; + } + + varint::encode(str, temp._buffer.size()); + str.write_raw(temp.bytes()); + + return str; +} + +// Map reader +template +istream& +operator>>(istream& str, std::map& map) +{ + auto size = uint64_t(0); + varint::decode(str, size); + if (size > str._buffer.size()) { + throw ReadError("Map is longer than remaining data"); + } + + istream r; + const auto size_diff = static_cast(size); + r._buffer = std::vector{ str._buffer.end() - size_diff, str._buffer.end() }; + + map.clear(); + while (r._buffer.size() > 0) { + K key{}; + V value{}; + r >> key >> value; + map[key] = value; + } + + str._buffer.erase(str._buffer.end() - size_diff, str._buffer.end()); + + return str; +} + +// Tuple writer helper +template +typename std::enable_if::type +write_tuple_elements(ostream&, const std::tuple&) +{ +} + +template +typename std::enable_if::type +write_tuple_elements(ostream& str, const std::tuple& t) +{ + str << std::get(t); + write_tuple_elements(str, t); +} + +// Tuple writer +template +ostream& +operator<<(ostream& str, const std::tuple& tuple) +{ + write_tuple_elements(str, tuple); + return str; +} + +// Tuple reader helper +template +typename std::enable_if::type +read_tuple_elements(istream&, std::tuple&) +{ +} + +template +typename std::enable_if::type +read_tuple_elements(istream& str, std::tuple& t) +{ + str >> std::get(t); + read_tuple_elements(str, t); +} + +// Tuple reader +template +istream& +operator>>(istream& str, std::tuple& tuple) +{ + read_tuple_elements(str, tuple); + return str; +} + // Abbreviations template std::vector diff --git a/src/crypto.cpp b/src/crypto.cpp index 5b4fe851..9ec64229 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -184,7 +184,7 @@ CipherSuite::get() const return ciphers_X448_CHACHA20POLY1305_SHA512_Ed448; #endif -#if !defined(P256_SHA256) +#if !defined(P256_SHA256) && defined(WITH_PQ) case ID::MLKEM768X25519_AES256GCM_SHA384_Ed25519: return ciphers_MLKEM768X25519_AES256GCM_SHA384_Ed25519; diff --git a/src/key_schedule.cpp b/src/key_schedule.cpp index 15169b2b..dc51acfb 100644 --- a/src/key_schedule.cpp +++ b/src/key_schedule.cpp @@ -576,4 +576,44 @@ operator==(const TranscriptHash& lhs, const TranscriptHash& rhs) return confirmed && interim; } +tls::ostream& +operator<<(tls::ostream& str, const SecretTree& obj) +{ + str << obj.suite + << obj.group_size + << obj.root + << obj.secrets + << obj.secret_size; + return str; +} + +tls::istream& +operator>>(tls::istream& str, SecretTree& obj) +{ + str >> obj.suite + >> obj.group_size + >> obj.root + >> obj.secrets + >> obj.secret_size; + return str; +} + +tls::ostream& +operator<<(tls::ostream& str, const GroupKeySource& obj) +{ + str << obj.suite + << obj.secret_tree + << obj.chains; + return str; +} + +tls::istream& +operator>>(tls::istream& str, GroupKeySource& obj) +{ + str >> obj.suite + >> obj.secret_tree + >> obj.chains; + return str; +} + } // namespace MLS_NAMESPACE diff --git a/src/state.cpp b/src/state.cpp index ab5e911f..1cd1fe74 100644 --- a/src/state.cpp +++ b/src/state.cpp @@ -2469,4 +2469,42 @@ State::successor(LeafIndex index, return next; } +tls::ostream& +operator<<(tls::ostream& str, const State& obj) +{ + str << obj._suite + << obj._group_id + << obj._epoch + << obj._tree + << obj._tree_priv + << obj._transcript_hash + << obj._extensions + << obj._key_schedule + << obj._keys + << obj._index + << obj._identity_priv + << obj._external_psks + << obj._resumption_psks; + return str; +} + +tls::istream& +operator>>(tls::istream& str, State& obj) +{ + str >> obj._suite + >> obj._group_id + >> obj._epoch + >> obj._tree + >> obj._tree_priv + >> obj._transcript_hash + >> obj._extensions + >> obj._key_schedule + >> obj._keys + >> obj._index + >> obj._identity_priv + >> obj._external_psks + >> obj._resumption_psks; + return str; +} + } // namespace MLS_NAMESPACE diff --git a/src/treekem.cpp b/src/treekem.cpp index 35899b7e..a38fb343 100644 --- a/src/treekem.cpp +++ b/src/treekem.cpp @@ -1327,9 +1327,12 @@ TreeKEMPublicKey::exists_in_tree(const SignaturePublicKey& key, tls::ostream& operator<<(tls::ostream& str, const TreeKEMPublicKey& obj) { - // Empty tree + str << obj.suite; + if (obj.size.val == 0) { - return str << std::vector{}; + str << std::vector{}; + str << obj.hashes; + return str; } LeafIndex cut = LeafIndex{ obj.size.val - 1 }; @@ -1345,15 +1348,24 @@ operator<<(tls::ostream& str, const TreeKEMPublicKey& obj) view.at(i.val) = obj.nodes.at(i); } - return str << view; + str << view; + str << obj.hashes; + return str; } tls::istream& operator>>(tls::istream& str, TreeKEMPublicKey& obj) { + // Deserialize the cipher suite first + str >> obj.suite; + // Read the node list std::vector nodes; str >> nodes; + + // Read the hashes map + str >> obj.hashes; + if (nodes.empty()) { return str; }