Skip to content

Commit 52cd20e

Browse files
authored
feat: celery integration (#527)
1 parent 894f419 commit 52cd20e

8 files changed

Lines changed: 1703 additions & 5 deletions

File tree

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
pypi/posthog: minor
3+
---
4+
5+
feat: add Celery integration and improve PostHog client fork safety

examples/celery_integration.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
"""
2+
Celery integration example for the PostHog Python SDK.
3+
4+
Demonstrates how to use ``PosthogCeleryIntegration`` with:
5+
- producer-side and worker-side instrumentation (publishing events and context propagation)
6+
- context propagation (distinct ID, session ID, tags) from producer to worker
7+
- task lifecycle events (published, started, success, failure, retry)
8+
- exception capture from failed tasks
9+
- ``task_filter`` customization hook
10+
11+
Setup:
12+
1. Set ``POSTHOG_PROJECT_API_KEY`` and ``POSTHOG_HOST`` in your environment
13+
2. Install dependencies: pip install posthog celery redis
14+
3. Start Redis: redis-server
15+
4. Start the worker: celery -A examples.celery_integration worker --loglevel=info
16+
5. Run the producer: python -m examples.celery_integration
17+
"""
18+
19+
import os
20+
import time
21+
from typing import Any, Optional
22+
23+
from celery import Celery
24+
from celery.signals import worker_process_shutdown
25+
26+
import posthog
27+
from posthog.integrations.celery import PosthogCeleryIntegration
28+
29+
30+
# --- Configuration ---
31+
32+
# PostHog project key and host; both can be overridden via the environment.
POSTHOG_PROJECT_API_KEY = os.environ.get("POSTHOG_PROJECT_API_KEY", "phc_...")
POSTHOG_HOST = os.environ.get("POSTHOG_HOST", "http://localhost:8000")

# Celery application for this example, backed by the local Redis broker
# started in the setup steps.
app = Celery("examples.celery_integration", broker="redis://localhost:6379/0")
39+
40+
41+
# --- Integration wiring ---
42+
43+
44+
def configure_posthog() -> None:
    """Configure the module-level PostHog client for this example.

    Applies the project API key and host read from the environment and
    disables local flag evaluation before calling ``posthog.setup()``.
    """
    posthog.api_key = POSTHOG_PROJECT_API_KEY
    posthog.host = POSTHOG_HOST
    # to not require personal_api_key for this example
    posthog.enable_local_evaluation = False
    posthog.setup()
51+
52+
53+
def task_filter(task_name: Optional[str], task_properties: dict[str, Any]) -> bool:
    """Return ``False`` for health-check tasks and ``True`` for everything else.

    A task counts as a health check when its name ends with
    ``".health_check"``. ``task_properties`` is accepted to satisfy the
    hook signature but is not inspected.
    """
    is_health_check = task_name is not None and task_name.endswith(".health_check")
    return not is_health_check
57+
58+
59+
def create_integration() -> PosthogCeleryIntegration:
    """Build the Celery integration used by this example.

    Exception capture, task lifecycle events and context propagation are
    all enabled, and ``task_filter`` is installed as the filtering hook.
    """
    celery_integration = PosthogCeleryIntegration(
        capture_exceptions=True,
        capture_task_lifecycle_events=True,
        propagate_context=True,
        task_filter=task_filter,
    )
    return celery_integration
66+
67+
68+
# Wire everything up at import time so both the producer script and any
# process that imports this module get a configured client and an
# instrumented integration.
configure_posthog()
integration = create_integration()
integration.instrument()
71+
72+
73+
# --- Worker process setup ---
74+
# On a single host the forked child inherits the PostHog client and
75+
# integration, so nothing extra is needed. If workers run on different
76+
# hosts, import ``worker_process_init`` from ``celery.signals`` and uncomment the signal and handler below to initialise a fresh
77+
# client and integration in each worker process. If using a custom flag
78+
# definition cache provider, reinitialize your client in each worker with
79+
# the custom provider, and reinstrument the integration with that new client
80+
# instance.
81+
# @worker_process_init.connect
82+
# def on_worker_process_init(**kwargs) -> None:
83+
# global integration
84+
# configure_posthog()
85+
# integration = create_integration()
86+
# integration.instrument()
87+
# return
88+
89+
90+
# Use this signal to shutdown the integration and PostHog client
91+
# in the worker processes. Calling shutdown() is important to flush
92+
# any pending events and is required even if the workers are running
93+
# on the same host as the producer.
94+
@worker_process_shutdown.connect
def on_worker_process_shutdown(**kwargs) -> None:
    """Shut down the integration, then the PostHog client, in a worker process.

    The signal's keyword arguments are accepted but unused.
    """
    # Order matters: the integration is shut down before the client.
    for component in (integration, posthog):
        component.shutdown()
98+
99+
100+
# --- Example tasks ---
101+
102+
103+
@app.task
def health_check() -> dict[str, str]:
    """Trivial task reporting an ``ok`` status; excluded by ``task_filter``."""
    return dict(status="ok")
106+
107+
108+
@app.task(max_retries=3)
def process_order(order_id: str) -> dict:
    """Process an order successfully and emit a custom PostHog event.

    Returns a dict carrying the ``order_id`` and a ``"completed"`` status.
    """
    # simulate work
    time.sleep(0.1)

    # Custom event inside the task - context tags propagated from the
    # producer (e.g. "source", "release") should appear on this event
    # and this should be attributed to the correct distinct ID and session.
    event_properties = {"order_id": order_id, "amount": 99.99}
    posthog.capture("celery example order processed", properties=event_properties)

    outcome = {"order_id": order_id, "status": "completed"}
    return outcome
124+
125+
126+
@app.task(bind=True, max_retries=3)
def send_notification(self, user_id: str, message: str) -> None:
    """A task that fails and retries before eventually succeeding.

    While ``self.request.retries`` is below 2 the task requests a retry
    (with a 120 second countdown) by raising the result of ``self.retry``;
    after that it returns normally.
    """
    if self.request.retries >= 2:
        return None

    failure = ConnectionError("notification service unavailable")
    raise self.retry(exc=failure, countdown=120)
135+
136+
137+
@app.task
def failing_task() -> None:
    """A task that unconditionally raises ``ValueError``."""
    error = ValueError("something went wrong")
    raise error
141+
142+
143+
# --- Producer code ---
144+
145+
if __name__ == "__main__":
    print("PostHog Celery Integration Example")
    print("=" * 40)
    print()

    try:
        # Set up PostHog context before dispatching tasks.
        # The integration propagates this context to workers via task headers.
        with posthog.new_context(fresh=True):
            posthog.identify_context("user-123")
            posthog.set_context_session("session-user-123-abc")
            posthog.tag("source", "celery_integration_example_script")
            posthog.tag("release", "v1.2.3")

            print("Dispatching tasks...")

            # This task is intentionally filtered and should not emit task events.
            result = health_check.delay()
            print(f" health_check dispatched (filtered): {result.id}")

            # This task will produce events:
            #   celery task published (sender side)
            #   celery task started (worker side)
            #   celery example order processed (custom event, should carry
            #     propagated context tags)
            #   celery task success (worker side, includes duration)
            result = process_order.delay("order-456")
            print(f" process_order dispatched: {result.id}")

            # This task will produce events:
            #   celery task published
            #   celery task started
            #   celery task retry (with reason)
            #   celery task started (retry attempt)
            #   celery task success
            result = send_notification.delay("user-123", "Hello!")
            print(f" send_notification dispatched: {result.id}")

            # This task will produce events:
            #   celery task published
            #   celery task started
            #   celery task failure (with error_type and error_message)
            result = failing_task.delay()
            print(f" failing_task dispatched: {result.id}")

        print()
        print("Tasks dispatched. Check your Celery worker logs and PostHog for events.")
        print()
    finally:
        # Shut down the integration and client in producer process. Done in
        # a ``finally`` so events captured before a dispatch error (for
        # example an unreachable broker) are still flushed.
        integration.shutdown()
        posthog.shutdown()

mypy-baseline.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ posthog/client.py:0: error: Incompatible types in assignment (expression has typ
1616
posthog/client.py:0: error: Incompatible types in assignment (expression has type "dict[Any, Any]", variable has type "None") [assignment]
1717
posthog/client.py:0: error: "None" has no attribute "__iter__" (not iterable) [attr-defined]
1818
posthog/client.py:0: error: Statement is unreachable [unreachable]
19-
posthog/client.py:0: error: Right operand of "and" is never evaluated [unreachable]
20-
posthog/client.py:0: error: Incompatible types in assignment (expression has type "Poller", variable has type "None") [assignment]
21-
posthog/client.py:0: error: "None" has no attribute "start" [attr-defined]
22-
posthog/client.py:0: error: Statement is unreachable [unreachable]
2319
posthog/client.py:0: error: Statement is unreachable [unreachable]
2420
posthog/client.py:0: error: Name "parse_qs" already defined (possibly by an import) [no-redef]
2521
posthog/client.py:0: error: Name "urlparse" already defined (possibly by an import) [no-redef]

posthog/client.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import sys
66
import warnings
7+
import weakref
78
from datetime import datetime, timedelta, timezone
89
from typing import Any, Dict, List, Optional, Union
910
from uuid import uuid4
@@ -60,6 +61,7 @@
6061
get,
6162
normalize_host,
6263
remote_config,
64+
reset_sessions,
6365
)
6466
from posthog.types import (
6567
FeatureFlag,
@@ -223,6 +225,7 @@ def __init__(
223225
Category:
224226
Initialization
225227
"""
228+
self._max_queue_size = max_queue_size
226229
self.queue = queue.Queue(max_queue_size)
227230

228231
# api_key: This should be the Team API Key (token), public
@@ -245,8 +248,9 @@ def __init__(
245248
self.feature_flags_request_timeout_seconds = (
246249
feature_flags_request_timeout_seconds
247250
)
248-
self.poller = None
251+
self.poller: Optional[Poller] = None
249252
self.distinct_ids_feature_flags_reported = SizeLimitedDict(MAX_DICT_SIZE, set)
253+
self.flag_fallback_cache_url = flag_fallback_cache_url
250254
self.flag_cache = self._initialize_flag_cache(flag_fallback_cache_url)
251255
self.flag_definition_version = 0
252256
self._flags_etag: Optional[str] = None
@@ -338,6 +342,12 @@ def __init__(
338342
if send:
339343
consumer.start()
340344

345+
if hasattr(os, "register_at_fork"):
346+
weak_self = weakref.ref(self)
347+
os.register_at_fork(
348+
after_in_child=lambda: Client._reinit_after_fork_weak(weak_self)
349+
)
350+
341351
def _set_before_send(self, before_send):
342352
if before_send is not None:
343353
if callable(before_send):
@@ -1125,6 +1135,69 @@ def capture_exception(
11251135
except Exception as e:
11261136
self.log.exception(f"Failed to capture exception: {e}")
11271137

1138+
@staticmethod
def _reinit_after_fork_weak(weak_self):
    """Fork hook that reinitializes a client held through a weak reference.

    ``weak_self`` is called to resolve the client. If it resolves to
    ``None`` (the client no longer exists) the hook does nothing;
    otherwise the client's ``_reinit_after_fork()`` is invoked. Holding
    only a weak reference means the hook itself does not keep the client
    alive.
    """
    client = weak_self()
    if client is not None:
        client._reinit_after_fork()
1148+
1149+
def _reinit_after_fork(self):
    """Reinitialize fork-unsafe client state in a forked child process.

    Registered via os.register_at_fork(after_in_child=...) so it runs
    in each child before user code resumes, covering all code paths
    (capture, flush, join, etc.).

    Python threads do not survive fork() and queue.Queue internal locks
    may be in an inconsistent state, so the event queue, consumer threads
    and other state are replaced. Inherited queue items are not retained
    as they'll be handled by the parent process's consumers.

    Only background workers that existed in the parent are recreated:
    consumers are rebuilt from the inherited ones, and the feature-flag
    poller is restarted only if the parent already had one.
    """
    if self.consumers:
        # Fresh queue shared by all replacement consumers.
        self.queue = queue.Queue(self._max_queue_size)

        new_consumers = []
        for old in self.consumers:
            # Mirror the inherited consumer's settings on the new queue.
            consumer = Consumer(
                self.queue,
                old.api_key,
                flush_at=old.flush_at,
                host=old.host,
                on_error=old.on_error,
                flush_interval=old.flush_interval,
                gzip=old.gzip,
                retries=old.retries,
                timeout=old.timeout,
                historical_migration=old.historical_migration,
            )
            new_consumers.append(consumer)

            if self.send:
                consumer.start()

        self.consumers = new_consumers

    # The inherited poller thread is dead in the child. Replace it only when
    # the parent actually had one, so a child never starts polling that the
    # parent had not begun.
    # NOTE(review): assumes the poller is created lazily elsewhere in the
    # client when flags are first loaded - confirm against the rest of the
    # class.
    parent_had_poller = self.poller is not None
    if self.enable_local_evaluation and parent_had_poller:
        self.poller = Poller(
            interval=timedelta(seconds=self.poll_interval),
            execute=self._load_feature_flags,
        )
        self.poller.start()
    else:
        self.poller = None

    # If using Redis cache, we must reinitialize to get a fresh connection (fork-safe).
    # If using Memory cache, we keep it as-is to benefit from the inherited warm cache.
    if isinstance(self.flag_cache, RedisFlagCache):
        self.flag_cache = self._initialize_flag_cache(self.flag_fallback_cache_url)

    reset_sessions()
1200+
11281201
def _enqueue(self, msg, disable_geoip):
11291202
# type: (...) -> Optional[str]
11301203
"""Push a new `msg` onto the queue, return `(success, msg)`"""

0 commit comments

Comments
 (0)