diff --git a/src/datasource/__tests__/loaders.test.ts b/src/datasource/__tests__/loaders.test.ts index 55288465..cecf5a10 100644 --- a/src/datasource/__tests__/loaders.test.ts +++ b/src/datasource/__tests__/loaders.test.ts @@ -162,6 +162,22 @@ describe(LoaderFactory, () => { }) expect(await loader.load('zzz')).toMatchObject(dummyRow) }) + + it('accepts a query options function', async () => { + const dummyRow = { id: 999, name: 'zzz', code: 'zzz' } + const getData = jest.fn((): [DummyRowType] => [dummyRow]) + const loader = factory.create('name', { getData }, () => ({ limit: 1 })) + await loader.load('zzz') + expect(getData).toHaveBeenCalledWith( + expect.anything(), + expect.anything(), + expect.anything(), + expect.anything(), + expect.objectContaining({ + limit: 1, + }) + ) + }) }) describe('createMulti', () => { @@ -235,6 +251,24 @@ describe(LoaderFactory, () => { ] `) }) + + it('accepts a query options function', async () => { + const dummyRow = { id: 999, name: 'zzz', code: 'zzz' } + const getData = jest.fn((): [DummyRowType] => [dummyRow]) + const loader = factory.createMulti(['name', 'code'], { getData }, () => ({ + limit: 1, + })) + await loader.load({ name: 'zzz', code: 'zzz' }) + expect(getData).toHaveBeenCalledWith( + expect.anything(), + expect.anything(), + expect.anything(), + expect.anything(), + expect.objectContaining({ + limit: 1, + }) + ) + }) }) describe('auto-priming loaders', () => { diff --git a/src/datasource/loaders/LoaderFactory.ts b/src/datasource/loaders/LoaderFactory.ts index 76c007b5..76fe3546 100644 --- a/src/datasource/loaders/LoaderFactory.ts +++ b/src/datasource/loaders/LoaderFactory.ts @@ -1,4 +1,5 @@ import DataLoader from 'dataloader' +import { QueryOptions } from '../queries/QueryBuilder' import { ExtendedDataLoader, @@ -69,7 +70,8 @@ export default class LoaderFactory { key: TColumnName, options: LoaderOptions & { multi: true - } + }, + queryOptions?: () => QueryOptions ): ExtendedDataLoader public create< TColumnName extends SearchableKeys & keyof TRowType & string @@ -78,7 +80,8 @@ export default class LoaderFactory { columnType: string, options: LoaderOptions & { multi: true - } + }, + queryOptions?: () => QueryOptions ): ExtendedDataLoader public create< TColumnName extends SearchableKeys & keyof TRowType & string @@ -86,7 +89,8 @@ export default class LoaderFactory { key: TColumnName, options?: LoaderOptions & { multi?: false - } + }, + queryOptions?: () => QueryOptions ): ExtendedDataLoader public create< TColumnName extends SearchableKeys & keyof TRowType & string @@ -95,7 +99,8 @@ export default class LoaderFactory { columnType?: string, options?: LoaderOptions & { multi?: false - } + }, + queryOptions?: () => QueryOptions ): ExtendedDataLoader public create< TColumnName extends SearchableKeys & keyof TRowType & string, @@ -103,16 +108,21 @@ export default class LoaderFactory { >( key: TColumnName, columnType?: string | LoaderOptions, - options?: LoaderOptions + options?: LoaderOptions | (() => QueryOptions), + queryOptions?: () => QueryOptions ): DataLoader { + if (typeof options === 'function') { + queryOptions = options + options = undefined + } if (typeof columnType === 'object') { options = columnType columnType = undefined - } else if (typeof options === 'undefined') { - options = {} } - const getData = options.getData || this.getData + const actualOptions = options || {} + + const getData = actualOptions.getData || this.getData const type: string = columnType || this.options.columnTypes[key] @@ -122,13 +132,16 @@ export default class LoaderFactory { callbackFn, autoPrime, primeLoaders, - } = options + } = actualOptions const loader = new DataLoader< TColType, TRowType[] | (TRowType | undefined) >(async (args: readonly TColType[]) => { - const data = await getData(args, key, type, loader, options) + const data = await getData(args, key, type, loader, { + ...actualOptions, + ...queryOptions?.(), + }) data.forEach((row, idx, arr) => { callbackFn && callbackFn(row, idx, arr) @@ -172,7 +185,8 @@ export default class LoaderFactory { key: TColumnNames, options: LoaderOptions & { multi: true - } + }, + queryOptions?: () => QueryOptions ): ExtendedDataLoader public createMulti< TColumnNames extends Array< @@ -186,7 +200,8 @@ export default class LoaderFactory { columnTypes: string[], options: LoaderOptions & { multi: true - } + }, + queryOptions?: () => QueryOptions ): ExtendedDataLoader public createMulti< TColumnNames extends Array< @@ -199,7 +214,8 @@ export default class LoaderFactory { key: TColumnNames, options?: LoaderOptions & { multi?: false - } + }, + queryOptions?: () => QueryOptions ): ExtendedDataLoader public createMulti< TColumnNames extends Array< @@ -213,7 +229,8 @@ export default class LoaderFactory { columnTypes?: string[], options?: LoaderOptions & { multi?: false - } + }, + queryOptions?: () => QueryOptions ): ExtendedDataLoader public createMulti< TColumnNames extends Array< @@ -225,16 +242,21 @@ export default class LoaderFactory { >( keys: TColumnNames, columnTypes?: string[] | LoaderOptions, - options?: LoaderOptions + options?: LoaderOptions | (() => QueryOptions), + queryOptions?: () => QueryOptions ): DataLoader { + if (typeof options === 'function') { + queryOptions = options + options = undefined + } if (typeof columnTypes === 'object' && !Array.isArray(columnTypes)) { options = columnTypes columnTypes = undefined - } else if (typeof options === 'undefined') { - options = {} } - const getData = options.getData || this.getDataMulti + const actualOptions = options || {} + + const getData = actualOptions.getData || this.getDataMulti const types: string[] = columnTypes || keys.map((key) => this.options.columnTypes[key]) @@ -252,7 +274,7 @@ export default class LoaderFactory { callbackFn, autoPrime, primeLoaders, - } = options + } = actualOptions const loader = new DataLoader< TBatchKey, @@ -265,7 +287,10 @@ export default class LoaderFactory { keys, types, loader, - options + { + ...actualOptions, + ...queryOptions?.(), + } ) data.forEach((row, idx, arr) => {