Skip to content

Stop zmq sockets #1377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ipykernel/kernelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,8 @@ async def shell_main(self, subshell_id: str | None):
await to_thread.run_sync(self.shell_stop.wait)
tg.cancel_scope.cancel()

await socket.stop()

async def process_shell(self, socket=None):
# socket=None is valid if kernel subshells are not supported.
try:
Expand Down
9 changes: 1 addition & 8 deletions ipykernel/subshell.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,4 @@ async def create_pair_socket(
self._pair_socket = zmq_anyio.Socket(context, zmq.PAIR)
self._pair_socket.connect(address)
self.start_soon(self._pair_socket.start)

def run(self) -> None:
try:
super().run()
finally:
if self._pair_socket is not None:
self._pair_socket.close()
self._pair_socket = None
self.add_teardown_callback(self._pair_socket.stop)
4 changes: 4 additions & 0 deletions ipykernel/subshell_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,13 @@ def close(self) -> None:
async def get_control_other_socket(self, thread: BaseThread) -> zmq_anyio.Socket:
if not self._control_other_socket.started.is_set():
await thread.task_group.start(self._control_other_socket.start)
thread.add_teardown_callback(self._control_other_socket.stop)
return self._control_other_socket

async def get_control_shell_channel_socket(self, thread: BaseThread) -> zmq_anyio.Socket:
if not self._control_shell_channel_socket.started.is_set():
await thread.task_group.start(self._control_shell_channel_socket.start)
thread.add_teardown_callback(self._control_shell_channel_socket.stop)
return self._control_shell_channel_socket

def get_other_socket(self, subshell_id: str | None) -> zmq_anyio.Socket:
Expand Down Expand Up @@ -281,6 +283,8 @@ async def _listen_for_subshell_reply(
# Subshell no longer exists so exit gracefully
return
raise
finally:
await shell_channel_socket.stop()

async def _process_control_request(
self, request: dict[str, t.Any], subshell_task: t.Any
Expand Down
54 changes: 36 additions & 18 deletions ipykernel/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from collections.abc import Awaitable
from inspect import isawaitable
from queue import Queue
from threading import Event, Thread
from typing import Any, Callable
Expand All @@ -26,6 +27,7 @@ def __init__(self, **kwargs):
self.is_pydev_daemon_thread = True
self._tasks: Queue[tuple[str, Callable[[], Awaitable[Any]]] | None] = Queue()
self._result: Queue[Any] = Queue()
self._teardown_callbacks: list[Callable[[], Any] | Callable[[], Awaitable[Any]]] = []
self._exception: Exception | None = None

@property
Expand All @@ -47,6 +49,9 @@ def run_sync(self, func: Callable[..., Any]) -> Any:
self._tasks.put(("run_sync", func))
return self._result.get()

def add_teardown_callback(self, func: Callable[[], Any] | Callable[[], Awaitable[Any]]) -> None:
self._teardown_callbacks.append(func)

def run(self) -> None:
"""Run the thread."""
try:
Expand All @@ -55,24 +60,37 @@ def run(self) -> None:
self._exception = exc

async def _main(self) -> None:
async with create_task_group() as tg:
self._task_group = tg
self.started.set()
while True:
task = await to_thread.run_sync(self._tasks.get)
if task is None:
break
func, arg = task
if func == "start_soon":
tg.start_soon(arg)
elif func == "run_async":
res = await arg
self._result.put(res)
else: # func == "run_sync"
res = arg()
self._result.put(res)

tg.cancel_scope.cancel()
try:
async with create_task_group() as tg:
self._task_group = tg
self.started.set()
while True:
task = await to_thread.run_sync(self._tasks.get)
if task is None:
break
func, arg = task
if func == "start_soon":
tg.start_soon(arg)
elif func == "run_async":
res = await arg
self._result.put(res)
else: # func == "run_sync"
res = arg()
self._result.put(res)

tg.cancel_scope.cancel()
finally:
exception = None
for teardown_callback in self._teardown_callbacks[::-1]:
try:
res = teardown_callback()
if isawaitable(res):
await res
except Exception as exc:
if exception is None:
exception = exc
if exception is not None:
raise exception

def stop(self) -> None:
"""Stop the thread.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies = [
"psutil>=5.7",
"packaging>=22",
"anyio>=4.8.0,<5.0.0",
"zmq-anyio >=0.3.6",
"zmq-anyio >=0.3.9",
]

[project.urls]
Expand Down
Loading