|
| 1 | +import asyncio |
| 2 | +from aiohttp import web |
| 3 | +import json |
| 4 | + |
1 | 5 | from ten import (
|
2 |
| - Extension, |
3 |
| - TenEnv, |
| 6 | + AsyncExtension, |
| 7 | + AsyncTenEnv, |
4 | 8 | Cmd,
|
5 | 9 | StatusCode,
|
6 | 10 | CmdResult,
|
7 | 11 | )
|
8 |
| -from http.server import HTTPServer, BaseHTTPRequestHandler |
9 |
| -import threading |
10 |
| -from functools import partial |
11 |
| -import re |
12 |
| - |
13 |
| - |
14 |
| -class HTTPHandler(BaseHTTPRequestHandler): |
15 |
| - def __init__(self, ten: TenEnv, *args, directory=None, **kwargs): |
16 |
| - ten.log_debug(f"new handler: {directory} {args} {kwargs}") |
17 |
| - self.ten = ten |
18 |
| - super().__init__(*args, **kwargs) |
19 |
| - |
20 |
| - def do_POST(self): |
21 |
| - self.ten.log_debug(f"post request incoming {self.path}") |
22 |
| - |
23 |
| - # match path /cmd/<cmd_name> |
24 |
| - match = re.match(r"^/cmd/([^/]+)$", self.path) |
25 |
| - if match: |
26 |
| - cmd_name = match.group(1) |
27 |
| - try: |
28 |
| - content_length = int(self.headers["Content-Length"]) |
29 |
| - input = self.rfile.read(content_length).decode("utf-8") |
30 |
| - self.ten.log_info(f"incoming request {self.path} {input}") |
31 |
| - |
32 |
| - # processing by send_cmd |
33 |
| - cmd_result_event = threading.Event() |
34 |
| - cmd_result: CmdResult |
35 |
| - |
36 |
| - def cmd_callback(_, result, ten_error): |
37 |
| - nonlocal cmd_result_event |
38 |
| - nonlocal cmd_result |
39 |
| - cmd_result = result |
40 |
| - self.ten.log_info( |
41 |
| - "cmd callback result: {}".format( |
42 |
| - cmd_result.get_property_to_json("") |
43 |
| - ) |
44 |
| - ) |
45 |
| - cmd_result_event.set() |
46 |
| - |
47 |
| - cmd = Cmd.create(cmd_name) |
48 |
| - cmd.set_property_from_json("", input) |
49 |
| - self.ten.send_cmd(cmd, cmd_callback) |
50 |
| - event_got = cmd_result_event.wait(timeout=5) |
51 |
| - |
52 |
| - # return response |
53 |
| - if not event_got: # timeout |
54 |
| - self.send_response_only(504) |
55 |
| - self.end_headers() |
56 |
| - return |
57 |
| - self.send_response( |
58 |
| - 200 if cmd_result.get_status_code() == StatusCode.OK else 502 |
59 |
| - ) |
60 |
| - self.send_header("Content-Type", "application/json") |
61 |
| - self.end_headers() |
62 |
| - self.wfile.write( |
63 |
| - cmd_result.get_property_to_json("").encode(encoding="utf_8") |
64 |
| - ) |
65 |
| - except Exception as e: |
66 |
| - self.ten.log_warn("failed to handle request, err {}".format(e)) |
67 |
| - self.send_response_only(500) |
68 |
| - self.end_headers() |
69 |
| - else: |
70 |
| - self.ten.log_warn(f"invalid path: {self.path}") |
71 |
| - self.send_response_only(404) |
72 |
| - self.end_headers() |
73 |
| - |
74 |
| - |
75 |
| -class HTTPServerExtension(Extension): |
| 12 | + |
| 13 | + |
| 14 | +class HTTPServerExtension(AsyncExtension): |
76 | 15 | def __init__(self, name: str):
|
77 | 16 | super().__init__(name)
|
78 |
| - self.listen_addr = "127.0.0.1" |
79 |
| - self.listen_port = 8888 |
80 |
| - self.cmd_white_list = None |
81 |
| - self.server = None |
82 |
| - self.thread = None |
83 |
| - |
84 |
| - def on_start(self, ten: TenEnv): |
85 |
| - self.listen_addr = ten.get_property_string("listen_addr") |
86 |
| - self.listen_port = ten.get_property_int("listen_port") |
87 |
| - """ |
88 |
| - white_list = ten.get_property_string("cmd_white_list") |
89 |
| - if len(white_list) > 0: |
90 |
| - self.cmd_white_list = white_list.split(",") |
91 |
| - """ |
92 |
| - |
93 |
| - ten.log_info( |
94 |
| - f"on_start {self.listen_addr}:{self.listen_port}, {self.cmd_white_list}" |
95 |
| - ) |
96 |
| - |
97 |
| - self.server = HTTPServer( |
98 |
| - (self.listen_addr, self.listen_port), partial(HTTPHandler, ten) |
99 |
| - ) |
100 |
| - self.thread = threading.Thread(target=self.server.serve_forever) |
101 |
| - self.thread.start() |
102 |
| - |
103 |
| - ten.on_start_done() |
104 |
| - |
105 |
| - def on_stop(self, ten: TenEnv): |
106 |
| - self.server.shutdown() |
107 |
| - self.thread.join() |
108 |
| - ten.on_stop_done() |
109 |
| - |
110 |
| - def on_cmd(self, ten: TenEnv, cmd: Cmd): |
| 17 | + self.listen_addr: str = "127.0.0.1" |
| 18 | + self.listen_port: int = 8888 |
| 19 | + |
| 20 | + self.ten_env: AsyncTenEnv = None |
| 21 | + |
| 22 | + # http server instances |
| 23 | + self.app = web.Application() |
| 24 | + self.runner = None |
| 25 | + |
| 26 | + # POST /cmd/{cmd_name} |
| 27 | + async def handle_post_cmd(self, request): |
| 28 | + ten_env = self.ten_env |
| 29 | + |
| 30 | + try: |
| 31 | + cmd_name = request.match_info.get('cmd_name') |
| 32 | + |
| 33 | + req_json = await request.json() |
| 34 | + input = json.dumps(req_json, ensure_ascii=False) |
| 35 | + |
| 36 | + ten_env.log_debug( |
| 37 | + f"process incoming request {request.method} {request.path} {input}") |
| 38 | + |
| 39 | + cmd = Cmd.create(cmd_name) |
| 40 | + cmd.set_property_from_json("", input) |
| 41 | + [cmd_result, _] = await asyncio.wait_for(ten_env.send_cmd(cmd), 5.0) |
| 42 | + |
| 43 | + # return response |
| 44 | + status = 200 if cmd_result.get_status_code() == StatusCode.OK else 502 |
| 45 | + return web.json_response( |
| 46 | + cmd_result.get_property_to_json(""), status=status |
| 47 | + ) |
| 48 | + except json.JSONDecodeError: |
| 49 | + return web.Response(status=400) |
| 50 | + except asyncio.TimeoutError: |
| 51 | + return web.Response(status=504) |
| 52 | + except Exception as e: |
| 53 | + ten_env.log_warn( |
| 54 | + "failed to handle request with unknown exception, err {}".format(e)) |
| 55 | + return web.Response(status=500) |
| 56 | + |
| 57 | + async def on_start(self, ten_env: AsyncTenEnv): |
| 58 | + if await ten_env.is_property_exist("listen_addr"): |
| 59 | + self.listen_addr = await ten_env.get_property_string("listen_addr") |
| 60 | + if await ten_env.is_property_exist("listen_port"): |
| 61 | + self.listen_port = await ten_env.get_property_int("listen_port") |
| 62 | + self.ten_env = ten_env |
| 63 | + |
| 64 | + ten_env.log_info( |
| 65 | + f"http server listening on {self.listen_addr}:{self.listen_port}") |
| 66 | + |
| 67 | + self.app.router.add_post("/cmd/{cmd_name}", self.handle_post_cmd) |
| 68 | + self.runner = web.AppRunner(self.app) |
| 69 | + await self.runner.setup() |
| 70 | + site = web.TCPSite(self.runner, self.listen_addr, self.listen_port) |
| 71 | + await site.start() |
| 72 | + |
| 73 | + async def on_stop(self, ten_env: AsyncTenEnv): |
| 74 | + await self.runner.cleanup() |
| 75 | + self.ten_env = None |
| 76 | + |
| 77 | + async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd): |
111 | 78 | cmd_name = cmd.get_name()
|
112 |
| - ten.log_info("on_cmd {cmd_name}") |
113 |
| - cmd_result = CmdResult.create(StatusCode.OK) |
114 |
| - ten.return_result(cmd_result, cmd) |
| 79 | + ten_env.log_debug(f"on_cmd {cmd_name}") |
| 80 | + ten_env.return_result(CmdResult.create(StatusCode.OK), cmd) |
0 commit comments