diff --git a/packages/material-ui/src/Select/Select.js b/packages/material-ui/src/Select/Select.js index 7a66fbb4001fbf..a16d4fc67bb268 100644 --- a/packages/material-ui/src/Select/Select.js +++ b/packages/material-ui/src/Select/Select.js @@ -4,6 +4,7 @@ import React from 'react'; import PropTypes from 'prop-types'; import SelectInput from './SelectInput'; import withStyles from '../styles/withStyles'; +import mergeClasses from '../styles/mergeClasses'; import ArrowDropDownIcon from '../internal/svg-icons/ArrowDropDown'; import Input from '../Input'; import { styles as nativeSelectStyles } from '../NativeSelect/NativeSelect'; @@ -32,19 +33,15 @@ function Select(props) { } = props; const inputComponent = native ? NativeSelectInput : SelectInput; - const inputNativeProps = { - children, - classes, - IconComponent, - type: undefined, // We render a select. We can ignore the type provided by the `Input`. - }; return React.cloneElement(input, { // Most of the logic is implemented in `SelectInput`. // The `Select` component is a simple API wrapper to expose something better to play with. inputComponent, inputProps: { - ...inputNativeProps, + children, + IconComponent, + type: undefined, // We render a select. We can ignore the type provided by the `Input`. ...(native ? {} : { @@ -59,6 +56,13 @@ function Select(props) { SelectDisplayProps, }), ...inputProps, + classes: inputProps + ? mergeClasses({ + baseClasses: classes, + newClasses: inputProps.classes, + Component: Select, + }) + : classes, ...(input ? input.props.inputProps : {}), }, ...other, diff --git a/packages/material-ui/src/Select/Select.test.js b/packages/material-ui/src/Select/Select.test.js index e18af1c8eb3cc1..7870def8b3ed82 100644 --- a/packages/material-ui/src/Select/Select.test.js +++ b/packages/material-ui/src/Select/Select.test.js @@ -9,14 +9,14 @@ describe(', children: [1, 2], }; before(() => { shallow = createShallow({ dive: true }); - classes = getClasses(); mount = createMount(); }); @@ -25,18 +25,35 @@ describe('); + const wrapper = shallow(); + const wrapper = shallow(, + ); + assert.deepEqual(wrapper.props().inputProps.classes, { + ...classes, + root: `${classes.root} root`, + }); + }); + }); + it('should be able to mount the component', () => { const wrapper = mount( - None diff --git a/packages/material-ui/src/styles/mergeClasses.js b/packages/material-ui/src/styles/mergeClasses.js new file mode 100644 index 00000000000000..e0eed2392c3300 --- /dev/null +++ b/packages/material-ui/src/styles/mergeClasses.js @@ -0,0 +1,41 @@ +import warning from 'warning'; +import getDisplayName from 'recompose/getDisplayName'; + +function mergeClasses(options = {}) { + const { baseClasses, newClasses, Component, noBase = false } = options; + + if (!newClasses) { + return baseClasses; + } + + return { + ...baseClasses, + ...Object.keys(newClasses).reduce((accumulator, key) => { + warning( + baseClasses[key] || noBase, + [ + `Material-UI: the key \`${key}\` ` + + `provided to the classes property is not implemented in ${getDisplayName(Component)}.`, + `You can only override one of the following: ${Object.keys(baseClasses).join(',')}`, + ].join('\n'), + ); + + warning( + !newClasses[key] || typeof newClasses[key] === 'string', + [ + `Material-UI: the key \`${key}\` ` + + `provided to the classes property is not valid for ${getDisplayName(Component)}.`, + `You need to provide a non empty string instead of: ${newClasses[key]}.`, + ].join('\n'), + ); + + if (newClasses[key]) { + accumulator[key] = `${baseClasses[key]} ${newClasses[key]}`; + } + + return accumulator; + }, {}), + }; +} + +export default mergeClasses; diff --git a/packages/material-ui/src/styles/withStyles.js b/packages/material-ui/src/styles/withStyles.js index 2e57c83e41fe56..2560d6d3da9f14 100644 --- a/packages/material-ui/src/styles/withStyles.js +++ b/packages/material-ui/src/styles/withStyles.js @@ -8,6 +8,7 @@ import contextTypes from 'react-jss/lib/contextTypes'; import { create } from 'jss'; import * as ns from 'react-jss/lib/ns'; import jssPreset from './jssPreset'; +import mergeClasses from './mergeClasses'; import createMuiTheme from './createMuiTheme'; import themeListener from './themeListener'; import createGenerateClassName from './createGenerateClassName'; @@ -165,44 +166,12 @@ const withStyles = (stylesOrCreator, options = {}) => Component => { } if (generate) { - if (this.props.classes) { - this.cacheClasses.value = { - ...this.cacheClasses.lastJSS, - ...Object.keys(this.props.classes).reduce((accumulator, key) => { - warning( - this.cacheClasses.lastJSS[key] || this.disableStylesGeneration, - [ - `Material-UI: the key \`${key}\` ` + - `provided to the classes property is not implemented in ${getDisplayName( - Component, - )}.`, - `You can only override one of the following: ${Object.keys( - this.cacheClasses.lastJSS, - ).join(',')}`, - ].join('\n'), - ); - - warning( - !this.props.classes[key] || typeof this.props.classes[key] === 'string', - [ - `Material-UI: the key \`${key}\` ` + - `provided to the classes property is not valid for ${getDisplayName( - Component, - )}.`, - `You need to provide a non empty string instead of: ${this.props.classes[key]}.`, - ].join('\n'), - ); - - if (this.props.classes[key]) { - accumulator[key] = `${this.cacheClasses.lastJSS[key]} ${this.props.classes[key]}`; - } - - return accumulator; - }, {}), - }; - } else { - this.cacheClasses.value = this.cacheClasses.lastJSS; - } + this.cacheClasses.value = mergeClasses({ + baseClasses: this.cacheClasses.lastJSS, + newClasses: this.props.classes, + Component, + noBase: this.disableStylesGeneration, + }); } return this.cacheClasses.value;