|
21 | 21 | Any, |
22 | 22 | Dict, |
23 | 23 | FrozenSet, |
| 24 | + Iterator, |
24 | 25 | List, |
25 | 26 | Literal, |
| 27 | + MutableMapping, |
26 | 28 | NotRequired, |
27 | 29 | Optional, |
28 | 30 | Set, |
29 | 31 | Tuple, |
30 | 32 | TypedDict, |
31 | 33 | Union, |
32 | 34 | cast, |
| 35 | + overload, |
33 | 36 | ) |
34 | 37 |
|
35 | 38 | import zmq |
@@ -452,6 +455,102 @@ async def execute( |
452 | 455 | raise |
453 | 456 |
|
454 | 457 |
|
| 458 | +@dataclass(frozen=True) |
| 459 | +class AgentKernelRegistryKey: |
| 460 | + agent_id: AgentId |
| 461 | + kernel_id: KernelId |
| 462 | + |
| 463 | + |
| 464 | +class KernelRegistry(MutableMapping[AgentKernelRegistryKey, AbstractKernel]): |
| 465 | + _registry: MutableMapping[AgentKernelRegistryKey, AbstractKernel] |
| 466 | + _global_registry: MutableMapping[KernelId, AbstractKernel] |
| 467 | + |
| 468 | + def __init__(self) -> None: |
| 469 | + super().__init__() |
| 470 | + |
| 471 | + self._registry = {} |
| 472 | + self._global_registry = {} |
| 473 | + |
| 474 | + def agent_view(self, agent_id: AgentId) -> "KernelRegistryAgentView": |
| 475 | + return KernelRegistryAgentView(self, agent_id) |
| 476 | + |
| 477 | + def global_view(self) -> "KernelRegistryGlobalView": |
| 478 | + return KernelRegistryGlobalView(self) |
| 479 | + |
| 480 | + @overload |
| 481 | + def __getitem__(self, key: KernelId) -> AbstractKernel: ... |
| 482 | + |
| 483 | + @overload |
| 484 | + def __getitem__(self, key: AgentKernelRegistryKey) -> AbstractKernel: ... |
| 485 | + |
| 486 | + def __getitem__(self, key: KernelId | AgentKernelRegistryKey) -> AbstractKernel: |
| 487 | + if isinstance(key, AgentKernelRegistryKey): |
| 488 | + return self._registry[key] |
| 489 | + else: |
| 490 | + return self._global_registry[key] |
| 491 | + |
| 492 | + def __setitem__(self, key: AgentKernelRegistryKey, value: AbstractKernel) -> None: |
| 493 | + self._registry[key] = value |
| 494 | + self._global_registry[key.kernel_id] = value |
| 495 | + |
| 496 | + def __delitem__(self, key: AgentKernelRegistryKey) -> None: |
| 497 | + del self._registry[key] |
| 498 | + del self._global_registry[key.kernel_id] |
| 499 | + |
| 500 | + def __iter__(self) -> Iterator[AgentKernelRegistryKey]: |
| 501 | + return iter(self._registry) |
| 502 | + |
| 503 | + def __len__(self) -> int: |
| 504 | + return len(self._registry) |
| 505 | + |
| 506 | + |
| 507 | +class KernelRegistryAgentView(MutableMapping[KernelId, AbstractKernel]): |
| 508 | + _registry: KernelRegistry |
| 509 | + _agent_id: AgentId |
| 510 | + |
| 511 | + def __init__(self, kernel_registry: KernelRegistry, agent_id: AgentId) -> None: |
| 512 | + super().__init__() |
| 513 | + |
| 514 | + self._registry = kernel_registry |
| 515 | + self._agent_id = agent_id |
| 516 | + |
| 517 | + def __getitem__(self, key: KernelId) -> AbstractKernel: |
| 518 | + return self._registry[AgentKernelRegistryKey(self._agent_id, key)] |
| 519 | + |
| 520 | + def __setitem__(self, key: KernelId, value: AbstractKernel) -> None: |
| 521 | + self._registry[AgentKernelRegistryKey(self._agent_id, key)] = value |
| 522 | + |
| 523 | + def __delitem__(self, key: KernelId) -> None: |
| 524 | + del self._registry[AgentKernelRegistryKey(self._agent_id, key)] |
| 525 | + |
| 526 | + def __iter__(self) -> Iterator[KernelId]: |
| 527 | + for registry_key in self._registry: |
| 528 | + if registry_key.agent_id == self._agent_id: |
| 529 | + yield registry_key.kernel_id |
| 530 | + |
| 531 | + def __len__(self) -> int: |
| 532 | + return sum(1 for key in self._registry if key.agent_id == self._agent_id) |
| 533 | + |
| 534 | + |
| 535 | +class KernelRegistryGlobalView(Mapping[KernelId, AbstractKernel]): |
| 536 | + _registry: KernelRegistry |
| 537 | + |
| 538 | + def __init__(self, kernel_registry: KernelRegistry) -> None: |
| 539 | + super().__init__() |
| 540 | + |
| 541 | + self._registry = kernel_registry |
| 542 | + |
| 543 | + def __getitem__(self, key: KernelId) -> AbstractKernel: |
| 544 | + return self._registry[key] |
| 545 | + |
| 546 | + def __iter__(self) -> Iterator[KernelId]: |
| 547 | + for registry_key in self._registry: |
| 548 | + yield registry_key.kernel_id |
| 549 | + |
| 550 | + def __len__(self) -> int: |
| 551 | + return len(self._registry) |
| 552 | + |
| 553 | + |
455 | 554 | _zctx = None |
456 | 555 |
|
457 | 556 |
|
|
0 commit comments