Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/mls/crypto.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ struct KeyAndNonce
{
bytes key;
bytes nonce;

TLS_SERIALIZABLE(key, nonce)
};

// opaque HashReference<V>;
Expand Down Expand Up @@ -272,7 +274,7 @@ struct SignaturePrivateKey
void set_public_key(CipherSuite suite);
std::string to_jwk(CipherSuite suite) const;

TLS_SERIALIZABLE(data)
TLS_SERIALIZABLE(data, public_key)

private:
SignaturePrivateKey(bytes priv_data, bytes pub_data);
Expand Down
32 changes: 29 additions & 3 deletions include/mls/key_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ struct HashRatchet
size_t nonce_size;
size_t secret_size;

TLS_SERIALIZABLE(suite, next_secret, next_generation, cache, key_size, nonce_size, secret_size);

// These defaults are necessary for use with containers
HashRatchet() = default;
HashRatchet(const HashRatchet& other) = default;
Expand Down Expand Up @@ -51,13 +53,23 @@ struct SecretTree
NodeIndex root;
std::map<NodeIndex, bytes> secrets;
size_t secret_size;

TLS_SERIALIZABLE(suite, group_size, root, secrets, secret_size);

friend tls::ostream& operator<<(tls::ostream& str, const SecretTree& obj);
friend tls::istream& operator>>(tls::istream& str, SecretTree& obj);
};

tls::ostream&
operator<<(tls::ostream& str, const SecretTree& obj);
tls::istream&
operator>>(tls::istream& str, SecretTree& obj);

using ReuseGuard = std::array<uint8_t, 4>;

struct GroupKeySource
{
enum struct RatchetType
enum struct RatchetType : uint8_t
{
handshake,
application,
Expand Down Expand Up @@ -89,14 +101,22 @@ struct GroupKeySource
HashRatchet& chain(ContentType type, LeafIndex sender);

static const std::array<RatchetType, 2> all_ratchet_types;

TLS_SERIALIZABLE(suite, secret_tree, chains);

friend tls::ostream& operator<<(tls::ostream& str, const GroupKeySource& obj);
friend tls::istream& operator>>(tls::istream& str, GroupKeySource& obj);
};

tls::ostream&
operator<<(tls::ostream& str, const GroupKeySource& obj);
tls::istream&
operator>>(tls::istream& str, GroupKeySource& obj);

struct KeyScheduleEpoch
{
private:
CipherSuite suite;

public:
bytes joiner_secret;
bytes epoch_secret;

Expand All @@ -113,6 +133,8 @@ struct KeyScheduleEpoch

HPKEPrivateKey external_priv;

TLS_SERIALIZABLE(suite, joiner_secret, epoch_secret, sender_data_secret, encryption_secret, exporter_secret, epoch_authenticator, external_secret, confirmation_key, confirmation_tag, membership_key, resumption_psk, init_secret, external_priv);

KeyScheduleEpoch() = default;

// Full initializer, used by invited joiner
Expand Down Expand Up @@ -191,6 +213,10 @@ struct TranscriptHash
bytes confirmed;
bytes interim;

TLS_SERIALIZABLE(suite, confirmed, interim);

TranscriptHash() = default;

// For a new group
TranscriptHash(CipherSuite suite_in);

Expand Down
14 changes: 14 additions & 0 deletions include/mls/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class State
/// Constructors
///

// For deserialization purposes only
State() = default;

// Initialize an empty group
State(bytes group_id,
CipherSuite suite,
Expand Down Expand Up @@ -267,6 +270,8 @@ class State
using EpochRef = std::tuple<bytes, epoch_t>;
std::map<EpochRef, bytes> _resumption_psks;

TLS_SERIALIZABLE(_suite, _group_id, _epoch, _tree, _tree_priv, _transcript_hash, _extensions, _key_schedule, _keys, _index, _identity_priv, _external_psks, _resumption_psks);

// Cache of Proposals and update secrets
struct CachedProposal
{
Expand Down Expand Up @@ -467,6 +472,15 @@ class State
bool has_path,
const std::vector<PSKWithSecret>& psks,
const std::optional<bytes>& force_init_secret) const;

// Friend operators for serialization
friend tls::ostream& operator<<(tls::ostream& str, const State& obj);
friend tls::istream& operator>>(tls::istream& str, State& obj);
};

tls::ostream&
operator<<(tls::ostream& str, const State& obj);
tls::istream&
operator>>(tls::istream& str, State& obj);

} // namespace MLS_NAMESPACE
6 changes: 5 additions & 1 deletion include/mls/treekem.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ struct TreeKEMPrivateKey
std::map<NodeIndex, bytes> path_secrets;
std::map<NodeIndex, HPKEPrivateKey> private_key_cache;

TLS_SERIALIZABLE(suite, index, update_secret, path_secrets, private_key_cache);

static TreeKEMPrivateKey solo(CipherSuite suite,
LeafIndex index,
HPKEPrivateKey leaf_priv);
Expand Down Expand Up @@ -250,7 +252,7 @@ struct TreeKEMPublicKey
OptionalNode& node_at(LeafIndex n);
const OptionalNode& node_at(LeafIndex n) const;

TLS_SERIALIZABLE(nodes)
TLS_SERIALIZABLE(suite, size, nodes, hashes)

#if ENABLE_TREE_DUMP
void dump() const;
Expand Down Expand Up @@ -291,6 +293,8 @@ struct TreeKEMPublicKey
OptionalNode blank_node;

friend struct TreeKEMPrivateKey;
friend tls::ostream& operator<<(tls::ostream& str, const TreeKEMPublicKey& obj);
friend tls::istream& operator>>(tls::istream& str, TreeKEMPublicKey& obj);
};

tls::ostream&
Expand Down
99 changes: 99 additions & 0 deletions lib/tls_syntax/include/tls/tls_syntax.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <namespace.h>
#include <optional>
#include <stdexcept>
#include <tuple>
#include <vector>

#include <tls/compat.h>
Expand Down Expand Up @@ -67,6 +68,9 @@ class ostream
template<typename T>
friend ostream& operator<<(ostream& out, const std::vector<T>& data);

template<typename K, typename V>
friend ostream& operator<<(ostream& out, const std::map<K, V>& data);

friend struct varint;
};

Expand Down Expand Up @@ -115,6 +119,9 @@ class istream
template<typename T>
friend istream& operator>>(istream& in, std::vector<T>& data);

template<typename K, typename V>
friend istream& operator>>(istream& in, std::map<K, V>& data);

friend struct varint;
};

Expand Down Expand Up @@ -312,6 +319,98 @@ operator>>(istream& str, std::vector<T>& vec)
return str;
}

// Map writer
template<typename K, typename V>
ostream&
operator<<(ostream& str, const std::map<K, V>& map)
{
ostream temp;
for (const auto& [key, value] : map) {
temp << key << value;
}

varint::encode(str, temp._buffer.size());
str.write_raw(temp.bytes());

return str;
}

// Map reader
template<typename K, typename V>
istream&
operator>>(istream& str, std::map<K, V>& map)
{
auto size = uint64_t(0);
varint::decode(str, size);
if (size > str._buffer.size()) {
throw ReadError("Map is longer than remaining data");
}

istream r;
const auto size_diff = static_cast<ptrdiff_t>(size);
r._buffer = std::vector<uint8_t>{ str._buffer.end() - size_diff, str._buffer.end() };

map.clear();
while (r._buffer.size() > 0) {
K key{};
V value{};
r >> key >> value;
map[key] = value;
}

str._buffer.erase(str._buffer.end() - size_diff, str._buffer.end());

return str;
}

// Tuple writer helper
template<size_t I = 0, typename... Tp>
typename std::enable_if<I == sizeof...(Tp), void>::type
write_tuple_elements(ostream&, const std::tuple<Tp...>&)
{
}

template<size_t I = 0, typename... Tp>
typename std::enable_if<I < sizeof...(Tp), void>::type
write_tuple_elements(ostream& str, const std::tuple<Tp...>& t)
{
str << std::get<I>(t);
write_tuple_elements<I + 1, Tp...>(str, t);
}

// Tuple writer
template<typename... Tp>
ostream&
operator<<(ostream& str, const std::tuple<Tp...>& tuple)
{
write_tuple_elements(str, tuple);
return str;
}

// Tuple reader helper
template<size_t I = 0, typename... Tp>
typename std::enable_if<I == sizeof...(Tp), void>::type
read_tuple_elements(istream&, std::tuple<Tp...>&)
{
}

template<size_t I = 0, typename... Tp>
typename std::enable_if<I < sizeof...(Tp), void>::type
read_tuple_elements(istream& str, std::tuple<Tp...>& t)
{
str >> std::get<I>(t);
read_tuple_elements<I + 1, Tp...>(str, t);
}

// Tuple reader
template<typename... Tp>
istream&
operator>>(istream& str, std::tuple<Tp...>& tuple)
{
read_tuple_elements(str, tuple);
return str;
}

// Abbreviations
template<typename T>
std::vector<uint8_t>
Expand Down
2 changes: 1 addition & 1 deletion src/crypto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ CipherSuite::get() const
return ciphers_X448_CHACHA20POLY1305_SHA512_Ed448;
#endif

#if !defined(P256_SHA256)
#if !defined(P256_SHA256) && defined(WITH_PQ)
case ID::MLKEM768X25519_AES256GCM_SHA384_Ed25519:
return ciphers_MLKEM768X25519_AES256GCM_SHA384_Ed25519;

Expand Down
40 changes: 40 additions & 0 deletions src/key_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,4 +576,44 @@ operator==(const TranscriptHash& lhs, const TranscriptHash& rhs)
return confirmed && interim;
}

tls::ostream&
operator<<(tls::ostream& str, const SecretTree& obj)
{
str << obj.suite
<< obj.group_size
<< obj.root
<< obj.secrets
<< obj.secret_size;
return str;
}

tls::istream&
operator>>(tls::istream& str, SecretTree& obj)
{
str >> obj.suite
>> obj.group_size
>> obj.root
>> obj.secrets
>> obj.secret_size;
return str;
}

tls::ostream&
operator<<(tls::ostream& str, const GroupKeySource& obj)
{
str << obj.suite
<< obj.secret_tree
<< obj.chains;
return str;
}

tls::istream&
operator>>(tls::istream& str, GroupKeySource& obj)
{
str >> obj.suite
>> obj.secret_tree
>> obj.chains;
return str;
}

} // namespace MLS_NAMESPACE
38 changes: 38 additions & 0 deletions src/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2469,4 +2469,42 @@ State::successor(LeafIndex index,
return next;
}

tls::ostream&
operator<<(tls::ostream& str, const State& obj)
{
str << obj._suite
<< obj._group_id
<< obj._epoch
<< obj._tree
<< obj._tree_priv
<< obj._transcript_hash
<< obj._extensions
<< obj._key_schedule
<< obj._keys
<< obj._index
<< obj._identity_priv
<< obj._external_psks
<< obj._resumption_psks;
return str;
}

tls::istream&
operator>>(tls::istream& str, State& obj)
{
str >> obj._suite
>> obj._group_id
>> obj._epoch
>> obj._tree
>> obj._tree_priv
>> obj._transcript_hash
>> obj._extensions
>> obj._key_schedule
>> obj._keys
>> obj._index
>> obj._identity_priv
>> obj._external_psks
>> obj._resumption_psks;
return str;
}

} // namespace MLS_NAMESPACE
Loading