1- import { Flex , IconButton , Input , InputGroup , InputRightElement } from '@invoke-ai/ui-library' ;
1+ import { Checkbox , Flex , IconButton , Input , InputGroup , InputRightElement , Text } from '@invoke-ai/ui-library' ;
22import { useAppDispatch , useAppSelector } from 'app/store/storeHooks' ;
3- import { selectSearchTerm , setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice' ;
3+ import {
4+ type FilterableModelType ,
5+ modelSelectionChanged ,
6+ selectFilteredModelType ,
7+ selectSearchTerm ,
8+ selectSelectedModelKeys ,
9+ setSearchTerm ,
10+ } from 'features/modelManagerV2/store/modelManagerV2Slice' ;
411import { t } from 'i18next' ;
512import type { ChangeEventHandler } from 'react' ;
6- import { memo , useCallback } from 'react' ;
13+ import { memo , useCallback , useMemo } from 'react' ;
714import { PiXBold } from 'react-icons/pi' ;
15+ import { modelConfigsAdapterSelectors , useGetModelConfigsQuery } from 'services/api/endpoints/models' ;
16+ import type { AnyModelConfig } from 'services/api/types' ;
817
918import { ModelTypeFilter } from './ModelTypeFilter' ;
1019
1120export const ModelListNavigation = memo ( ( ) => {
1221 const dispatch = useAppDispatch ( ) ;
1322 const searchTerm = useAppSelector ( selectSearchTerm ) ;
23+ const filteredModelType = useAppSelector ( selectFilteredModelType ) ;
24+ const selectedModelKeys = useAppSelector ( selectSelectedModelKeys ) ;
25+ const { data } = useGetModelConfigsQuery ( ) ;
26+
27+ // Calculate displayed (filtered) model keys
28+ const displayedModelKeys = useMemo ( ( ) => {
29+ const modelConfigs = modelConfigsAdapterSelectors . selectAll ( data ?? { ids : [ ] , entities : { } } ) ;
30+ const filteredModels = modelsFilter ( modelConfigs , searchTerm , filteredModelType ) ;
31+ return filteredModels . map ( ( m ) => m . key ) ;
32+ } , [ data , searchTerm , filteredModelType ] ) ;
33+
34+ // Calculate checkbox state
35+ const { allSelected, someSelected } = useMemo ( ( ) => {
36+ if ( displayedModelKeys . length === 0 ) {
37+ return { allSelected : false , someSelected : false } ;
38+ }
39+ const selectedSet = new Set ( selectedModelKeys ) ;
40+ const displayedSelectedCount = displayedModelKeys . filter ( ( key ) => selectedSet . has ( key ) ) . length ;
41+ return {
42+ allSelected : displayedSelectedCount === displayedModelKeys . length ,
43+ someSelected : displayedSelectedCount > 0 && displayedSelectedCount < displayedModelKeys . length ,
44+ } ;
45+ } , [ displayedModelKeys , selectedModelKeys ] ) ;
1446
1547 const handleSearch : ChangeEventHandler < HTMLInputElement > = useCallback (
1648 ( event ) => {
@@ -23,28 +55,56 @@ export const ModelListNavigation = memo(() => {
2355 dispatch ( setSearchTerm ( '' ) ) ;
2456 } , [ dispatch ] ) ;
2557
58+ const handleToggleAll = useCallback ( ( ) => {
59+ if ( allSelected ) {
60+ // Deselect all displayed models
61+ const displayedSet = new Set ( displayedModelKeys ) ;
62+ const newSelection = selectedModelKeys . filter ( ( key ) => ! displayedSet . has ( key ) ) ;
63+ dispatch ( modelSelectionChanged ( newSelection ) ) ;
64+ } else {
65+ // Select all displayed models (merge with existing selection)
66+ const selectedSet = new Set ( selectedModelKeys ) ;
67+ displayedModelKeys . forEach ( ( key ) => selectedSet . add ( key ) ) ;
68+ dispatch ( modelSelectionChanged ( Array . from ( selectedSet ) ) ) ;
69+ }
70+ } , [ allSelected , displayedModelKeys , selectedModelKeys , dispatch ] ) ;
71+
2672 return (
2773 < Flex gap = { 2 } alignItems = "center" justifyContent = "space-between" >
28- < InputGroup >
29- < Input
30- placeholder = { t ( 'modelManager.search' ) }
31- value = { searchTerm || '' }
32- data-testid = "board-search-input"
33- onChange = { handleSearch }
34- />
35-
36- { ! ! searchTerm ?. length && (
37- < InputRightElement h = "full" pe = { 2 } >
38- < IconButton
39- size = "sm"
40- variant = "link"
41- aria-label = { t ( 'boards.clearSearch' ) }
42- icon = { < PiXBold /> }
43- onClick = { clearSearch }
44- />
45- </ InputRightElement >
46- ) }
47- </ InputGroup >
74+ < Flex gap = { 2 } alignItems = "center" >
75+ < Flex gap = { 2 } alignItems = "center" flexShrink = { 0 } >
76+ < Checkbox
77+ isChecked = { allSelected }
78+ isIndeterminate = { someSelected }
79+ onChange = { handleToggleAll }
80+ isDisabled = { displayedModelKeys . length === 0 }
81+ aria-label = { t ( 'modelManager.selectAll' ) }
82+ />
83+ < Text fontSize = "sm" fontWeight = "medium" whiteSpace = "nowrap" >
84+ { t ( 'modelManager.selectAll' ) }
85+ </ Text >
86+ </ Flex >
87+ < InputGroup >
88+ < Input
89+ placeholder = { t ( 'modelManager.search' ) }
90+ value = { searchTerm || '' }
91+ data-testid = "board-search-input"
92+ onChange = { handleSearch }
93+ />
94+
95+ { ! ! searchTerm ?. length && (
96+ < InputRightElement h = "full" pe = { 2 } >
97+ < IconButton
98+ size = "sm"
99+ variant = "link"
100+ aria-label = { t ( 'boards.clearSearch' ) }
101+ icon = { < PiXBold /> }
102+ onClick = { clearSearch }
103+ />
104+ </ InputRightElement >
105+ ) }
106+ </ InputGroup >
107+ </ Flex >
48108 < Flex shrink = { 0 } >
49109 < ModelTypeFilter />
50110 </ Flex >
@@ -53,3 +113,34 @@ export const ModelListNavigation = memo(() => {
53113} ) ;
54114
55115ModelListNavigation . displayName = 'ModelListNavigation' ;
116+
117+ const modelsFilter = < T extends AnyModelConfig > (
118+ data : T [ ] ,
119+ nameFilter : string ,
120+ filteredModelType : FilterableModelType | null
121+ ) : T [ ] => {
122+ return data . filter ( ( model ) => {
123+ const matchesFilter =
124+ model . name . toLowerCase ( ) . includes ( nameFilter . toLowerCase ( ) ) ||
125+ model . base . toLowerCase ( ) . includes ( nameFilter . toLowerCase ( ) ) ||
126+ model . type . toLowerCase ( ) . includes ( nameFilter . toLowerCase ( ) ) ||
127+ model . description ?. toLowerCase ( ) . includes ( nameFilter . toLowerCase ( ) ) ||
128+ model . format . toLowerCase ( ) . includes ( nameFilter . toLowerCase ( ) ) ;
129+
130+ const matchesType = getMatchesType ( model , filteredModelType ) ;
131+
132+ return matchesFilter && matchesType ;
133+ } ) ;
134+ } ;
135+
136+ const getMatchesType = ( modelConfig : AnyModelConfig , filteredModelType : FilterableModelType | null ) : boolean => {
137+ if ( filteredModelType === 'refiner' ) {
138+ return modelConfig . base === 'sdxl-refiner' ;
139+ }
140+
141+ if ( filteredModelType === 'main' && modelConfig . base === 'sdxl-refiner' ) {
142+ return false ;
143+ }
144+
145+ return filteredModelType ? modelConfig . type === filteredModelType : true ;
146+ } ;
0 commit comments