Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 41 additions & 35 deletions src/protocol/sasl/scram-sha.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { createHash, createHmac, pbkdf2Sync, randomBytes } from 'node:crypto'
import { createHash, createHmac, pbkdf2, randomBytes } from 'node:crypto'
import { promisify } from 'node:util'
import { createPromisifiedCallback, kCallbackPromise, type CallbackWithPromise } from '../../apis/callbacks.ts'
import { type SASLAuthenticationAPI, type SaslAuthenticateResponse } from '../../apis/security/sasl-authenticate-v2.ts'
import { AuthenticationError } from '../../errors.ts'
Expand All @@ -19,7 +20,7 @@ export interface ScramAlgorithmDefinition {

export interface ScramCryptoModule {
h: (definition: ScramAlgorithmDefinition, data: string | Buffer) => Buffer
hi: (definition: ScramAlgorithmDefinition, password: string, salt: Buffer, iterations: number) => Buffer
hi: (definition: ScramAlgorithmDefinition, password: string, salt: Buffer, iterations: number) => Promise<Buffer>
hmac: (definition: ScramAlgorithmDefinition, key: Buffer, data: string | Buffer) => Buffer
xor: (a: Buffer, b: Buffer) => Buffer
}
Expand Down Expand Up @@ -62,8 +63,10 @@ export function h (definition: ScramAlgorithmDefinition, data: string | Buffer):
return createHash(definition.algorithm).update(data).digest()
}

export function hi (definition: ScramAlgorithmDefinition, password: string, salt: Buffer, iterations: number): Buffer {
return pbkdf2Sync(password, salt, iterations, definition.keyLength, definition.algorithm)
const pbkdf2Async = promisify(pbkdf2)

export function hi (definition: ScramAlgorithmDefinition, password: string, salt: Buffer, iterations: number): Promise<Buffer> {
return pbkdf2Async(password, salt, iterations, definition.keyLength, definition.algorithm)
}

export function hmac (definition: ScramAlgorithmDefinition, key: Buffer, data: string | Buffer): Buffer {
Expand Down Expand Up @@ -143,37 +146,40 @@ function performAuthentication (
// ClientProof := ClientKey XOR ClientSignature
// ServerKey := HMAC(SaltedPassword, "Server Key")
// ServerSignature := HMAC(ServerKey, AuthMessage)
const saltedPassword = hi(definition, password, salt, iterations)
const clientKey = hmac(definition, saltedPassword, HMAC_CLIENT_KEY)
const storedKey = h(definition, clientKey)
const clientFinalMessageWithoutProof = `c=${GS2_HEADER_BASE64},r=${serverNonce}`
const authMessage = `${clientFirstMessageBare},${serverFirstMessage},${clientFinalMessageWithoutProof}`
const clientSignature = hmac(definition, storedKey, authMessage)
const clientProof = xor(clientKey, clientSignature)
const serverKey = hmac(definition, saltedPassword, HMAC_SERVER_KEY)
const serverSignature = hmac(definition, serverKey, authMessage)

authenticateAPI(connection, Buffer.from(`${clientFinalMessageWithoutProof},p=${clientProof.toString('base64')}`), (
error,
lastResponse
) => {
if (error) {
callback(new AuthenticationError('SASL authentication failed.', { cause: error }))
return
}

// Send the last message to the server
const lastData = parseParameters(lastResponse!.authBytes)

if (lastData.e) {
callback(new AuthenticationError(lastData.e))
return
} else if (lastData.v !== serverSignature.toString('base64')) {
callback(new AuthenticationError('Invalid server signature.'))
return
}

callback(null, lastResponse)
hi(definition, password, salt, iterations).then(saltedPassword => {
const clientKey = hmac(definition, saltedPassword, HMAC_CLIENT_KEY)
const storedKey = h(definition, clientKey)
const clientFinalMessageWithoutProof = `c=${GS2_HEADER_BASE64},r=${serverNonce}`
const authMessage = `${clientFirstMessageBare},${serverFirstMessage},${clientFinalMessageWithoutProof}`
const clientSignature = hmac(definition, storedKey, authMessage)
const clientProof = xor(clientKey, clientSignature)
const serverKey = hmac(definition, saltedPassword, HMAC_SERVER_KEY)
const serverSignature = hmac(definition, serverKey, authMessage)

authenticateAPI(connection, Buffer.from(`${clientFinalMessageWithoutProof},p=${clientProof.toString('base64')}`), (
error,
lastResponse
) => {
if (error) {
callback(new AuthenticationError('SASL authentication failed.', { cause: error }))
return
}

// Send the last message to the server
const lastData = parseParameters(lastResponse!.authBytes)

if (lastData.e) {
callback(new AuthenticationError(lastData.e))
return
} else if (lastData.v !== serverSignature.toString('base64')) {
callback(new AuthenticationError('Invalid server signature.'))
return
}

callback(null, lastResponse)
})
}).catch(error => {
callback(new AuthenticationError('SASL authentication failed.', { cause: error }))
})
})
}
Expand Down
8 changes: 4 additions & 4 deletions test/protocol/sasl/scram-sha.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,22 @@ test('h should hash data using the algorithm from the definition', () => {
})

// Test hi (PBKDF2) function
test('hi should derive key using PBKDF2', () => {
test('hi should derive key using PBKDF2', async () => {
const sha256Def = ScramAlgorithms['SHA-256']
const password = 'password'
const salt = Buffer.from('salt')
const iterations = 1

// Should return a buffer of the expected length
const key = hi(sha256Def, password, salt, iterations)
const key = await hi(sha256Def, password, salt, iterations)
strictEqual(key.length, sha256Def.keyLength, 'Key length should match algorithm definition')

// Different iterations should produce different results
const key2 = hi(sha256Def, password, salt, 2)
const key2 = await hi(sha256Def, password, salt, 2)
ok(!key.equals(key2), 'Different iterations should produce different keys')

// Different definitions should produce different results
const key512 = hi(ScramAlgorithms['SHA-512'], password, salt, iterations)
const key512 = await hi(ScramAlgorithms['SHA-512'], password, salt, iterations)
strictEqual(key512.length, ScramAlgorithms['SHA-512'].keyLength, 'Key length should match SHA-512 definition')
ok(!key.equals(key512), 'Different algorithms should produce different keys')
})
Expand Down