From 6b5c7a38523a4a8014612b2d77334bd01ef3b4b4 Mon Sep 17 00:00:00 2001 From: Software Developer Date: Wed, 11 Dec 2024 16:32:06 +0000 Subject: [PATCH 1/3] move `PATCH /mlflow/model-versions/update` endpoint. Signed-off-by: Software Developer --- magefiles/generate/endpoints.go | 2 +- mlflow_go/store/model_registry.py | 5 ++++ pkg/contract/service/model_registry.g.go | 1 + pkg/lib/model_registry.g.go | 8 +++++++ pkg/model_registry/service/model_versions.go | 13 ++++++++++ .../store/sql/model_versions.go | 24 +++++++++++++++++++ pkg/model_registry/store/store.go | 1 + pkg/server/routes/model_registry.g.go | 11 +++++++++ 8 files changed, 64 insertions(+), 1 deletion(-) diff --git a/magefiles/generate/endpoints.go b/magefiles/generate/endpoints.go index 0a917ac7..c0bcb2d5 100644 --- a/magefiles/generate/endpoints.go +++ b/magefiles/generate/endpoints.go @@ -56,7 +56,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{ // "searchRegisteredModels", "getLatestVersions", // "createModelVersion", - // "updateModelVersion", + "updateModelVersion", // "transitionModelVersionStage", "deleteModelVersion", // "getModelVersion", diff --git a/mlflow_go/store/model_registry.py b/mlflow_go/store/model_registry.py index f42f6904..01d76e9f 100644 --- a/mlflow_go/store/model_registry.py +++ b/mlflow_go/store/model_registry.py @@ -8,6 +8,7 @@ GetLatestVersions, GetRegisteredModel, RenameRegisteredModel, + UpdateModelVersion, UpdateRegisteredModel, ) @@ -84,6 +85,10 @@ def delete_model_version(self, name, version): request = DeleteModelVersion(name=name, version=str(version)) self.service.call_endpoint(get_lib().ModelRegistryServiceDeleteModelVersion, request) + def update_model_version(self, name, version, description=None): + request = UpdateModelVersion(name=name, version=str(version), description=description) + self.service.call_endpoint(get_lib().ModelRegistryServiceUpdateModelVersion, request) + def ModelRegistryStore(cls): return type(cls.__name__, (_ModelRegistryStore, cls), {}) diff --git a/pkg/contract/service/model_registry.g.go b/pkg/contract/service/model_registry.g.go index af202b9d..573383c7 100644 --- a/pkg/contract/service/model_registry.g.go +++ b/pkg/contract/service/model_registry.g.go @@ -15,5 +15,6 @@ type ModelRegistryService interface { DeleteRegisteredModel(ctx context.Context, input *protos.DeleteRegisteredModel) (*protos.DeleteRegisteredModel_Response, *contract.Error) GetRegisteredModel(ctx context.Context, input *protos.GetRegisteredModel) (*protos.GetRegisteredModel_Response, *contract.Error) GetLatestVersions(ctx context.Context, input *protos.GetLatestVersions) (*protos.GetLatestVersions_Response, *contract.Error) + UpdateModelVersion(ctx context.Context, input *protos.UpdateModelVersion) (*protos.UpdateModelVersion_Response, *contract.Error) DeleteModelVersion(ctx context.Context, input *protos.DeleteModelVersion) (*protos.DeleteModelVersion_Response, *contract.Error) } diff --git a/pkg/lib/model_registry.g.go b/pkg/lib/model_registry.g.go index e114f057..9a5bdba8 100644 --- a/pkg/lib/model_registry.g.go +++ b/pkg/lib/model_registry.g.go @@ -47,6 +47,14 @@ func ModelRegistryServiceGetLatestVersions(serviceID int64, requestData unsafe.P } return invokeServiceMethod(service.GetLatestVersions, new(protos.GetLatestVersions), requestData, requestSize, responseSize) } +//export ModelRegistryServiceUpdateModelVersion +func ModelRegistryServiceUpdateModelVersion(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { + service, err := modelRegistryServices.Get(serviceID) + if err != nil { + return makePointerFromError(err, responseSize) + } + return invokeServiceMethod(service.UpdateModelVersion, new(protos.UpdateModelVersion), requestData, requestSize, responseSize) +} //export ModelRegistryServiceDeleteModelVersion func ModelRegistryServiceDeleteModelVersion(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { service, err := modelRegistryServices.Get(serviceID) diff --git a/pkg/model_registry/service/model_versions.go b/pkg/model_registry/service/model_versions.go index 725c68bc..ac53eca7 100644 --- a/pkg/model_registry/service/model_versions.go +++ b/pkg/model_registry/service/model_versions.go @@ -86,3 +86,16 @@ func (m *ModelRegistryService) DeleteModelVersion( return &protos.DeleteModelVersion_Response{}, nil } + +func (m *ModelRegistryService) UpdateModelVersion( + ctx context.Context, input *protos.UpdateModelVersion, +) (*protos.UpdateModelVersion_Response, *contract.Error) { + modelVersion, err := m.store.UpdateModelVersion(ctx, input.GetName(), input.GetVersion(), input.GetDescription()) + if err != nil { + return nil, err + } + + return &protos.UpdateModelVersion_Response{ + ModelVersion: modelVersion.ToProto(), + }, nil +} diff --git a/pkg/model_registry/store/sql/model_versions.go b/pkg/model_registry/store/sql/model_versions.go index 1ecfc586..4b18973e 100644 --- a/pkg/model_registry/store/sql/model_versions.go +++ b/pkg/model_registry/store/sql/model_versions.go @@ -353,3 +353,27 @@ func (m *ModelRegistrySQLStore) DeleteModelVersion(ctx context.Context, name, ve return nil } + +func (m *ModelRegistrySQLStore) UpdateModelVersion( + ctx context.Context, name, version, description string, +) (*entities.ModelVersion, *contract.Error) { + modelVersion, err := m.GetModelVersion(ctx, name, version) + if err != nil { + return nil, err + } + + if err := m.db.WithContext(ctx).Model( + &models.ModelVersion{}, + ).Where( + "name = ?", modelVersion.Name, + ).Where( + "version = ?", modelVersion.Version, + ).Updates(&models.ModelVersion{ + Description: sql.NullString{String: description, Valid: description != ""}, + LastUpdatedTime: time.Now().UnixMilli(), + }).Error; err != nil { + return nil, contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "error updating model version", err) + } + + return modelVersion, nil +} diff --git a/pkg/model_registry/store/store.go b/pkg/model_registry/store/store.go index fbbffe29..b1245e30 100644 --- a/pkg/model_registry/store/store.go +++ b/pkg/model_registry/store/store.go @@ -16,4 +16,5 @@ type ModelRegistryStore interface { RenameRegisteredModel(ctx context.Context, name, newName string) (*entities.RegisteredModel, *contract.Error) DeleteRegisteredModel(ctx context.Context, name string) *contract.Error DeleteModelVersion(ctx context.Context, name, version string) *contract.Error + UpdateModelVersion(ctx context.Context, name, version, description string) (*entities.ModelVersion, *contract.Error) } diff --git a/pkg/server/routes/model_registry.g.go b/pkg/server/routes/model_registry.g.go index 80618f06..ca0f2489 100644 --- a/pkg/server/routes/model_registry.g.go +++ b/pkg/server/routes/model_registry.g.go @@ -77,6 +77,17 @@ func RegisterModelRegistryServiceRoutes(service service.ModelRegistryService, pa } return ctx.JSON(output) }) + app.Patch("/mlflow/model-versions/update", func(ctx *fiber.Ctx) error { + input := &protos.UpdateModelVersion{} + if err := parser.ParseBody(ctx, input); err != nil { + return err + } + output, err := service.UpdateModelVersion(utils.NewContextWithLoggerFromFiberContext(ctx), input) + if err != nil { + return err + } + return ctx.JSON(output) + }) app.Delete("/mlflow/model-versions/delete", func(ctx *fiber.Ctx) error { input := &protos.DeleteModelVersion{} if err := parser.ParseBody(ctx, input); err != nil { From 57f1d405b97f4b67288c7bcaf32dfc5b3dfd876e Mon Sep 17 00:00:00 2001 From: DSuhinin Date: Wed, 11 Dec 2024 20:15:58 +0100 Subject: [PATCH 2/3] Revert "move `PATCH /mlflow/model-versions/update` endpoint." This reverts commit 6b5c7a38523a4a8014612b2d77334bd01ef3b4b4. --- magefiles/generate/endpoints.go | 2 +- mlflow_go/store/model_registry.py | 5 ---- pkg/contract/service/model_registry.g.go | 1 - pkg/lib/model_registry.g.go | 8 ------- pkg/model_registry/service/model_versions.go | 13 ---------- .../store/sql/model_versions.go | 24 ------------------- pkg/model_registry/store/store.go | 1 - pkg/server/routes/model_registry.g.go | 11 --------- 8 files changed, 1 insertion(+), 64 deletions(-) diff --git a/magefiles/generate/endpoints.go b/magefiles/generate/endpoints.go index c0bcb2d5..0a917ac7 100644 --- a/magefiles/generate/endpoints.go +++ b/magefiles/generate/endpoints.go @@ -56,7 +56,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{ // "searchRegisteredModels", "getLatestVersions", // "createModelVersion", - "updateModelVersion", + // "updateModelVersion", // "transitionModelVersionStage", "deleteModelVersion", // "getModelVersion", diff --git a/mlflow_go/store/model_registry.py b/mlflow_go/store/model_registry.py index 01d76e9f..f42f6904 100644 --- a/mlflow_go/store/model_registry.py +++ b/mlflow_go/store/model_registry.py @@ -8,7 +8,6 @@ GetLatestVersions, GetRegisteredModel, RenameRegisteredModel, - UpdateModelVersion, UpdateRegisteredModel, ) @@ -85,10 +84,6 @@ def delete_model_version(self, name, version): request = DeleteModelVersion(name=name, version=str(version)) self.service.call_endpoint(get_lib().ModelRegistryServiceDeleteModelVersion, request) - def update_model_version(self, name, version, description=None): - request = UpdateModelVersion(name=name, version=str(version), description=description) - self.service.call_endpoint(get_lib().ModelRegistryServiceUpdateModelVersion, request) - def ModelRegistryStore(cls): return type(cls.__name__, (_ModelRegistryStore, cls), {}) diff --git a/pkg/contract/service/model_registry.g.go b/pkg/contract/service/model_registry.g.go index 573383c7..af202b9d 100644 --- a/pkg/contract/service/model_registry.g.go +++ b/pkg/contract/service/model_registry.g.go @@ -15,6 +15,5 @@ type ModelRegistryService interface { DeleteRegisteredModel(ctx context.Context, input *protos.DeleteRegisteredModel) (*protos.DeleteRegisteredModel_Response, *contract.Error) GetRegisteredModel(ctx context.Context, input *protos.GetRegisteredModel) (*protos.GetRegisteredModel_Response, *contract.Error) GetLatestVersions(ctx context.Context, input *protos.GetLatestVersions) (*protos.GetLatestVersions_Response, *contract.Error) - UpdateModelVersion(ctx context.Context, input *protos.UpdateModelVersion) (*protos.UpdateModelVersion_Response, *contract.Error) DeleteModelVersion(ctx context.Context, input *protos.DeleteModelVersion) (*protos.DeleteModelVersion_Response, *contract.Error) } diff --git a/pkg/lib/model_registry.g.go b/pkg/lib/model_registry.g.go index 9a5bdba8..e114f057 100644 --- a/pkg/lib/model_registry.g.go +++ b/pkg/lib/model_registry.g.go @@ -47,14 +47,6 @@ func ModelRegistryServiceGetLatestVersions(serviceID int64, requestData unsafe.P } return invokeServiceMethod(service.GetLatestVersions, new(protos.GetLatestVersions), requestData, requestSize, responseSize) } -//export ModelRegistryServiceUpdateModelVersion -func ModelRegistryServiceUpdateModelVersion(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { - service, err := modelRegistryServices.Get(serviceID) - if err != nil { - return makePointerFromError(err, responseSize) - } - return invokeServiceMethod(service.UpdateModelVersion, new(protos.UpdateModelVersion), requestData, requestSize, responseSize) -} //export ModelRegistryServiceDeleteModelVersion func ModelRegistryServiceDeleteModelVersion(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { service, err := modelRegistryServices.Get(serviceID) diff --git a/pkg/model_registry/service/model_versions.go b/pkg/model_registry/service/model_versions.go index ac53eca7..725c68bc 100644 --- a/pkg/model_registry/service/model_versions.go +++ b/pkg/model_registry/service/model_versions.go @@ -86,16 +86,3 @@ func (m *ModelRegistryService) DeleteModelVersion( return &protos.DeleteModelVersion_Response{}, nil } - -func (m *ModelRegistryService) UpdateModelVersion( - ctx context.Context, input *protos.UpdateModelVersion, -) (*protos.UpdateModelVersion_Response, *contract.Error) { - modelVersion, err := m.store.UpdateModelVersion(ctx, input.GetName(), input.GetVersion(), input.GetDescription()) - if err != nil { - return nil, err - } - - return &protos.UpdateModelVersion_Response{ - ModelVersion: modelVersion.ToProto(), - }, nil -} diff --git a/pkg/model_registry/store/sql/model_versions.go b/pkg/model_registry/store/sql/model_versions.go index 4b18973e..1ecfc586 100644 --- a/pkg/model_registry/store/sql/model_versions.go +++ b/pkg/model_registry/store/sql/model_versions.go @@ -353,27 +353,3 @@ func (m *ModelRegistrySQLStore) DeleteModelVersion(ctx context.Context, name, ve return nil } - -func (m *ModelRegistrySQLStore) UpdateModelVersion( - ctx context.Context, name, version, description string, -) (*entities.ModelVersion, *contract.Error) { - modelVersion, err := m.GetModelVersion(ctx, name, version) - if err != nil { - return nil, err - } - - if err := m.db.WithContext(ctx).Model( - &models.ModelVersion{}, - ).Where( - "name = ?", modelVersion.Name, - ).Where( - "version = ?", modelVersion.Version, - ).Updates(&models.ModelVersion{ - Description: sql.NullString{String: description, Valid: description != ""}, - LastUpdatedTime: time.Now().UnixMilli(), - }).Error; err != nil { - return nil, contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "error updating model version", err) - } - - return modelVersion, nil -} diff --git a/pkg/model_registry/store/store.go b/pkg/model_registry/store/store.go index b1245e30..fbbffe29 100644 --- a/pkg/model_registry/store/store.go +++ b/pkg/model_registry/store/store.go @@ -16,5 +16,4 @@ type ModelRegistryStore interface { RenameRegisteredModel(ctx context.Context, name, newName string) (*entities.RegisteredModel, *contract.Error) DeleteRegisteredModel(ctx context.Context, name string) *contract.Error DeleteModelVersion(ctx context.Context, name, version string) *contract.Error - UpdateModelVersion(ctx context.Context, name, version, description string) (*entities.ModelVersion, *contract.Error) } diff --git a/pkg/server/routes/model_registry.g.go b/pkg/server/routes/model_registry.g.go index ca0f2489..80618f06 100644 --- a/pkg/server/routes/model_registry.g.go +++ b/pkg/server/routes/model_registry.g.go @@ -77,17 +77,6 @@ func RegisterModelRegistryServiceRoutes(service service.ModelRegistryService, pa } return ctx.JSON(output) }) - app.Patch("/mlflow/model-versions/update", func(ctx *fiber.Ctx) error { - input := &protos.UpdateModelVersion{} - if err := parser.ParseBody(ctx, input); err != nil { - return err - } - output, err := service.UpdateModelVersion(utils.NewContextWithLoggerFromFiberContext(ctx), input) - if err != nil { - return err - } - return ctx.JSON(output) - }) app.Delete("/mlflow/model-versions/delete", func(ctx *fiber.Ctx) error { input := &protos.DeleteModelVersion{} if err := parser.ParseBody(ctx, input); err != nil { From 4b37881fe4e16074d97c3c15538f5de56e82bf62 Mon Sep 17 00:00:00 2001 From: dsuhinin Date: Thu, 6 Feb 2025 18:42:39 +0000 Subject: [PATCH 3/3] Move `POST /mlflow/model-versions/create` endpoint Signed-off-by: dsuhinin --- magefiles/generate/endpoints.go | 2 +- magefiles/generate/validations.go | 3 + magefiles/tests.go | 1 + mlflow_go_backend/store/model_registry.py | 25 +++++ pkg/contract/service/model_registry.g.go | 1 + pkg/entities/model_version.go | 15 ++- pkg/entities/model_version_tag.go | 7 ++ pkg/lib/model_registry.g.go | 8 ++ pkg/model_registry/service/model_versions.go | 39 +++++-- pkg/model_registry/store/sql/helpers.go | 81 ++++++++++++++ .../store/sql/model_versions.go | 103 ++++++++++++++++++ .../store/sql/models/model_versions.go | 2 + pkg/model_registry/store/store.go | 3 + pkg/protos/model_registry.pb.go | 6 +- pkg/server/routes/model_registry.g.go | 11 ++ 15 files changed, 286 insertions(+), 21 deletions(-) create mode 100644 pkg/model_registry/store/sql/helpers.go diff --git a/magefiles/generate/endpoints.go b/magefiles/generate/endpoints.go index 236874f0..47460890 100644 --- a/magefiles/generate/endpoints.go +++ b/magefiles/generate/endpoints.go @@ -55,7 +55,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{ "getRegisteredModel", // "searchRegisteredModels", "getLatestVersions", - // "createModelVersion", + "createModelVersion", "updateModelVersion", "transitionModelVersionStage", "deleteModelVersion", diff --git a/magefiles/generate/validations.go b/magefiles/generate/validations.go index 22bac3bf..2005d3bb 100644 --- a/magefiles/generate/validations.go +++ b/magefiles/generate/validations.go @@ -73,4 +73,7 @@ var validations = map[string]string{ "SetModelVersionTag_Version": "stringAsInteger", "GetModelVersion_Version": "stringAsInteger", "GetModelVersionDownloadUri_Version": "stringAsInteger", + "CreateModelVersion_Name": "notEmpty,required", + "ModelVersionTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", + "ModelVersionTag_Value": "omitempty,max=5000,truncate=5000", } diff --git a/magefiles/tests.go b/magefiles/tests.go index de341e6b..d91948a2 100644 --- a/magefiles/tests.go +++ b/magefiles/tests.go @@ -34,6 +34,7 @@ func runPythonTests(pytestArgs []string) error { // "--log-cli-level=DEBUG", "--confcutdir=.", "-k", "not [file", + "-vv", } args = append(args, pytestArgs...) diff --git a/mlflow_go_backend/store/model_registry.py b/mlflow_go_backend/store/model_registry.py index df279532..b1f55226 100644 --- a/mlflow_go_backend/store/model_registry.py +++ b/mlflow_go_backend/store/model_registry.py @@ -3,6 +3,7 @@ from mlflow.entities.model_registry import ModelVersion, RegisteredModel from mlflow.protos.model_registry_pb2 import ( + CreateModelVersion, CreateRegisteredModel, DeleteModelVersion, DeleteModelVersionTag, @@ -176,6 +177,30 @@ def get_model_version_download_uri(self, name, version): ) return response.artifact_uri + def create_model_version( + self, + name, + source, + run_id=None, + tags=None, + run_link=None, + description=None, + local_model_path=None, + ): + request = CreateModelVersion( + name=name, + source=source, + run_id=run_id, + tags=[tag.to_proto() for tag in tags] if tags else [], + run_link=run_link, + description=description, + ) + response = self.service.call_endpoint( + get_lib().ModelRegistryServiceCreateModelVersion, request + ) + + return ModelVersion.from_proto(response.model_version) + def ModelRegistryStore(cls): return type(cls.__name__, (_ModelRegistryStore, cls), {}) diff --git a/pkg/contract/service/model_registry.g.go b/pkg/contract/service/model_registry.g.go index 7e1cc921..73842c49 100644 --- a/pkg/contract/service/model_registry.g.go +++ b/pkg/contract/service/model_registry.g.go @@ -16,6 +16,7 @@ type ModelRegistryService interface { DeleteRegisteredModel(ctx context.Context, input *protos.DeleteRegisteredModel) (*protos.DeleteRegisteredModel_Response, *contract.Error) GetRegisteredModel(ctx context.Context, input *protos.GetRegisteredModel) (*protos.GetRegisteredModel_Response, *contract.Error) GetLatestVersions(ctx context.Context, input *protos.GetLatestVersions) (*protos.GetLatestVersions_Response, *contract.Error) + CreateModelVersion(ctx context.Context, input *protos.CreateModelVersion) (*protos.CreateModelVersion_Response, *contract.Error) UpdateModelVersion(ctx context.Context, input *protos.UpdateModelVersion) (*protos.UpdateModelVersion_Response, *contract.Error) TransitionModelVersionStage(ctx context.Context, input *protos.TransitionModelVersionStage) (*protos.TransitionModelVersionStage_Response, *contract.Error) DeleteModelVersion(ctx context.Context, input *protos.DeleteModelVersion) (*protos.DeleteModelVersion_Response, *contract.Error) diff --git a/pkg/entities/model_version.go b/pkg/entities/model_version.go index 66dfbbfe..4899887b 100644 --- a/pkg/entities/model_version.go +++ b/pkg/entities/model_version.go @@ -28,16 +28,19 @@ type ModelVersion struct { func (mv ModelVersion) ToProto() *protos.ModelVersion { modelVersion := protos.ModelVersion{ Name: utils.PtrTo(mv.Name), + Tags: make([]*protos.ModelVersionTag, 0, len(mv.Tags)), + Source: utils.PtrTo(mv.Source), + Status: utils.PtrTo(protos.ModelVersionStatus(protos.ModelVersionStatus_value[mv.Status])), + RunLink: utils.PtrTo(mv.RunLink), Version: utils.PtrTo(strconv.Itoa(int(mv.Version))), + Description: utils.PtrTo(mv.Description), CurrentStage: utils.PtrTo(mv.CurrentStage), CreationTimestamp: utils.PtrTo(mv.CreationTime), LastUpdatedTimestamp: utils.PtrTo(mv.LastUpdatedTime), - Description: utils.PtrTo(mv.Description), - UserId: utils.PtrTo(mv.UserID), - Source: utils.PtrTo(mv.Source), - Status: utils.PtrTo(protos.ModelVersionStatus(protos.ModelVersionStatus_value[mv.Status])), - Tags: make([]*protos.ModelVersionTag, 0, len(mv.Tags)), - RunLink: utils.PtrTo(mv.RunLink), + } + + if mv.UserID != "" { + modelVersion.UserId = utils.PtrTo(mv.UserID) } if mv.RunID != "" { diff --git a/pkg/entities/model_version_tag.go b/pkg/entities/model_version_tag.go index 22adb93b..4e05e43e 100644 --- a/pkg/entities/model_version_tag.go +++ b/pkg/entities/model_version_tag.go @@ -16,3 +16,10 @@ func (mvt ModelVersionTag) ToProto() *protos.ModelVersionTag { Value: utils.PtrTo(mvt.Value), } } + +func NewModelVersionTag(proto *protos.ModelVersionTag) *ModelVersionTag { + return &ModelVersionTag{ + Key: proto.GetKey(), + Value: proto.GetValue(), + } +} diff --git a/pkg/lib/model_registry.g.go b/pkg/lib/model_registry.g.go index 29f248b5..146c29f0 100644 --- a/pkg/lib/model_registry.g.go +++ b/pkg/lib/model_registry.g.go @@ -55,6 +55,14 @@ func ModelRegistryServiceGetLatestVersions(serviceID int64, requestData unsafe.P } return invokeServiceMethod(service.GetLatestVersions, new(protos.GetLatestVersions), requestData, requestSize, responseSize) } +//export ModelRegistryServiceCreateModelVersion +func ModelRegistryServiceCreateModelVersion(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { + service, err := modelRegistryServices.Get(serviceID) + if err != nil { + return makePointerFromError(err, responseSize) + } + return invokeServiceMethod(service.CreateModelVersion, new(protos.CreateModelVersion), requestData, requestSize, responseSize) +} //export ModelRegistryServiceUpdateModelVersion func ModelRegistryServiceUpdateModelVersion(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { service, err := modelRegistryServices.Get(serviceID) diff --git a/pkg/model_registry/service/model_versions.go b/pkg/model_registry/service/model_versions.go index 591f53dd..cd048f48 100644 --- a/pkg/model_registry/service/model_versions.go +++ b/pkg/model_registry/service/model_versions.go @@ -3,10 +3,10 @@ package service import ( "context" "fmt" - "strconv" "strings" "github.com/mlflow/mlflow-go-backend/pkg/contract" + "github.com/mlflow/mlflow-go-backend/pkg/entities" "github.com/mlflow/mlflow-go-backend/pkg/model_registry/store/sql/models" "github.com/mlflow/mlflow-go-backend/pkg/protos" "github.com/mlflow/mlflow-go-backend/pkg/utils" @@ -38,16 +38,7 @@ func (m *ModelRegistryService) DeleteModelVersion( func (m *ModelRegistryService) GetModelVersion( ctx context.Context, input *protos.GetModelVersion, ) (*protos.GetModelVersion_Response, *contract.Error) { - // by some strange reason GetModelVersion.Version has a string type so we can't apply our validation, - // that's why such a custom validation exists to satisfy Python tests. - version := input.GetVersion() - if _, err := strconv.Atoi(version); err != nil { - return nil, contract.NewErrorWith( - protos.ErrorCode_INVALID_PARAMETER_VALUE, "Model version must be an integer", err, - ) - } - - modelVersion, err := m.store.GetModelVersion(ctx, input.GetName(), version, true) + modelVersion, err := m.store.GetModelVersion(ctx, input.GetName(), input.GetVersion(), true) if err != nil { return nil, err } @@ -157,3 +148,29 @@ func (m *ModelRegistryService) GetModelVersionDownloadUri( ArtifactUri: utils.PtrTo(artifactURI), }, nil } + +func (m *ModelRegistryService) CreateModelVersion( + ctx context.Context, input *protos.CreateModelVersion, +) (*protos.CreateModelVersion_Response, *contract.Error) { + tags := make([]*entities.ModelVersionTag, 0, len(input.GetTags())) + for _, tag := range input.GetTags() { + tags = append(tags, entities.NewModelVersionTag(tag)) + } + + modelVersion, err := m.store.CreateModelVersion( + ctx, + input.GetName(), + input.GetSource(), + input.GetRunId(), + tags, + input.GetRunLink(), + input.GetDescription(), + ) + if err != nil { + return nil, err + } + + return &protos.CreateModelVersion_Response{ + ModelVersion: modelVersion.ToProto(), + }, nil +} diff --git a/pkg/model_registry/store/sql/helpers.go b/pkg/model_registry/store/sql/helpers.go new file mode 100644 index 00000000..5666597d --- /dev/null +++ b/pkg/model_registry/store/sql/helpers.go @@ -0,0 +1,81 @@ +package sql + +import ( + "errors" + "net/url" + "strconv" + "strings" + + "github.com/mlflow/mlflow-go-backend/pkg/entities" +) + +// GetNextVersion returns the next version number for a given registered model. +func GetNextVersion(sqlRegisteredModel *entities.RegisteredModel) int32 { + if len(sqlRegisteredModel.Versions) > 0 { + maxVersion := sqlRegisteredModel.Versions[0].Version + for _, mv := range sqlRegisteredModel.Versions { + if mv.Version > maxVersion { + maxVersion = mv.Version + } + } + + return maxVersion + 1 + } + + return 1 +} + +type ParsedModelURI struct { + Name string + Version string + Stage string + Alias string +} + +//nolint:cyclop,err113,mnd,wrapcheck +func ParseModelURI(uri string) (*ParsedModelURI, error) { + parsed, err := url.Parse(uri) + if err != nil { + return nil, err + } + + if parsed.Scheme != "models" { + return nil, errors.New("invalid model URI scheme") + } + + path := strings.TrimLeft(parsed.Path, "/") + if path == "" { + return nil, errors.New("invalid model URI path") + } + + parts := strings.Split(path, "/") + if len(parts) > 2 { + return nil, errors.New("invalid model URI format") + } + + if len(parts) == 2 { + name, suffix := parts[0], parts[1] + if suffix == "" { + return nil, errors.New("invalid model URI suffix") + } + + if _, err := strconv.Atoi(suffix); err == nil { + // The suffix is a specific version + return &ParsedModelURI{Name: name, Version: suffix}, nil + } else if strings.EqualFold(suffix, "latest") { + // The suffix is "latest" + return &ParsedModelURI{Name: name}, nil + } + + // The suffix is a specific stage + return &ParsedModelURI{Name: name, Stage: suffix}, nil + } + + // The URI is an alias URI + aliasParts := strings.SplitN(parts[0], "@", 2) + if len(aliasParts) != 2 || aliasParts[1] == "" { + return nil, errors.New("invalid model alias format") + } + + return &ParsedModelURI{Name: aliasParts[0], Alias: aliasParts[1]}, nil +} diff --git a/pkg/model_registry/store/sql/model_versions.go b/pkg/model_registry/store/sql/model_versions.go index 5c8f5047..e3b97465 100644 --- a/pkg/model_registry/store/sql/model_versions.go +++ b/pkg/model_registry/store/sql/model_versions.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "net/url" "strconv" "strings" "time" @@ -18,6 +19,10 @@ import ( "github.com/mlflow/mlflow-go-backend/pkg/protos" ) +const ( + CreateModelVersionRetries = 3 +) + func (m *ModelRegistrySQLStore) GetLatestVersions( ctx context.Context, name string, stages []string, ) ([]*protos.ModelVersion, *contract.Error) { @@ -78,6 +83,104 @@ func (m *ModelRegistrySQLStore) GetLatestVersions( return results, nil } +//nolint:funlen,cyclop,staticcheck +func (m *ModelRegistrySQLStore) CreateModelVersion( + ctx context.Context, name, source, runID string, tags []*entities.ModelVersionTag, runLink, description string, +) (*entities.ModelVersion, *contract.Error) { + storageLocation := source + + parsedSource, parsedSourceErr := url.Parse(source) + if parsedSourceErr != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("failed to parse source %q", source), + parsedSourceErr, + ) + } + + if parsedSource.Scheme == "models" { + parsedModelURI, err := ParseModelURI(source) + if err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Unable to fetch model from model URI source artifact location '%s'.", source), + parsedSourceErr, + ) + } + + downloadURI, contractErr := m.GetModelVersionDownloadURI(ctx, parsedModelURI.Name, parsedModelURI.Version) + if contractErr != nil { + return nil, contractErr + } + + storageLocation = downloadURI + } + + registeredModel, err := m.GetRegisteredModel(ctx, name) + if err != nil { + return nil, err + } + + creationTime := time.Now().UnixMilli() + for range CreateModelVersionRetries { + registeredModel.LastUpdatedTime = creationTime + nextVersion := GetNextVersion(registeredModel) + modelVersion := &models.ModelVersion{ + Name: name, + Version: nextVersion, + CreationTime: creationTime, + LastUpdatedTime: creationTime, + Description: sql.NullString{Valid: description != "", String: description}, + Status: models.ModelVersionStatusReady, + Source: source, + RunID: runID, + RunLink: runLink, + CurrentStage: models.ModelVersionStageNone, + StorageLocation: storageLocation, + } + + modelTags := make([]models.ModelVersionTag, 0, len(tags)) + for _, tag := range tags { + modelTags = append(modelTags, models.ModelVersionTag{ + Key: tag.Key, + Value: tag.Value, + Name: registeredModel.Name, + Version: nextVersion, + }) + } + + if err := m.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Create(modelVersion).Error; err != nil { + return err + } + if len(modelTags) > 0 { + if err := tx.Create(modelTags).Error; err != nil { + return err + } + modelVersion.Tags = modelTags + } + + return nil + }); err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Model Version creation error (name=%s).", name), + err, + ) + } + + return modelVersion.ToEntity(), nil + } + + return nil, contract.NewError( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf( + "Model Version creation error (name=%s). Giving up after %d attempts.", + name, CreateModelVersionRetries, + ), + ) +} + func (m *ModelRegistrySQLStore) GetModelVersion( ctx context.Context, name, version string, eager bool, ) (*entities.ModelVersion, *contract.Error) { diff --git a/pkg/model_registry/store/sql/models/model_versions.go b/pkg/model_registry/store/sql/models/model_versions.go index 7484f778..f0a0e853 100644 --- a/pkg/model_registry/store/sql/models/model_versions.go +++ b/pkg/model_registry/store/sql/models/model_versions.go @@ -10,6 +10,8 @@ import ( const StageDeletedInternal = "Deleted_Internal" +const ModelVersionStatusReady = "READY" + // ModelVersion mapped from table . // //revive:disable:exported diff --git a/pkg/model_registry/store/store.go b/pkg/model_registry/store/store.go index f790bd63..030b6b33 100644 --- a/pkg/model_registry/store/store.go +++ b/pkg/model_registry/store/store.go @@ -17,6 +17,9 @@ type ModelRegistryStore interface { type ModelVersionStore interface { GetLatestVersions(ctx context.Context, name string, stages []string) ([]*protos.ModelVersion, *contract.Error) + CreateModelVersion( + ctx context.Context, name, source, runID string, tags []*entities.ModelVersionTag, runLink, description string, + ) (*entities.ModelVersion, *contract.Error) GetModelVersion(ctx context.Context, name, version string, eager bool) (*entities.ModelVersion, *contract.Error) DeleteModelVersion(ctx context.Context, name, version string) *contract.Error UpdateModelVersion(ctx context.Context, name, version, description string) (*entities.ModelVersion, *contract.Error) diff --git a/pkg/protos/model_registry.pb.go b/pkg/protos/model_registry.pb.go index ea2060a1..41b22723 100644 --- a/pkg/protos/model_registry.pb.go +++ b/pkg/protos/model_registry.pb.go @@ -768,7 +768,7 @@ type CreateModelVersion struct { unknownFields protoimpl.UnknownFields // Register model under this name - Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name"` + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name" validate:"notEmpty,required"` // URI indicating the location of the model artifacts. Source *string `protobuf:"bytes,2,opt,name=source" json:"source,omitempty" query:"source" params:"source"` // MLflow run ID for correlation, if “source“ was generated by an experiment run in @@ -1250,9 +1250,9 @@ type ModelVersionTag struct { unknownFields protoimpl.UnknownFields // The tag key. - Key *string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty" query:"key" params:"key"` + Key *string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty" query:"key" params:"key" validate:"required,max=250,validMetricParamOrTagName,pathIsUnique"` // The tag value. - Value *string `protobuf:"bytes,2,opt,name=value" json:"value,omitempty" query:"value" params:"value"` + Value *string `protobuf:"bytes,2,opt,name=value" json:"value,omitempty" query:"value" params:"value" validate:"omitempty,max=5000,truncate=5000"` } func (x *ModelVersionTag) Reset() { diff --git a/pkg/server/routes/model_registry.g.go b/pkg/server/routes/model_registry.g.go index efdd6f14..44664d87 100644 --- a/pkg/server/routes/model_registry.g.go +++ b/pkg/server/routes/model_registry.g.go @@ -88,6 +88,17 @@ func RegisterModelRegistryServiceRoutes(service service.ModelRegistryService, pa } return ctx.JSON(output) }) + app.Post("/mlflow/model-versions/create", func(ctx *fiber.Ctx) error { + input := &protos.CreateModelVersion{} + if err := parser.ParseBody(ctx, input); err != nil { + return err + } + output, err := service.CreateModelVersion(utils.NewContextWithLoggerFromFiberContext(ctx), input) + if err != nil { + return err + } + return ctx.JSON(output) + }) app.Patch("/mlflow/model-versions/update", func(ctx *fiber.Ctx) error { input := &protos.UpdateModelVersion{} if err := parser.ParseBody(ctx, input); err != nil {