diff --git a/packages/zundo/__tests__/options.test.ts b/packages/zundo/__tests__/options.test.ts index 7c11175..7721534 100644 --- a/packages/zundo/__tests__/options.test.ts +++ b/packages/zundo/__tests__/options.test.ts @@ -435,7 +435,7 @@ describe('Middleware options', () => { store.temporal.getState() as TemporalStateWithInternals; expect(__internal).toBeDefined(); expect(__internal.handleUserSet).toBeInstanceOf(Function); - expect(__internal.onSave).toBeInstanceOf(Function); + expect(__internal.onSave).toBe(undefined); }); describe('onSave', () => { it('should call onSave cb without adding a new state when onSave is set by user', () => { @@ -452,6 +452,7 @@ describe('Middleware options', () => { act(() => { onSave(store.getState(), store.getState()); }); + expect(__internal.onSave).toBeInstanceOf(Function); expect(store.temporal.getState().pastStates.length).toBe(0); expect(console.error).toHaveBeenCalledTimes(1); }); diff --git a/packages/zundo/src/index.ts b/packages/zundo/src/index.ts index 5b914c8..cc1b7fe 100644 --- a/packages/zundo/src/index.ts +++ b/packages/zundo/src/index.ts @@ -27,54 +27,58 @@ declare module 'zustand/vanilla' { } } -type ZundoImpl = ( - config: StateCreator, - options: ZundoOptions, -) => StateCreator; +const zundoImpl = + ( + config: StateCreator, + { + partialize = (state: TState) => state, + equality, + onSave, + limit, + handleSet: userlandSetFactory = (handleSetCb) => handleSetCb, + } = {} as ZundoOptions, + ): StateCreator => + (set, get, _store) => { + type TState = ReturnType; + type StoreAddition = StoreApi>; -const zundoImpl: ZundoImpl = (config, baseOptions) => (set, get, _store) => { - type TState = ReturnType; - type StoreAddition = StoreApi>; + const temporalStore = createVanillaTemporal(set, get, { + partialize, + equality, + onSave, + limit, + }); - const options = { - partialize: (state: TState) => state, - handleSet: (handleSetCb: typeof set) => handleSetCb, - ...baseOptions, - }; - const { partialize, handleSet: userlandSetFactory } = options; - - const temporalStore = createVanillaTemporal(set, get, options); + const store = _store as Mutate< + StoreApi, + [['temporal', StoreAddition]] + >; + const { setState } = store; - const store = _store as Mutate< - StoreApi, - [['temporal', StoreAddition]] - >; - const { setState } = store; + // TODO: should temporal be only temporalStore.getState()? + // We can hide the rest of the store in the secret internals. + store.temporal = temporalStore; - // TODO: should temporal be only temporalStore.getState()? - // We can hide the rest of the store in the secret internals. - store.temporal = temporalStore; + const curriedUserLandSet = userlandSetFactory( + temporalStore.getState().__internal.handleUserSet, + ); - const curriedUserLandSet = userlandSetFactory( - temporalStore.getState().__internal.handleUserSet, - ); + const modifiedSetState: typeof setState = (state, replace) => { + const pastState = partialize(get()); + setState(state, replace); + curriedUserLandSet(pastState); + }; + store.setState = modifiedSetState; - const modifiedSetState: typeof setState = (state, replace) => { - const pastState = partialize(get()); - setState(state, replace); - curriedUserLandSet(pastState); - }; - store.setState = modifiedSetState; + const modifiedSetter: typeof set = (state, replace) => { + // Get most up to date state. Should this be the same as the state in the callback? + const pastState = partialize(get()); + set(state, replace); + curriedUserLandSet(pastState); + }; - const modifiedSetter: typeof set = (state, replace) => { - // Get most up to date state. Should this be the same as the state in the callback? - const pastState = partialize(get()); - set(state, replace); - curriedUserLandSet(pastState); + return config(modifiedSetter, get, _store); }; - return config(modifiedSetter, get, _store); -}; - export const temporal = zundoImpl as unknown as Zundo; -export type { ZundoOptions, Zundo, TemporalState }; \ No newline at end of file +export type { ZundoOptions, Zundo, TemporalState }; diff --git a/packages/zundo/src/temporal.ts b/packages/zundo/src/temporal.ts index f154294..2b1c305 100644 --- a/packages/zundo/src/temporal.ts +++ b/packages/zundo/src/temporal.ts @@ -1,18 +1,16 @@ import { createStore, type StoreApi } from 'zustand'; -import type { TemporalStateWithInternals, ZundoOptions } from './types'; +import type { TemporalStateWithInternals, WithRequired, ZundoOptions } from './types'; export const createVanillaTemporal = ( userSet: StoreApi['setState'], userGet: StoreApi['getState'], - baseOptions?: ZundoOptions, + { + partialize, + equality, + onSave, + limit, + } = {} as Omit, | 'partialize'>, 'handleSet'>, ) => { - const options = { - partialize: (state: TState) => state, - equality: (a: TState, b: TState) => false, - onSave: () => {}, - ...baseOptions, - }; - const { partialize, onSave, limit, equality } = options; return createStore>()((set, get) => { return { @@ -73,7 +71,7 @@ export const createVanillaTemporal = ( const currentState = partialize(userGet()); if ( trackingStatus === 'tracking' && - !equality(currentState, pastState) + !equality?.(currentState, pastState) ) { if (limit && ps.length >= limit) { ps.shift(); diff --git a/packages/zundo/src/types.ts b/packages/zundo/src/types.ts index 1ac0e0e..668b511 100644 --- a/packages/zundo/src/types.ts +++ b/packages/zundo/src/types.ts @@ -1,6 +1,6 @@ import type { StoreApi } from 'zustand'; -type onSave = (pastState: TState, currentState: TState) => void; +type onSave = ((pastState: TState, currentState: TState) => void) | undefined; export interface TemporalStateWithInternals { pastStates: TState[]; @@ -37,3 +37,6 @@ export type TemporalState = Omit< TemporalStateWithInternals, '__internal' >; + +// https://stackoverflow.com/a/69328045/9931154 +export type WithRequired = T & { [P in K]-?: T[P] }