packages/fxa-shared/db/redis.ts (279 lines of code) (raw):

/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
import { readdirSync, readFileSync } from 'fs';
import Redis from 'ioredis';
import { basename, extname, resolve } from 'path';
import { StatsD } from 'hot-shots';
import { ILogger } from '../log';
import { AccessToken as AccessToken } from './models/auth/access-token';
import { RefreshTokenMetadata } from './models/auth/refresh-token-meta-data';
import opentelemetry from '@opentelemetry/api';

const tracer = opentelemetry.trace.getTracer('redis-tracer');
const hex = require('buf').to.hex;

/**
 * Configuration accepted by {@link RedisShared}: the ioredis connection
 * options plus a handful of wrapper-specific settings.
 */
export type Config = {
  enabled?: boolean;
  enableMetrics?: boolean;
  /** Copied to `keyPrefix` by the constructor when `keyPrefix` is not set. */
  prefix?: string;
  recordLimit?: number;
  maxttl?: number | string;
  /** Per-command timeout in milliseconds; a falsy value falls back to 1000. */
  timeoutMs?: number;
} & Redis.RedisOptions;

/**
 * Commands that are expected to be attached to the redis client by the Lua
 * scripts loaded in the {@link RedisShared} constructor.
 */
interface ICustomRedisCache {
  getAccessToken(uid: string): Promise<any>;
  getAccessTokens(uid: string): Promise<any>;
  getSessionTokens(uid: string): Promise<any>;
  pruneSessionTokens(uid: string, tokenIds: string): Promise<any>;
  touchSessionToken(uid: string, token: any): Promise<any>;
}

/**
 * Type guard for ICustomRedisCache: true only when every custom command is
 * present on the given client.
 */
function isCustomRedisCache(
  redis: any
): redis is Redis.Redis & ICustomRedisCache {
  return (
    'getAccessToken' in redis &&
    'getAccessTokens' in redis &&
    'getSessionTokens' in redis &&
    'pruneSessionTokens' in redis &&
    'touchSessionToken' in redis
  );
}

/**
 * Thin wrapper around an ioredis client that registers the Lua scripts found
 * in the `luaScripts` directory as custom commands, and exposes token related
 * helpers instrumented with StatsD metrics and OpenTelemetry spans.
 */
export class RedisShared {
  public readonly redis: Redis.Redis & ICustomRedisCache;

  protected get keyPrefix() {
    return this.config.keyPrefix;
  }

  protected get recordLimit() {
    return this.config.recordLimit;
  }

  protected get maxttl() {
    return this.config.maxttl;
  }

  /** Timeout applied to the raced commands below; defaults to 1000 ms. */
  protected get timeoutMs() {
    return this.config.timeoutMs || 1000;
  }

  /**
   * Creates the redis client, wires connection-event metrics and loads the
   * custom Lua commands.
   *
   * @param config - Connection and wrapper options. Note that `keyPrefix` is
   *   set on this object from `prefix` when only `prefix` was provided.
   * @param log - Optional logger.
   * @param metrics - Optional StatsD client.
   * @throws Error when the loaded scripts do not define every command of
   *   ICustomRedisCache.
   */
  constructor(
    protected readonly config: Config,
    protected readonly log?: ILogger,
    protected readonly metrics?: StatsD
  ) {
    if (!config.keyPrefix && config.prefix) {
      config.keyPrefix = config.prefix;
    }

    const redis = new Redis(config);

    // Listen to all client events
    redis.on('connect', () => {
      this.metrics?.increment('redis.connect');
    });
    redis.on('ready', () => {
      this.metrics?.increment('redis.ready');
    });
    redis.on('error', (err) => {
      this.metrics?.increment('redis.error');
      log?.error('RedisShared', {
        msg: `RedisShared: Redis error encountered ${err}`,
        host: config.host,
        port: config.port,
        error: err,
      });
    });
    redis.on('close', () => {
      this.metrics?.increment('redis.close');
    });
    redis.on('reconnecting', () => {
      this.metrics?.increment('redis.reconnecting');
    });
    redis.on('end', () => {
      this.metrics?.increment('redis.end');
    });

    const scriptsDirectory = resolve(__dirname, 'luaScripts');

    // Applies custom scripts which are turned into methods on
    // the redis object.
    this.defineCommands(redis, scriptsDirectory);

    // Invoke type guard to make sure custom scripts were loaded
    // properly. Fail hard otherwise.
    if (isCustomRedisCache(redis)) {
      this.redis = redis;
    } else {
      this.log?.warn('RedisShared', {
        msg: 'RedisShared: Missing scripts to fully define a customized redis instance.',
      });
      throw new Error(
        'Missing scripts to fully define a customized redis instance.'
      );
    }
  }

  /** Registers every `.lua` script found in `directory` on the client. */
  protected defineCommands(redis: Redis.Redis, directory: string) {
    this.getScriptNames(directory).forEach((name: string) =>
      this.defineCommand(redis, name, directory)
    );
  }

  /**
   * Returns a promise that resolves with `value` after `ms` milliseconds,
   * unless `cancel` settles first, in which case it never settles.
   * Intended to be raced against `cancel`.
   */
  protected resolveInMs(
    cancel: Promise<any>,
    ms: number,
    value?: any
  ): Promise<any> {
    return new Promise((resolve) =>
      this.cacellableAction(() => resolve(value), ms, cancel)
    );
  }

  /**
   * Returns a promise that rejects with `err` after `ms` milliseconds,
   * unless `cancel` settles first, in which case it never settles.
   * Intended to be raced against `cancel`.
   */
  protected async rejectInMs(
    cancel: Promise<any>,
    ms: number,
    err = new Error('redis timeout')
  ) {
    return new Promise((_, reject) =>
      this.cacellableAction(() => reject(err), ms, cancel)
    );
  }

  /**
   * Schedules `cb` to run after `ms` milliseconds and cancels it as soon as
   * `cancel` settles.
   */
  protected cacellableAction(cb: () => void, ms: number, cancel: Promise<any>) {
    const id = setTimeout(cb, ms);
    const clear = () => clearTimeout(id);
    // Clear the timer on fulfilment AND rejection. Passing a rejection
    // handler also keeps this derived chain from surfacing as an unhandled
    // rejection; the failure itself is still observed by whoever awaits
    // `cancel` (e.g. through Promise.race).
    cancel.then(clear, clear);
  }

  /**
   * Registers a single script. Script files are named `<command>_<n>.lua`:
   * the part before the underscore becomes the command name and the part
   * after it is passed as `numberOfKeys`.
   */
  private defineCommand(
    redis: Redis.Redis,
    scriptName: string,
    directory: string
  ) {
    const [name, numberOfKeys] = scriptName.split('_');
    redis.defineCommand(name, {
      lua: this.readScript(directory, scriptName),
      numberOfKeys: +numberOfKeys,
    });
  }

  /** Reads `<directory>/<name>.lua` as UTF-8 text. */
  private readScript(directory: string, name: string) {
    return readFileSync(resolve(directory, `${name}.lua`), {
      encoding: 'utf8',
    });
  }

  /** Lists the `.lua` files of `directory`, without their extension. */
  private getScriptNames(directory: string) {
    const dir = resolve(directory);
    const scriptNames = readdirSync(dir, { withFileTypes: true })
      .filter(
        (dirent: any) => dirent.isFile() && extname(dirent.name) === '.lua'
      )
      .map((dirent: any) => basename(dirent.name, '.lua'));
    return scriptNames;
  }

  /** Closes the connection by issuing QUIT. */
  async close() {
    await this.redis.quit();
  }

  /** Deletes `key` and returns the reply of the DEL command. */
  async del(key: string): Promise<number | undefined> {
    const result = await this.redis.del(key);
    return result;
  }

  /**
   * Reads the hash stored under the hex-encoded `uid` and parses each field
   * with RefreshTokenMetadata.parse.
   *
   * @returns The parsed tokens keyed by hash field. An empty object is
   *   returned when the command does not answer within `timeoutMs` or when
   *   any error occurs (errors are logged, not thrown).
   */
  async getRefreshTokens(uid: Buffer | string) {
    this.metrics?.increment('redis.getRefreshTokens');
    const span = tracer.startSpan('redis.getRefreshTokens');
    try {
      const p1 = this.redis.hgetall(hex(uid));
      const p2 = this.resolveInMs(p1, this.timeoutMs, {});
      const tokens = await Promise.race([p1, p2]);
      span.setAttribute(
        'redis.getRefreshTokens.tokens.length',
        Object.keys(tokens).length
      );
      this.metrics?.histogram(
        'redis.getRefreshTokens.tokens.length',
        Object.keys(tokens).length
      );
      for (const id of Object.keys(tokens)) {
        tokens[id] = RefreshTokenMetadata.parse(tokens[id]);
      }
      return tokens;
    } catch (e) {
      this.metrics?.increment('redis.getRefreshTokens.error');
      this.log?.error('RedisShared', { error: e });
      return {};
    } finally {
      span.end();
    }
  }

  /**
   * Runs the custom `pruneSessionTokens` command with the JSON-encoded list
   * of token ids.
   *
   * @throws Error('redis timeout') when no reply arrives within `timeoutMs`,
   *   or whatever error the command itself rejects with.
   */
  async pruneSessionTokens(uid: string, tokenIds: string[] = []) {
    this.metrics?.increment('redis.pruneSessionTokens');
    const span = tracer.startSpan('redis.pruneSessionTokens');
    try {
      const p1 = this.redis.pruneSessionTokens(uid, JSON.stringify(tokenIds));
      const p2 = this.rejectInMs(p1, this.timeoutMs);
      return await Promise.race([p1, p2]);
    } finally {
      // End the span on the failure path too (timeout or redis error).
      span.end();
    }
  }

  /**
   * Removes the given token ids (hex-encoded) from the hash stored under the
   * hex-encoded `uid`.
   *
   * @returns The HDEL reply, `undefined` when the command does not answer
   *   within `timeoutMs`, or 0 when there is nothing to prune.
   */
  async pruneRefreshTokens(
    uid: Buffer | String,
    tokenIdsToPrune: Buffer[] | string[]
  ) {
    this.metrics?.increment('redis.pruneRefreshTokens');

    // HDEL needs at least one field; sending none is rejected by redis as an
    // arity error, so skip the round trip entirely.
    if (tokenIdsToPrune.length === 0) {
      return 0;
    }

    const span = tracer.startSpan('redis.pruneRefreshTokens');
    try {
      const p1 = this.redis.hdel(
        hex(uid),
        ...tokenIdsToPrune.map((v) => hex(v))
      );
      const p2 = this.resolveInMs(p1, this.timeoutMs);
      return await Promise.race([p1, p2]);
    } finally {
      // End the span on the failure path too.
      span.end();
    }
  }

  /**
   * Runs the custom `getSessionTokens` command and JSON-parses its reply.
   *
   * @returns The parsed reply, or an empty object when the command times
   *   out, fails, or returns something that cannot be parsed (errors are
   *   logged, not thrown).
   */
  async getSessionTokens(uid: string) {
    this.metrics?.increment('redis.getSessionTokens');
    const span = tracer.startSpan('redis.getSessionTokens');
    try {
      const p1 = this.redis.getSessionTokens(uid);
      const p2 = this.rejectInMs(p1, this.timeoutMs);
      const value = await Promise.race([p1, p2]);
      if (value?.length > 0) {
        span.setAttribute('redis.getSessionTokens.tokens.length', value.length);
        this.metrics?.histogram(
          'redis.getSessionTokens.tokens.length',
          value.length
        );
      }
      return JSON.parse(value as string);
    } catch (e) {
      this.log?.error('RedisShared', {
        error: e,
      });
      return {};
    } finally {
      span.end();
    }
  }

  /**
   * Runs the custom `getAccessTokens` command for the hex-encoded `uid` and
   * parses every returned entry with AccessToken.parse.
   *
   * @returns The parsed tokens, or an empty array when there are none or an
   *   error occurs (errors are logged, not thrown).
   */
  async getAccessTokens(uid: Buffer | String) {
    this.metrics?.increment('redis.getAccessTokens');
    const span = tracer.startSpan('redis.getAccessTokens');
    try {
      const values = await this.redis.getAccessTokens(hex(uid));
      // A nullish or empty reply means "no tokens"; do not route it through
      // the error path.
      if (!values?.length) {
        return [];
      }
      span.setAttribute('redis.getAccessTokens.tokens.length', values.length);
      this.metrics?.histogram(
        'redis.getAccessTokens.tokens.length',
        values.length
      );
      return values.map((v: string) => AccessToken.parse(v));
    } catch (e) {
      this.log?.error('RedisShared', {
        error: e,
      });
      return [];
    } finally {
      span.end();
    }
  }

  /**
   * Runs the custom `getAccessToken` command for the hex-encoded id and
   * parses the reply with AccessToken.parse.
   *
   * @returns The parsed token, or `null` when the reply is empty or an error
   *   occurs (errors are logged, not thrown).
   */
  async getAccessToken(uid: Buffer | String) {
    this.metrics?.increment('redis.getAccessToken');
    const span = tracer.startSpan('redis.getAccessToken');
    try {
      const value = await this.redis.getAccessToken(hex(uid));
      if (value) return AccessToken.parse(value);
    } catch (e) {
      this.log?.error('RedisShared', {
        error: e,
      });
    } finally {
      span.end();
    }
    return null;
  }

  /**
   * Serializes `token` (dropping null/undefined properties) and passes it to
   * the custom `touchSessionToken` command.
   *
   * @returns The command reply, or `undefined` when no reply arrives within
   *   `timeoutMs`.
   */
  async touchSessionToken(uid: string, token: any) {
    this.metrics?.increment('redis.touchSessionToken');
    const span = tracer.startSpan('redis.touchSessionToken');
    try {
      // remove keys with null values
      const json = JSON.stringify(token, (k, v) => (v == null ? undefined : v));
      span.setAttribute('redis.touchSessionToken.json.size', json.length);
      const p1 = this.redis.touchSessionToken(uid, json);
      const p2 = this.resolveInMs(p1, this.timeoutMs);
      return await Promise.race([p1, p2]);
    } finally {
      // End the span even when serialization or the command fails.
      span.end();
    }
  }
}