diff --git a/src/protocol/sasl/scram-sha.ts b/src/protocol/sasl/scram-sha.ts index 78df39b8..60f7ecb7 100644 --- a/src/protocol/sasl/scram-sha.ts +++ b/src/protocol/sasl/scram-sha.ts @@ -1,4 +1,5 @@ -import { createHash, createHmac, pbkdf2Sync, randomBytes } from 'node:crypto' +import { createHash, createHmac, pbkdf2, randomBytes } from 'node:crypto' +import { promisify } from 'node:util' import { createPromisifiedCallback, kCallbackPromise, type CallbackWithPromise } from '../../apis/callbacks.ts' import { type SASLAuthenticationAPI, type SaslAuthenticateResponse } from '../../apis/security/sasl-authenticate-v2.ts' import { AuthenticationError } from '../../errors.ts' @@ -19,7 +20,7 @@ export interface ScramAlgorithmDefinition { export interface ScramCryptoModule { h: (definition: ScramAlgorithmDefinition, data: string | Buffer) => Buffer - hi: (definition: ScramAlgorithmDefinition, password: string, salt: Buffer, iterations: number) => Buffer + hi: (definition: ScramAlgorithmDefinition, password: string, salt: Buffer, iterations: number) => Promise hmac: (definition: ScramAlgorithmDefinition, key: Buffer, data: string | Buffer) => Buffer xor: (a: Buffer, b: Buffer) => Buffer } @@ -62,8 +63,10 @@ export function h (definition: ScramAlgorithmDefinition, data: string | Buffer): return createHash(definition.algorithm).update(data).digest() } -export function hi (definition: ScramAlgorithmDefinition, password: string, salt: Buffer, iterations: number): Buffer { - return pbkdf2Sync(password, salt, iterations, definition.keyLength, definition.algorithm) +const pbkdf2Async = promisify(pbkdf2) + +export function hi (definition: ScramAlgorithmDefinition, password: string, salt: Buffer, iterations: number): Promise { + return pbkdf2Async(password, salt, iterations, definition.keyLength, definition.algorithm) } export function hmac (definition: ScramAlgorithmDefinition, key: Buffer, data: string | Buffer): Buffer { @@ -143,37 +146,40 @@ function performAuthentication ( // ClientProof := ClientKey XOR ClientSignature // ServerKey := HMAC(SaltedPassword, "Server Key") // ServerSignature := HMAC(ServerKey, AuthMessage) - const saltedPassword = hi(definition, password, salt, iterations) - const clientKey = hmac(definition, saltedPassword, HMAC_CLIENT_KEY) - const storedKey = h(definition, clientKey) - const clientFinalMessageWithoutProof = `c=${GS2_HEADER_BASE64},r=${serverNonce}` - const authMessage = `${clientFirstMessageBare},${serverFirstMessage},${clientFinalMessageWithoutProof}` - const clientSignature = hmac(definition, storedKey, authMessage) - const clientProof = xor(clientKey, clientSignature) - const serverKey = hmac(definition, saltedPassword, HMAC_SERVER_KEY) - const serverSignature = hmac(definition, serverKey, authMessage) - - authenticateAPI(connection, Buffer.from(`${clientFinalMessageWithoutProof},p=${clientProof.toString('base64')}`), ( - error, - lastResponse - ) => { - if (error) { - callback(new AuthenticationError('SASL authentication failed.', { cause: error })) - return - } - - // Send the last message to the server - const lastData = parseParameters(lastResponse!.authBytes) - - if (lastData.e) { - callback(new AuthenticationError(lastData.e)) - return - } else if (lastData.v !== serverSignature.toString('base64')) { - callback(new AuthenticationError('Invalid server signature.')) - return - } - - callback(null, lastResponse) + hi(definition, password, salt, iterations).then(saltedPassword => { + const clientKey = hmac(definition, saltedPassword, HMAC_CLIENT_KEY) + const storedKey = h(definition, clientKey) + const clientFinalMessageWithoutProof = `c=${GS2_HEADER_BASE64},r=${serverNonce}` + const authMessage = `${clientFirstMessageBare},${serverFirstMessage},${clientFinalMessageWithoutProof}` + const clientSignature = hmac(definition, storedKey, authMessage) + const clientProof = xor(clientKey, clientSignature) + const serverKey = hmac(definition, saltedPassword, HMAC_SERVER_KEY) + const serverSignature = hmac(definition, serverKey, authMessage) + + authenticateAPI(connection, Buffer.from(`${clientFinalMessageWithoutProof},p=${clientProof.toString('base64')}`), ( + error, + lastResponse + ) => { + if (error) { + callback(new AuthenticationError('SASL authentication failed.', { cause: error })) + return + } + + // Send the last message to the server + const lastData = parseParameters(lastResponse!.authBytes) + + if (lastData.e) { + callback(new AuthenticationError(lastData.e)) + return + } else if (lastData.v !== serverSignature.toString('base64')) { + callback(new AuthenticationError('Invalid server signature.')) + return + } + + callback(null, lastResponse) + }) + }).catch(error => { + callback(new AuthenticationError('SASL authentication failed.', { cause: error })) }) }) } diff --git a/test/protocol/sasl/scram-sha.test.ts b/test/protocol/sasl/scram-sha.test.ts index 7135f476..e2253bc9 100644 --- a/test/protocol/sasl/scram-sha.test.ts +++ b/test/protocol/sasl/scram-sha.test.ts @@ -85,22 +85,22 @@ test('h should hash data using the algorithm from the definition', () => { }) // Test hi (PBKDF2) function -test('hi should derive key using PBKDF2', () => { +test('hi should derive key using PBKDF2', async () => { const sha256Def = ScramAlgorithms['SHA-256'] const password = 'password' const salt = Buffer.from('salt') const iterations = 1 // Should return a buffer of the expected length - const key = hi(sha256Def, password, salt, iterations) + const key = await hi(sha256Def, password, salt, iterations) strictEqual(key.length, sha256Def.keyLength, 'Key length should match algorithm definition') // Different iterations should produce different results - const key2 = hi(sha256Def, password, salt, 2) + const key2 = await hi(sha256Def, password, salt, 2) ok(!key.equals(key2), 'Different iterations should produce different keys') // Different definitions should produce different results - const key512 = hi(ScramAlgorithms['SHA-512'], password, salt, iterations) + const key512 = await hi(ScramAlgorithms['SHA-512'], password, salt, iterations) strictEqual(key512.length, ScramAlgorithms['SHA-512'].keyLength, 'Key length should match SHA-512 definition') ok(!key.equals(key512), 'Different algorithms should produce different keys') })