From 99207619fa8afa2da62b33009215937e80a7d48e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 26 Feb 2026 19:05:33 +0000 Subject: [PATCH 1/3] Initial plan From 336dd6eb5fa56471f27c8192f011961e73d5ab7e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 26 Feb 2026 19:13:04 +0000 Subject: [PATCH 2/3] Add pufferlib/PMLL.py from PR #405 and comprehensive test suite (58 tests) Co-authored-by: drQedwards <213266729+drQedwards@users.noreply.github.com> --- pufferlib/PMLL.py | 775 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_pmll.py | 583 ++++++++++++++++++++++++++++++++++ 2 files changed, 1358 insertions(+) create mode 100644 pufferlib/PMLL.py create mode 100644 tests/test_pmll.py diff --git a/pufferlib/PMLL.py b/pufferlib/PMLL.py new file mode 100644 index 0000000000..00713990ed --- /dev/null +++ b/pufferlib/PMLL.py @@ -0,0 +1,775 @@ +""" +pmll.py (PufferLib-ready) — Persistent Memory Logic Loop (PMLL) + +Refactor goals (vs. your original draft): +- Clean layering: Backend interface + CTypes backend + pure-Python fallback +- Optional native acceleration: libpmll_backend.so (SIMD/intrinsics) if present +- Thread-safe / async-safe memory controller +- PufferLib integration: usable as a plug-in “memory module” inside RL policies +- Torch integration: PML attention block that can be inserted into a network +- No hard dependency on torch/numpy unless you use the corresponding features + +Notes: +- This file does NOT assume a specific pufferlib policy API, but provides + adapters that work with typical “nn.Module policy + forward(obs)” patterns. +- If you have a concrete PufferLib policy base class you’re using, you can + subclass/mixin the PMLLPolicyMixin below and call its helpers. 
+ +License: MIT +""" + +from __future__ import annotations + +import os +import time +import json +import math +import ctypes +import hashlib +import threading +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Protocol, runtime_checkable + +# Optional deps (only used if available/needed) +try: + import numpy as np +except Exception: # pragma: no cover + np = None + +try: + import torch + import torch.nn as nn + import torch.nn.functional as F +except Exception: # pragma: no cover + torch = None + nn = None + F = None + + +# ============================================================================= +# Utilities: stable hashing + JSONL persistence +# ============================================================================= + +def _stable_json_dumps(obj: Any) -> str: + return json.dumps(obj, sort_keys=True, separators=(",", ":"), ensure_ascii=False) + +def deterministic_hash(payload: Any, salt: str = "") -> str: + h = hashlib.sha256() + h.update(salt.encode("utf-8")) + h.update(_stable_json_dumps(payload).encode("utf-8")) + return h.hexdigest() + +@dataclass +class MemoryBlock: + payload: Dict[str, Any] + mid: str + ts: float + meta: Optional[Dict[str, Any]] = None + +class JSONLStore: + """ + Simple append-only log + optional periodic snapshot. + Designed for long runs (RL training) without loading everything each time. 
+ """ + def __init__(self, root: str): + self.root = root + os.makedirs(root, exist_ok=True) + self.log_path = os.path.join(root, "pmll_log.jsonl") + self.snapshot_path = os.path.join(root, "pmll_snapshot.json") + + def append(self, block: MemoryBlock) -> None: + with open(self.log_path, "a", encoding="utf-8") as f: + f.write(_stable_json_dumps(block.__dict__) + "\n") + + def save_snapshot(self, blocks: List[MemoryBlock]) -> None: + with open(self.snapshot_path, "w", encoding="utf-8") as f: + f.write(_stable_json_dumps([b.__dict__ for b in blocks])) + + def load(self) -> List[MemoryBlock]: + # Prefer snapshot for faster cold start + if os.path.exists(self.snapshot_path): + try: + with open(self.snapshot_path, "r", encoding="utf-8") as f: + arr = json.loads(f.read()) + return [MemoryBlock(**x) for x in arr] + except Exception: + pass + + blocks: List[MemoryBlock] = [] + if os.path.exists(self.log_path): + with open(self.log_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + blocks.append(MemoryBlock(**json.loads(line))) + return blocks + + +# ============================================================================= +# Backend interface (native or python) +# ============================================================================= + +@runtime_checkable +class PMLLBackend(Protocol): + def phi(self, idx: int, n: int) -> int: + ... + + def process_promise_queue(self) -> None: + ... + + def trigger_compression(self, rho: float) -> None: + ... + + def utilization(self) -> float: + ... + + # Optional accelerated attention (batched) + def vectorized_attention( + self, + q: "torch.Tensor", + k: "torch.Tensor", + v: "torch.Tensor" + ) -> Optional["torch.Tensor"]: + ... 
class PythonBackend:
    """
    Pure-Python fallback backend.

    - ``phi`` is plain modular slot assignment
    - promise-queue and compression hooks are no-ops (the controller owns them)
    - utilization is pushed in by the controller via ``_set_utilization``
    """

    def __init__(self):
        self._util = 0.0

    def phi(self, idx: int, n: int) -> int:
        # Simple modulo slotting; collisions resolve by overwrite upstream.
        return idx % n

    def process_promise_queue(self) -> None:
        return  # no native queue to drain

    def trigger_compression(self, rho: float) -> None:
        return  # no native pool to compress

    def utilization(self) -> float:
        return float(self._util)

    def _set_utilization(self, u: float) -> None:
        # Clamp to [0, 1] so a bad caller can never report nonsense.
        self._util = max(0.0, min(1.0, float(u)))

    def vectorized_attention(self, q, k, v):
        return None  # no acceleration available


class CTypesBackend:
    """
    Optional native backend for SIMD/intrinsics acceleration, loaded via ctypes.

    Expects a shared library exposing:
        int  phi(int id, int n)
        void process_promise_queue(PromiseQueue*, MemoryPool*)
        void trigger_compression(MemoryPool*, float rho)
        void vectorized_attention(float* q, float* k, float* v, int d)

    Only specific hot paths are accelerated; the Python controller still
    owns all logic and safety.
    """

    class MemoryPool(ctypes.Structure):
        _fields_ = [
            ("size", ctypes.c_int),
            ("data", ctypes.POINTER(ctypes.c_void_p)),
            ("utilization", ctypes.c_float),
        ]

    class PromiseQueue(ctypes.Structure):
        _fields_ = [
            ("capacity", ctypes.c_int),
            ("head", ctypes.c_int),
            ("tail", ctypes.c_int),
            ("promises", ctypes.POINTER(ctypes.c_void_p)),
        ]

    def __init__(self, so_path: str):
        self.lib = ctypes.CDLL(so_path)

        # Required symbol.
        self.lib.phi.argtypes = [ctypes.c_int, ctypes.c_int]
        self.lib.phi.restype = ctypes.c_int

        # Optional symbols: probe and wire up only what the library exports.
        self._has_pq = hasattr(self.lib, "process_promise_queue")
        if self._has_pq:
            self.lib.process_promise_queue.argtypes = [
                ctypes.POINTER(self.PromiseQueue),
                ctypes.POINTER(self.MemoryPool),
            ]
            self.lib.process_promise_queue.restype = ctypes.POINTER(self.MemoryPool)

        self._has_comp = hasattr(self.lib, "trigger_compression")
        if self._has_comp:
            self.lib.trigger_compression.argtypes = [
                ctypes.POINTER(self.MemoryPool),
                ctypes.c_float,
            ]
            self.lib.trigger_compression.restype = None

        self._has_attn = hasattr(self.lib, "vectorized_attention")
        if self._has_attn:
            self.lib.vectorized_attention.argtypes = [
                ctypes.POINTER(ctypes.c_float),
                ctypes.POINTER(ctypes.c_float),
                ctypes.POINTER(ctypes.c_float),
                ctypes.c_int,
            ]
            self.lib.vectorized_attention.restype = None

        # Minimal struct instances; extend to mirror the full C layout if needed.
        self.c_pool = self.MemoryPool(0, None, 0.0)
        self.c_queue = self.PromiseQueue(0, 0, 0, None)

    def phi(self, idx: int, n: int) -> int:
        return int(self.lib.phi(int(idx), int(n)))

    def process_promise_queue(self) -> None:
        if self._has_pq:
            self.lib.process_promise_queue(self.c_queue, self.c_pool)

    def trigger_compression(self, rho: float) -> None:
        if self._has_comp:
            self.lib.trigger_compression(self.c_pool, ctypes.c_float(float(rho)))

    def utilization(self) -> float:
        return float(self.c_pool.utilization)

    def vectorized_attention(self, q, k, v):
        # Skip unless torch is importable and the library exports the symbol.
        if torch is None or not self._has_attn:
            return None
        # Native side expects contiguous float32.
        qf = q.contiguous().float()
        kf = k.contiguous().float()
        vf = v.contiguous().float()
        d = qf.shape[-1]
        # Raw data_ptr() access is only safe for CPU tensors.
        if qf.is_cuda or kf.is_cuda or vf.is_cuda:
            return None

        q_ptr = ctypes.cast(qf.data_ptr(), ctypes.POINTER(ctypes.c_float))
        k_ptr = ctypes.cast(kf.data_ptr(), ctypes.POINTER(ctypes.c_float))
        v_ptr = ctypes.cast(vf.data_ptr(), ctypes.POINTER(ctypes.c_float))
        self.lib.vectorized_attention(q_ptr, k_ptr, v_ptr, ctypes.c_int(int(d)))
        return None  # operates in place / acts as an accelerator hint only


def make_backend(
    *,
    so_path: Optional[str] = None,
) -> PMLLBackend:
    """Return a CTypesBackend when ``so_path`` loads; otherwise PythonBackend."""
    if so_path:
        try:
            return CTypesBackend(so_path)
        except Exception:
            pass  # unloadable library: silently fall back to pure Python
    return PythonBackend()


@dataclass
class Promise:
    # Write-behind record waiting to be flushed into the pool.
    pid: int
    data: Any
    ttl_s: float
    importance: float
    created_ts: float

    def expired(self, now: float) -> bool:
        """True once the promise has lived at least ``ttl_s`` seconds."""
        return (now - self.created_ts) >= self.ttl_s
class MemoryController:
    """
    PMLL memory controller.

    Responsibilities:
    - slot assignment through ``backend.phi``
    - a write-behind promise queue
    - compression triggers once utilization crosses a threshold
    - thread safety (re-entrant lock around all pool state)

    Typical RL usage: write (state->kv) / (obs->features) promises each
    environment step, periodically flush them with ``process_promises``,
    then pull related entries back with ``retrieve_relevant`` for
    attention / policy decisions.
    """

    def __init__(
        self,
        pool_size: int,
        *,
        backend: Optional[PMLLBackend] = None,
        hash_salt: str = "",
        store_dir: Optional[str] = None,
        snapshot_every: int = 500,
        default_ttl_s: float = 3600.0,
        compression_rho: float = 0.10,
        compress_when_util_gt: float = 0.80,
        enable_python_compress_fallback: bool = True,
    ):
        self.pool_size = int(pool_size)
        self.backend = backend or make_backend()
        self.hash_salt = hash_salt

        self.default_ttl_s = float(default_ttl_s)
        self.compression_rho = float(compression_rho)
        self.compress_when_util_gt = float(compress_when_util_gt)
        self.enable_python_compress_fallback = bool(enable_python_compress_fallback)

        self._lock = threading.RLock()

        # Persistent pool: one slot per phi() bucket, holding arbitrary objects.
        self.pool: List[Optional[Any]] = [None] * self.pool_size

        # Write-behind buffer of pending promises.
        self.promises: List[Promise] = []

        # Optional on-disk persistence of write metadata.
        self.store = JSONLStore(store_dir) if store_dir else None
        self.snapshot_every = max(1, int(snapshot_every))
        self._append_count = 0

        # Occupancy counter backing the Python-side utilization estimate.
        self._occupied = 0

    # ------------- Promise write path -------------

    def write(
        self,
        *,
        pid: int,
        data: Any,
        ttl_s: Optional[float] = None,
        importance: Optional[float] = None,
    ) -> None:
        """Enqueue a promise; it reaches the pool on the next ``process_promises``."""
        now = time.time()
        ttl = self.default_ttl_s if ttl_s is None else float(ttl_s)
        score = self._importance_score(data) if importance is None else float(importance)

        promise = Promise(pid=int(pid), data=data, ttl_s=ttl, importance=score, created_ts=now)
        with self._lock:
            self.promises.append(promise)

    def process_promises(self) -> None:
        """
        Flush all pending, non-expired promises into their pool slots.

        A native backend may additionally perform internal housekeeping.
        """
        now = time.time()

        with self._lock:
            if not self.promises:
                self._update_utilization_locked()
                return

            survivors: List[Promise] = []
            for promise in self.promises:
                if promise.expired(now):
                    continue
                slot = self.backend.phi(promise.pid, self.pool_size)
                if self.pool[slot] is None:
                    self._occupied += 1
                self.pool[slot] = promise.data

                # Optionally persist a compact metadata record.
                if self.store:
                    record = MemoryBlock(
                        payload={
                            "pid": promise.pid,
                            "slot": slot,
                            "importance": promise.importance,
                            "ts": now,
                        },
                        mid=deterministic_hash(
                            {"pid": promise.pid, "slot": slot, "ts": now},
                            salt=self.hash_salt,
                        ),
                        ts=now,
                        meta={"ttl_s": promise.ttl_s},
                    )
                    self.store.append(record)
                    self._append_count += 1
                    if (self._append_count % self.snapshot_every) == 0:
                        # Snapshot the metadata log only, never the raw pool
                        # (pool entries may be tensors / non-serializable).
                        self.store.save_snapshot([record])

            # Promises are write-behind: consumed once flushed.
            self.promises = survivors

            # Give a native backend the chance to drain its internal queue.
            try:
                self.backend.process_promise_queue()
            except Exception:
                pass

            self._update_utilization_locked()

        # Compression check deliberately runs outside the lock.
        if self.utilization() > self.compress_when_util_gt:
            self.trigger_compression(self.compression_rho)

    # ------------- Read / retrieve path -------------

    def read_slot(self, slot: int) -> Any:
        """Return whatever occupies ``slot`` (index wraps modulo pool size)."""
        with self._lock:
            return self.pool[int(slot) % self.pool_size]

    def retrieve_relevant(
        self,
        query: Any,
        *,
        threshold: float = 0.50,
        max_items: int = 64,
        key_extractor: Optional[Any] = None,
        similarity: Optional[Any] = None,
    ) -> List[Any]:
        """
        Return pool entries whose key is similar enough to ``query``.

        - requires torch (returns [] otherwise)
        - ``key_extractor`` maps an entry to its key embedding; the default
          handles ``(k, v)`` tuples and ``{"k": ...}`` dicts
        - ``similarity(query, key) -> float`` defaults to cosine similarity
        """
        if torch is None:
            return []

        if similarity is None:
            similarity = cosine_sim_torch

        if key_extractor is None:
            def key_extractor(entry):
                # Default: expect a (k, v) tuple or a {"k": ...} dict.
                if isinstance(entry, tuple) and len(entry) >= 1:
                    return entry[0]
                if isinstance(entry, dict) and "k" in entry:
                    return entry["k"]
                return None

        scored: List[Tuple[float, Any]] = []
        with self._lock:
            for entry in self.pool:
                if entry is None:
                    continue
                key = key_extractor(entry)
                if key is None:
                    continue
                try:
                    score = float(similarity(query, key))
                except Exception:
                    continue  # incompatible entry: skip rather than fail
                if score >= threshold:
                    scored.append((score, entry))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [entry for _, entry in scored[: int(max_items)]]

    # ------------- Compression -------------

    def trigger_compression(self, rho: float = 0.10) -> None:
        """
        Ask the backend to compress; optionally fall back to a Python
        heuristic that clears roughly the lowest ``rho`` fraction of slots.
        """
        try:
            self.backend.trigger_compression(float(rho))
            return  # backend accepted the call (possibly as a no-op)
        except Exception:
            pass

        if not self.enable_python_compress_fallback:
            return
        if np is None:
            return

        with self._lock:
            # Placeholder importance scores; swap in real scoring as needed.
            scores = np.random.rand(self.pool_size)
            cutoff = float(np.quantile(scores, float(rho)))
            for i in range(self.pool_size):
                if self.pool[i] is not None and float(scores[i]) < cutoff:
                    self.pool[i] = None
            self._occupied = sum(1 for slot in self.pool if slot is not None)
            self._update_utilization_locked()

    # ------------- Utilization -------------

    def utilization(self) -> float:
        """Current pool fullness in [0, 1]; prefers the backend's report."""
        try:
            reported = float(self.backend.utilization())
            if not math.isnan(reported) and reported > 0.0:
                return max(0.0, min(1.0, reported))
        except Exception:
            pass

        with self._lock:
            return float(self._occupied) / float(self.pool_size)

    def _update_utilization_locked(self) -> None:
        # Mirror occupancy into the python backend for debugging visibility.
        u = float(self._occupied) / float(self.pool_size)
        if isinstance(self.backend, PythonBackend):
            self.backend._set_utilization(u)

    # ------------- Importance scoring -------------

    def _importance_score(self, data: Any) -> float:
        # Placeholder: replace with ERS / novelty / recency scoring if desired.
        if np is not None:
            return float(np.random.rand())
        return 0.5


# =============================================================================
# Torch: Hybrid PML Attention block (drop-in nn.Module)
# =============================================================================

def cosine_sim_torch(a: "torch.Tensor", b: "torch.Tensor") -> "torch.Tensor":
    """Cosine similarity of two tensors flattened to vectors."""
    av = a.flatten()
    bv = b.flatten()
    return torch.dot(av, bv) / (torch.norm(av) * torch.norm(bv) + 1e-9)
class PMLAttention(nn.Module):
    """
    Hybrid attention block:

        A = alpha * local_attention + (1 - alpha) * persistent_attention

    Shapes (handled loosely):
        q:                  [B, D] or [D]
        k_local / v_local:  [T, D] or [B, T, D]
        persistent entries: (k, v) pairs dimensionally compatible with q
    """

    def __init__(
        self,
        memory: MemoryController,
        *,
        persistent_threshold: float = 0.50,
        persistent_max_items: int = 64,
        native_attention_threshold: int = 32,
    ):
        super().__init__()
        self.memory = memory
        self.persistent_threshold = float(persistent_threshold)
        self.persistent_max_items = int(persistent_max_items)
        self.native_attention_threshold = int(native_attention_threshold)

    def forward(self, q: torch.Tensor, k_local: torch.Tensor, v_local: torch.Tensor) -> torch.Tensor:
        # Local attention over the provided cache.
        local_out = self._attend(q, k_local, v_local)

        # Retrieve related entries from persistent memory.
        matches = self.memory.retrieve_relevant(
            q,
            threshold=self.persistent_threshold,
            max_items=self.persistent_max_items,
        )
        if not matches:
            return local_out

        k_p, v_p = self._extract_kv(matches)

        # Large persistent sets may be pre-touched by a native backend (hint only).
        if k_p.shape[0] > self.native_attention_threshold:
            try:
                self.memory.backend.vectorized_attention(q, k_p, v_p)
            except Exception:
                pass

        persistent_out = self._attend(q, k_p, v_p)
        alpha = self._alpha(q, k_local, k_p)
        return alpha * local_out + (1.0 - alpha) * persistent_out

    def _attend(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Normalize to q: (D,) and k/v: (T, D); batched inputs use row 0 only.
        q_vec = q[0] if q.dim() == 2 else q
        if k.dim() == 3:
            k_mat, v_mat = k[0], v[0]
        else:
            k_mat, v_mat = k, v

        scores = (k_mat @ q_vec) / math.sqrt(q_vec.shape[-1])  # [T]
        weights = torch.softmax(scores, dim=-1)
        return (weights.unsqueeze(-1) * v_mat).sum(dim=0)

    def _extract_kv(self, rel: List[Any]) -> Tuple[torch.Tensor, torch.Tensor]:
        # Accept (k, v) tuples or {"k": ..., "v": ...} dicts.
        keys, values = [], []
        for entry in rel:
            if isinstance(entry, tuple) and len(entry) >= 2:
                keys.append(entry[0])
                values.append(entry[1])
            elif isinstance(entry, dict) and "k" in entry and "v" in entry:
                keys.append(entry["k"])
                values.append(entry["v"])
        return torch.stack(keys, dim=0), torch.stack(values, dim=0)

    def _alpha(self, q: torch.Tensor, k_local: torch.Tensor, k_p: torch.Tensor) -> torch.Tensor:
        # Scalar blend weight from distances to the two key centroids.
        q_vec = q[0] if q.dim() == 2 else q
        k_mat = k_local[0] if k_local.dim() == 3 else k_local

        dist_local = torch.norm(q_vec - k_mat.mean(dim=0), p=2)
        dist_persistent = torch.norm(q_vec - k_p.mean(dim=0), p=2)
        return torch.sigmoid(dist_local - dist_persistent)


# =============================================================================
# PufferLib integration helpers
# =============================================================================

class PMLLPolicyMixin:
    """
    Light-weight mixin for PufferLib-like policies.

    Typical pattern:
        class MyPolicy(nn.Module, PMLLPolicyMixin):
            def __init__(...):
                nn.Module.__init__(self)
                PMLLPolicyMixin.__init__(self, pmll=MemoryController(...))
                self.pmll_attn = PMLAttention(self.pmll)

            def forward(self, obs, state=None):
                features = self.encoder(obs)
                self.pmll_write_kv(features)  # store features/kv as promises
                self.pmll_process()           # flush promises periodically
                ...
    """

    def __init__(self, pmll: MemoryController):
        self.pmll = pmll
        self._pmll_step = 0

    def pmll_write(self, pid: int, data: Any, ttl_s: Optional[float] = None, importance: Optional[float] = None) -> None:
        self.pmll.write(pid=pid, data=data, ttl_s=ttl_s, importance=importance)

    def pmll_process(self) -> None:
        self.pmll.process_promises()

    def pmll_write_kv(
        self,
        k: torch.Tensor,
        v: Optional[torch.Tensor] = None,
        *,
        pid: Optional[int] = None,
        ttl_s: Optional[float] = None,
        importance: Optional[float] = None,
    ) -> None:
        """
        Convenience: store a detached ``(k, v)`` pair. ``v`` defaults to ``k``
        and ``pid`` defaults to a rolling step hash.
        """
        if torch is None:
            return
        self._pmll_step += 1
        if pid is None:
            pid = hash((int(self._pmll_step), int(time.time())))
        if v is None:
            v = k
        self.pmll.write(pid=int(pid), data=(k.detach(), v.detach()), ttl_s=ttl_s, importance=importance)


# =============================================================================
# Optional: Minimal "Transformer-like" wrapper
# =============================================================================

class PMLLTransformer(nn.Module):
    """
    Minimal example of wiring PMLAttention into a model: encoder + memory
    + hybrid attention. This is NOT a drop-in torch.nn.Transformer
    replacement — it is a pufferlib-friendly building block.
    """

    def __init__(self, d_model: int, pool_size: int = 1024, so_path: Optional[str] = None):
        if nn is None:
            raise ImportError("torch is required for PMLLTransformer.")
        super().__init__()
        self.d_model = int(d_model)
        self.memory = MemoryController(pool_size, backend=make_backend(so_path=so_path))
        self.attn = PMLAttention(self.memory)

        # Small stand-in encoder; swap in a real policy backbone.
        self.encoder = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, D]
        feats = self.encoder(x)

        # Pretend the local cache holds a single timestep.
        k_local = feats.unsqueeze(1)  # [B, 1, D]
        v_local = feats.unsqueeze(1)

        out = self.attn(feats, k_local, v_local)

        # Write each batch row into persistent memory.
        for b in range(x.shape[0]):
            pid = hash((b, time.time_ns()))
            self.memory.write(pid=pid, data=(feats[b].detach(), out.detach()), ttl_s=3600.0)

        self.memory.process_promises()
        return out
Backend selection + # - If you compiled a native backend, set env PMLL_SO=/path/to/libpmll_backend.so + so = os.environ.get("PMLL_SO", None) + + backend = make_backend(so_path=so) + mc = MemoryController(pool_size=128, backend=backend, store_dir=None) + + # Basic write/process/read + mc.write(pid=123, data={"hello": "world"}, ttl_s=10.0, importance=0.9) + mc.process_promises() + slot = backend.phi(123, 128) + print("slot:", slot, "value:", mc.read_slot(slot)) + + # Torch test if available + if torch is not None: + model = PMLLTransformer(d_model=32, pool_size=256, so_path=so) + x = torch.randn(8, 32) + y = model(x) + print("out:", y.shape, "util:", mc.utilization()) diff --git a/tests/test_pmll.py b/tests/test_pmll.py new file mode 100644 index 0000000000..c41687a02b --- /dev/null +++ b/tests/test_pmll.py @@ -0,0 +1,583 @@ +""" +Tests for pufferlib/PMLL.py — Persistent Memory Logic Loop (PMLL) + +Covers: utilities, JSONLStore, PythonBackend, MemoryController, + Promise, make_backend, and Torch-dependent components + (cosine_sim_torch, PMLAttention, PMLLTransformer, PMLLPolicyMixin). 
+""" + +import json +import math +import os +import tempfile +import time + +import pytest +import numpy as np +import torch +import torch.nn as nn + +from pufferlib.PMLL import ( + _stable_json_dumps, + deterministic_hash, + MemoryBlock, + JSONLStore, + PythonBackend, + CTypesBackend, + make_backend, + Promise, + MemoryController, + cosine_sim_torch, + PMLAttention, + PMLLTransformer, + PMLLPolicyMixin, +) + + +# ============================================================================ +# Utilities +# ============================================================================ + +class TestStableJsonDumps: + def test_sorted_keys(self): + assert _stable_json_dumps({"b": 2, "a": 1}) == '{"a":1,"b":2}' + + def test_no_extra_spaces(self): + result = _stable_json_dumps({"key": "value"}) + assert " " not in result + + def test_nested(self): + result = _stable_json_dumps({"a": {"c": 3, "b": 2}}) + assert result == '{"a":{"b":2,"c":3}}' + + +class TestDeterministicHash: + def test_same_input_same_hash(self): + h1 = deterministic_hash({"x": 1}) + h2 = deterministic_hash({"x": 1}) + assert h1 == h2 + + def test_different_input_different_hash(self): + h1 = deterministic_hash({"x": 1}) + h2 = deterministic_hash({"x": 2}) + assert h1 != h2 + + def test_salt_changes_hash(self): + h1 = deterministic_hash({"x": 1}, salt="a") + h2 = deterministic_hash({"x": 1}, salt="b") + assert h1 != h2 + + def test_returns_hex_string(self): + h = deterministic_hash("hello") + assert isinstance(h, str) + assert len(h) == 64 # SHA-256 hex digest + + +class TestMemoryBlock: + def test_creation(self): + blk = MemoryBlock(payload={"k": 1}, mid="abc", ts=1.0) + assert blk.payload == {"k": 1} + assert blk.mid == "abc" + assert blk.ts == 1.0 + assert blk.meta is None + + def test_with_meta(self): + blk = MemoryBlock(payload={}, mid="x", ts=0.0, meta={"ttl": 10}) + assert blk.meta == {"ttl": 10} + + +# ============================================================================ +# JSONLStore +# 
class TestJSONLStore:
    # Round-trip behavior of the append-only log and snapshot preference.
    def test_append_and_load_from_log(self):
        with tempfile.TemporaryDirectory() as td:
            store = JSONLStore(td)
            blk = MemoryBlock(payload={"a": 1}, mid="m1", ts=1.0)
            store.append(blk)

            loaded = store.load()
            assert len(loaded) == 1
            assert loaded[0].mid == "m1"
            assert loaded[0].payload == {"a": 1}

    def test_snapshot_preferred_over_log(self):
        with tempfile.TemporaryDirectory() as td:
            store = JSONLStore(td)
            blk_log = MemoryBlock(payload={"src": "log"}, mid="log1", ts=1.0)
            store.append(blk_log)

            blk_snap = MemoryBlock(payload={"src": "snap"}, mid="snap1", ts=2.0)
            store.save_snapshot([blk_snap])

            loaded = store.load()
            assert len(loaded) == 1
            assert loaded[0].mid == "snap1"

    def test_load_empty(self):
        with tempfile.TemporaryDirectory() as td:
            store = JSONLStore(td)
            assert store.load() == []

    def test_multiple_appends(self):
        with tempfile.TemporaryDirectory() as td:
            store = JSONLStore(td)
            for i in range(5):
                store.append(MemoryBlock(payload={"i": i}, mid=f"m{i}", ts=float(i)))

            loaded = store.load()
            assert len(loaded) == 5
            assert [b.mid for b in loaded] == [f"m{i}" for i in range(5)]


# ============================================================================
# PythonBackend
# ============================================================================

class TestPythonBackend:
    # The pure-Python backend: modulo slotting, clamped utilization, no-ops.
    def test_phi_modulo(self):
        b = PythonBackend()
        assert b.phi(10, 8) == 2
        assert b.phi(0, 5) == 0
        assert b.phi(7, 7) == 0

    def test_utilization_initial(self):
        b = PythonBackend()
        assert b.utilization() == 0.0

    def test_set_utilization_clamped(self):
        b = PythonBackend()
        b._set_utilization(0.5)
        assert b.utilization() == 0.5
        b._set_utilization(1.5)
        assert b.utilization() == 1.0
        b._set_utilization(-0.5)
        assert b.utilization() == 0.0

    def test_process_promise_queue_noop(self):
        b = PythonBackend()
        b.process_promise_queue()  # should not raise

    def test_trigger_compression_noop(self):
        b = PythonBackend()
        b.trigger_compression(0.1)  # should not raise

    def test_vectorized_attention_returns_none(self):
        b = PythonBackend()
        assert b.vectorized_attention(None, None, None) is None


# ============================================================================
# make_backend
# ============================================================================

class TestMakeBackend:
    # Factory falls back to PythonBackend when no loadable .so is given.
    def test_default_returns_python_backend(self):
        b = make_backend()
        assert isinstance(b, PythonBackend)

    def test_invalid_so_path_falls_back(self):
        b = make_backend(so_path="/nonexistent/path.so")
        assert isinstance(b, PythonBackend)


# ============================================================================
# Promise
# ============================================================================

class TestPromise:
    # TTL expiry semantics, including the exact boundary (>= is inclusive).
    def test_not_expired(self):
        p = Promise(pid=1, data="x", ttl_s=100.0, importance=0.5, created_ts=time.time())
        assert not p.expired(time.time())

    def test_expired(self):
        p = Promise(pid=1, data="x", ttl_s=1.0, importance=0.5, created_ts=time.time() - 2.0)
        assert p.expired(time.time())

    def test_exact_boundary(self):
        now = time.time()
        p = Promise(pid=1, data="x", ttl_s=5.0, importance=0.5, created_ts=now - 5.0)
        assert p.expired(now)


# ============================================================================
# MemoryController
# ============================================================================

class TestMemoryController:
    # Write-behind queue flushing, slotting, utilization, and persistence.
    def test_write_and_process(self):
        mc = MemoryController(pool_size=64, store_dir=None)
        mc.write(pid=42, data={"hello": "world"}, ttl_s=60.0, importance=0.9)
        mc.process_promises()

        slot = mc.backend.phi(42, 64)
        assert mc.read_slot(slot) == {"hello": "world"}

    def test_utilization_increases(self):
        mc = MemoryController(pool_size=10, store_dir=None)
        assert mc.utilization() == 0.0

        mc.write(pid=0, data="a", ttl_s=60.0, importance=1.0)
        mc.process_promises()
        assert mc.utilization() > 0.0

    def test_expired_promises_not_stored(self):
        mc = MemoryController(pool_size=64, store_dir=None)
        mc.write(pid=99, data="old", ttl_s=0.0, importance=0.5)
        time.sleep(0.01)
        mc.process_promises()

        slot = mc.backend.phi(99, 64)
        assert mc.read_slot(slot) is None

    def test_read_slot_wraps(self):
        mc = MemoryController(pool_size=10, store_dir=None)
        # slot 100 % 10 == 0
        assert mc.read_slot(100) is None

    def test_multiple_writes_same_slot(self):
        mc = MemoryController(pool_size=64, store_dir=None)
        mc.write(pid=5, data="first", ttl_s=60.0, importance=1.0)
        mc.process_promises()
        mc.write(pid=5, data="second", ttl_s=60.0, importance=1.0)
        mc.process_promises()

        slot = mc.backend.phi(5, 64)
        assert mc.read_slot(slot) == "second"

    def test_default_backend(self):
        mc = MemoryController(pool_size=8)
        assert isinstance(mc.backend, PythonBackend)

    def test_pool_size(self):
        mc = MemoryController(pool_size=32)
        assert mc.pool_size == 32
        assert len(mc.pool) == 32

    def test_with_persistence(self):
        with tempfile.TemporaryDirectory() as td:
            mc = MemoryController(pool_size=16, store_dir=td)
            mc.write(pid=7, data="persisted", ttl_s=60.0, importance=0.8)
            mc.process_promises()

            # Check that the log file was written
            assert os.path.exists(mc.store.log_path)
            loaded = mc.store.load()
            assert len(loaded) == 1

    def test_process_empty_queue(self):
        mc = MemoryController(pool_size=8, store_dir=None)
        mc.process_promises()  # should not raise


class _FailingBackend(PythonBackend):
    """Backend whose trigger_compression raises, forcing the python fallback."""
    def trigger_compression(self, rho: float) -> None:
        raise RuntimeError("no native compression")


class TestMemoryControllerCompression:
    # The python fallback only runs when the backend call raises.
    def test_trigger_compression_python_fallback(self):
        backend = _FailingBackend()
        mc = MemoryController(pool_size=20, backend=backend, store_dir=None,
                              enable_python_compress_fallback=True,
                              compress_when_util_gt=1.1)  # prevent auto-compress
        # Fill the pool
        for i in range(20):
            mc.write(pid=i, data=f"data_{i}", ttl_s=3600.0, importance=1.0)
        mc.process_promises()

        occupied_before = mc._occupied
        assert occupied_before == 20

        np.random.seed(42)
        mc.trigger_compression(rho=0.5)
        # Some slots should have been cleared by the python fallback
        assert mc._occupied < occupied_before

    def test_compression_noop_with_python_backend(self):
        """PythonBackend.trigger_compression is a no-op, so the controller
        returns early without reaching the fallback path."""
        mc = MemoryController(pool_size=10, store_dir=None,
                              enable_python_compress_fallback=True)
        for i in range(10):
            mc.write(pid=i, data=f"d{i}", ttl_s=3600.0, importance=1.0)
        mc.process_promises()

        mc.trigger_compression(rho=0.5)
        # PythonBackend.trigger_compression is a no-op (doesn't raise),
        # so controller returns immediately without pruning
        assert mc._occupied == 10

    def test_compression_disabled(self):
        backend = _FailingBackend()
        mc = MemoryController(pool_size=10, backend=backend, store_dir=None,
                              enable_python_compress_fallback=False)
        for i in range(10):
            mc.write(pid=i, data=f"d{i}", ttl_s=3600.0, importance=1.0)
        mc.process_promises()

        mc.trigger_compression(rho=0.5)
        # Fallback disabled, backend raises => nothing changes
        assert mc._occupied == 10


class TestMemoryControllerRetrieve:
    # Similarity-based retrieval with threshold and max_items limits.
    def test_retrieve_with_tensor_entries(self):
        mc = MemoryController(pool_size=64, store_dir=None)
        k1 = torch.tensor([1.0, 0.0, 0.0])
        v1 = torch.tensor([10.0, 20.0, 30.0])
        mc.write(pid=1, data=(k1, v1), ttl_s=3600.0, importance=1.0)
        mc.process_promises()

        query = torch.tensor([1.0, 0.0, 0.0])
        results = mc.retrieve_relevant(query, threshold=0.5)
        assert len(results) >= 1

    def test_retrieve_empty_pool(self):
        mc = MemoryController(pool_size=8, store_dir=None)
        query = torch.tensor([1.0, 0.0, 0.0])
        results = mc.retrieve_relevant(query)
        assert results == []

    def test_retrieve_with_threshold(self):
        mc = MemoryController(pool_size=64, store_dir=None)
        # Store an entry orthogonal to query
        k = torch.tensor([0.0, 1.0, 0.0])
        v = torch.tensor([1.0, 1.0, 1.0])
        mc.write(pid=1, data=(k, v), ttl_s=3600.0, importance=1.0)
        mc.process_promises()

        query = torch.tensor([1.0, 0.0, 0.0])
        # With high threshold, orthogonal vectors shouldn't match
        results = mc.retrieve_relevant(query, threshold=0.9)
        assert len(results) == 0

    def test_retrieve_max_items(self):
        mc = MemoryController(pool_size=64, store_dir=None)
        for i in range(10):
            k = torch.randn(4)
            mc.write(pid=i, data=(k, k), ttl_s=3600.0, importance=1.0)
        mc.process_promises()

        query = torch.randn(4)
        results = mc.retrieve_relevant(query, threshold=-1.0, max_items=3)
        assert len(results) <= 3


# ============================================================================
# Torch: cosine_sim_torch
# ============================================================================

class TestCosineSim:
    # Identical, orthogonal, and opposite vectors.
    def test_identical_vectors(self):
        a = torch.tensor([1.0, 2.0, 3.0])
        sim = cosine_sim_torch(a, a)
        assert abs(float(sim) - 1.0) < 1e-5

    def test_orthogonal_vectors(self):
        a = torch.tensor([1.0, 0.0])
        b = torch.tensor([0.0, 1.0])
        sim = cosine_sim_torch(a, b)
        assert abs(float(sim)) < 1e-5

    def test_opposite_vectors(self):
        a = torch.tensor([1.0, 0.0])
        b = torch.tensor([-1.0, 0.0])
        sim = cosine_sim_torch(a, b)
        assert float(sim) < -0.99


# ============================================================================
# Torch: PMLAttention
# ============================================================================

class TestPMLAttention:
    def test_forward_no_persistent(self):
        mc = MemoryController(pool_size=32, store_dir=None)
        attn = PMLAttention(mc)

        q = torch.randn(2, 8)
        k_local =
torch.randn(2, 4, 8) + v_local = torch.randn(2, 4, 8) + + out = attn(q, k_local, v_local) + assert out.shape == (8,) + + def test_forward_with_persistent(self): + mc = MemoryController(pool_size=64, store_dir=None) + attn = PMLAttention(mc, persistent_threshold=-1.0) + + # Pre-fill memory with entries similar to the query + q_vec = torch.randn(8) + for i in range(5): + k = q_vec + torch.randn(8) * 0.01 + v = torch.randn(8) + mc.write(pid=i, data=(k, v), ttl_s=3600.0, importance=1.0) + mc.process_promises() + + q = q_vec.unsqueeze(0) # [1, 8] + k_local = torch.randn(1, 3, 8) + v_local = torch.randn(1, 3, 8) + + out = attn(q, k_local, v_local) + assert out.shape == (8,) + + def test_attend_1d_query(self): + mc = MemoryController(pool_size=8, store_dir=None) + attn = PMLAttention(mc) + + q = torch.randn(4) # 1D query + k = torch.randn(3, 4) + v = torch.randn(3, 4) + out = attn._attend(q, k, v) + assert out.shape == (4,) + + def test_extract_kv_tuples(self): + mc = MemoryController(pool_size=8, store_dir=None) + attn = PMLAttention(mc) + + entries = [ + (torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])), + (torch.tensor([5.0, 6.0]), torch.tensor([7.0, 8.0])), + ] + ks, vs = attn._extract_kv(entries) + assert ks.shape == (2, 2) + assert vs.shape == (2, 2) + + def test_extract_kv_dicts(self): + mc = MemoryController(pool_size=8, store_dir=None) + attn = PMLAttention(mc) + + entries = [ + {"k": torch.tensor([1.0]), "v": torch.tensor([2.0])}, + ] + ks, vs = attn._extract_kv(entries) + assert ks.shape == (1, 1) + + +# ============================================================================ +# Torch: PMLLTransformer +# ============================================================================ + +class TestPMLLTransformer: + def test_forward_shape(self): + model = PMLLTransformer(d_model=16, pool_size=32) + x = torch.randn(4, 16) + out = model(x) + assert out.shape == (16,) + + def test_encoder_params(self): + model = PMLLTransformer(d_model=8, pool_size=16) + params 
= list(model.parameters()) + assert len(params) > 0 + + def test_multiple_forward_passes(self): + model = PMLLTransformer(d_model=8, pool_size=64) + for _ in range(3): + x = torch.randn(2, 8) + out = model(x) + assert out.shape == (8,) + + +# ============================================================================ +# PMLLPolicyMixin +# ============================================================================ + +class TestPMLLPolicyMixin: + def test_mixin_write_and_process(self): + mc = MemoryController(pool_size=32, store_dir=None) + mixin = PMLLPolicyMixin(mc) + + mixin.pmll_write(pid=10, data="test_data", ttl_s=60.0, importance=0.7) + mixin.pmll_process() + + slot = mc.backend.phi(10, 32) + assert mc.read_slot(slot) == "test_data" + + def test_mixin_write_kv(self): + mc = MemoryController(pool_size=32, store_dir=None) + mixin = PMLLPolicyMixin(mc) + + k = torch.randn(8) + v = torch.randn(8) + mixin.pmll_write_kv(k, v, pid=42, ttl_s=60.0) + mixin.pmll_process() + + slot = mc.backend.phi(42, 32) + entry = mc.read_slot(slot) + assert isinstance(entry, tuple) + assert torch.allclose(entry[0], k) + assert torch.allclose(entry[1], v) + + def test_mixin_write_kv_v_none(self): + mc = MemoryController(pool_size=32, store_dir=None) + mixin = PMLLPolicyMixin(mc) + + k = torch.randn(4) + mixin.pmll_write_kv(k, pid=7, ttl_s=60.0) + mixin.pmll_process() + + slot = mc.backend.phi(7, 32) + entry = mc.read_slot(slot) + assert isinstance(entry, tuple) + # v should default to k + assert torch.allclose(entry[0], entry[1]) + + def test_mixin_step_counter(self): + mc = MemoryController(pool_size=32, store_dir=None) + mixin = PMLLPolicyMixin(mc) + assert mixin._pmll_step == 0 + mixin.pmll_write_kv(torch.randn(4)) + assert mixin._pmll_step == 1 + mixin.pmll_write_kv(torch.randn(4)) + assert mixin._pmll_step == 2 + + def test_mixin_with_nn_module(self): + """Verify the mixin pattern works with nn.Module as documented.""" + mc = MemoryController(pool_size=16, store_dir=None) + + 
class SimplePolicy(nn.Module, PMLLPolicyMixin): + def __init__(self): + nn.Module.__init__(self) + PMLLPolicyMixin.__init__(self, pmll=mc) + self.linear = nn.Linear(4, 4) + + def forward(self, x): + out = self.linear(x) + self.pmll_write_kv(out) + self.pmll_process() + return out + + policy = SimplePolicy() + x = torch.randn(2, 4) + out = policy(x) + assert out.shape == (2, 4) + + +# ============================================================================ +# Smoke test (same as __main__ but as a proper test) +# ============================================================================ + +class TestSmokeTest: + def test_basic_smoke(self): + """Mirrors the __main__ smoke test.""" + backend = make_backend() + mc = MemoryController(pool_size=128, backend=backend, store_dir=None) + + mc.write(pid=123, data={"hello": "world"}, ttl_s=10.0, importance=0.9) + mc.process_promises() + slot = backend.phi(123, 128) + assert mc.read_slot(slot) == {"hello": "world"} + + def test_torch_smoke(self): + """Mirrors the torch section of __main__.""" + model = PMLLTransformer(d_model=32, pool_size=256) + x = torch.randn(8, 32) + y = model(x) + assert y.shape == (32,) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From c6c0cc6e8d2bb02ff93bd2a8876c5376ce1a8fef Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 26 Feb 2026 19:14:04 +0000 Subject: [PATCH 3/3] Address code review: make tests deterministic (avoid time.sleep, reuse timestamps) Co-authored-by: drQedwards <213266729+drQedwards@users.noreply.github.com> --- tests/test_pmll.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_pmll.py b/tests/test_pmll.py index c41687a02b..3e4de57606 100644 --- a/tests/test_pmll.py +++ b/tests/test_pmll.py @@ -188,8 +188,9 @@ def test_invalid_so_path_falls_back(self): class TestPromise: def test_not_expired(self): - p = Promise(pid=1, data="x", ttl_s=100.0, importance=0.5, 
created_ts=time.time())
-        assert not p.expired(time.time())
+        now = time.time()
+        p = Promise(pid=1, data="x", ttl_s=100.0, importance=0.5, created_ts=now)
+        assert not p.expired(now)
 
     def test_expired(self):
         p = Promise(pid=1, data="x", ttl_s=1.0, importance=0.5, created_ts=time.time() - 2.0)
@@ -224,8 +225,11 @@ def test_utilization_increases(self):
 
     def test_expired_promises_not_stored(self):
         mc = MemoryController(pool_size=64, store_dir=None)
+        # ttl_s=0.0 means the promise expires once created_ts falls below the
+        # current time; the manual ageing below makes that hold deterministically.
         mc.write(pid=99, data="old", ttl_s=0.0, importance=0.5)
-        time.sleep(0.01)
+        # Manually age the promise so it's guaranteed expired
+        mc.promises[0].created_ts -= 1.0
         mc.process_promises()
 
         slot = mc.backend.phi(99, 64)