diff --git a/aiohttp_apischema/generator.py b/aiohttp_apischema/generator.py index 7663443..b8a28f4 100644 --- a/aiohttp_apischema/generator.py +++ b/aiohttp_apischema/generator.py @@ -352,17 +352,10 @@ async def _on_startup(self, app: web.Application) -> None: # Strip qualifiers (Required/NotRequired) from param_type. # TODO(PY311): (remove tuple) Annotated[insp.type, *insp.metadata] param_type = Annotated[(insp.type, *insp.metadata)] if insp.metadata else insp.type - extracted_type = insp.type - while get_origin(extracted_type) is Literal: - extracted_type = get_args(extracted_type)[0] - try: - is_str = issubclass(extracted_type, str) # type: ignore[arg-type] - except TypeError: - is_str = isinstance(extracted_type, str) # Literal + models.append((key, "validation", TypeAdapter(param_type))) # We also need to convert values to Json for runtime checking. - ann_type = param_type if is_str else Json[param_type] # type: ignore[misc,valid-type] - models.append((key, "validation", TypeAdapter(ann_type))) + ann_type = param_type | Json[param_type] td[param_name] = Required[ann_type] if required else NotRequired[ann_type] endpoints["query"] = TypeAdapter(TypedDict(query.__name__, td)) # type: ignore[attr-defined,operator] for code, model in endpoints["resps"].items(): diff --git a/tests/test_generator.py b/tests/test_generator.py index 005c80d..86308ce 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -314,16 +314,9 @@ async def handler(request: web.Request, *, query: QueryArgs) -> APIResponse[int] "prefixItems": [{"type": "string"}, {"type": "integer"}, {"type": "number"}]} paths = {"/foo": {"get": { "operationId": "handler", - "parameters": [{"name": "foo", "in": "query", "required": True, "schema": { - "contentMediaType": "application/json", - "contentSchema": {"type": "integer"}, "type": "string"}}, - {"name": "bar", "in": "query", "required": False, "schema": { - "contentMediaType": "application/json", - "contentSchema": bar, "type": "string"}}, - {"name": "baz", "in": "query", "required": True, "schema": { - "contentMediaType": "application/json", - "contentSchema": {"$ref": "#/components/schemas/Baz"}, - "type": "string"}}, + "parameters": [{"name": "foo", "in": "query", "required": True, "schema": {"type": "integer"}}, + {"name": "bar", "in": "query", "required": False, "schema": bar}, + {"name": "baz", "in": "query", "required": True, "schema": {"$ref": "#/components/schemas/Baz"}}, {"name": "spam", "in": "query", "required": False, "schema": { "type": "string", "const": "eggz"}}], "responses": { @@ -354,6 +347,32 @@ async def handler(request: web.Request, *, query: QueryArgs) -> APIResponse[int] assert result[1]["type"] == "string_type" +async def test_query_literal(aiohttp_client: AiohttpClient) -> None: + schema_gen = SchemaGenerator() + + class QueryArgs(TypedDict): + foo: Literal[42, "spam"] + + @schema_gen.api() + async def handler(request: web.Request, *, query: QueryArgs) -> APIResponse[int | str]: + return APIResponse(query["foo"]) + + app = web.Application() + schema_gen.setup(app) + app.router.add_get("/foo", handler) + + client = await aiohttp_client(app) + async with client.get("/foo", params={"foo": 42}) as resp: + assert resp.status == 200 + result = await resp.json() + assert result == 42 + + async with client.get("/foo", params={"foo": "spam"}) as resp: + assert resp.status == 200 + result = await resp.json() + assert result == "spam" + + async def test_query_pydantic_annotations(aiohttp_client: AiohttpClient) -> None: schema_gen = SchemaGenerator() @@ -374,7 +393,7 @@ async def handler(request: web.Request, *, query: QueryArgs) -> APIResponse[int] schema = await resp.json() param = schema["paths"]["/foo"]["get"]["parameters"][0] - assert param["schema"]["contentSchema"]["description"] == "Some description" + assert param["schema"]["description"] == "Some description" assert param["schema"]["default"] == 42 async with client.get("/foo") as resp: