Skip to content

Commit 65e70e7

Browse files
committed
MB-65972: Add support for auto-renewal JWTs for MemcachedConnection
Allow MemcachedConnection to keep a JWT builder around and update the JWT in use when the server returns AuthStale. Change-Id: Ia386459210e555dca51ce90e0a2993bc41729b22 Reviewed-on: https://review.couchbase.org/c/kv_engine/+/225359 Reviewed-by: Vesko Karaganev <vesko.karaganev@couchbase.com> Tested-by: Build Bot <build@couchbase.com>
1 parent 93f8d40 commit 65e70e7

File tree

5 files changed

+87
-15
lines changed

5 files changed

+87
-15
lines changed

json_web_token/builder.cc

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
namespace cb::jwt {
1919
class BuilderImpl : public Builder {
2020
public:
21-
BuilderImpl(nlohmann::json initial_header, nlohmann::json initial_payload)
21+
BuilderImpl(nlohmann::json initial_header,
22+
nlohmann::json initial_payload,
23+
std::optional<std::chrono::seconds> lifetime)
2224
: header(std::move(initial_header)),
23-
payload(std::move(initial_payload)) {
25+
payload(std::move(initial_payload)),
26+
lifetime(lifetime) {
2427
if (!header.contains("typ")) {
2528
header["typ"] = "JWT";
2629
}
@@ -61,30 +64,40 @@ class BuilderImpl : public Builder {
6164
payload[name] = value;
6265
}
6366
std::string build() override {
67+
if (lifetime.has_value()) {
68+
const auto now = std::chrono::system_clock::now();
69+
const auto expiration = now + lifetime.value();
70+
setIssuedAt(now);
71+
setExpiration(expiration);
72+
}
6473
return fmt::format("{}.{}",
6574
cb::base64url::encode(header.dump()),
6675
cb::base64url::encode(payload.dump()));
6776
}
6877

6978
nlohmann::json header;
7079
nlohmann::json payload;
80+
std::optional<std::chrono::seconds> lifetime;
7181
};
7282

7383
class PlainBuilderImpl : public BuilderImpl {
7484
public:
75-
explicit PlainBuilderImpl(nlohmann::json initial)
76-
: BuilderImpl({{"alg", "none"}}, std::move(initial)) {
85+
explicit PlainBuilderImpl(nlohmann::json initial,
86+
std::optional<std::chrono::seconds> lifetime)
87+
: BuilderImpl({{"alg", "none"}}, std::move(initial), lifetime) {
7788
}
7889

7990
[[nodiscard]] std::unique_ptr<Builder> clone() const override {
80-
return std::make_unique<PlainBuilderImpl>(payload);
91+
return std::make_unique<PlainBuilderImpl>(payload, lifetime);
8192
}
8293
};
8394

8495
class HS256BuilderImpl : public BuilderImpl {
8596
public:
86-
HS256BuilderImpl(std::string passphrase, nlohmann::json initial)
87-
: BuilderImpl({{"alg", "HS256"}}, std::move(initial)),
97+
HS256BuilderImpl(std::string passphrase,
98+
nlohmann::json initial,
99+
std::optional<std::chrono::seconds> lifetime)
100+
: BuilderImpl({{"alg", "HS256"}}, std::move(initial), lifetime),
88101
passphrase(std::move(passphrase)) {
89102
}
90103

@@ -96,22 +109,25 @@ class HS256BuilderImpl : public BuilderImpl {
96109
}
97110

98111
[[nodiscard]] std::unique_ptr<Builder> clone() const override {
99-
return std::make_unique<HS256BuilderImpl>(passphrase, payload);
112+
return std::make_unique<HS256BuilderImpl>(
113+
passphrase, payload, lifetime);
100114
}
101115

102116
protected:
103117
const std::string passphrase;
104118
};
105119

106-
std::unique_ptr<Builder> Builder::create(std::string_view alg,
107-
std::string_view passphrase,
108-
nlohmann::json payload) {
120+
std::unique_ptr<Builder> Builder::create(
121+
std::string_view alg,
122+
std::string_view passphrase,
123+
nlohmann::json payload,
124+
std::optional<std::chrono::seconds> lifetime) {
109125
if (alg == "HS256") {
110-
return std::make_unique<HS256BuilderImpl>(std::string(passphrase),
111-
std::move(payload));
126+
return std::make_unique<HS256BuilderImpl>(
127+
std::string(passphrase), std::move(payload), lifetime);
112128
}
113129
if (alg.empty() || alg == "none") {
114-
return std::make_unique<PlainBuilderImpl>(std::move(payload));
130+
return std::make_unique<PlainBuilderImpl>(std::move(payload), lifetime);
115131
}
116132
throw std::invalid_argument("Invalid Algorithm");
117133
}

json_web_token/builder.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <nlohmann/json.hpp>
1313
#include <chrono>
1414
#include <memory>
15+
#include <optional>
1516
#include <string_view>
1617

1718
namespace cb::jwt {
@@ -33,7 +34,8 @@ class Builder {
3334
[[nodiscard]] static std::unique_ptr<Builder> create(
3435
std::string_view alg = "none",
3536
std::string_view passphrase = {},
36-
nlohmann::json payload = nlohmann::json::object());
37+
nlohmann::json payload = nlohmann::json::object(),
38+
std::optional<std::chrono::seconds> lifetime = {});
3739
/// Set the expiration for the token to use
3840
virtual void setExpiration(std::chrono::system_clock::time_point exp) = 0;
3941
/// Specify when we can start using the token

protocol/connection/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ target_link_libraries(mc_client_connection PRIVATE
2424
platform
2525
${COUCHBASE_NETWORK_LIBS}
2626
PRIVATE
27+
json_web_token
2728
folly_io_callbacks
2829
cbcompress
2930
json_validator

protocol/connection/client_connection.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <folly/io/IOBuf.h>
1818
#include <folly/io/async/AsyncSSLSocket.h>
1919
#include <json/syntax_validator.h>
20+
#include <json_web_token/builder.h>
2021
#include <mcbp/codec/dcp_snapshot_marker.h>
2122
#include <mcbp/codec/frameinfo.h>
2223
#include <mcbp/mcbp.h>
@@ -38,6 +39,8 @@
3839
#include <netdb.h>
3940
#include <netinet/tcp.h> // For TCP_NODELAY etc
4041
#endif
42+
#include <json_web_token/token.h>
43+
4144
#include <stdexcept>
4245
#include <string>
4346
#include <system_error>
@@ -954,6 +957,9 @@ std::unique_ptr<MemcachedConnection> MemcachedConnection::clone(
954957
if (!agent_name.empty()) {
955958
ret->setAgentName(std::move(agent_name));
956959
}
960+
if (tokenBuilder) {
961+
ret->tokenBuilder = tokenBuilder->clone();
962+
}
957963

958964
if (connect) {
959965
ret->connect();
@@ -1239,6 +1245,29 @@ void MemcachedConnection::recvResponse(BinprotResponse& response,
12391245
traceData = response.getTracingData();
12401246
}
12411247

1248+
void MemcachedConnection::setTokenBuilder(
1249+
std::unique_ptr<cb::jwt::Builder> builder) {
1250+
if (!ssl) {
1251+
throw std::logic_error(
1252+
"MemcachedConnection::setTokenBuilder: "
1253+
"SSL must be enabled to use JWT");
1254+
}
1255+
tokenBuilder = std::move(builder);
1256+
}
1257+
1258+
void MemcachedConnection::authenticateWithToken() {
1259+
auto token = tokenBuilder->build();
1260+
1261+
auto parsed = cb::jwt::Token::parse(tokenBuilder->build());
1262+
if (!parsed->payload.contains("sub")) {
1263+
throw std::logic_error(
1264+
"MemcachedConnection::authenticateWithToken: "
1265+
"The token must contain a sub field");
1266+
}
1267+
1268+
doSaslAuthenticate(parsed->payload["sub"], token, "OAUTHBEARER");
1269+
}
1270+
12421271
void MemcachedConnection::authenticate(
12431272
const std::string& user,
12441273
const std::optional<std::string>& password,
@@ -2257,6 +2286,16 @@ BinprotResponse MemcachedConnection::execute(
22572286
sendCommand(command);
22582287
recvResponse(response, command.getOp(), readTimeout);
22592288

2289+
if (response.getStatus() == cb::mcbp::Status::AuthStale &&
2290+
tokenBuilder) {
2291+
if (command.getOp() == cb::mcbp::ClientOpcode::SaslAuth) {
2292+
// we don't want to recursively do sasl auth
2293+
return true;
2294+
}
2295+
authenticateWithToken();
2296+
return false;
2297+
}
2298+
22602299
bool retry_by_tmpfail =
22612300
auto_retry_tmpfail &&
22622301
response.getStatus() == cb::mcbp::Status::Etmpfail;

protocol/connection/client_connection.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
#include <utility>
3333
#include <vector>
3434

35+
namespace cb::jwt {
36+
class Builder;
37+
}
3538
namespace cb::mcbp::request {
3639
class FrameInfo;
3740
}
@@ -436,6 +439,16 @@ class MemcachedConnection {
436439
const std::optional<std::string>& password = {},
437440
const std::string& mech = "PLAIN");
438441

442+
/**
443+
* Set the connection to use JWT (and auto renewal of tokens). Note
444+
* that JWT may only be used over TLS
445+
*
446+
* @param builder The builder used to build JWT tokens
447+
*/
448+
void setTokenBuilder(std::unique_ptr<cb::jwt::Builder> builder);
449+
450+
void authenticateWithToken();
451+
439452
/**
440453
* Create a bucket
441454
*
@@ -1225,6 +1238,7 @@ class MemcachedConnection {
12251238
std::string name;
12261239
std::string serverInterfaceUuid;
12271240
std::optional<std::chrono::microseconds> traceData;
1241+
std::unique_ptr<cb::jwt::Builder> tokenBuilder;
12281242

12291243
using Featureset = std::unordered_set<uint16_t>;
12301244

0 commit comments

Comments
 (0)