Skip to content

Add tornado WebSocketHandler instrumentation support #3498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#3464](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3464))
- `opentelemetry-instrumentation-redis` Add support for redis client-specific instrumentation.
([#3143](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3143))
- `opentelemetry-instrumentation-tornado` Add support for `WebSocketHandler` instrumentation
([#3448](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2761))

### Fixed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def client_response_hook(span, future):
from typing import Collection, Dict

import tornado.web
import tornado.websocket
import wrapt
from wrapt import wrap_function_wrapper

Expand Down Expand Up @@ -351,12 +352,20 @@ def patch_handler_class(tracer, server_histograms, cls, request_hook=None):
"prepare",
partial(_prepare, tracer, server_histograms, request_hook),
)
_wrap(cls, "on_finish", partial(_on_finish, tracer, server_histograms))
_wrap(
cls,
"log_exception",
partial(_log_exception, tracer, server_histograms),
)

if issubclass(cls, tornado.websocket.WebSocketHandler):
_wrap(
cls,
"on_close",
partial(_websockethandler_on_close, tracer, server_histograms),
)
else:
_wrap(cls, "on_finish", partial(_on_finish, tracer, server_histograms))
return True


Expand All @@ -365,8 +374,11 @@ def unpatch_handler_class(cls):
return

unwrap(cls, "prepare")
unwrap(cls, "on_finish")
unwrap(cls, "log_exception")
if issubclass(cls, tornado.websocket.WebSocketHandler):
unwrap(cls, "on_close")
else:
unwrap(cls, "on_finish")
delattr(cls, _OTEL_PATCHED_KEY)


Expand Down Expand Up @@ -394,13 +406,21 @@ def _prepare(


def _on_finish(tracer, server_histograms, func, handler, args, kwargs):
response = func(*args, **kwargs)

_record_on_finish_metrics(server_histograms, handler)
try:
Copy link
Contributor

@xrmx xrmx May 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure we should catch user code exception and change the code behavior, or if we do we should re-raise later.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its just a finally block to ensure that we finish the trace span and metrics, even if the user code raises an exception. It doesn't catch, so it shouldn't need to re-raise?

return func(*args, **kwargs)
finally:
_record_on_finish_metrics(server_histograms, handler)
_finish_span(tracer, handler)

_finish_span(tracer, handler)

return response
def _websockethandler_on_close(
tracer, server_histograms, func, handler, args, kwargs
):
try:
func()
finally:
_record_on_finish_metrics(server_histograms, handler)
_finish_span(tracer, handler)


def _log_exception(tracer, server_histograms, func, handler, args, kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.


import asyncio
from unittest.mock import Mock, patch

import tornado.websocket
from http_server_mock import HttpServerMock
from tornado.httpclient import HTTPClientError
from tornado.testing import AsyncHTTPTestCase
Expand Down Expand Up @@ -450,6 +452,54 @@ def test_handler_on_finish(self):

self.assertEqual(auditor.kind, SpanKind.INTERNAL)

@tornado.testing.gen_test()
async def test_websockethandler(self):
ws_client = await tornado.websocket.websocket_connect(
f"ws://127.0.0.1:{self.get_http_port()}/echo_socket"
)

await ws_client.write_message("world")
resp = await ws_client.read_message()
self.assertEqual(resp, "hello world")

ws_client.close()
await asyncio.sleep(0.5)

spans = self.sorted_spans(self.memory_exporter.get_finished_spans())
self.assertEqual(len(spans), 3)
close_span, msg_span, req_span = spans

self.assertEqual(req_span.name, "GET /echo_socket")
self.assertEqual(req_span.context.trace_id, msg_span.context.trace_id)
self.assertIsNone(req_span.parent)
self.assertEqual(req_span.kind, SpanKind.SERVER)
self.assertSpanHasAttributes(
req_span,
{
SpanAttributes.HTTP_METHOD: "GET",
SpanAttributes.HTTP_SCHEME: "http",
SpanAttributes.HTTP_HOST: "127.0.0.1:"
+ str(self.get_http_port()),
SpanAttributes.HTTP_TARGET: "/echo_socket",
SpanAttributes.HTTP_CLIENT_IP: "127.0.0.1",
SpanAttributes.HTTP_STATUS_CODE: 101,
"tornado.handler": "tests.tornado_test_app.EchoWebSocketHandler",
},
)

self.assertEqual(msg_span.name, "audit_message")
self.assertFalse(msg_span.context.is_remote)
self.assertEqual(msg_span.kind, SpanKind.INTERNAL)
self.assertEqual(msg_span.parent.span_id, req_span.context.span_id)

self.assertEqual(close_span.name, "audit_on_close")
self.assertFalse(close_span.context.is_remote)
self.assertEqual(close_span.parent.span_id, req_span.context.span_id)
self.assertEqual(
close_span.context.trace_id, msg_span.context.trace_id
)
self.assertEqual(close_span.kind, SpanKind.INTERNAL)

def test_exclude_lists(self):
def test_excluded(path):
self.fetch(path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time

import tornado.web
import tornado.websocket
from tornado import gen


Expand Down Expand Up @@ -110,6 +111,16 @@ def get(self):
raise tornado.web.HTTPError(403)


class EchoWebSocketHandler(tornado.websocket.WebSocketHandler):
async def on_message(self, message):
with self.application.tracer.start_as_current_span("audit_message"):
self.write_message(f"hello {message}")

def on_close(self):
with self.application.tracer.start_as_current_span("audit_on_close"):
time.sleep(0.05)


def make_app(tracer):
app = tornado.web.Application(
[
Expand All @@ -122,6 +133,7 @@ def make_app(tracer):
(r"/ping", HealthCheckHandler),
(r"/test_custom_response_headers", CustomResponseHeaderHandler),
(r"/raise_403", RaiseHTTPErrorHandler),
(r"/echo_socket", EchoWebSocketHandler),
]
)
app.tracer = tracer
Expand Down