diff --git a/magefiles/generate/endpoints.go b/magefiles/generate/endpoints.go index 236874f0..47460890 100644 --- a/magefiles/generate/endpoints.go +++ b/magefiles/generate/endpoints.go @@ -55,7 +55,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{ "getRegisteredModel", // "searchRegisteredModels", "getLatestVersions", - // "createModelVersion", + "createModelVersion", "updateModelVersion", "transitionModelVersionStage", "deleteModelVersion", diff --git a/magefiles/generate/validations.go b/magefiles/generate/validations.go index 22bac3bf..2005d3bb 100644 --- a/magefiles/generate/validations.go +++ b/magefiles/generate/validations.go @@ -73,4 +73,7 @@ var validations = map[string]string{ "SetModelVersionTag_Version": "stringAsInteger", "GetModelVersion_Version": "stringAsInteger", "GetModelVersionDownloadUri_Version": "stringAsInteger", + "CreateModelVersion_Name": "notEmpty,required", + "ModelVersionTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", + "ModelVersionTag_Value": "omitempty,max=5000,truncate=5000", } diff --git a/magefiles/tests.go b/magefiles/tests.go index de341e6b..d91948a2 100644 --- a/magefiles/tests.go +++ b/magefiles/tests.go @@ -34,6 +34,7 @@ func runPythonTests(pytestArgs []string) error { // "--log-cli-level=DEBUG", "--confcutdir=.", "-k", "not [file", + "-vv", } args = append(args, pytestArgs...) diff --git a/mlflow_go_backend/store/model_registry.py b/mlflow_go_backend/store/model_registry.py index df279532..b1f55226 100644 --- a/mlflow_go_backend/store/model_registry.py +++ b/mlflow_go_backend/store/model_registry.py @@ -3,6 +3,7 @@ from mlflow.entities.model_registry import ModelVersion, RegisteredModel from mlflow.protos.model_registry_pb2 import ( + CreateModelVersion, CreateRegisteredModel, DeleteModelVersion, DeleteModelVersionTag, @@ -176,6 +177,30 @@ def get_model_version_download_uri(self, name, version): ) return response.artifact_uri + def create_model_version( + self, + name, + source, + run_id=None, + tags=None, + run_link=None, + description=None, + local_model_path=None, + ): + request = CreateModelVersion( + name=name, + source=source, + run_id=run_id, + tags=[tag.to_proto() for tag in tags] if tags else [], + run_link=run_link, + description=description, + ) + response = self.service.call_endpoint( + get_lib().ModelRegistryServiceCreateModelVersion, request + ) + + return ModelVersion.from_proto(response.model_version) + def ModelRegistryStore(cls): return type(cls.__name__, (_ModelRegistryStore, cls), {}) diff --git a/pkg/contract/service/model_registry.g.go b/pkg/contract/service/model_registry.g.go index 7e1cc921..73842c49 100644 --- a/pkg/contract/service/model_registry.g.go +++ b/pkg/contract/service/model_registry.g.go @@ -16,6 +16,7 @@ type ModelRegistryService interface { DeleteRegisteredModel(ctx context.Context, input *protos.DeleteRegisteredModel) (*protos.DeleteRegisteredModel_Response, *contract.Error) GetRegisteredModel(ctx context.Context, input *protos.GetRegisteredModel) (*protos.GetRegisteredModel_Response, *contract.Error) GetLatestVersions(ctx context.Context, input *protos.GetLatestVersions) (*protos.GetLatestVersions_Response, *contract.Error) + CreateModelVersion(ctx context.Context, input *protos.CreateModelVersion) (*protos.CreateModelVersion_Response, *contract.Error) UpdateModelVersion(ctx context.Context, input *protos.UpdateModelVersion) (*protos.UpdateModelVersion_Response, *contract.Error) TransitionModelVersionStage(ctx context.Context, input *protos.TransitionModelVersionStage) (*protos.TransitionModelVersionStage_Response, *contract.Error) DeleteModelVersion(ctx context.Context, input *protos.DeleteModelVersion) (*protos.DeleteModelVersion_Response, *contract.Error) diff --git a/pkg/entities/model_version.go b/pkg/entities/model_version.go index 66dfbbfe..4899887b 100644 --- a/pkg/entities/model_version.go +++ b/pkg/entities/model_version.go @@ -28,16 +28,19 @@ type ModelVersion struct { func (mv ModelVersion) ToProto() *protos.ModelVersion { modelVersion := protos.ModelVersion{ Name: utils.PtrTo(mv.Name), + Tags: make([]*protos.ModelVersionTag, 0, len(mv.Tags)), + Source: utils.PtrTo(mv.Source), + Status: utils.PtrTo(protos.ModelVersionStatus(protos.ModelVersionStatus_value[mv.Status])), + RunLink: utils.PtrTo(mv.RunLink), Version: utils.PtrTo(strconv.Itoa(int(mv.Version))), + Description: utils.PtrTo(mv.Description), CurrentStage: utils.PtrTo(mv.CurrentStage), CreationTimestamp: utils.PtrTo(mv.CreationTime), LastUpdatedTimestamp: utils.PtrTo(mv.LastUpdatedTime), - Description: utils.PtrTo(mv.Description), - UserId: utils.PtrTo(mv.UserID), - Source: utils.PtrTo(mv.Source), - Status: utils.PtrTo(protos.ModelVersionStatus(protos.ModelVersionStatus_value[mv.Status])), - Tags: make([]*protos.ModelVersionTag, 0, len(mv.Tags)), - RunLink: utils.PtrTo(mv.RunLink), + } + + if mv.UserID != "" { + modelVersion.UserId = utils.PtrTo(mv.UserID) } if mv.RunID != "" { diff --git a/pkg/entities/model_version_tag.go b/pkg/entities/model_version_tag.go index 22adb93b..4e05e43e 100644 --- a/pkg/entities/model_version_tag.go +++ b/pkg/entities/model_version_tag.go @@ -16,3 +16,10 @@ func (mvt ModelVersionTag) ToProto() *protos.ModelVersionTag { Value: utils.PtrTo(mvt.Value), } } + +func NewModelVersionTag(proto *protos.ModelVersionTag) *ModelVersionTag { + return &ModelVersionTag{ + Key: proto.GetKey(), + Value: proto.GetValue(), + } +} diff --git a/pkg/lib/model_registry.g.go b/pkg/lib/model_registry.g.go index 29f248b5..146c29f0 100644 --- a/pkg/lib/model_registry.g.go +++ b/pkg/lib/model_registry.g.go @@ -55,6 +55,14 @@ func ModelRegistryServiceGetLatestVersions(serviceID int64, requestData unsafe.P } return invokeServiceMethod(service.GetLatestVersions, new(protos.GetLatestVersions), requestData, requestSize, responseSize) } +//export ModelRegistryServiceCreateModelVersion +func ModelRegistryServiceCreateModelVersion(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { + service, err := modelRegistryServices.Get(serviceID) + if err != nil { + return makePointerFromError(err, responseSize) + } + return invokeServiceMethod(service.CreateModelVersion, new(protos.CreateModelVersion), requestData, requestSize, responseSize) +} //export ModelRegistryServiceUpdateModelVersion func ModelRegistryServiceUpdateModelVersion(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { service, err := modelRegistryServices.Get(serviceID) diff --git a/pkg/model_registry/service/model_versions.go b/pkg/model_registry/service/model_versions.go index 591f53dd..cd048f48 100644 --- a/pkg/model_registry/service/model_versions.go +++ b/pkg/model_registry/service/model_versions.go @@ -3,10 +3,10 @@ package service import ( "context" "fmt" - "strconv" "strings" "github.com/mlflow/mlflow-go-backend/pkg/contract" + "github.com/mlflow/mlflow-go-backend/pkg/entities" "github.com/mlflow/mlflow-go-backend/pkg/model_registry/store/sql/models" "github.com/mlflow/mlflow-go-backend/pkg/protos" "github.com/mlflow/mlflow-go-backend/pkg/utils" @@ -38,16 +38,7 @@ func (m *ModelRegistryService) DeleteModelVersion( func (m *ModelRegistryService) GetModelVersion( ctx context.Context, input *protos.GetModelVersion, ) (*protos.GetModelVersion_Response, *contract.Error) { - // by some strange reason GetModelVersion.Version has a string type so we can't apply our validation, - // that's why such a custom validation exists to satisfy Python tests. - version := input.GetVersion() - if _, err := strconv.Atoi(version); err != nil { - return nil, contract.NewErrorWith( - protos.ErrorCode_INVALID_PARAMETER_VALUE, "Model version must be an integer", err, - ) - } - - modelVersion, err := m.store.GetModelVersion(ctx, input.GetName(), version, true) + modelVersion, err := m.store.GetModelVersion(ctx, input.GetName(), input.GetVersion(), true) if err != nil { return nil, err } @@ -157,3 +148,29 @@ func (m *ModelRegistryService) GetModelVersionDownloadUri( ArtifactUri: utils.PtrTo(artifactURI), }, nil } + +func (m *ModelRegistryService) CreateModelVersion( + ctx context.Context, input *protos.CreateModelVersion, +) (*protos.CreateModelVersion_Response, *contract.Error) { + tags := make([]*entities.ModelVersionTag, 0, len(input.GetTags())) + for _, tag := range input.GetTags() { + tags = append(tags, entities.NewModelVersionTag(tag)) + } + + modelVersion, err := m.store.CreateModelVersion( + ctx, + input.GetName(), + input.GetSource(), + input.GetRunId(), + tags, + input.GetRunLink(), + input.GetDescription(), + ) + if err != nil { + return nil, err + } + + return &protos.CreateModelVersion_Response{ + ModelVersion: modelVersion.ToProto(), + }, nil +} diff --git a/pkg/model_registry/store/sql/helpers.go b/pkg/model_registry/store/sql/helpers.go new file mode 100644 index 00000000..5666597d --- /dev/null +++ b/pkg/model_registry/store/sql/helpers.go @@ -0,0 +1,81 @@ +package sql + +import ( + "errors" + "net/url" + "strconv" + "strings" + + "github.com/mlflow/mlflow-go-backend/pkg/entities" +) + +// GetNextVersion returns the next version number for a given registered model. +func GetNextVersion(sqlRegisteredModel *entities.RegisteredModel) int32 { + if len(sqlRegisteredModel.Versions) > 0 { + maxVersion := sqlRegisteredModel.Versions[0].Version + for _, mv := range sqlRegisteredModel.Versions { + if mv.Version > maxVersion { + maxVersion = mv.Version + } + } + + return maxVersion + 1 + } + + return 1 +} + +type ParsedModelURI struct { + Name string + Version string + Stage string + Alias string +} + +//nolint:cyclop,err113,mnd,wrapcheck +func ParseModelURI(uri string) (*ParsedModelURI, error) { + parsed, err := url.Parse(uri) + if err != nil { + return nil, err + } + + if parsed.Scheme != "models" { + return nil, errors.New("invalid model URI scheme") + } + + path := strings.TrimLeft(parsed.Path, "/") + if path == "" { + return nil, errors.New("invalid model URI path") + } + + parts := strings.Split(path, "/") + if len(parts) > 2 { + return nil, errors.New("invalid model URI format") + } + + if len(parts) == 2 { + name, suffix := parts[0], parts[1] + if suffix == "" { + return nil, errors.New("invalid model URI suffix") + } + + if _, err := strconv.Atoi(suffix); err == nil { + // The suffix is a specific version + return &ParsedModelURI{Name: name, Version: suffix}, nil + } else if strings.EqualFold(suffix, "latest") { + // The suffix is "latest" + return &ParsedModelURI{Name: name}, nil + } + + // The suffix is a specific stage + return &ParsedModelURI{Name: name, Stage: suffix}, nil + } + + // The URI is an alias URI + aliasParts := strings.SplitN(parts[0], "@", 2) + if len(aliasParts) != 2 || aliasParts[1] == "" { + return nil, errors.New("invalid model alias format") + } + + return &ParsedModelURI{Name: aliasParts[0], Alias: aliasParts[1]}, nil +} diff --git a/pkg/model_registry/store/sql/model_versions.go b/pkg/model_registry/store/sql/model_versions.go index 5c8f5047..e3b97465 100644 --- a/pkg/model_registry/store/sql/model_versions.go +++ b/pkg/model_registry/store/sql/model_versions.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "net/url" "strconv" "strings" "time" @@ -18,6 +19,10 @@ import ( "github.com/mlflow/mlflow-go-backend/pkg/protos" ) +const ( + CreateModelVersionRetries = 3 +) + func (m *ModelRegistrySQLStore) GetLatestVersions( ctx context.Context, name string, stages []string, ) ([]*protos.ModelVersion, *contract.Error) { @@ -78,6 +83,104 @@ func (m *ModelRegistrySQLStore) GetLatestVersions( return results, nil } +//nolint:funlen,cyclop,staticcheck +func (m *ModelRegistrySQLStore) CreateModelVersion( + ctx context.Context, name, source, runID string, tags []*entities.ModelVersionTag, runLink, description string, +) (*entities.ModelVersion, *contract.Error) { + storageLocation := source + + parsedSource, parsedSourceErr := url.Parse(source) + if parsedSourceErr != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("failed to parse source %q", source), + parsedSourceErr, + ) + } + + if parsedSource.Scheme == "models" { + parsedModelURI, err := ParseModelURI(source) + if err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Unable to fetch model from model URI source artifact location '%s'.", source), + parsedSourceErr, + ) + } + + downloadURI, contractErr := m.GetModelVersionDownloadURI(ctx, parsedModelURI.Name, parsedModelURI.Version) + if contractErr != nil { + return nil, contractErr + } + + storageLocation = downloadURI + } + + registeredModel, err := m.GetRegisteredModel(ctx, name) + if err != nil { + return nil, err + } + + creationTime := time.Now().UnixMilli() + for range CreateModelVersionRetries { + registeredModel.LastUpdatedTime = creationTime + nextVersion := GetNextVersion(registeredModel) + modelVersion := &models.ModelVersion{ + Name: name, + Version: nextVersion, + CreationTime: creationTime, + LastUpdatedTime: creationTime, + Description: sql.NullString{Valid: description != "", String: description}, + Status: models.ModelVersionStatusReady, + Source: source, + RunID: runID, + RunLink: runLink, + CurrentStage: models.ModelVersionStageNone, + StorageLocation: storageLocation, + } + + modelTags := make([]models.ModelVersionTag, 0, len(tags)) + for _, tag := range tags { + modelTags = append(modelTags, models.ModelVersionTag{ + Key: tag.Key, + Value: tag.Value, + Name: registeredModel.Name, + Version: nextVersion, + }) + } + + if err := m.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Create(modelVersion).Error; err != nil { + return err + } + if len(modelTags) > 0 { + if err := tx.Create(modelTags).Error; err != nil { + return err + } + modelVersion.Tags = modelTags + } + + return nil + }); err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Model Version creation error (name=%s).", name), + err, + ) + } + + return modelVersion.ToEntity(), nil + } + + return nil, contract.NewError( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf( + "Model Version creation error (name=%s). Giving up after %d attempts.", + name, CreateModelVersionRetries, + ), + ) +} + func (m *ModelRegistrySQLStore) GetModelVersion( ctx context.Context, name, version string, eager bool, ) (*entities.ModelVersion, *contract.Error) { diff --git a/pkg/model_registry/store/sql/models/model_versions.go b/pkg/model_registry/store/sql/models/model_versions.go index 7484f778..f0a0e853 100644 --- a/pkg/model_registry/store/sql/models/model_versions.go +++ b/pkg/model_registry/store/sql/models/model_versions.go @@ -10,6 +10,8 @@ import ( const StageDeletedInternal = "Deleted_Internal" +const ModelVersionStatusReady = "READY" + // ModelVersion mapped from table . // //revive:disable:exported diff --git a/pkg/model_registry/store/store.go b/pkg/model_registry/store/store.go index f790bd63..030b6b33 100644 --- a/pkg/model_registry/store/store.go +++ b/pkg/model_registry/store/store.go @@ -17,6 +17,9 @@ type ModelRegistryStore interface { type ModelVersionStore interface { GetLatestVersions(ctx context.Context, name string, stages []string) ([]*protos.ModelVersion, *contract.Error) + CreateModelVersion( + ctx context.Context, name, source, runID string, tags []*entities.ModelVersionTag, runLink, description string, + ) (*entities.ModelVersion, *contract.Error) GetModelVersion(ctx context.Context, name, version string, eager bool) (*entities.ModelVersion, *contract.Error) DeleteModelVersion(ctx context.Context, name, version string) *contract.Error UpdateModelVersion(ctx context.Context, name, version, description string) (*entities.ModelVersion, *contract.Error) diff --git a/pkg/protos/model_registry.pb.go b/pkg/protos/model_registry.pb.go index ea2060a1..41b22723 100644 --- a/pkg/protos/model_registry.pb.go +++ b/pkg/protos/model_registry.pb.go @@ -768,7 +768,7 @@ type CreateModelVersion struct { unknownFields protoimpl.UnknownFields // Register model under this name - Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name"` + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name" validate:"notEmpty,required"` // URI indicating the location of the model artifacts. Source *string `protobuf:"bytes,2,opt,name=source" json:"source,omitempty" query:"source" params:"source"` // MLflow run ID for correlation, if “source“ was generated by an experiment run in @@ -1250,9 +1250,9 @@ type ModelVersionTag struct { unknownFields protoimpl.UnknownFields // The tag key. - Key *string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty" query:"key" params:"key"` + Key *string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty" query:"key" params:"key" validate:"required,max=250,validMetricParamOrTagName,pathIsUnique"` // The tag value. - Value *string `protobuf:"bytes,2,opt,name=value" json:"value,omitempty" query:"value" params:"value"` + Value *string `protobuf:"bytes,2,opt,name=value" json:"value,omitempty" query:"value" params:"value" validate:"omitempty,max=5000,truncate=5000"` } func (x *ModelVersionTag) Reset() { diff --git a/pkg/server/routes/model_registry.g.go b/pkg/server/routes/model_registry.g.go index efdd6f14..44664d87 100644 --- a/pkg/server/routes/model_registry.g.go +++ b/pkg/server/routes/model_registry.g.go @@ -88,6 +88,17 @@ func RegisterModelRegistryServiceRoutes(service service.ModelRegistryService, pa } return ctx.JSON(output) }) + app.Post("/mlflow/model-versions/create", func(ctx *fiber.Ctx) error { + input := &protos.CreateModelVersion{} + if err := parser.ParseBody(ctx, input); err != nil { + return err + } + output, err := service.CreateModelVersion(utils.NewContextWithLoggerFromFiberContext(ctx), input) + if err != nil { + return err + } + return ctx.JSON(output) + }) app.Patch("/mlflow/model-versions/update", func(ctx *fiber.Ctx) error { input := &protos.UpdateModelVersion{} if err := parser.ParseBody(ctx, input); err != nil {