diff --git a/src/llama_stack_client/lib/cli/models/models.py b/src/llama_stack_client/lib/cli/models/models.py index c724e5d5..24662a55 100644 --- a/src/llama_stack_client/lib/cli/models/models.py +++ b/src/llama_stack_client/lib/cli/models/models.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json from typing import Optional import click @@ -93,12 +94,23 @@ def get_model(ctx, model_id: str): console.print(table) +class JSONParamType(click.ParamType): + name = "json" + + def convert(self, value, param, ctx): + try: + return json.loads(value) + except json.JSONDecodeError as e: + self.fail(f"Invalid JSON: {e}", param, ctx) + + @click.command(name="register", help="Register a new model at distribution endpoint") @click.help_option("-h", "--help") @click.argument("model_id") @click.option("--provider-id", help="Provider ID for the model", default=None) @click.option("--provider-model-id", help="Provider's model ID", default=None) -@click.option("--metadata", help="JSON metadata for the model", default=None) +@click.option("--metadata", type=JSONParamType(), help="JSON metadata for the model", default=None) +@click.option("--model-type", type=click.Choice(["llm", "embedding"]), default="llm", help="Model type: llm, embedding") @click.pass_context @handle_client_errors("register model") def register_model( @@ -107,6 +119,7 @@ def register_model( provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str], + model_type: Optional[str], ): """Register a new model at distribution endpoint""" client = ctx.obj["client"] @@ -117,6 +130,7 @@ def register_model( provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata, + model_type=model_type, ) if response: console.print(f"[green]Successfully registered model {model_id}[/green]")