diff --git a/packages/core/src/aggregate.ts b/packages/core/src/aggregate.ts index f31ed1543..f2613cb8b 100644 --- a/packages/core/src/aggregate.ts +++ b/packages/core/src/aggregate.ts @@ -4,6 +4,7 @@ import { Aggregate, AggregateBy, Aggregates, + ColumnTypes, Items, Row, Rows, @@ -48,12 +49,23 @@ export function getItems(row: Row): string[] { * @param aggs - The aggregates object containing the order and values of the aggregates. * @param items - The items object containing the data items. * @param attributeColumns - The array of attribute columns to calculate the attributes from. + * @param columnTypes - The column type definitions. */ -function updateAggValues(aggs: Aggregates, items: Items, attributeColumns: string[]) { +function updateAggValues( + aggs: Aggregates, + items: Items, + attributeColumns: string[], + columnTypes: ColumnTypes, +) { aggs.order.forEach((aggId) => { aggs.values[aggId].atts = { dataset: { - ...getSixNumberSummary(items, getItems(aggs.values[aggId]), attributeColumns), + ...getSixNumberSummary( + items, + getItems(aggs.values[aggId]), + attributeColumns, + columnTypes, + ), }, derived: { deviation: aggs.values[aggId].atts.derived.deviation }, }; @@ -87,6 +99,7 @@ export const getAggSize = (row: Row) => { * @param level - The level of the aggregation. * @param items - The items to be aggregated. * @param attributeColumns - The attribute columns to be considered in the aggregation. + * @param columnTypes - The column type definitions. * @param parentPrefix - The parent prefix for the aggregated subsets. * @returns The aggregated subsets. */ @@ -95,6 +108,7 @@ function aggregateByDegree( level: number, items: Items, attributeColumns: string[], + columnTypes: ColumnTypes, parentPrefix: string, ) { if (subsets.order.length === 0) return subsets; @@ -152,7 +166,7 @@ function aggregateByDegree( relevantAggregate.atts.derived.deviation += subset.atts.derived.deviation; }); - updateAggValues(aggs, items, attributeColumns); + updateAggValues(aggs, items, attributeColumns, columnTypes); return aggs; } @@ -165,6 +179,7 @@ function aggregateByDegree( * @param level - The level of aggregation. * @param items - The items to be aggregated. * @param attributeColumns - The attribute columns used for aggregation. + * @param columnTypes - The column type definitions. * @param parentPrefix - The parent prefix used for aggregation. * @returns The aggregated subsets. */ @@ -174,6 +189,7 @@ function aggregateBySets( level: number, items: Items, attributeColumns: string[], + columnTypes: ColumnTypes, parentPrefix: string, ) { if (subsets.order.length === 0) return subsets; @@ -259,7 +275,7 @@ function aggregateBySets( }); }); - updateAggValues(aggs, items, attributeColumns); + updateAggValues(aggs, items, attributeColumns, columnTypes); return aggs; } @@ -271,6 +287,7 @@ function aggregateBySets( * @param level - The level of the aggregation. * @param items - The items to be aggregated. * @param attributeColumns - The attribute columns to be considered. + * @param columnTypes - The column type definitions. * @param parentPrefix - The parent prefix for the aggregation. * @returns The aggregated subsets. */ @@ -279,6 +296,7 @@ function aggregateByDeviation( level: number, items: Items, attributeColumns: string[], + columnTypes: ColumnTypes, parentPrefix: string, ) { if (subsets.order.length === 0) return subsets; @@ -347,7 +365,7 @@ function aggregateByDeviation( relevantAggregate.atts.derived.deviation += subset.atts.derived.deviation; }); - updateAggValues(aggs, items, attributeColumns); + updateAggValues(aggs, items, attributeColumns, columnTypes); return aggs; } @@ -361,6 +379,7 @@ function aggregateByDeviation( * @param level - The level of the aggregation. * @param items - The items associated with the subsets. * @param attributeColumns - The attribute columns to be considered for aggregation. + * @param columnTypes - The column type definitions. * @param parentPrefix - The prefix for the parent identifier. * @returns The aggregated subsets. */ @@ -371,6 +390,7 @@ function aggregateByOverlaps( level: number, items: Items, attributeColumns: string[], + columnTypes: ColumnTypes, parentPrefix: string, ) { if (subsets.order.length === 0) return subsets; @@ -451,7 +471,7 @@ function aggregateByOverlaps( }); }); - updateAggValues(aggs, items, attributeColumns); + updateAggValues(aggs, items, attributeColumns, columnTypes); return aggs; } @@ -465,6 +485,7 @@ function aggregateByOverlaps( * @param sets - The sets associated with the subsets. * @param items - The items associated with the subsets. * @param attributeColumns - The attribute columns to consider when aggregating by deviations. + * @param columnTypes - The column type definitions. * @param level - The level of aggregation (default: 1). * @param parentPrefix - The parent prefix for the aggregated subsets (default: ''). * @returns The aggregated subsets. @@ -476,15 +497,38 @@ function aggregateSubsets( sets: Sets, items: Items, attributeColumns: string[], + columnTypes: ColumnTypes, level: number = 1, parentPrefix: string = '', ) { if (aggregateBy === 'Degree') - return aggregateByDegree(subsets, level, items, attributeColumns, parentPrefix); + return aggregateByDegree( + subsets, + level, + items, + attributeColumns, + columnTypes, + parentPrefix, + ); if (aggregateBy === 'Sets') - return aggregateBySets(subsets, sets, level, items, attributeColumns, parentPrefix); + return aggregateBySets( + subsets, + sets, + level, + items, + attributeColumns, + columnTypes, + parentPrefix, + ); if (aggregateBy === 'Deviations') - return aggregateByDeviation(subsets, level, items, attributeColumns, parentPrefix); + return aggregateByDeviation( + subsets, + level, + items, + attributeColumns, + columnTypes, + parentPrefix, + ); if (aggregateBy === 'Overlaps') { return aggregateByOverlaps( subsets, @@ -493,6 +537,7 @@ function aggregateSubsets( level, items, attributeColumns, + columnTypes, parentPrefix, ); } @@ -508,6 +553,7 @@ function aggregateSubsets( * @param sets - The sets associated with the subsets. * @param items - The items associated with the subsets. * @param attributeColumns - The attribute columns to be considered during aggregation. + * @param columnTypes - The column type definitions. * @returns The aggregated rows. */ export function firstAggregation( @@ -517,6 +563,7 @@ export function firstAggregation( sets: Sets, items: Items, attributeColumns: string[], + columnTypes: ColumnTypes, ): Rows { return aggregateSubsets( subsets, @@ -525,6 +572,7 @@ export function firstAggregation( sets, items, attributeColumns, + columnTypes, ); } @@ -537,6 +585,7 @@ export function firstAggregation( * @param sets - The sets data. * @param items - The items data. * @param attributeColumns - The attribute columns to consider. + * @param columnTypes - The column type definitions. * @returns The aggregated result. */ export function secondAggregation( @@ -546,6 +595,7 @@ export function secondAggregation( sets: Sets, items: Items, attributeColumns: string[], + columnTypes: ColumnTypes, ) { const aggs: Aggregates = { values: {}, @@ -562,6 +612,7 @@ export function secondAggregation( sets, items, attributeColumns, + columnTypes, 2, `${agg.id}-`, ); diff --git a/packages/core/src/process.ts b/packages/core/src/process.ts index 085d6296b..c22d910ec 100644 --- a/packages/core/src/process.ts +++ b/packages/core/src/process.ts @@ -162,21 +162,29 @@ function processRawData(data: TableRow[], columns: ColumnTypes) { } /** - * Calculates the six-number summary for each attribute in the given items. + * Calculates the six-number summary for each numerical attribute in the given items. + * Categorical attributes are automatically filtered out. * * @param items - The items to calculate the summary for. * @param memberItems - The member items to consider. * @param attributeColumns - The attribute columns to calculate the summary for. - * @returns An object containing the six-number summary for each attribute. + * @param columnTypes - The column type definitions to filter numerical attributes. + * @returns An object containing the six-number summary for each numerical attribute. */ export function getSixNumberSummary( items: Items, memberItems: string[], attributeColumns: string[], + columnTypes: ColumnTypes, ): AttributeList { const attributes: AttributeList = {}; - attributeColumns.forEach((attribute) => { + // Only process numerical attributes; categorical attributes don't have meaningful numerical statistics + const numericalAttributes = attributeColumns.filter( + (attr) => columnTypes[attr] === 'number', + ); + + numericalAttributes.forEach((attribute) => { const values = memberItems .map((d) => items[d].atts[attribute] as number) .filter((val) => !Number.isNaN(val)); @@ -201,6 +209,7 @@ export function getSixNumberSummary( * @param setColumns - The array of column names representing the set columns. * @param items - The items object containing the data items. * @param attributeColumns - The array of column names representing the attribute columns. + * @param columnTypes - The column type definitions. * @returns The sets object containing the retrieved sets. */ function getSets( @@ -208,6 +217,7 @@ function getSets( setColumns: ColumnName[], items: Items, attributeColumns: ColumnName[], + columnTypes: ColumnTypes, ) { const setMembershipStatus: { [col: string]: SetMembershipStatus } = {}; @@ -225,7 +235,12 @@ function getSets( size: setMembership[col].length, setMembership: { ...setMembershipStatus, [col]: 'Yes' }, atts: { - dataset: getSixNumberSummary(items, setMembership[col], attributeColumns), + dataset: getSixNumberSummary( + items, + setMembership[col], + attributeColumns, + columnTypes, + ), derived: { deviation: 0 }, }, }; @@ -246,7 +261,7 @@ function getSets( export function process(data: TableRow[], columns: ColumnTypes): CoreUpsetData { const { items, setMembership, labelColumn, setColumns, attributeColumns } = processRawData(data, columns); - const sets = getSets(setMembership, setColumns, items, attributeColumns); + const sets = getSets(setMembership, setColumns, items, attributeColumns, columns); return { label: labelColumn, @@ -265,6 +280,7 @@ export function process(data: TableRow[], columns: ColumnTypes): CoreUpsetData { * @param sets - The sets used to calculate subsets. * @param vSets - The vSets used to calculate subsets. * @param attributeColumns - The attribute columns used to calculate subsets. + * @param columnTypes - The column type definitions. * @returns The calculated subsets. */ export function getSubsets( @@ -272,6 +288,7 @@ export function getSubsets( sets: Sets, vSets: string[], attributeColumns: string[], + columnTypes: ColumnTypes, ): Subsets { if (vSets.length === 0) { return { @@ -349,7 +366,7 @@ export function getSubsets( setMembership: setMembershipStatus, atts: { derived: { deviation: subsetDeviation }, - dataset: getSixNumberSummary(dataItems, itm, attributeColumns), + dataset: getSixNumberSummary(dataItems, itm, attributeColumns, columnTypes), }, }; diff --git a/packages/core/src/render.ts b/packages/core/src/render.ts index 21fef892d..11221be05 100644 --- a/packages/core/src/render.ts +++ b/packages/core/src/render.ts @@ -67,6 +67,7 @@ const firstAggRR = (data: CoreUpsetData, state: UpsetConfig) => { data.sets, state.visibleSets, data.attributeColumns, + data.columnTypes, ); return firstAggregation( subsets, @@ -75,6 +76,7 @@ const firstAggRR = (data: CoreUpsetData, state: UpsetConfig) => { data.sets, data.items, data.attributeColumns, + data.columnTypes, ); }; @@ -98,6 +100,7 @@ const secondAggRR = (data: CoreUpsetData, state: UpsetConfig) => { data.sets, data.items, data.attributeColumns, + data.columnTypes, ); return secondAgg; @@ -165,6 +168,7 @@ const sortByRR = (data: CoreUpsetData, state: UpsetConfig, ignoreQuery = false) data.sets, state.visibleSets, data.attributeColumns, + data.columnTypes, ); renderRows = getQueryResult(subsets, state.setQuery.query); } else { diff --git a/packages/core/src/sort.ts b/packages/core/src/sort.ts index 05f8299a8..239341663 100644 --- a/packages/core/src/sort.ts +++ b/packages/core/src/sort.ts @@ -197,11 +197,13 @@ function sortByAttribute(rows: Intersections, sortBy: string, sortByOrder?: Sort meanA = values[a].atts.derived.deviation; meanB = values[b].atts.derived.deviation; } else { - meanA = values[a].atts.dataset[sortBy].mean; - meanB = values[b].atts.dataset[sortBy].mean; + // For categorical attributes, dataset[sortBy] will be undefined since + // getSixNumberSummary filters them out; they will be sorted to the bottom + meanA = values[a].atts.dataset[sortBy]?.mean; + meanB = values[b].atts.dataset[sortBy]?.mean; } - // If one of the values is undefined (empty subset), sort it to the bottom + // If one of the values is undefined (empty subset or categorical attribute), sort it to the bottom if (!meanA) { return 1; } diff --git a/packages/upset/src/components/Header/AttributeButton.tsx b/packages/upset/src/components/Header/AttributeButton.tsx index e434a478c..373f04f7c 100644 --- a/packages/upset/src/components/Header/AttributeButton.tsx +++ b/packages/upset/src/components/Header/AttributeButton.tsx @@ -13,6 +13,7 @@ import { ContextMenuItem } from '../../types'; import { allowAttributeRemovalAtom } from '../../atoms/config/allowAttributeRemovalAtom'; import { attributePlotsSelector } from '../../atoms/config/plotAtoms'; import { UpsetActions } from '../../provenance'; +import { attTypesSelector } from '../../atoms/attributeAtom'; type Props = { /** @@ -44,9 +45,13 @@ export const AttributeButton: FC = ({ label, tooltip }) => { const setContextMenu = useSetRecoilState(contextMenuAtom); const attributePlots = useRecoilValue(attributePlotsSelector); + const attTypes = useRecoilValue(attTypesSelector); const allowAttributeRemoval = useRecoilValue(allowAttributeRemovalAtom); + // Check if this attribute is categorical (categorical attributes cannot be sorted by mean) + const isCategorical = attTypes[label] === 'category'; + /** * Sorts the attribute in the specified order. * @@ -60,8 +65,15 @@ export const AttributeButton: FC = ({ label, tooltip }) => { * Handles the click event of the button. * If the attribute is not currently sorted, it sorts it in ascending order. * If the attribute is already sorted, it toggles between ascending and descending order. + * Categorical attributes cannot be sorted, so clicking them does nothing. */ const handleOnClick = (e: React.MouseEvent) => { + // Don't allow sorting by categorical attributes + if (isCategorical) { + e.stopPropagation(); + return; + } + if (sortBy !== label) { sortByHeader('Ascending'); } else { @@ -84,26 +96,34 @@ export const AttributeButton: FC = ({ label, tooltip }) => { * @returns An array of menu items. */ function getMenuItems(): ContextMenuItem[] { - const items = [ - { - label: `Sort by ${label} - Ascending`, - onClick: () => { - sortByHeader('Ascending'); - handleContextMenuClose(); + const items: ContextMenuItem[] = []; + + // Only add sort options for non-categorical attributes + // Categorical attributes don't have meaningful numerical statistics (mean) to sort by + if (!isCategorical) { + items.push( + { + label: `Sort by ${label} - Ascending`, + onClick: () => { + sortByHeader('Ascending'); + handleContextMenuClose(); + }, + disabled: sortBy === label && sortByOrder === 'Ascending', }, - disabled: sortBy === label && sortByOrder === 'Ascending', - }, - { - label: `Sort by ${label} - Descending`, - onClick: () => { - sortByHeader('Descending'); - handleContextMenuClose(); + { + label: `Sort by ${label} - Descending`, + onClick: () => { + sortByHeader('Descending'); + handleContextMenuClose(); + }, + disabled: sortBy === label && sortByOrder === 'Descending', }, - disabled: sortBy === label && sortByOrder === 'Descending', - }, - ]; + ); + } - if (!UPSET_ATTS.includes(label)) { + // Only add plot type options for non-categorical, non-UPSET attributes + // Categorical attributes only have one visualization type (stacked bar) and cannot be changed + if (!UPSET_ATTS.includes(label) && !isCategorical) { // for every possible value of the type AttributePlotType (from core), add a menu item Object.values(AttributePlotType).forEach((plot) => { items.push({