diff --git a/CHANGELOG.md b/CHANGELOG.md index efb766bda4..3bfe2cd793 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#3464](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3464)) - `opentelemetry-instrumentation-redis` Add support for redis client-specific instrumentation. ([#3143](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3143)) +- `opentelemetry-instrumentation-tornado` Add support for `WebSocketHandler` instrumentation + ([#3448](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2761)) ### Fixed diff --git a/instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py b/instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py index 9fbf88a74d..2737e68e67 100644 --- a/instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py @@ -162,6 +162,7 @@ def client_response_hook(span, future): from typing import Collection, Dict import tornado.web +import tornado.websocket import wrapt from wrapt import wrap_function_wrapper @@ -351,12 +352,20 @@ def patch_handler_class(tracer, server_histograms, cls, request_hook=None): "prepare", partial(_prepare, tracer, server_histograms, request_hook), ) - _wrap(cls, "on_finish", partial(_on_finish, tracer, server_histograms)) _wrap( cls, "log_exception", partial(_log_exception, tracer, server_histograms), ) + + if issubclass(cls, tornado.websocket.WebSocketHandler): + _wrap( + cls, + "on_close", + partial(_websockethandler_on_close, tracer, server_histograms), + ) + else: + _wrap(cls, "on_finish", partial(_on_finish, tracer, server_histograms)) return True @@ -365,8 +374,11 @@ def unpatch_handler_class(cls): return unwrap(cls, "prepare") - unwrap(cls, "on_finish") unwrap(cls, "log_exception") + if issubclass(cls, tornado.websocket.WebSocketHandler): + unwrap(cls, "on_close") + else: + unwrap(cls, "on_finish") delattr(cls, _OTEL_PATCHED_KEY) @@ -394,13 +406,21 @@ def _prepare( def _on_finish(tracer, server_histograms, func, handler, args, kwargs): - response = func(*args, **kwargs) - - _record_on_finish_metrics(server_histograms, handler) + try: + return func(*args, **kwargs) + finally: + _record_on_finish_metrics(server_histograms, handler) + _finish_span(tracer, handler) - _finish_span(tracer, handler) - return response +def _websockethandler_on_close( + tracer, server_histograms, func, handler, args, kwargs +): + try: + func() + finally: + _record_on_finish_metrics(server_histograms, handler) + _finish_span(tracer, handler) def _log_exception(tracer, server_histograms, func, handler, args, kwargs): diff --git a/instrumentation/opentelemetry-instrumentation-tornado/tests/test_instrumentation.py b/instrumentation/opentelemetry-instrumentation-tornado/tests/test_instrumentation.py index daf2ddd846..ea09c9b1a7 100644 --- a/instrumentation/opentelemetry-instrumentation-tornado/tests/test_instrumentation.py +++ b/instrumentation/opentelemetry-instrumentation-tornado/tests/test_instrumentation.py @@ -13,8 +13,10 @@ # limitations under the License. +import asyncio from unittest.mock import Mock, patch +import tornado.websocket from http_server_mock import HttpServerMock from tornado.httpclient import HTTPClientError from tornado.testing import AsyncHTTPTestCase @@ -450,6 +452,54 @@ def test_handler_on_finish(self): self.assertEqual(auditor.kind, SpanKind.INTERNAL) + @tornado.testing.gen_test() + async def test_websockethandler(self): + ws_client = await tornado.websocket.websocket_connect( + f"ws://127.0.0.1:{self.get_http_port()}/echo_socket" + ) + + await ws_client.write_message("world") + resp = await ws_client.read_message() + self.assertEqual(resp, "hello world") + + ws_client.close() + await asyncio.sleep(0.5) + + spans = self.sorted_spans(self.memory_exporter.get_finished_spans()) + self.assertEqual(len(spans), 3) + close_span, msg_span, req_span = spans + + self.assertEqual(req_span.name, "GET /echo_socket") + self.assertEqual(req_span.context.trace_id, msg_span.context.trace_id) + self.assertIsNone(req_span.parent) + self.assertEqual(req_span.kind, SpanKind.SERVER) + self.assertSpanHasAttributes( + req_span, + { + SpanAttributes.HTTP_METHOD: "GET", + SpanAttributes.HTTP_SCHEME: "http", + SpanAttributes.HTTP_HOST: "127.0.0.1:" + + str(self.get_http_port()), + SpanAttributes.HTTP_TARGET: "/echo_socket", + SpanAttributes.HTTP_CLIENT_IP: "127.0.0.1", + SpanAttributes.HTTP_STATUS_CODE: 101, + "tornado.handler": "tests.tornado_test_app.EchoWebSocketHandler", + }, + ) + + self.assertEqual(msg_span.name, "audit_message") + self.assertFalse(msg_span.context.is_remote) + self.assertEqual(msg_span.kind, SpanKind.INTERNAL) + self.assertEqual(msg_span.parent.span_id, req_span.context.span_id) + + self.assertEqual(close_span.name, "audit_on_close") + self.assertFalse(close_span.context.is_remote) + self.assertEqual(close_span.parent.span_id, req_span.context.span_id) + self.assertEqual( + close_span.context.trace_id, msg_span.context.trace_id + ) + self.assertEqual(close_span.kind, SpanKind.INTERNAL) + def test_exclude_lists(self): def test_excluded(path): self.fetch(path) diff --git a/instrumentation/opentelemetry-instrumentation-tornado/tests/tornado_test_app.py b/instrumentation/opentelemetry-instrumentation-tornado/tests/tornado_test_app.py index 9e84c74aca..1523375212 100644 --- a/instrumentation/opentelemetry-instrumentation-tornado/tests/tornado_test_app.py +++ b/instrumentation/opentelemetry-instrumentation-tornado/tests/tornado_test_app.py @@ -2,6 +2,7 @@ import time import tornado.web +import tornado.websocket from tornado import gen @@ -110,6 +111,16 @@ def get(self): raise tornado.web.HTTPError(403) +class EchoWebSocketHandler(tornado.websocket.WebSocketHandler): + async def on_message(self, message): + with self.application.tracer.start_as_current_span("audit_message"): + self.write_message(f"hello {message}") + + def on_close(self): + with self.application.tracer.start_as_current_span("audit_on_close"): + time.sleep(0.05) + + def make_app(tracer): app = tornado.web.Application( [ @@ -122,6 +133,7 @@ def make_app(tracer): (r"/ping", HealthCheckHandler), (r"/test_custom_response_headers", CustomResponseHeaderHandler), (r"/raise_403", RaiseHTTPErrorHandler), + (r"/echo_socket", EchoWebSocketHandler), ] ) app.tracer = tracer