diff --git a/scheduler/cmd/scheduler/main.go b/scheduler/cmd/scheduler/main.go index 9676b8c7a5..79ab7bf1d8 100644 --- a/scheduler/cmd/scheduler/main.go +++ b/scheduler/cmd/scheduler/main.go @@ -262,7 +262,7 @@ func main() { } // Create stores - ss := store.NewMemoryStore(logger, store.NewLocalSchedulerStore(), eventHub) + ss := store.NewModelServerService(logger, store.NewLocalSchedulerStore(), eventHub) ps := pipeline.NewPipelineStore(logger, eventHub, ss) es := experiment.NewExperimentServer(logger, eventHub, ss, ps) cleaner := cleaner.NewVersionCleaner(ss, logger) diff --git a/scheduler/pkg/envoy/processor/incremental_benchmark_test.go b/scheduler/pkg/envoy/processor/incremental_benchmark_test.go index 653195ecb6..33f473091c 100644 --- a/scheduler/pkg/envoy/processor/incremental_benchmark_test.go +++ b/scheduler/pkg/envoy/processor/incremental_benchmark_test.go @@ -122,7 +122,7 @@ func benchmarkModelUpdate( eventHub, err := coordinator.NewEventHub(logger) require.NoError(b, err) - memoryStore := store.NewMemoryStore(logger, store.NewLocalSchedulerStore(), eventHub) + memoryStore := store.NewModelServerService(logger, store.NewLocalSchedulerStore(), eventHub) pipelineStore := pipeline.NewPipelineStore(logger, eventHub, memoryStore) ip, err := NewIncrementalProcessor( "some node", diff --git a/scheduler/pkg/envoy/processor/incremental_test.go b/scheduler/pkg/envoy/processor/incremental_test.go index c2b1989265..c48e7bb27f 100644 --- a/scheduler/pkg/envoy/processor/incremental_test.go +++ b/scheduler/pkg/envoy/processor/incremental_test.go @@ -342,7 +342,7 @@ func TestRollingUpdate(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - modelStore := store.NewMemoryStore(log.New(), store.NewLocalSchedulerStore(), nil) + modelStore := store.NewModelServerService(log.New(), store.NewLocalSchedulerStore(), nil) xdsCache, err := xdscache.NewSeldonXDSCache(log.New(), &xdscache.PipelineGatewayDetails{Host: "pipeline", GrpcPort: 1, 
HttpPort: 2}, nil) g.Expect(err).To(BeNil()) inc := &IncrementalProcessor{ @@ -420,7 +420,7 @@ func TestDraining(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - modelStore := store.NewMemoryStore(log.New(), store.NewLocalSchedulerStore(), nil) + modelStore := store.NewModelServerService(log.New(), store.NewLocalSchedulerStore(), nil) xdsCache, err := xdscache.NewSeldonXDSCache(log.New(), &xdscache.PipelineGatewayDetails{Host: "pipeline", GrpcPort: 1, HttpPort: 2}, nil) g.Expect(err).To(BeNil()) inc := &IncrementalProcessor{ @@ -564,7 +564,7 @@ func TestModelSync(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - modelStore := store.NewMemoryStore(log.New(), store.NewLocalSchedulerStore(), nil) + modelStore := store.NewModelServerService(log.New(), store.NewLocalSchedulerStore(), nil) xdsCache, err := xdscache.NewSeldonXDSCache(log.New(), &xdscache.PipelineGatewayDetails{Host: "pipeline", GrpcPort: 1, HttpPort: 2}, nil) g.Expect(err).To(BeNil()) inc := &IncrementalProcessor{ @@ -808,7 +808,7 @@ func TestEnvoySettings(t *testing.T) { t.Run(test.name, func(t *testing.T) { logger := log.New() eventHub, _ := coordinator.NewEventHub(logger) - memoryStore := store.NewMemoryStore(log.New(), store.NewLocalSchedulerStore(), eventHub) + memoryStore := store.NewModelServerService(log.New(), store.NewLocalSchedulerStore(), eventHub) xdsCache, err := xdscache.NewSeldonXDSCache(log.New(), &xdscache.PipelineGatewayDetails{Host: "pipeline", GrpcPort: 1, HttpPort: 2}, nil) g.Expect(err).To(BeNil()) inc := &IncrementalProcessor{ diff --git a/scheduler/pkg/envoy/processor/server_test.go b/scheduler/pkg/envoy/processor/server_test.go index 0de114fd43..450ebfb746 100644 --- a/scheduler/pkg/envoy/processor/server_test.go +++ b/scheduler/pkg/envoy/processor/server_test.go @@ -43,7 +43,7 @@ func TestFetch(t *testing.T) { logger := log.New() - memoryStore := store.NewMemoryStore(logger, 
store.NewLocalSchedulerStore(), nil) + memoryStore := store.NewModelServerService(logger, store.NewLocalSchedulerStore(), nil) pipelineHandler := pipeline.NewPipelineStore(logger, nil, memoryStore) xdsCache, err := xdscache.NewSeldonXDSCache(log.New(), &xdscache.PipelineGatewayDetails{}, nil) diff --git a/scheduler/pkg/kafka/dataflow/server_test.go b/scheduler/pkg/kafka/dataflow/server_test.go index 58e9d72076..cca28851ef 100644 --- a/scheduler/pkg/kafka/dataflow/server_test.go +++ b/scheduler/pkg/kafka/dataflow/server_test.go @@ -853,7 +853,7 @@ func createTestScheduler(t *testing.T, serverName string) (*ChainerServer, *coor eventHub, _ := coordinator.NewEventHub(logger) - schedulerStore := store.NewMemoryStore(logger, store.NewLocalSchedulerStore(), eventHub) + schedulerStore := store.NewModelServerService(logger, store.NewLocalSchedulerStore(), eventHub) pipelineServer := pipeline.NewPipelineStore(logger, eventHub, schedulerStore) data := diff --git a/scheduler/pkg/server/control_plane_test.go b/scheduler/pkg/server/control_plane_test.go index b51132a4b3..5b6dd12310 100644 --- a/scheduler/pkg/server/control_plane_test.go +++ b/scheduler/pkg/server/control_plane_test.go @@ -42,7 +42,7 @@ func TestStartServerStream(t *testing.T) { { name: "ok", server: &SchedulerServer{ - modelStore: store.NewMemoryStore(log.New(), store.NewLocalSchedulerStore(), nil), + modelStore: store.NewModelServerService(log.New(), store.NewLocalSchedulerStore(), nil), logger: log.New(), timeout: 10 * time.Millisecond, }, @@ -50,7 +50,7 @@ func TestStartServerStream(t *testing.T) { { name: "timeout", server: &SchedulerServer{ - modelStore: store.NewMemoryStore(log.New(), store.NewLocalSchedulerStore(), nil), + modelStore: store.NewModelServerService(log.New(), store.NewLocalSchedulerStore(), nil), logger: log.New(), timeout: 1 * time.Millisecond, }, diff --git a/scheduler/pkg/server/server_status_test.go b/scheduler/pkg/server/server_status_test.go index 19f0d240c2..4ab5d92d99 100644 --- 
a/scheduler/pkg/server/server_status_test.go +++ b/scheduler/pkg/server/server_status_test.go @@ -144,7 +144,7 @@ func TestModelsStatusStream(t *testing.T) { }, }, server: &SchedulerServer{ - modelStore: store.NewMemoryStore(log.New(), store.NewLocalSchedulerStore(), nil), + modelStore: store.NewModelServerService(log.New(), store.NewLocalSchedulerStore(), nil), logger: log.New(), timeout: 10 * time.Millisecond, }, @@ -157,7 +157,7 @@ func TestModelsStatusStream(t *testing.T) { }, }, server: &SchedulerServer{ - modelStore: store.NewMemoryStore(log.New(), store.NewLocalSchedulerStore(), nil), + modelStore: store.NewModelServerService(log.New(), store.NewLocalSchedulerStore(), nil), logger: log.New(), timeout: 1 * time.Millisecond, }, @@ -763,7 +763,7 @@ func TestServersStatusStream(t *testing.T) { }, }, server: &SchedulerServer{ - modelStore: store.NewMemoryStore(log.New(), store.NewLocalSchedulerStore(), nil), + modelStore: store.NewModelServerService(log.New(), store.NewLocalSchedulerStore(), nil), logger: log.New(), timeout: 10 * time.Millisecond, }, @@ -799,7 +799,7 @@ func TestServersStatusStream(t *testing.T) { }, }, server: &SchedulerServer{ - modelStore: store.NewMemoryStore(log.New(), store.NewLocalSchedulerStore(), nil), + modelStore: store.NewModelServerService(log.New(), store.NewLocalSchedulerStore(), nil), logger: log.New(), timeout: 10 * time.Millisecond, }, @@ -836,7 +836,7 @@ func TestServersStatusStream(t *testing.T) { }, }, server: &SchedulerServer{ - modelStore: store.NewMemoryStore(log.New(), store.NewLocalSchedulerStore(), nil), + modelStore: store.NewModelServerService(log.New(), store.NewLocalSchedulerStore(), nil), logger: log.New(), timeout: 10 * time.Millisecond, }, @@ -851,7 +851,7 @@ func TestServersStatusStream(t *testing.T) { }, }, server: &SchedulerServer{ - modelStore: store.NewMemoryStore(log.New(), store.NewLocalSchedulerStore(), nil), + modelStore: store.NewModelServerService(log.New(), store.NewLocalSchedulerStore(), nil), 
logger: log.New(), timeout: 1 * time.Millisecond, }, diff --git a/scheduler/pkg/server/server_test.go b/scheduler/pkg/server/server_test.go index 97506d011c..40b965a876 100644 --- a/scheduler/pkg/server/server_test.go +++ b/scheduler/pkg/server/server_test.go @@ -58,7 +58,7 @@ func TestLoadModel(t *testing.T) { eventHub, err := coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - schedulerStore := store.NewMemoryStore(logger, store.NewLocalSchedulerStore(), eventHub) + schedulerStore := store.NewModelServerService(logger, store.NewLocalSchedulerStore(), eventHub) experimentServer := experiment.NewExperimentServer(logger, eventHub, nil, nil) pipelineServer := pipeline.NewPipelineStore(logger, eventHub, schedulerStore) sync := synchroniser.NewSimpleSynchroniser(time.Duration(10 * time.Millisecond)) @@ -364,7 +364,7 @@ func TestUnloadModel(t *testing.T) { log.SetLevel(log.DebugLevel) eventHub, err := coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - schedulerStore := store.NewMemoryStore(logger, store.NewLocalSchedulerStore(), eventHub) + schedulerStore := store.NewModelServerService(logger, store.NewLocalSchedulerStore(), eventHub) experimentServer := experiment.NewExperimentServer(logger, eventHub, nil, nil) pipelineServer := pipeline.NewPipelineStore(logger, eventHub, schedulerStore) mockAgent := &mockAgentHandler{} @@ -707,7 +707,7 @@ func TestServerNotify(t *testing.T) { log.SetLevel(log.DebugLevel) eventHub, err := coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - schedulerStore := store.NewMemoryStore(logger, store.NewLocalSchedulerStore(), eventHub) + schedulerStore := store.NewModelServerService(logger, store.NewLocalSchedulerStore(), eventHub) sync := synchroniser.NewSimpleSynchroniser(time.Duration(10 * time.Millisecond)) scheduler := scheduler2.NewSimpleScheduler(logger, schedulerStore, diff --git a/scheduler/pkg/store/memory.go b/scheduler/pkg/store/memory.go index 4a0611c92c..e60991aaa3 100644 --- 
a/scheduler/pkg/store/memory.go +++ b/scheduler/pkg/store/memory.go @@ -23,42 +23,49 @@ import ( "github.com/seldonio/seldon-core/scheduler/v2/pkg/store/utils" ) -type MemoryStore struct { +// todo implement interface +type todo interface { + GetModel() + PutModel() +} + +type ModelServerService struct { mu sync.RWMutex opLocks sync.Map - store *LocalSchedulerStore + cache *LocalSchedulerStore + db todo logger log.FieldLogger eventHub *coordinator.EventHub } -func NewMemoryStore( +func NewModelServerService( logger log.FieldLogger, store *LocalSchedulerStore, eventHub *coordinator.EventHub, -) *MemoryStore { - return &MemoryStore{ - store: store, - logger: logger.WithField("source", "MemoryStore"), +) *ModelServerService { + return &ModelServerService{ + cache: store, + logger: logger.WithField("source", "ModelServerService"), eventHub: eventHub, } } -func (m *MemoryStore) GetAllModels() []string { +func (m *ModelServerService) GetAllModels() []string { m.mu.RLock() defer m.mu.RUnlock() var modelNames []string - for modelName := range m.store.models { + for modelName := range m.cache.models { modelNames = append(modelNames, modelName) } return modelNames } -func (m *MemoryStore) GetModels() ([]*ModelSnapshot, error) { +func (m *ModelServerService) GetModels() ([]*ModelSnapshot, error) { m.mu.RLock() defer m.mu.RUnlock() foundModels := []*ModelSnapshot{} - for name, model := range m.store.models { + for name, model := range m.cache.models { snapshot := &ModelSnapshot{ Name: name, Deleted: model.IsDeleted(), @@ -69,12 +76,12 @@ func (m *MemoryStore) GetModels() ([]*ModelSnapshot, error) { return foundModels, nil } -func (m *MemoryStore) addModelVersionIfNotExists(req *agent.ModelVersion) (*Model, *ModelVersion) { +func (m *ModelServerService) addModelVersionIfNotExists(req *agent.ModelVersion) (*Model, *ModelVersion) { modelName := req.GetModel().GetMeta().GetName() - model, ok := m.store.models[modelName] + model, ok := m.cache.models[modelName] if !ok { model = 
&Model{} - m.store.models[modelName] = model + m.cache.models[modelName] = model } if existingModelVersion := model.GetVersion(req.GetVersion()); existingModelVersion == nil { modelVersion := NewDefaultModelVersion(req.GetModel(), req.GetVersion()) @@ -88,7 +95,7 @@ func (m *MemoryStore) addModelVersionIfNotExists(req *agent.ModelVersion) (*Mode } } -func (m *MemoryStore) addNextModelVersion(model *Model, pbmodel *pb.Model) { +func (m *ModelServerService) addNextModelVersion(model *Model, pbmodel *pb.Model) { // if we start from a clean state, lets use the generation id as the starting version // this is to ensure that we have monotonic increasing version numbers // and we never reset back to 1 @@ -105,7 +112,7 @@ func (m *MemoryStore) addNextModelVersion(model *Model, pbmodel *pb.Model) { }) } -func (m *MemoryStore) UpdateModel(req *pb.LoadModelRequest) error { +func (m *ModelServerService) UpdateModel(req *pb.LoadModelRequest) error { logger := m.logger.WithField("func", "UpdateModel") m.mu.Lock() defer m.mu.Unlock() @@ -117,16 +124,23 @@ func (m *MemoryStore) UpdateModel(req *pb.LoadModelRequest) error { modelName, ) } - model, ok := m.store.models[modelName] - if !ok { + + model := m.cache.GetModel(modelName) + if model == nil { + // the update would first be done to the etcd to prevent missmatch of updates between cache and real db + model = &Model{} - m.store.models[modelName] = model m.addNextModelVersion(model, req.GetModel()) + m.db.PutModel() // put the etcd first instead of the cached data + m.cache.PutModel(modelName, model) + } else if model.IsDeleted() { if model.Inactive() { model = &Model{} - m.store.models[modelName] = model m.addNextModelVersion(model, req.GetModel()) + m.db.PutModel() + m.cache.PutModel(modelName, model) + } else { return fmt.Errorf( "Model %s is in process of deletion - new model can not be created", @@ -157,19 +171,23 @@ func (m *MemoryStore) UpdateModel(req *pb.LoadModelRequest) error { return nil } -func (m *MemoryStore) 
getModelImpl(key string) *ModelSnapshot { - model, ok := m.store.models[key] - if ok { - return m.deepCopy(model, key) +func (m *ModelServerService) getModelImpl(key string) (*ModelSnapshot, error) { + model, err := m.cache.GetModel(key) + if err != nil { + return nil, err } - - return &ModelSnapshot{ - Name: key, - Versions: nil, + if model == nil { + return &ModelSnapshot{ + Name: key, + Versions: nil, + }, nil } + + return m.deepCopy(model, key), nil + } -func (m *MemoryStore) deepCopy(model *Model, key string) *ModelSnapshot { +func (m *ModelServerService) deepCopy(model *Model, key string) *ModelSnapshot { snapshot := &ModelSnapshot{ Name: key, Deleted: model.IsDeleted(), @@ -182,13 +200,13 @@ func (m *MemoryStore) deepCopy(model *Model, key string) *ModelSnapshot { return snapshot } -func (m *MemoryStore) LockModel(modelId string) { +func (m *ModelServerService) LockModel(modelId string) { var lock sync.RWMutex existingLock, _ := m.opLocks.LoadOrStore(modelId, &lock) existingLock.(*sync.RWMutex).Lock() } -func (m *MemoryStore) UnlockModel(modelId string) { +func (m *ModelServerService) UnlockModel(modelId string) { logger := m.logger.WithField("func", "UnlockModel") lock, loaded := m.opLocks.Load(modelId) if loaded { @@ -198,13 +216,13 @@ func (m *MemoryStore) UnlockModel(modelId string) { } } -func (m *MemoryStore) GetModel(key string) (*ModelSnapshot, error) { +func (m *ModelServerService) GetModel(key string) (*ModelSnapshot, error) { m.mu.RLock() defer m.mu.RUnlock() - return m.getModelImpl(key), nil + return m.getModelImpl(key) } -func (m *MemoryStore) RemoveModel(req *pb.UnloadModelRequest) error { +func (m *ModelServerService) RemoveModel(req *pb.UnloadModelRequest) error { err := m.removeModelImpl(req) if err != nil { return err @@ -212,11 +230,11 @@ func (m *MemoryStore) RemoveModel(req *pb.UnloadModelRequest) error { return nil } -func (m *MemoryStore) removeModelImpl(req *pb.UnloadModelRequest) error { +func (m *ModelServerService) 
removeModelImpl(req *pb.UnloadModelRequest) error { m.mu.Lock() defer m.mu.Unlock() modelName := req.GetModel().GetName() - model, ok := m.store.models[modelName] + model, ok := m.cache.models[modelName] if ok { // Updating the k8s meta is required to be updated so status updates back (to manager) // will match latest generation value. Previous generation values might be ignored by manager. @@ -233,20 +251,20 @@ func (m *MemoryStore) removeModelImpl(req *pb.UnloadModelRequest) error { } } -func (m *MemoryStore) GetServers(shallow bool, modelDetails bool) ([]*ServerSnapshot, error) { +func (m *ModelServerService) GetServers(shallow bool, modelDetails bool) ([]*ServerSnapshot, error) { m.mu.RLock() defer m.mu.RUnlock() var servers []*ServerSnapshot - for _, server := range m.store.servers { + for _, server := range m.cache.servers { servers = append(servers, server.CreateSnapshot(shallow, modelDetails)) } return servers, nil } -func (m *MemoryStore) GetServer(serverKey string, shallow bool, modelDetails bool) (*ServerSnapshot, error) { +func (m *ModelServerService) GetServer(serverKey string, shallow bool, modelDetails bool) (*ServerSnapshot, error) { m.mu.RLock() defer m.mu.RUnlock() - server := m.store.servers[serverKey] + server := m.cache.servers[serverKey] if server == nil { return nil, fmt.Errorf("Server [%s] not found", serverKey) } else { @@ -260,20 +278,20 @@ func (m *MemoryStore) GetServer(serverKey string, shallow bool, modelDetails boo } } -func (m *MemoryStore) getServerStats(serverKey string) *ServerStats { +func (m *ModelServerService) getServerStats(serverKey string) *ServerStats { return &ServerStats{ NumEmptyReplicas: m.numEmptyServerReplicas(serverKey), MaxNumReplicaHostedModels: m.maxNumModelReplicasForServer(serverKey), } } -func (m *MemoryStore) getModelServer( +func (m *ModelServerService) getModelServer( modelKey string, version uint32, serverKey string, ) (*Model, *ModelVersion, *Server, error) { // Validate - model, ok := 
m.store.models[modelKey] + model, ok := m.cache.models[modelKey] if !ok { return nil, nil, nil, fmt.Errorf("failed to find model %s", modelKey) } @@ -281,14 +299,14 @@ func (m *MemoryStore) getModelServer( if modelVersion == nil { return nil, nil, nil, fmt.Errorf("Version not found for model %s, version %d", modelKey, version) } - server, ok := m.store.servers[serverKey] + server, ok := m.cache.servers[serverKey] if !ok { return nil, nil, nil, fmt.Errorf("failed to find server %s", serverKey) } return model, modelVersion, server, nil } -func (m *MemoryStore) UpdateLoadedModels( +func (m *ModelServerService) UpdateLoadedModels( modelKey string, version uint32, serverKey string, @@ -309,7 +327,7 @@ func (m *MemoryStore) UpdateLoadedModels( return nil } -func (m *MemoryStore) updateLoadedModelsImpl( +func (m *ModelServerService) updateLoadedModelsImpl( modelKey string, version uint32, serverKey string, @@ -318,7 +336,7 @@ func (m *MemoryStore) updateLoadedModelsImpl( logger := m.logger.WithField("func", "updateLoadedModelsImpl") // Validate - model, ok := m.store.models[modelKey] + model, ok := m.cache.models[modelKey] if !ok { return nil, fmt.Errorf("failed to find model %s", modelKey) } @@ -339,7 +357,7 @@ func (m *MemoryStore) updateLoadedModelsImpl( }, nil } - server, ok := m.store.servers[serverKey] + server, ok := m.cache.servers[serverKey] if !ok { return nil, fmt.Errorf("failed to find server %s", serverKey) } @@ -426,7 +444,7 @@ func (m *MemoryStore) updateLoadedModelsImpl( } } -func (m *MemoryStore) UnloadVersionModels(modelKey string, version uint32) (bool, error) { +func (m *ModelServerService) UnloadVersionModels(modelKey string, version uint32) (bool, error) { evt, updated, err := m.unloadVersionModelsImpl(modelKey, version) if err != nil { return updated, err @@ -440,13 +458,13 @@ func (m *MemoryStore) UnloadVersionModels(modelKey string, version uint32) (bool return updated, nil } -func (m *MemoryStore) unloadVersionModelsImpl(modelKey string, version 
uint32) (*coordinator.ModelEventMsg, bool, error) { +func (m *ModelServerService) unloadVersionModelsImpl(modelKey string, version uint32) (*coordinator.ModelEventMsg, bool, error) { logger := m.logger.WithField("func", "UnloadVersionModels") m.mu.Lock() defer m.mu.Unlock() // Validate - model, ok := m.store.models[modelKey] + model, ok := m.cache.models[modelKey] if !ok { return nil, false, fmt.Errorf("failed to find model %s", modelKey) } @@ -486,7 +504,7 @@ func (m *MemoryStore) unloadVersionModelsImpl(modelKey string, version uint32) ( return nil, false, nil } -func (m *MemoryStore) UpdateModelState( +func (m *ModelServerService) UpdateModelState( modelKey string, version uint32, serverKey string, @@ -516,7 +534,7 @@ func (m *MemoryStore) UpdateModelState( return nil } -func (m *MemoryStore) updateModelStateImpl( +func (m *ModelServerService) updateModelStateImpl( modelKey string, version uint32, serverKey string, @@ -563,7 +581,7 @@ func (m *MemoryStore) updateModelStateImpl( // Update models loaded onto replica for relevant state if desiredState == Loaded || desiredState == Loading || desiredState == Unloaded || desiredState == LoadFailed { - server, ok := m.store.servers[serverKey] + server, ok := m.cache.servers[serverKey] if ok { replica, ok := server.replicas[replicaIdx] if ok { @@ -611,12 +629,12 @@ func (m *MemoryStore) updateModelStateImpl( return nil, nil, nil } -func (m *MemoryStore) updateReservedMemory( +func (m *ModelServerService) updateReservedMemory( modelReplicaState ModelReplicaState, serverKey string, replicaIdx int, memBytes uint64, ) { // update reserved memory that is being used for sorting replicas // do we need to lock replica update? 
- server, ok := m.store.servers[serverKey] + server, ok := m.cache.servers[serverKey] if ok { replica, okReplica := server.replicas[replicaIdx] if okReplica { @@ -629,7 +647,7 @@ func (m *MemoryStore) updateReservedMemory( } } -func (m *MemoryStore) AddServerReplica(request *agent.AgentSubscribeRequest) error { +func (m *ModelServerService) AddServerReplica(request *agent.AgentSubscribeRequest) error { evts, serverEvt, err := m.addServerReplicaImpl(request) if err != nil { return err @@ -650,14 +668,14 @@ func (m *MemoryStore) AddServerReplica(request *agent.AgentSubscribeRequest) err return nil } -func (m *MemoryStore) addServerReplicaImpl(request *agent.AgentSubscribeRequest) ([]coordinator.ModelEventMsg, coordinator.ServerEventMsg, error) { +func (m *ModelServerService) addServerReplicaImpl(request *agent.AgentSubscribeRequest) ([]coordinator.ModelEventMsg, coordinator.ServerEventMsg, error) { m.mu.Lock() defer m.mu.Unlock() - server, ok := m.store.servers[request.ServerName] + server, ok := m.cache.servers[request.ServerName] if !ok { server = NewServer(request.ServerName, request.Shared) - m.store.servers[request.ServerName] = server + m.cache.servers[request.ServerName] = server } server.shared = request.Shared @@ -693,7 +711,7 @@ func (m *MemoryStore) addServerReplicaImpl(request *agent.AgentSubscribeRequest) return evts, serverEvt, nil } -func (m *MemoryStore) RemoveServerReplica(serverName string, replicaIdx int) ([]string, error) { +func (m *ModelServerService) RemoveServerReplica(serverName string, replicaIdx int) ([]string, error) { models, evts, err := m.removeServerReplicaImpl(serverName, replicaIdx) if err != nil { return nil, err @@ -709,11 +727,11 @@ func (m *MemoryStore) RemoveServerReplica(serverName string, replicaIdx int) ([] return models, nil } -func (m *MemoryStore) removeServerReplicaImpl(serverName string, replicaIdx int) ([]string, []coordinator.ModelEventMsg, error) { +func (m *ModelServerService) removeServerReplicaImpl(serverName 
string, replicaIdx int) ([]string, []coordinator.ModelEventMsg, error) { m.mu.Lock() defer m.mu.Unlock() - server, ok := m.store.servers[serverName] + server, ok := m.cache.servers[serverName] if !ok { return nil, nil, fmt.Errorf("Failed to find server %s", serverName) } @@ -724,7 +742,7 @@ func (m *MemoryStore) removeServerReplicaImpl(serverName string, replicaIdx int) delete(server.replicas, replicaIdx) // TODO we should not reschedule models on servers with dedicated models, e.g. non shareable servers if len(server.replicas) == 0 { - delete(m.store.servers, serverName) + delete(m.cache.servers, serverName) } loadedModelsRemoved, loadedEvts := m.removeModelfromServerReplica(serverReplica.loadedModels, replicaIdx) loadingModelsRemoved, loadingEtvs := m.removeModelfromServerReplica(serverReplica.loadingModels, replicaIdx) @@ -735,13 +753,13 @@ func (m *MemoryStore) removeServerReplicaImpl(serverName string, replicaIdx int) return modelsRemoved, evts, nil } -func (m *MemoryStore) removeModelfromServerReplica(lModels map[ModelVersionID]bool, replicaIdx int) ([]string, []coordinator.ModelEventMsg) { +func (m *ModelServerService) removeModelfromServerReplica(lModels map[ModelVersionID]bool, replicaIdx int) ([]string, []coordinator.ModelEventMsg) { logger := m.logger.WithField("func", "RemoveServerReplica") var modelNames []string var evts []coordinator.ModelEventMsg // Find models to reschedule due to this server replica being removed for modelVersionID := range lModels { - model, ok := m.store.models[modelVersionID.Name] + model, ok := m.cache.models[modelVersionID.Name] if ok { modelVersion := model.GetVersion(modelVersionID.Version) if modelVersion != nil { @@ -778,16 +796,16 @@ func (m *MemoryStore) removeModelfromServerReplica(lModels map[ModelVersionID]bo return modelNames, evts } -func (m *MemoryStore) DrainServerReplica(serverName string, replicaIdx int) ([]string, error) { +func (m *ModelServerService) DrainServerReplica(serverName string, replicaIdx int) 
([]string, error) { m.mu.Lock() defer m.mu.Unlock() return m.drainServerReplicaImpl(serverName, replicaIdx) } -func (m *MemoryStore) drainServerReplicaImpl(serverName string, replicaIdx int) ([]string, error) { +func (m *ModelServerService) drainServerReplicaImpl(serverName string, replicaIdx int) ([]string, error) { logger := m.logger.WithField("func", "DrainServerReplica") - server, ok := m.store.servers[serverName] + server, ok := m.cache.servers[serverName] if !ok { return nil, fmt.Errorf("Failed to find server %s", serverName) } @@ -811,12 +829,12 @@ func (m *MemoryStore) drainServerReplicaImpl(serverName string, replicaIdx int) return append(loadedModels, loadingModels...), nil } -func (m *MemoryStore) findModelsToReSchedule(models map[ModelVersionID]bool, replicaIdx int) []string { +func (m *ModelServerService) findModelsToReSchedule(models map[ModelVersionID]bool, replicaIdx int) []string { logger := m.logger.WithField("func", "DrainServerReplica") modelsReSchedule := make([]string, 0) for modelVersionID := range models { - model, ok := m.store.models[modelVersionID.Name] + model, ok := m.cache.models[modelVersionID.Name] if ok { modelVersion := model.GetVersion(modelVersionID.Version) if modelVersion != nil { @@ -831,17 +849,17 @@ func (m *MemoryStore) findModelsToReSchedule(models map[ModelVersionID]bool, rep return modelsReSchedule } -func (m *MemoryStore) ServerNotify(request *pb.ServerNotify) error { +func (m *ModelServerService) ServerNotify(request *pb.ServerNotify) error { logger := m.logger.WithField("func", "MemoryServerNotify") m.mu.Lock() defer m.mu.Unlock() logger.Debugf("ServerNotify %v", request) - server, ok := m.store.servers[request.Name] + server, ok := m.cache.servers[request.Name] if !ok { server = NewServer(request.Name, request.Shared) - m.store.servers[request.Name] = server + m.cache.servers[request.Name] = server } server.SetExpectedReplicas(int(request.ExpectedReplicas)) server.SetMinReplicas(int(request.MinReplicas)) @@ -850,9 
+868,9 @@ func (m *MemoryStore) ServerNotify(request *pb.ServerNotify) error { return nil } -func (m *MemoryStore) numEmptyServerReplicas(serverName string) uint32 { +func (m *ModelServerService) numEmptyServerReplicas(serverName string) uint32 { emptyReplicas := uint32(0) - server, ok := m.store.servers[serverName] + server, ok := m.cache.servers[serverName] if !ok { return emptyReplicas } @@ -864,9 +882,9 @@ func (m *MemoryStore) numEmptyServerReplicas(serverName string) uint32 { return emptyReplicas } -func (m *MemoryStore) maxNumModelReplicasForServer(serverName string) uint32 { +func (m *ModelServerService) maxNumModelReplicasForServer(serverName string) uint32 { maxNumModels := uint32(0) - for _, model := range m.store.models { + for _, model := range m.cache.models { latest := model.Latest() if latest != nil && latest.Server() == serverName { maxNumModels = max(maxNumModels, uint32(latest.DesiredReplicas())) @@ -887,7 +905,7 @@ func toSchedulerLoadedModels(agentLoadedModels []*agent.ModelVersion) map[ModelV return loadedModels } -func (m *MemoryStore) SetModelGwModelState(name string, versionNumber uint32, status ModelState, reason string, source string) error { +func (m *ModelServerService) SetModelGwModelState(name string, versionNumber uint32, status ModelState, reason string, source string) error { logger := m.logger.WithField("func", "SetModelGwModelState") logger.Debugf("Attempt to set model-gw state on model %s:%d status:%s", name, versionNumber, status.String()) @@ -905,13 +923,13 @@ func (m *MemoryStore) SetModelGwModelState(name string, versionNumber uint32, st return nil } -func (m *MemoryStore) setModelGwModelStateImpl(name string, versionNumber uint32, status ModelState, reason, source string) ([]*coordinator.ModelEventMsg, error) { +func (m *ModelServerService) setModelGwModelStateImpl(name string, versionNumber uint32, status ModelState, reason, source string) ([]*coordinator.ModelEventMsg, error) { var evts []*coordinator.ModelEventMsg 
m.mu.Lock() defer m.mu.Unlock() - model, ok := m.store.models[name] + model, ok := m.cache.models[name] if !ok { return nil, fmt.Errorf("failed to find model %s", name) } diff --git a/scheduler/pkg/store/memory_status.go b/scheduler/pkg/store/memory_status.go index 13ec647bb6..89e5a258d3 100644 --- a/scheduler/pkg/store/memory_status.go +++ b/scheduler/pkg/store/memory_status.go @@ -111,11 +111,11 @@ func updateModelState(isLatest bool, modelVersion *ModelVersion, prevModelVersio } } -func (m *MemoryStore) FailedScheduling(modelID string, version uint32, reason string, reset bool) error { +func (m *ModelServerService) FailedScheduling(modelID string, version uint32, reason string, reset bool) error { m.mu.Lock() defer m.mu.Unlock() - model, ok := m.store.models[modelID] + model, ok := m.cache.models[modelID] if !ok { return fmt.Errorf("model %s not found", modelID) } @@ -156,7 +156,7 @@ func (m *MemoryStore) FailedScheduling(modelID string, version uint32, reason st return fmt.Errorf("model %s found, version %d not found", modelID, version) } -func (m *MemoryStore) updateModelStatus(isLatest bool, deleted bool, modelVersion *ModelVersion, prevModelVersion *ModelVersion) { +func (m *ModelServerService) updateModelStatus(isLatest bool, deleted bool, modelVersion *ModelVersion, prevModelVersion *ModelVersion) { logger := m.logger.WithField("func", "updateModelStatus") stats := calcModelVersionStatistics(modelVersion, deleted) logger.Debugf("Stats %+v modelVersion %+v prev model %+v", stats, modelVersion, prevModelVersion) @@ -164,7 +164,7 @@ func (m *MemoryStore) updateModelStatus(isLatest bool, deleted bool, modelVersio updateModelState(isLatest, modelVersion, prevModelVersion, stats, deleted) } -func (m *MemoryStore) setModelGwStatusToTerminate(isLatest bool, modelVersion *ModelVersion) { +func (m *ModelServerService) setModelGwStatusToTerminate(isLatest bool, modelVersion *ModelVersion) { if !isLatest { modelVersion.state.ModelGwState = ModelTerminated 
modelVersion.state.ModelGwReason = "Not latest version" @@ -174,12 +174,12 @@ func (m *MemoryStore) setModelGwStatusToTerminate(isLatest bool, modelVersion *M } } -func (m *MemoryStore) UnloadModelGwVersionModels(modelKey string, version uint32) (bool, error) { +func (m *ModelServerService) UnloadModelGwVersionModels(modelKey string, version uint32) (bool, error) { m.mu.Lock() defer m.mu.Unlock() fmt.Println("UnloadModelGwVersionModels called for ", modelKey, " version ", version) - model, ok := m.store.models[modelKey] + model, ok := m.cache.models[modelKey] if !ok { return false, fmt.Errorf("failed to find model %s", modelKey) } diff --git a/scheduler/pkg/store/memory_status_test.go b/scheduler/pkg/store/memory_status_test.go index 07d3d194ab..3d7e62fe8a 100644 --- a/scheduler/pkg/store/memory_status_test.go +++ b/scheduler/pkg/store/memory_status_test.go @@ -228,7 +228,7 @@ func TestUpdateStatus(t *testing.T) { logger := log.New() eventHub, err := coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - ms := NewMemoryStore(logger, test.store, eventHub) + ms := NewModelServerService(logger, test.store, eventHub) model, modelVersion, _, err := ms.getModelServer(test.modelName, test.version, test.serverName) var prevModelVersion *ModelVersion if test.prevVersion != nil { diff --git a/scheduler/pkg/store/memory_test.go b/scheduler/pkg/store/memory_test.go index 31d259f84a..b43aaed6d0 100644 --- a/scheduler/pkg/store/memory_test.go +++ b/scheduler/pkg/store/memory_test.go @@ -219,7 +219,7 @@ func TestUpdateModel(t *testing.T) { logger := log.New() eventHub, err := coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - ms := NewMemoryStore(logger, test.store, eventHub) + ms := NewModelServerService(logger, test.store, eventHub) err = ms.UpdateModel(test.loadModelReq) if test.err != nil { g.Expect(err.Error()).To(BeIdenticalTo(test.err.Error())) @@ -281,7 +281,7 @@ func TestGetModel(t *testing.T) { logger := log.New() eventHub, err := 
coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - ms := NewMemoryStore(logger, test.store, eventHub) + ms := NewModelServerService(logger, test.store, eventHub) model, err := ms.GetModel(test.key) if test.err == nil { g.Expect(err).To(BeNil()) @@ -457,7 +457,7 @@ func TestGetServer(t *testing.T) { logger := log.New() eventHub, err := coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - ms := NewMemoryStore(logger, test.store, eventHub) + ms := NewModelServerService(logger, test.store, eventHub) server, err := ms.GetServer(test.key, false, true) if !test.isErr { g.Expect(err).To(BeNil()) @@ -524,7 +524,7 @@ func TestRemoveModel(t *testing.T) { logger := log.New() eventHub, err := coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - ms := NewMemoryStore(logger, test.store, eventHub) + ms := NewModelServerService(logger, test.store, eventHub) err = ms.RemoveModel(&pb.UnloadModelRequest{Model: &pb.ModelReference{Name: test.key}}) if !test.err { g.Expect(err).To(BeNil()) @@ -959,7 +959,7 @@ func TestUpdateLoadedModels(t *testing.T) { if test.isModelDeleted { test.store.models[test.modelKey].SetDeleted() } - ms := NewMemoryStore(logger, test.store, eventHub) + ms := NewModelServerService(logger, test.store, eventHub) msg, err := ms.updateLoadedModelsImpl(test.modelKey, test.version, test.serverKey, test.replicas) if !test.err { g.Expect(err).To(BeNil()) @@ -1308,7 +1308,7 @@ func TestUpdateModelState(t *testing.T) { }, ) - ms := NewMemoryStore(logger, test.store, eventHub) + ms := NewModelServerService(logger, test.store, eventHub) err = ms.UpdateModelState(test.modelKey, test.version, test.serverKey, test.replicaIdx, &test.availableMemory, test.expectedState, test.desiredState, "", test.modelRuntimeInfo) if !test.err { g.Expect(err).To(BeNil()) @@ -1617,7 +1617,7 @@ func TestUpdateModelStatus(t *testing.T) { logger := log.New() eventHub, err := coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - ms := NewMemoryStore(logger, 
&LocalSchedulerStore{}, eventHub) + ms := NewModelServerService(logger, &LocalSchedulerStore{}, eventHub) ms.updateModelStatus(true, test.deleted, test.modelVersion, test.prevAvailableModelVersion) g.Expect(test.modelVersion.state.State).To(Equal(test.expectedState)) g.Expect(test.modelVersion.state.Reason).To(Equal(test.expectedReason)) @@ -1774,7 +1774,7 @@ func TestAddModelVersionIfNotExists(t *testing.T) { logger := log.New() eventHub, err := coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - ms := NewMemoryStore(logger, test.store, eventHub) + ms := NewModelServerService(logger, test.store, eventHub) ms.addModelVersionIfNotExists(test.modelVersion) modelName := test.modelVersion.GetModel().GetMeta().GetName() g.Expect(test.store.models[modelName].GetVersions()).To(Equal(test.expected)) @@ -1901,7 +1901,7 @@ func TestAddServerReplica(t *testing.T) { logger := log.New() eventHub, err := coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - ms := NewMemoryStore(logger, test.store, eventHub) + ms := NewModelServerService(logger, test.store, eventHub) // register a callback to check if the event is triggered serverEvents := int64(0) @@ -1969,7 +1969,7 @@ func TestRemoveServerReplica(t *testing.T) { serverName: "server1", replicaIdx: 0, serverExists: true, - modelsReturned: 0, // no models really defined in store + modelsReturned: 0, // no models really defined in cache }, { name: "ReplicaRemovedAndDeleted", @@ -2029,7 +2029,7 @@ func TestRemoveServerReplica(t *testing.T) { modelsReturned: 1, }, { - name: "ReplicaRemovedAndServerDeleted but no model version in store", + name: "ReplicaRemovedAndServerDeleted but no model version in cache", store: &LocalSchedulerStore{ servers: map[string]*Server{ "server1": { @@ -2128,7 +2128,7 @@ func TestRemoveServerReplica(t *testing.T) { logger := log.New() eventHub, err := coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - ms := NewMemoryStore(logger, test.store, eventHub) + ms := 
NewModelServerService(logger, test.store, eventHub) models, err := ms.RemoveServerReplica(test.serverName, test.replicaIdx) g.Expect(err).To(BeNil()) g.Expect(test.modelsReturned).To(Equal(len(models))) @@ -2257,7 +2257,7 @@ func TestDrainServerReplica(t *testing.T) { logger := log.New() eventHub, err := coordinator.NewEventHub(logger) g.Expect(err).To(BeNil()) - ms := NewMemoryStore(logger, test.store, eventHub) + ms := NewModelServerService(logger, test.store, eventHub) models, err := ms.DrainServerReplica(test.serverName, test.replicaIdx) g.Expect(err).To(BeNil()) g.Expect(test.modelsReturned).To(Equal(models)) diff --git a/scheduler/pkg/store/mesh.go b/scheduler/pkg/store/mesh.go index 3566f5f8af..1c84691b7b 100644 --- a/scheduler/pkg/store/mesh.go +++ b/scheduler/pkg/store/mesh.go @@ -22,20 +22,37 @@ import ( pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler" ) +// Cache abstracts keyed access to the models and servers held by a store backend. +type Cache interface { + GetModel(key string) *Model + GetModelKeys() []string + PutModel(key string, model *Model) + GetServer(key string) *Server + PutServer(key string, server *Server) +} + type LocalSchedulerStore struct { - servers map[string]*Server - models map[string]*Model - failedToScheduleModels map[string]bool + servers map[string]*Server + models map[string]*Model } func NewLocalSchedulerStore() *LocalSchedulerStore { m := LocalSchedulerStore{} m.servers = make(map[string]*Server) m.models = make(map[string]*Model) - m.failedToScheduleModels = make(map[string]bool) return &m } +// GetModel returns the model stored under key, or nil when there is no such entry. +func (lss *LocalSchedulerStore) GetModel(key string) *Model { + return lss.models[key] +} + +// PutModel stores model under key, replacing any entry already present. +func (lss *LocalSchedulerStore) PutModel(key string, model *Model) { + lss.models[key] = model +} + type Model struct { versions []*ModelVersion deleted atomic.Bool diff --git a/scheduler/pkg/store/test_memory_hack.go b/scheduler/pkg/store/test_memory_hack.go index 3a3921120e..39015880e7 100644 --- a/scheduler/pkg/store/test_memory_hack.go +++ b/scheduler/pkg/store/test_memory_hack.go @@
-19,7 +19,7 @@ import ( ) type TestMemoryStore struct { - *MemoryStore + *ModelServerService } type ModelID struct { @@ -28,8 +28,8 @@ type ModelID struct { } // NewTestMemory DO NOT USE for non-test code. This is purely meant for using in tests where an integration test is -// wanted where the real memory store is needed, but the test needs the ability to directly manipulate the model -// statuses, which can't be achieved with MemoryStore. TestMemoryStore embeds MemoryStore and adds DirectlyUpdateModelStatus +// wanted where the real memory cache is needed, but the test needs the ability to directly manipulate the model +// statuses, which can't be achieved with ModelServerService. TestMemoryStore embeds ModelServerService and adds DirectlyUpdateModelStatus // to modify the statuses. func NewTestMemory( t *testing.T, @@ -39,7 +39,7 @@ func NewTestMemory( if t == nil { panic("testing.T is required, must only be run via tests") } - m := NewMemoryStore(logger, store, eventHub) + m := NewModelServerService(logger, store, eventHub) return &TestMemoryStore{m} } @@ -47,7 +47,7 @@ func (t *TestMemoryStore) DirectlyUpdateModelStatus(model ModelID, state ModelSt t.mu.Lock() defer t.mu.Unlock() - found, ok := t.store.models[model.Name] + found, ok := t.cache.models[model.Name] if !ok { return errors.New("model not found") }