diff --git a/packages/grid/x-data-grid-pro/src/hooks/features/columnResize/useGridColumnResize.tsx b/packages/grid/x-data-grid-pro/src/hooks/features/columnResize/useGridColumnResize.tsx index adb2fd329b9f2..ed0bcfbc2549a 100644 --- a/packages/grid/x-data-grid-pro/src/hooks/features/columnResize/useGridColumnResize.tsx +++ b/packages/grid/x-data-grid-pro/src/hooks/features/columnResize/useGridColumnResize.tsx @@ -132,6 +132,7 @@ export const useGridColumnResize = ( const colDefRef = React.useRef(); const colElementRef = React.useRef(); + const colGroupingElementRef = React.useRef(); const colCellElementsRef = React.useRef(); const theme = useTheme(); @@ -158,7 +159,7 @@ export const useGridColumnResize = ( colElementRef.current!.style.minWidth = `${newWidth}px`; colElementRef.current!.style.maxWidth = `${newWidth}px`; - colCellElementsRef.current!.forEach((element) => { + [...colCellElementsRef.current!, ...colGroupingElementRef.current!].forEach((element) => { const div = element as HTMLDivElement; let finalWidth: `${number}px`; @@ -252,6 +253,12 @@ export const useGridColumnResize = ( `[data-field="${colDef.field}"]`, )!; + colGroupingElementRef.current = Array.from( + apiRef.current.columnHeadersContainerElementRef?.current!.querySelectorAll( + `[data-fields~="${colDef.field}"]`, + ) ?? [], + ); + colCellElementsRef.current = findGridCellElementsFromCol( colElementRef.current, apiRef.current, diff --git a/packages/grid/x-data-grid/src/hooks/features/columnHeaders/useGridColumnHeaders.tsx b/packages/grid/x-data-grid/src/hooks/features/columnHeaders/useGridColumnHeaders.tsx index 0052030ad4bdf..4b4482c958106 100644 --- a/packages/grid/x-data-grid/src/hooks/features/columnHeaders/useGridColumnHeaders.tsx +++ b/packages/grid/x-data-grid/src/hooks/features/columnHeaders/useGridColumnHeaders.tsx @@ -32,6 +32,9 @@ import { getFirstColumnIndexToRender } from '../columns/gridColumnsUtils'; import { useGridVisibleRows } from '../../utils/useGridVisibleRows'; import { getRenderableIndexes } from '../virtualization/useGridVirtualScroller'; +// TODO: add the possibility to switch this value if needed for customization +const MERGE_EMPTY_CELLS = true; + const GridColumnHeaderRow = styled('div', { name: 'MuiDataGrid', slot: 'ColumnHeaderRow', @@ -56,15 +59,22 @@ const HeaderDepth = (props: HeaderDepthProp) => { const { depthInfo, headerGroupingRowHeight, ...other } = props; return ( - {depthInfo.map(({ name, width }, index) => ( + {depthInfo.map(({ name, width, fields }, index) => ( {name} @@ -76,6 +86,7 @@ const HeaderDepth = (props: HeaderDepthProp) => { interface HeaderInfo { name: string | null; width: number; + fields: string[]; } interface UseGridColumnHeadersProps { @@ -378,26 +389,44 @@ export const useGridColumnHeaders = (props: UseGridColumnHeadersProps) => { const headerToRender: HeaderInfo[][] = []; for (let depth = 0; depth < headerGroupingMaxDepth; depth += 1) { // TODO: deal with header starting/ending outside of the virtualization - const initialHeader: HeaderInfo[] = [{ name: null, width: 0 }]; + const initialHeader: HeaderInfo[] = []; const depthInfo = renderedColumns.reduce((aggregated, column) => { const lastItem = aggregated[aggregated.length - 1]; + if (column.headers && column.headers.length > depth) { - if (lastItem.name === column.headers[depth]) { + if (lastItem && lastItem.name === column.headers[depth]) { + // Merge with the previous columns return [ ...aggregated.slice(0, aggregated.length - 1), - { ...lastItem, width: lastItem.width + (column.width ?? 0) }, + { + ...lastItem, + width: lastItem.width + (column.width ?? 0), + fields: [...lastItem.fields, column.field], + }, ]; } - return [...aggregated, { name: column.headers[depth], width: column.width ?? 0 }]; + // Create a new grouping + return [ + ...aggregated, + { name: column.headers[depth], width: column.width ?? 0, fields: [column.field] }, + ]; } - if (lastItem.name === null) { + + // It is the first level for which their is no group + if (MERGE_EMPTY_CELLS && lastItem && lastItem.name === null) { + // We merge with previous column return [ ...aggregated.slice(0, aggregated.length - 1), - { ...lastItem, width: lastItem.width + (column.width ?? 0) }, + { + ...lastItem, + width: lastItem.width + (column.width ?? 0), + fields: [...lastItem.fields, column.field], + }, ]; } - return [...aggregated, { name: null, width: column.width ?? 0 }]; + // We create new cell with correct rowSpan + return [...aggregated, { name: null, width: column.width ?? 0, fields: [column.field] }]; }, initialHeader); headerToRender.push([...depthInfo]);