Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 108 additions & 66 deletions packages/react-wallet-kit/src/providers/client/Provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,10 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
error instanceof TurnkeyError
? error
: new TurnkeyError(
"Facebook authentication failed",
TurnkeyErrorCodes.OAUTH_SIGNUP_ERROR,
error,
),
"Facebook authentication failed",
TurnkeyErrorCodes.OAUTH_SIGNUP_ERROR,
error,
),
);
}
});
Expand Down Expand Up @@ -414,10 +414,10 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
err instanceof TurnkeyError
? err
: new TurnkeyError(
"Discord authentication failed",
TurnkeyErrorCodes.OAUTH_SIGNUP_ERROR,
err,
),
"Discord authentication failed",
TurnkeyErrorCodes.OAUTH_SIGNUP_ERROR,
err,
),
);
}
}
Expand Down Expand Up @@ -494,10 +494,10 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
err instanceof TurnkeyError
? err
: new TurnkeyError(
"Twitter authentication failed",
TurnkeyErrorCodes.OAUTH_SIGNUP_ERROR,
err,
),
"Twitter authentication failed",
TurnkeyErrorCodes.OAUTH_SIGNUP_ERROR,
err,
),
);
}
}
Expand Down Expand Up @@ -920,7 +920,7 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
walletProviders: WalletProvider[],
onUpdateState: () => Promise<void>,
): Promise<() => void> {
if (walletProviders.length === 0) return () => {};
if (walletProviders.length === 0) return () => { };

const cleanups: Array<() => void> = [];

Expand All @@ -930,20 +930,20 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({

const ethProviders = masterConfig?.walletConfig?.chains.ethereum?.native
? walletProviders.filter(
(provider) =>
provider.chainInfo.namespace === Chain.Ethereum &&
nativeOnly(provider) &&
provider.connectedAddresses.length > 0,
)
(provider) =>
provider.chainInfo.namespace === Chain.Ethereum &&
nativeOnly(provider) &&
provider.connectedAddresses.length > 0,
)
: [];

const solProviders = masterConfig?.walletConfig?.chains.solana?.native
? walletProviders.filter(
(provider) =>
provider.chainInfo.namespace === Chain.Solana &&
nativeOnly(provider) &&
provider.connectedAddresses.length > 0,
)
(provider) =>
provider.chainInfo.namespace === Chain.Solana &&
nativeOnly(provider) &&
provider.connectedAddresses.length > 0,
)
: [];

// WalletConnect is excluded from native event wiring. Instead,
Expand Down Expand Up @@ -3177,33 +3177,35 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
const { verifier, codeChallenge } = await generateChallengePair();
sessionStorage.setItem("discord_verifier", verifier);

// Construct Discord Auth URL
const discordAuthUrl = new URL(DISCORD_AUTH_URL);
discordAuthUrl.searchParams.set("client_id", clientId);
discordAuthUrl.searchParams.set("redirect_uri", redirectURI);
discordAuthUrl.searchParams.set("response_type", "code");
discordAuthUrl.searchParams.set("code_challenge", codeChallenge);
discordAuthUrl.searchParams.set("code_challenge_method", "S256");
discordAuthUrl.searchParams.set("scope", "identify email");
discordAuthUrl.searchParams.set(
"state",
`provider=discord&flow=${flow}&publicKey=${encodeURIComponent(publicKey)}&nonce=${nonce}`,
);
// Generate random state for CSRF protection
const randomState = crypto.randomUUID();
sessionStorage.setItem("discord_state", randomState);

// Build state string with all parameters
let state = `provider=discord&flow=${flow}&publicKey=${encodeURIComponent(publicKey)}&nonce=${nonce}&randomState=${randomState}`;

// Append additional state parameters
if (additionalParameters) {
const extra = Object.entries(additionalParameters)
.map(
([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(v)}`,
)
.join("&");
if (extra) {
discordAuthUrl.searchParams.set(
"state",
discordAuthUrl.searchParams.get("state")! + `&${extra}`,
);
state += `&${extra}`;
}
}

// Construct Discord Auth URL
const discordAuthUrl = new URL(DISCORD_AUTH_URL);
discordAuthUrl.searchParams.set("client_id", clientId);
discordAuthUrl.searchParams.set("redirect_uri", redirectURI);
discordAuthUrl.searchParams.set("response_type", "code");
discordAuthUrl.searchParams.set("code_challenge", codeChallenge);
discordAuthUrl.searchParams.set("code_challenge_method", "S256");
discordAuthUrl.searchParams.set("scope", "identify email");
discordAuthUrl.searchParams.set("state", state);

if (openInPage) {
window.location.href = discordAuthUrl.toString();
return new Promise((_, reject) => {
Expand Down Expand Up @@ -3237,6 +3239,8 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
try {
if (authWindow.closed) {
clearInterval(interval);
sessionStorage.removeItem("discord_verifier");
sessionStorage.removeItem("discord_state");
reject(new Error("Authentication window was closed."));
return;
}
Expand All @@ -3246,24 +3250,40 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
const urlParams = new URLSearchParams(new URL(url).search);
const authCode = urlParams.get("code");
const stateParam = urlParams.get("state");
const sessionKey = stateParam
?.split("&")
.find((param) => param.startsWith("sessionKey="))
?.split("=")[1];

if (
authCode &&
stateParam &&
stateParam.includes("provider=discord")
) {
// Validate state to prevent CSRF attacks
const returnedRandomState = new URLSearchParams(stateParam).get("randomState");
const expectedRandomState = sessionStorage.getItem("discord_state");

if (!returnedRandomState || returnedRandomState !== expectedRandomState) {
authWindow.close();
clearInterval(interval);
sessionStorage.removeItem("discord_verifier");
sessionStorage.removeItem("discord_state");
reject(new TurnkeyError(
"OAuth state mismatch - possible CSRF attack",
TurnkeyErrorCodes.OAUTH_LOGIN_ERROR,
));
return;
}

authWindow.close();
clearInterval(interval);

const verifier = sessionStorage.getItem("discord_verifier");
if (!verifier) {
sessionStorage.removeItem("discord_state");
reject(new Error("Missing PKCE verifier"));
return;
}

const sessionKey = new URLSearchParams(stateParam).get("sessionKey");

client?.httpClient
.proxyOAuth2Authenticate({
provider: "OAUTH2_PROVIDER_DISCORD",
Expand All @@ -3275,6 +3295,7 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
})
.then((resp) => {
sessionStorage.removeItem("discord_verifier");
sessionStorage.removeItem("discord_state");

const oidcToken = resp.oidcToken;
if (params?.onOauthSuccess) {
Expand Down Expand Up @@ -3364,33 +3385,35 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
const { verifier, codeChallenge } = await generateChallengePair();
sessionStorage.setItem("twitter_verifier", verifier);

// Construct Twitter Auth URL
const twitterAuthUrl = new URL(X_AUTH_URL);
twitterAuthUrl.searchParams.set("client_id", clientId);
twitterAuthUrl.searchParams.set("redirect_uri", redirectURI);
twitterAuthUrl.searchParams.set("response_type", "code");
twitterAuthUrl.searchParams.set("code_challenge", codeChallenge);
twitterAuthUrl.searchParams.set("code_challenge_method", "S256");
twitterAuthUrl.searchParams.set("scope", "tweet.read users.read");
twitterAuthUrl.searchParams.set(
"state",
`provider=twitter&flow=${flow}&publicKey=${encodeURIComponent(publicKey)}&nonce=${nonce}`,
);
// Generate random state for CSRF protection
const randomState = crypto.randomUUID();
sessionStorage.setItem("twitter_state", randomState);

// Build state string with all parameters
let state = `provider=twitter&flow=${flow}&publicKey=${encodeURIComponent(publicKey)}&nonce=${nonce}&randomState=${randomState}`;

// Append additional state parameters
if (additionalParameters) {
const extra = Object.entries(additionalParameters)
.map(
([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(v)}`,
)
.join("&");
if (extra) {
twitterAuthUrl.searchParams.set(
"state",
twitterAuthUrl.searchParams.get("state")! + `&${extra}`,
);
state += `&${extra}`;
}
}

// Construct Twitter Auth URL
const twitterAuthUrl = new URL(X_AUTH_URL);
twitterAuthUrl.searchParams.set("client_id", clientId);
twitterAuthUrl.searchParams.set("redirect_uri", redirectURI);
twitterAuthUrl.searchParams.set("response_type", "code");
twitterAuthUrl.searchParams.set("code_challenge", codeChallenge);
twitterAuthUrl.searchParams.set("code_challenge_method", "S256");
twitterAuthUrl.searchParams.set("scope", "tweet.read users.read");
twitterAuthUrl.searchParams.set("state", state);

if (openInPage) {
window.location.href = twitterAuthUrl.toString();
return new Promise((_, reject) => {
Expand Down Expand Up @@ -3424,6 +3447,8 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
try {
if (authWindow.closed) {
clearInterval(interval);
sessionStorage.removeItem("twitter_verifier");
sessionStorage.removeItem("twitter_state");
reject(new Error("Authentication window was closed."));
return;
}
Expand All @@ -3433,24 +3458,40 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
const urlParams = new URLSearchParams(new URL(url).search);
const authCode = urlParams.get("code");
const stateParam = urlParams.get("state");
const sessionKey = stateParam
?.split("&")
.find((param) => param.startsWith("sessionKey="))
?.split("=")[1];

if (
authCode &&
stateParam &&
stateParam.includes("provider=twitter")
) {
// Validate state to prevent CSRF attacks
const returnedRandomState = new URLSearchParams(stateParam).get("randomState");
const expectedRandomState = sessionStorage.getItem("twitter_state");

if (!returnedRandomState || returnedRandomState !== expectedRandomState) {
authWindow.close();
clearInterval(interval);
sessionStorage.removeItem("twitter_verifier");
sessionStorage.removeItem("twitter_state");
reject(new TurnkeyError(
"OAuth state mismatch - possible CSRF attack",
TurnkeyErrorCodes.OAUTH_LOGIN_ERROR,
));
return;
}

authWindow.close();
clearInterval(interval);

const verifier = sessionStorage.getItem("twitter_verifier");
if (!verifier) {
sessionStorage.removeItem("twitter_state");
reject(new Error("Missing PKCE verifier"));
return;
}

const sessionKey = new URLSearchParams(stateParam).get("sessionKey");

client?.httpClient
.proxyOAuth2Authenticate({
provider: "OAUTH2_PROVIDER_X",
Expand All @@ -3462,6 +3503,7 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
})
.then((resp) => {
sessionStorage.removeItem("twitter_verifier");
sessionStorage.removeItem("twitter_state");

const oidcToken = resp.oidcToken;
if (params?.onOauthSuccess) {
Expand Down Expand Up @@ -5624,7 +5666,7 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
let currentUrl = "";
try {
currentUrl = onRampWindow?.location.href || "";
} catch {}
} catch { }

if (
currentUrl &&
Expand Down Expand Up @@ -5652,7 +5694,7 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
cleanup();
try {
onRampWindow?.close();
} catch {}
} catch { }
setCompleted(true);
resolveAction();
}
Expand Down Expand Up @@ -5839,7 +5881,7 @@ export const ClientProvider: React.FC<ClientProviderProps> = ({
}
};

let cleanup = () => {};
let cleanup = () => { };
initializeWalletProviderListeners(walletProviders, handleUpdateState)
.then((fn) => {
cleanup = fn;
Expand Down