|
1 | 1 | import argparse |
2 | | -from asyncio import StreamReader |
3 | 2 | import atexit |
4 | 3 | import base64 |
5 | 4 | from collections import OrderedDict |
|
10 | 9 | import pprint |
11 | 10 | import socket |
12 | 11 | import sys |
13 | | -from typing import Any |
14 | 12 | from typing import Awaitable |
15 | 13 | from typing import Callable |
16 | 14 | from typing import Dict |
|
24 | 22 | from urllib.parse import urlunparse |
25 | 23 |
|
26 | 24 | from aiohttp import ClientSession |
27 | | -from aiohttp import MultipartReader |
28 | 25 | from aiohttp import web |
29 | 26 | from aiohttp.web import Request |
30 | 27 | from aiohttp.web import middleware |
|
54 | 51 | from .trace_checks import CheckTraceDDService |
55 | 52 | from .trace_checks import CheckTracePeerService |
56 | 53 | from .trace_checks import CheckTraceStallAsync |
| 54 | +from .tracerflare import TracerFlareEvent |
| 55 | +from .tracerflare import v1_decode as v1_tracerflare_decode |
57 | 56 | from .tracestats import decode_v06 as tracestats_decode_v06 |
58 | 57 | from .tracestats import v06StatsPayload |
59 | 58 |
|
@@ -381,6 +380,19 @@ async def _apmtelemetry_by_session(self, token: Optional[str]) -> List[Telemetry |
381 | 380 | # TODO: Sort the events? |
382 | 381 | return events |
383 | 382 |
|
| 383 | + async def _tracerflares_by_session(self, token: Optional[str]) -> List[TracerFlareEvent]: |
| 384 | + """Return the tracer-flare events that belong to the given session token. |
| 385 | +
|
| 386 | + If token is None or if the token was used to manually start a session |
| 387 | + with /session-start then return all tracer-flare events that were sent |
| 388 | + since the last /session-start request was made. |
| 389 | + """ |
| 390 | + events: List[TracerFlareEvent] = [] |
| 391 | + for req in self._requests_by_session(token): |
| 392 | + if req.match_info.handler == self.handle_v1_tracer_flare: |
| 393 | + events.append(await v1_tracerflare_decode(req.headers, await req.read())) |
| 394 | + return events |
| 395 | + |
384 | 396 | async def _tracestats_by_session(self, token: Optional[str]) -> List[v06StatsPayload]: |
385 | 397 | stats: List[v06StatsPayload] = [] |
386 | 398 | for req in self._requests_by_session(token): |
@@ -517,21 +529,7 @@ async def handle_v2_apmtelemetry(self, request: Request) -> web.Response: |
517 | 529 | return web.HTTPOk() |
518 | 530 |
|
519 | 531 | async def handle_v1_tracer_flare(self, request: Request) -> web.Response: |
520 | | - # reconstruct stream from previously cached bytes |
521 | | - stream = StreamReader() |
522 | | - stream.feed_data(self._request_data(request)) |
523 | | - stream.feed_eof() |
524 | | - |
525 | | - tracer_flare: Dict[str, Any] = {} |
526 | | - |
527 | | - async for part in MultipartReader(request.headers, stream): |
528 | | - if part.name is not None: |
529 | | - if part.name == "flare_file": |
530 | | - tracer_flare[part.name] = await part.read() # zipfile |
531 | | - else: |
532 | | - tracer_flare[part.name] = await part.text() |
533 | | - |
534 | | - request["_tracer_flare"] = tracer_flare |
| 532 | + tracer_flare: TracerFlareEvent = await v1_tracerflare_decode(request.headers, self._request_data(request)) |
535 | 533 |
|
536 | 534 | expectedFields = ["source", "case_id", "email", "hostname", "flare_file"] |
537 | 535 | missingFields = [k for k in expectedFields if k not in tracer_flare] |
@@ -782,6 +780,11 @@ async def handle_session_apmtelemetry(self, request: Request) -> web.Response: |
782 | 780 | events = await self._apmtelemetry_by_session(token) |
783 | 781 | return web.json_response(events) |
784 | 782 |
|
| 783 | + async def handle_session_tracerflares(self, request: Request) -> web.Response: |
| 784 | + token = request["session_token"] |
| 785 | + events = await self._tracerflares_by_session(token) |
| 786 | + return web.json_response(events) |
| 787 | + |
785 | 788 | async def handle_session_tracestats(self, request: Request) -> web.Response: |
786 | 789 | token = request["session_token"] |
787 | 790 | stats = await self._tracestats_by_session(token) |
@@ -1029,6 +1032,7 @@ def make_app( |
1029 | 1032 | web.get("/test/session/snapshot", agent.handle_snapshot), |
1030 | 1033 | web.get("/test/session/traces", agent.handle_session_traces), |
1031 | 1034 | web.get("/test/session/apmtelemetry", agent.handle_session_apmtelemetry), |
| 1035 | + web.get("/test/session/tracerflares", agent.handle_session_tracerflares), |
1032 | 1036 | web.get("/test/session/stats", agent.handle_session_tracestats), |
1033 | 1037 | web.get("/test/session/requests", agent.handle_session_requests), |
1034 | 1038 | web.post("/test/session/responses/config", agent.handle_v07_remoteconfig_create), |
|
0 commit comments