diff --git a/src/SnackbarProvider.tsx b/src/SnackbarProvider.tsx index fc9afaf8..19980e9b 100644 --- a/src/SnackbarProvider.tsx +++ b/src/SnackbarProvider.tsx @@ -2,7 +2,7 @@ import React, { Component } from 'react'; import { createPortal } from 'react-dom'; import clsx from 'clsx'; import SnackbarContext from './SnackbarContext'; -import { MESSAGES, REASONS, originKeyExtractor, omitContainerKeys } from './utils/constants'; +import { MESSAGES, REASONS, originKeyExtractor, omitContainerKeys, merge, DEFAULTS } from './utils/constants'; import SnackbarItem from './SnackbarItem'; import SnackbarContainer from './SnackbarContainer'; import warning from './utils/warning'; @@ -51,6 +51,8 @@ class SnackbarProvider extends Component { enqueueSnackbar = (message: SnackbarMessage, { key, preventDuplicate, ...options }: OptionsObject = {}): SnackbarKey => { const hasSpecifiedKey = key || key === 0; const id = hasSpecifiedKey ? (key as SnackbarKey) : new Date().getTime() + Math.random(); + + const merger = merge(options, this.props, DEFAULTS); const snack: Snack = { key: id, ...options, @@ -58,9 +60,9 @@ class SnackbarProvider extends Component { open: true, entered: false, requestClose: false, - variant: options.variant || this.props.variant || 'default', - autoHideDuration: options.autoHideDuration || this.props.autoHideDuration || 5000, - anchorOrigin: options.anchorOrigin || this.props.anchorOrigin || { vertical: 'bottom', horizontal: 'left' }, + variant: merger('variant'), + anchorOrigin: merger('anchorOrigin'), + autoHideDuration: merger('autoHideDuration'), }; if (options.persist) { diff --git a/src/utils/constants.ts b/src/utils/constants.ts index d45a812c..0556e9d2 100644 --- a/src/utils/constants.ts +++ b/src/utils/constants.ts @@ -50,6 +50,30 @@ export const omitContainerKeys = (classes: SnackbarProviderProps['classes']): Sn Object.keys(classes).filter(key => !allClasses.container[key]).reduce((obj, key) => ({ ...obj, [key]: classes[key] }), {}) ); +export const DEFAULTS = { + variant: 'default', + autoHideDuration: 5000, + anchorOrigin: { + vertical: 'bottom', + horizontal: 'left', + }, +}; + +const numberOrNull = (numberish?: number | null) => ( + typeof numberish === 'number' || numberish === null +); + +// @ts-ignore +export const merge = (options, props, defaults) => (name: keyof Snack): any => { + if (name === 'autoHideDuration') { + if (numberOrNull(options.autoHideDuration)) return options.autoHideDuration; + if (numberOrNull(props.autoHideDuration)) return props.autoHideDuration; + return DEFAULTS.autoHideDuration; + } + + return options[name] || props[name] || defaults[name]; +}; + export const REASONS: { [key: string]: CloseReason } = { CLICKAWAY: 'clickaway', MAXSNACK: 'maxsnack',