diff --git a/swagger_py_codegen/command.py b/swagger_py_codegen/command.py index 6564475..08d7d8a 100644 --- a/swagger_py_codegen/command.py +++ b/swagger_py_codegen/command.py @@ -162,9 +162,11 @@ def print_version(ctx, param, value): @click.option('--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True, help='Show current version.') +@click.option('-a', '--use-async', default=False, is_flag=True, + help='Generate async request handlers (tornado)') def generate(destination, swagger_doc, force=False, package=None, template_dir=None, templates='flask', - specification=False, ui=False, validate=False): + specification=False, ui=False, validate=False, use_async=False): package = package or destination.replace('-', '_') data = spec_load(swagger_doc) if validate: @@ -198,6 +200,7 @@ def generate(destination, swagger_doc, force=False, package=None, click.secho('%-12s%s' % (status, ui_dest)) for code in generator.generate(): + code.data['use_async'] = use_async source = template.render_code(code) dest = join(destination, code.dest(env)) dest_exists = exists(dest) diff --git a/swagger_py_codegen/templates/tornado/validators.tpl b/swagger_py_codegen/templates/tornado/validators.tpl index 08b6592..fb171b1 100644 --- a/swagger_py_codegen/templates/tornado/validators.tpl +++ b/swagger_py_codegen/templates/tornado/validators.tpl @@ -71,7 +71,11 @@ class ValidatorAdaptor(object): def request_validate(obj): def _request_validate(view): @wraps(view) + {% if not use_async -%} def wrapper(*args, **kwargs): + {%- else -%} + async def wrapper(*args, **kwargs): + {%- endif %} request = obj.request endpoint = obj.endpoint user_info = obj.current_user @@ -88,7 +92,8 @@ def request_validate(obj): if location == 'json': value = getattr(request, 'body', MultiDict()) elif location == 'args': - value = getattr(request, 'query_arguments', MultiDict()) + value = {key: list(map(obj.decode_argument, value)) for key, value in + request.query_arguments.items()} for k,v in six.iteritems(value): if isinstance(v, list) and len(v) == 1: value[k] = v[0] @@ -101,7 +106,11 @@ def request_validate(obj): raise tornado.web.HTTPError(422, message='Unprocessable Entity', reason=json.dumps(reasons)) setattr(obj, location, result) + {% if not use_async -%} return view(*args, **kwargs) + {%- else -%} + return await view(*args, **kwargs) + {%- endif %} return wrapper return _request_validate @@ -109,8 +118,13 @@ def request_validate(obj): def response_filter(obj): def _response_filter(view): @wraps(view) + {% if not use_async -%} def wrapper(*args, **kwargs): resp = view(*args, **kwargs) + {%- else -%} + async def wrapper(*args, **kwargs): + resp = await view(*args, **kwargs) + {%- endif %} request = obj.request endpoint = obj.endpoint method = request.method @@ -145,7 +159,8 @@ def response_filter(obj): reason=json.dumps(errors)) obj.set_status(status) obj.set_headers(headers) - obj.write(json.dumps(resp)) + if resp: + obj.write(json.dumps(resp)) return return wrapper return _response_filter diff --git a/swagger_py_codegen/templates/tornado/view.tpl b/swagger_py_codegen/templates/tornado/view.tpl index 4971394..d649b01 100644 --- a/swagger_py_codegen/templates/tornado/view.tpl +++ b/swagger_py_codegen/templates/tornado/view.tpl @@ -9,7 +9,11 @@ class {{ name }}(ApiHandler): {%- for method, ins in methods.items() %} + {% if use_async -%} + async def {{ method.lower() }}(self{{ params.__len__() and ', ' or '' }}{{ params | join(', ') }}): + {%- else -%} def {{ method.lower() }}(self{{ params.__len__() and ', ' or '' }}{{ params | join(', ') }}): + {%- endif %} {%- for request in ins.requests %} print(self.{{ request }}) {%- endfor %}