From d4c2c895d9d935cc3328d8d593c0e27b6035fb03 Mon Sep 17 00:00:00 2001 From: marcorudolphflex Date: Tue, 28 Oct 2025 11:16:23 +0100 Subject: [PATCH] feat(tidy3d): FXC-3294-add-opt-in-local-cache-for-simulation-results --- CHANGELOG.md | 1 + docs/index.rst | 13 +- tests/config/conftest.py | 2 +- tests/test_cli/test_migrate.py | 2 +- .../test_components/autograd/test_autograd.py | 2 +- tests/test_web/test_local_cache.py | 383 +++++++++++++ tidy3d/config/loader.py | 5 +- tidy3d/config/sections.py | 66 +++ tidy3d/web/api/autograd/engine.py | 2 +- tidy3d/web/api/container.py | 236 ++++++-- tidy3d/web/api/run.py | 10 +- tidy3d/web/api/webapi.py | 247 +++++++-- tidy3d/web/cache.py | 518 ++++++++++++++++++ tidy3d/web/core/http_util.py | 2 +- 14 files changed, 1381 insertions(+), 108 deletions(-) create mode 100644 tests/test_web/test_local_cache.py create mode 100644 tidy3d/web/cache.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b94c363ff5..1772bca59b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Validation for `run_only` field in component modelers to catch duplicate or invalid matrix indices early with clear error messages. - Introduced a profile-based configuration manager with TOML persistence and runtime overrides exposed via `tidy3d.config`. - Added support of `os.PathLike` objects as paths like `pathlib.Path` alongside `str` paths in all path-related functions. +- Added configurable local simulation result caching with checksum validation, eviction limits, and per-call overrides across `web.run`, `web.load`, and job workflows. ### Changed - Improved performance of antenna metrics calculation by utilizing cached wave amplitude calculations instead of recomputing wave amplitudes for each port excitation in the `TerminalComponentModelerData`. diff --git a/docs/index.rst b/docs/index.rst index 5cada50d0e..ec9f38f384 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -168,6 +168,18 @@ This will produce the following plot, which visualizes the electromagnetic field You can now postprocess simulation data using the same python session, or view the results of this simulation on our web-based `graphical user interface (GUI) `_. +.. admonition:: Caching for repeated simulations + :class: tip + + Repeated runs of the same simulation can reuse solver results by enabling the + local cache: ``td.config.local_cache.enabled = True``. You may configure the cache directory + with ``local_cache.directory``. If the number of entries (``local_cache.max_entries``) or the storage size + (``local_cache.max_size_gb``) is exceeded, cache entries are evicted by least-recently-used (LRU) order. + You can clear all stored artifacts with ``td.web.cache.clear()``. + Additionally, there is server-side caching controlled via ``td.config.web.enable_caching`` + (enabled by default). While this can avoid recomputation on the server, it still requires + upload/download of results which is why we recommend enabling the local cache. + .. 
`TODO: open example in colab `_ @@ -262,4 +274,3 @@ Contents - diff --git a/tests/config/conftest.py b/tests/config/conftest.py index f0c5f7b1a3..a0208e4769 100644 --- a/tests/config/conftest.py +++ b/tests/config/conftest.py @@ -42,7 +42,7 @@ def mock_config_dir(tmp_path, monkeypatch): base_dir = tmp_path / "config_home" monkeypatch.setenv("TIDY3D_BASE_DIR", str(base_dir)) - return base_dir / ".tidy3d" + return base_dir / "config" @pytest.fixture diff --git a/tests/test_cli/test_migrate.py b/tests/test_cli/test_migrate.py index f6ae598567..6d8ed03e9c 100644 --- a/tests/test_cli/test_migrate.py +++ b/tests/test_cli/test_migrate.py @@ -18,7 +18,7 @@ def temp_config_dir(monkeypatch, tmp_path) -> Path: original_base = os.environ.get("TIDY3D_BASE_DIR") monkeypatch.setenv("TIDY3D_BASE_DIR", str(tmp_path)) reload_config(profile="default") - config_dir = Path(tmp_path) / ".tidy3d" + config_dir = Path(tmp_path) / "config" config_dir.mkdir(parents=True, exist_ok=True) yield config_dir if original_base is None: diff --git a/tests/test_components/autograd/test_autograd.py b/tests/test_components/autograd/test_autograd.py index 898ff330a8..8621eaded7 100644 --- a/tests/test_components/autograd/test_autograd.py +++ b/tests/test_components/autograd/test_autograd.py @@ -661,7 +661,7 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = True) -> None: # args = [("polyslab", "mode")] -def get_functions(structure_key: str, monitor_key: str) -> typing.Callable: +def get_functions(structure_key: str, monitor_key: str) -> dict[str, typing.Callable]: if structure_key == ALL_KEY: structure_keys = structure_keys_ else: diff --git a/tests/test_web/test_local_cache.py b/tests/test_web/test_local_cache.py new file mode 100644 index 0000000000..80c0ae99fb --- /dev/null +++ b/tests/test_web/test_local_cache.py @@ -0,0 +1,383 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + +import tidy3d as td +from tests.test_components.autograd.test_autograd import ALL_KEY, get_functions, params0 +from tidy3d import config +from tidy3d.config import get_manager +from tidy3d.web import Job, common, run_async +from tidy3d.web.api import webapi as web +from tidy3d.web.api.container import WebContainer +from tidy3d.web.api.webapi import load_simulation_if_cached +from tidy3d.web.cache import CACHE_ARTIFACT_NAME, clear, resolve_local_cache + +common.CONNECTION_RETRY_TIME = 0.1 + +MOCK_TASK_ID = "task-xyz" +# --- Fake pipeline global maps / queue --- +TASK_TO_SIM: dict[str, td.Simulation] = {} # task_id -> Simulation +PATH_TO_SIM: dict[str, td.Simulation] = {} # artifact path -> Simulation + + +def _reset_fake_maps(): + TASK_TO_SIM.clear() + PATH_TO_SIM.clear() + + +class _FakeStubData: + def __init__(self, simulation: td.Simulation): + self.simulation = simulation + + +@pytest.fixture +def basic_simulation(): + pulse = td.GaussianPulse(freq0=200e12, fwidth=20e12) + pt_dipole = td.PointDipole(source_time=pulse, polarization="Ex") + return td.Simulation( + size=(1, 1, 1), + grid_spec=td.GridSpec.auto(wavelength=1.0), + run_time=1e-12, + sources=[pt_dipole], + ) + + +@pytest.fixture(autouse=True) +def fake_data(monkeypatch, basic_simulation): + """Patch postprocess to return stub data bound to the correct simulation.""" + calls = {"postprocess": 0} + + def _fake_postprocess(path: str, lazy: bool = False): + calls["postprocess"] += 1 + p = Path(path) + sim = PATH_TO_SIM.get(str(p)) + if sim is None: + # Try to recover task_id from file payload written by _fake_download + try: + txt = 
p.read_text() + if "payload:" in txt: + task_id = txt.split("payload:", 1)[1].strip() + sim = TASK_TO_SIM.get(task_id) + except Exception: + pass + if sim is None: + # Last-resort fallback (keeps tests from crashing even if mapping failed) + sim = basic_simulation + return _FakeStubData(sim) + + monkeypatch.setattr(web.Tidy3dStubData, "postprocess", staticmethod(_fake_postprocess)) + return calls + + +def _patch_run_pipeline(monkeypatch): + """Patch upload, start, monitor, and download to avoid network calls and map sims.""" + counters = {"upload": 0, "start": 0, "monitor": 0, "download": 0} + _reset_fake_maps() # isolate between tests + + def _extract_simulation(kwargs): + """Extract the first td.Simulation object from upload kwargs.""" + if "simulation" in kwargs and isinstance(kwargs["simulation"], td.Simulation): + return kwargs["simulation"] + if "simulations" in kwargs: + sims = kwargs["simulations"] + if isinstance(sims, dict): + for sim in sims.values(): + if isinstance(sim, td.Simulation): + return sim + elif isinstance(sims, (list, tuple)): + for sim in sims: + if isinstance(sim, td.Simulation): + return sim + return None + + def _fake_upload(**kwargs): + counters["upload"] += 1 + task_id = f"{MOCK_TASK_ID}{kwargs['simulation']._hash_self()}" + sim = _extract_simulation(kwargs) + if sim is not None: + TASK_TO_SIM[task_id] = sim + return task_id + + def _fake_start(task_id, **kwargs): + counters["start"] += 1 + + def _fake_monitor(task_id, verbose=True): + counters["monitor"] += 1 + + def _fake_download(*, task_id, path, **kwargs): + counters["download"] += 1 + # Ensure we have a simulation for this task id (even if upload wasn't called) + sim = TASK_TO_SIM.get(task_id) + Path(path).write_text(f"payload:{task_id}") + if sim is not None: + PATH_TO_SIM[str(Path(path))] = sim + + def _fake__check_folder(*args, **kwargs): + pass + + def _fake_status(self): + return "success" + + monkeypatch.setattr(WebContainer, "_check_folder", _fake__check_folder) + monkeypatch.setattr(web, "upload", _fake_upload) + monkeypatch.setattr(web, "start", _fake_start) + monkeypatch.setattr(web, "monitor", _fake_monitor) + monkeypatch.setattr(web, "download", _fake_download) + monkeypatch.setattr(web, "estimate_cost", lambda *args, **kwargs: 0.0) + monkeypatch.setattr(Job, "status", property(_fake_status)) + monkeypatch.setattr( + web, + "get_info", + lambda task_id, verbose=True: type( + "_Info", (), {"solverVersion": "solver-1", "taskType": "FDTD"} + )(), + ) + return counters + + +def _reset_counters(counters: dict[str, int]) -> None: + for key in counters: + counters[key] = 0 + + +def _test_run_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data): + counters = _patch_run_pipeline(monkeypatch) + out_path = tmp_path / "result.hdf5" + clear() + + data = web.run(basic_simulation, task_name="demo", path=str(out_path)) + assert isinstance(data, _FakeStubData) + assert counters == {"upload": 1, "start": 1, "monitor": 1, "download": 1} + + _reset_counters(counters) + data2 = web.run(basic_simulation, task_name="demo", path=str(out_path)) + assert isinstance(data2, _FakeStubData) + assert counters == {"upload": 0, "start": 0, "monitor": 0, "download": 0} + + +def _test_load_simulation_if_cached(monkeypatch, tmp_path, basic_simulation): + counters = _patch_run_pipeline(monkeypatch) + out_path = tmp_path / "result_load_simulation_if_cached.hdf5" + clear() + + data = web.run(basic_simulation, task_name="demo", path=str(out_path)) + assert isinstance(data, _FakeStubData) + assert counters == {"upload": 1, 
"start": 1, "monitor": 1, "download": 1} + + sim_data_from_cache = load_simulation_if_cached(basic_simulation) + assert sim_data_from_cache.simulation == basic_simulation + + out_path2 = tmp_path / "result_load_simulation_if_cached2.hdf5" + sim_data_from_cache_with_path = load_simulation_if_cached(basic_simulation, path=out_path2) + assert sim_data_from_cache_with_path.simulation == basic_simulation + + +def _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path): + counters = _patch_run_pipeline(monkeypatch) + monkeypatch.setattr(config.local_cache, "max_entries", 128) + monkeypatch.setattr(config.local_cache, "max_size_gb", 10) + cache = resolve_local_cache(use_cache=True) + cache.clear() + _reset_fake_maps() + + _reset_counters(counters) + sim2 = basic_simulation.updated_copy(shutoff=1e-4) + sim3 = basic_simulation.updated_copy(shutoff=1e-3) + + data = run_async({"task1": basic_simulation, "task2": sim2}, path_dir=str(tmp_path)) + data_task1 = data["task1"] # access to store in cache + data_task2 = data["task2"] # access to store in cache + assert counters["download"] == 2 + assert isinstance(data_task1, _FakeStubData) + assert isinstance(data_task2, _FakeStubData) + assert len(cache) == 2 + + _reset_counters(counters) + run_async({"task1": basic_simulation, "task2": sim2}, path_dir=str(tmp_path)) + assert counters["download"] == 0 + assert isinstance(data_task1, _FakeStubData) + assert len(cache) == 2 + + _reset_counters(counters) + data = run_async({"task1": basic_simulation, "task3": sim3}, path_dir=str(tmp_path)) + + data_task1 = data["task1"] + data_task2 = data["task3"] # access to store in cache + + assert counters["download"] == 1 # sim3 is new + assert isinstance(data_task1, _FakeStubData) + assert isinstance(data_task2, _FakeStubData) + assert len(cache) == 3 + + +def _test_job_run_cache(monkeypatch, basic_simulation, tmp_path): + counters = _patch_run_pipeline(monkeypatch) + cache = resolve_local_cache(use_cache=True) + cache.clear() + job = Job(simulation=basic_simulation, task_name="test") + job.run() + + assert len(cache) == 1 + + _reset_counters(counters) + + job2 = Job(simulation=basic_simulation, task_name="test") + out1_path = str(tmp_path / "result.hdf5") + out2_path = str(tmp_path / "result2.hdf5") + job2.run(path=out1_path) + assert len(cache) == 1 + assert counters["download"] == 0 + + job.load(path=out2_path) + assert os.path.exists(out1_path) + assert os.path.exists(out2_path) + + +def _test_autograd_cache(monkeypatch): + counters = _patch_run_pipeline(monkeypatch) + cache = resolve_local_cache(use_cache=True) + cache.clear() + + functions = get_functions(ALL_KEY, "mode") + make_sim = functions["sim"] + sim = make_sim(params0) + web.run(sim) + assert counters["download"] == 1 + assert len(cache) == 1 + + _reset_counters(counters) + sim = make_sim(params0) + web.run(sim) + assert counters["download"] == 0 + assert len(cache) == 1 + + +def _test_load_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data): + clear() + counters = _patch_run_pipeline(monkeypatch) + out_path = tmp_path / "load.hdf5" + + cache = resolve_local_cache(use_cache=True) + + web.run(basic_simulation, task_name="demo", path=str(out_path)) + assert counters["download"] == 1 + assert len(cache) == 1 + + _reset_counters(counters) + data = web.load(None, path=str(out_path)) + assert isinstance(data, _FakeStubData) + assert counters["download"] == 0 # served from cache + assert len(cache) == 1 # still 1 item in cache + + +def _test_checksum_mismatch_triggers_refresh(monkeypatch, 
tmp_path, basic_simulation): + out_path = tmp_path / "checksum.hdf5" + clear() + + web.run(basic_simulation, task_name="demo", path=str(out_path)) + + cache = resolve_local_cache(use_cache=True) + metadata = cache.list()[0] + corrupted_path = cache.root / metadata["cache_key"] / CACHE_ARTIFACT_NAME + corrupted_path.write_text("corrupted") + + cache._fetch(metadata["cache_key"]) + assert len(cache) == 0 + + +def _test_cache_eviction_by_entries(monkeypatch, tmp_path_factory, basic_simulation): + monkeypatch.setattr(config.local_cache, "max_entries", 1) + cache = resolve_local_cache(use_cache=True) + cache.clear() + + file1 = tmp_path_factory.mktemp("art1") / CACHE_ARTIFACT_NAME + file1.write_text("a") + cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(file1), "FDTD") + assert len(cache) == 1 + + sim2 = basic_simulation.updated_copy(shutoff=1e-4) + file2 = tmp_path_factory.mktemp("art2") / CACHE_ARTIFACT_NAME + file2.write_text("b") + cache.store_result(_FakeStubData(sim2), MOCK_TASK_ID, str(file2), "FDTD") + + entries = cache.list() + assert len(entries) == 1 + assert entries[0]["simulation_hash"] == sim2._hash_self() + + +def _test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation): + monkeypatch.setattr(config.local_cache, "max_size_gb", float(10_000 * 1e-9)) + cache = resolve_local_cache(use_cache=True) + cache.clear() + + file1 = tmp_path_factory.mktemp("art1") / CACHE_ARTIFACT_NAME + file1.write_text("a" * 8_000) + cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(file1), "FDTD") + assert len(cache) == 1 + + sim2 = basic_simulation.updated_copy(shutoff=1e-4) + file2 = tmp_path_factory.mktemp("art2") / CACHE_ARTIFACT_NAME + file2.write_text("b" * 8_000) + cache.store_result(_FakeStubData(sim2), MOCK_TASK_ID, str(file2), "FDTD") + + entries = cache.list() + assert len(cache) == 1 + assert entries[0]["simulation_hash"] == sim2._hash_self() + + +def _test_configure_cache_roundtrip(monkeypatch, tmp_path): + monkeypatch.setattr(config.local_cache, "enabled", True) + monkeypatch.setattr(config.local_cache, "directory", tmp_path) + monkeypatch.setattr(config.local_cache, "max_size_gb", 1.23) + monkeypatch.setattr(config.local_cache, "max_entries", 5) + + local_cache = resolve_local_cache() + assert local_cache is not None + assert local_cache._root == tmp_path + assert local_cache.max_size_gb == 1.23 + assert local_cache.max_entries == 5 + + +def _test_env_var_overrides(monkeypatch, tmp_path): + cache_dir = tmp_path / "cache" + monkeypatch.setenv("TIDY3D_LOCAL_CACHE__ENABLED", "true") + monkeypatch.setenv("TIDY3D_LOCAL_CACHE__DIRECTORY", str(cache_dir)) + monkeypatch.setenv("TIDY3D_LOCAL_CACHE__MAX_SIZE_GB", "0.5") + monkeypatch.setenv("TIDY3D_LOCAL_CACHE__MAX_ENTRIES", "7") + manager = get_manager() + manager._reload() + + cache = resolve_local_cache() + assert cache is not None + assert cache._root == cache_dir + assert cache.max_size_gb == 0.5 + assert cache.max_entries == 7 + + monkeypatch.delenv("TIDY3D_LOCAL_CACHE__ENABLED", raising=False) + monkeypatch.delenv("TIDY3D_LOCAL_CACHE__DIRECTORY", raising=False) + monkeypatch.delenv("TIDY3D_LOCAL_CACHE__MAX_SIZE_GB", raising=False) + monkeypatch.delenv("TIDY3D_LOCAL_CACHE__MAX_ENTRIES", raising=False) + manager = get_manager() + manager._reload() + + +def test_cache_sequential(monkeypatch, tmp_path, tmp_path_factory, basic_simulation, fake_data): + """Run all critical cache tests in sequence to ensure stability.""" + monkeypatch.setattr(config.local_cache, "enabled", True) + + # 
this at first as runtime changes overrides env + _test_env_var_overrides(monkeypatch, tmp_path) + + _test_load_simulation_if_cached(monkeypatch, tmp_path, basic_simulation) + _test_run_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data) + _test_load_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data) + _test_checksum_mismatch_triggers_refresh(monkeypatch, tmp_path, basic_simulation) + _test_cache_eviction_by_entries(monkeypatch, tmp_path_factory, basic_simulation) + _test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation) + _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path) + _test_job_run_cache(monkeypatch, basic_simulation, tmp_path) + _test_autograd_cache(monkeypatch) + _test_configure_cache_roundtrip(monkeypatch, tmp_path) diff --git a/tidy3d/config/loader.py b/tidy3d/config/loader.py index ee1c359645..53046cb5ef 100644 --- a/tidy3d/config/loader.py +++ b/tidy3d/config/loader.py @@ -247,7 +247,8 @@ def resolve_config_directory() -> Path: base_override = os.getenv("TIDY3D_BASE_DIR") if base_override: - path = Path(base_override).expanduser().resolve() / ".tidy3d" + base_path = Path(base_override).expanduser().resolve() + path = base_path / "config" if _is_writable(path.parent): return path log.warning( @@ -281,7 +282,7 @@ def _xdg_config_home() -> Path: def _temporary_config_dir() -> Path: base = Path(tempfile.gettempdir()) / "tidy3d" base.mkdir(mode=0o700, exist_ok=True) - return base / ".tidy3d" + return base / "config" def _is_writable(path: Path) -> bool: diff --git a/tidy3d/config/sections.py b/tidy3d/config/sections.py index 62e5b8a473..fd47116e4e 100644 --- a/tidy3d/config/sections.py +++ b/tidy3d/config/sections.py @@ -2,7 +2,9 @@ from __future__ import annotations +import os import ssl +from os import PathLike from pathlib import Path from typing import Any, Literal, Optional from urllib.parse import urlparse @@ -11,7 +13,10 @@ from pydantic import ( BaseModel, ConfigDict, + DirectoryPath, Field, + NonNegativeFloat, + NonNegativeInt, PositiveInt, SecretStr, field_serializer, @@ -381,6 +386,67 @@ def apply_web(config: WebConfig) -> None: manager.apply_web_env(dict(config.env_vars)) +def _default_cache_directory() -> Path: + """Determine the default on-disk cache directory respecting platform conventions.""" + + base_override = os.getenv("TIDY3D_BASE_DIR") + if base_override: + base = Path(base_override).expanduser().resolve() + return (base / "cache" / "simulations").resolve() + else: + xdg_cache = os.getenv("XDG_CACHE_HOME") + if xdg_cache: + base = Path(xdg_cache).expanduser().resolve() + else: + base = Path.home() / ".cache" + return (base / "tidy3d" / "simulations").resolve() + + +@register_section("local_cache") +class LocalCacheConfig(ConfigSection): + """Settings controlling the optional local simulation cache.""" + + enabled: bool = Field( + False, + title="Enable cache", + description="Enable or disable the local simulation cache.", + json_schema_extra={"persist": True}, + ) + + directory: DirectoryPath = Field( + default_factory=_default_cache_directory, + title="Cache directory", + description="Directory where cached artifacts are stored.", + json_schema_extra={"persist": True}, + ) + + max_size_gb: NonNegativeFloat = Field( + 10.0, + title="Maximum cache size (GB)", + description="Maximum cache size in gigabytes. Set to 0 for no size limit.", + json_schema_extra={"persist": True}, + ) + + max_entries: NonNegativeInt = Field( + 0, + title="Maximum cache entries", + description="Maximum number of cache entries. 
Set to 0 for no limit.", + json_schema_extra={"persist": True}, + ) + + @field_validator("directory", mode="before") + def _ensure_directory_exists(cls, v: PathLike) -> Path: + """Expand ~, resolve path, and create directory if missing before DirectoryPath validation.""" + p = Path(v).expanduser().resolve() + p.mkdir(parents=True, exist_ok=True) + return p + + @field_serializer("directory") + def _serialize_directory(self, value: Path) -> str: + """Persist directory as strings.""" + return str(value) + + @register_section("plugins") class PluginsContainer(ConfigSection): """Container that holds plugin-specific configuration sections.""" diff --git a/tidy3d/web/api/autograd/engine.py b/tidy3d/web/api/autograd/engine.py index 23e97ef2fc..4838312d85 100644 --- a/tidy3d/web/api/autograd/engine.py +++ b/tidy3d/web/api/autograd/engine.py @@ -10,7 +10,7 @@ def parse_run_kwargs(**run_kwargs): """Parse the ``run_kwargs`` to extract what should be passed to the ``Job``/``Batch`` init.""" - job_fields = [*list(Job._upload_fields), "solver_version", "pay_type", "lazy"] + job_fields = [*list(Job._upload_fields), "solver_version", "pay_type", "lazy", "use_cache"] job_init_kwargs = {k: v for k, v in run_kwargs.items() if k in job_fields} return job_init_kwargs diff --git a/tidy3d/web/api/container.py b/tidy3d/web/api/container.py index ff37886761..e9c27c6fd9 100644 --- a/tidy3d/web/api/container.py +++ b/tidy3d/web/api/container.py @@ -2,8 +2,13 @@ from __future__ import annotations +import atexit import concurrent +import os +import shutil +import tempfile import time +import uuid from abc import ABC from collections.abc import Mapping from concurrent.futures import ThreadPoolExecutor @@ -12,6 +17,7 @@ from typing import Literal, Optional, Union import pydantic.v1 as pd +from pydantic.v1 import PrivateAttr from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeElapsedColumn from tidy3d.components.base import Tidy3dBaseModel, cached_property @@ -32,6 +38,7 @@ STATE_PROGRESS_PERCENTAGE, ) from tidy3d.web.api.tidy3d_stub import Tidy3dStub +from tidy3d.web.api.webapi import restore_simulation_if_cached from tidy3d.web.core.constants import TaskId, TaskName from tidy3d.web.core.task_core import Folder from tidy3d.web.core.task_info import RunInfo, TaskInfo @@ -241,6 +248,29 @@ class Job(WebContainer): "reduce_simulation", ) + _stash_path: Optional[str] = PrivateAttr(default=None) + + def _stash_path_for_job(self) -> str: + """Stash file which is a temporary location for the cached-restored file.""" + stash_dir = Path(tempfile.gettempdir()) / "tidy3d_stash" + stash_dir.mkdir(parents=True, exist_ok=True) + return str(Path(stash_dir / f"{self._cached_task_id}.hdf5")) + + def _materialize_from_stash(self, dst_path: os.PathLike) -> None: + """Atomic copy from stash to requested path.""" + tmp = str(dst_path) + ".part" + shutil.copy2(self._stash_path, tmp) + os.replace(tmp, dst_path) + + def clear_stash(self) -> None: + """Delete this job's stash file only.""" + if self._stash_path: + try: + if os.path.exists(self._stash_path): + os.remove(self._stash_path) + finally: + self._stash_path = None + def to_file(self, fname: PathLike) -> None: """Exports :class:`Tidy3dBaseModel` instance to .yaml, .json, or .hdf5 file @@ -258,7 +288,9 @@ def to_file(self, fname: PathLike) -> None: super(Job, self).to_file(fname=fname) # noqa: UP008 def run( - self, path: PathLike = DEFAULT_DATA_PATH, priority: Optional[int] = None + self, + path: PathLike = DEFAULT_DATA_PATH, + priority: Optional[int] = 
None, ) -> WorkflowDataType: """Run :class:`Job` all the way through and return data. @@ -274,17 +306,50 @@ def run( :class:`WorkflowDataType` Object containing simulation results. """ - self.upload() - if priority is None: - self.start() - else: - self.start(priority=priority) - self.monitor() - return self.load(path=path) + self._check_path_dir(path=path) + + loaded_from_cache = self.load_if_cached + if not loaded_from_cache: + self.upload() + if priority is None: + self.start() + else: + self.start(priority=priority) + self.monitor() + data = self.load(path=path) + + return data + + @cached_property + def load_if_cached(self) -> bool: + """Checks if results are cached and (if yes) restores them into our shared stash file.""" + # use temporary path as final destination is unknown + stash_path = self._stash_path_for_job() + + restored = restore_simulation_if_cached( + simulation=self.simulation, + path=stash_path, + reduce_simulation=self.reduce_simulation, + verbose=getattr(self, "verbose", True), + ) + + if restored is None: + return False + + self._stash_path = stash_path + atexit.register(self.clear_stash) + return True + + @cached_property + def _cached_task_id(self) -> TaskId: + """The task ID for jobs which are loaded from cache.""" + return "cached_" + self.task_name + "_" + str(uuid.uuid4()) @cached_property def task_id(self) -> TaskId: """The task ID for this ``Job``. Uploads the ``Job`` if it hasn't already been uploaded.""" + if self.load_if_cached: + return self._cached_task_id if self.task_id_cached: return self.task_id_cached self._check_folder(self.folder_name) @@ -298,7 +363,9 @@ def _upload(self) -> TaskId: return task_id def upload(self) -> None: - """Upload this ``Job``.""" + """Upload this ``Job`` if not already got cached results.""" + if self.load_if_cached: + return _ = self.task_id def get_info(self) -> TaskInfo: @@ -314,6 +381,8 @@ def get_info(self) -> TaskInfo: @property def status(self): """Return current status of :class:`Job`.""" + if self.load_if_cached: + return "success" if web._is_modeler_batch(self.task_id): detail = self.get_info() status = detail.totalStatus.value @@ -346,13 +415,16 @@ def start(self, priority: Optional[int] = None) -> None: Note ---- To monitor progress of the :class:`Job`, call :meth:`Job.monitor` after started. + Function has no effect if cache is enabled and data was found in cache. """ - web.start( - self.task_id, - solver_version=self.solver_version, - pay_type=self.pay_type, - priority=priority, - ) + loaded = self.load_if_cached + if not loaded: + web.start( + self.task_id, + solver_version=self.solver_version, + pay_type=self.pay_type, + priority=priority, + ) def get_run_info(self) -> RunInfo: """Return information about the running :class:`Job`. @@ -372,6 +444,8 @@ def monitor(self) -> None: To load the output of completed simulation into :class:`.SimulationData` objects, call :meth:`Job.load`. """ + if self.load_if_cached: + return web.monitor(self.task_id, verbose=self.verbose) def download(self, path: PathLike = DEFAULT_DATA_PATH) -> None: @@ -386,6 +460,9 @@ def download(self, path: PathLike = DEFAULT_DATA_PATH) -> None: ---- To load the data after download, use :meth:`Job.load`. """ + if self.load_if_cached: + self._materialize_from_stash(path) + return self._check_path_dir(path=path) web.download(task_id=self.task_id, path=path, verbose=self.verbose) @@ -403,8 +480,11 @@ def load(self, path: PathLike = DEFAULT_DATA_PATH) -> WorkflowDataType: Object containing simulation results. 
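+
+        Example
+        -------
+        A minimal sketch of a cache-aware run (``sim`` and the paths are
+        illustrative; assumes ``td.config.local_cache.enabled = True``)::
+
+            job = Job(simulation=sim, task_name="demo")
+            data = job.run(path="out/sim_data.hdf5")
+            # an identical simulation resolves from the local cache and skips
+            # upload, start, and monitor entirely
+            data2 = Job(simulation=sim, task_name="demo").run(path="out/sim_data2.hdf5")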
""" self._check_path_dir(path=path) + if self.load_if_cached: + self._materialize_from_stash(path) + data = web.load( - task_id=self.task_id, + task_id=None if self.load_if_cached else self.task_id, path=path, verbose=self.verbose, lazy=self.lazy, @@ -450,6 +530,8 @@ def estimate_cost(self, verbose: bool = True) -> float: Cost is calculated assuming the simulation runs for the full ``run_time``. If early shut-off is triggered, the cost is adjusted proportionately. """ + if self.load_if_cached: + return 0.0 return web.estimate_cost(self.task_id, verbose=verbose, solver_version=self.solver_version) def postprocess_start(self, worker_group: Optional[str] = None, verbose: bool = True) -> None: @@ -547,6 +629,11 @@ class BatchData(Tidy3dBaseModel, Mapping): verbose: bool = pd.Field( True, title="Verbose", description="Whether to print info messages and progressbars." ) + cached_tasks: Optional[dict[TaskName, bool]] = pd.Field( + None, + title="Cached Tasks", + description="Whether the data of a task came from the cache.", + ) lazy: bool = pd.Field( False, @@ -554,13 +641,27 @@ class BatchData(Tidy3dBaseModel, Mapping): description="Whether to load the actual data (lazy=False) or return a proxy that loads the data when accessed (lazy=True).", ) + is_downloaded: Optional[bool] = pd.Field( + False, + title="Is Downloaded", + description="Whether the simulation data was downloaded before.", + ) + def load_sim_data(self, task_name: str) -> WorkflowDataType: """Load a simulation data object from file by task name.""" task_data_path = Path(self.task_paths[task_name]) task_id = self.task_ids[task_name] - web.get_info(task_id) + from_cache = self.cached_tasks[task_name] if self.cached_tasks else False + if not from_cache: + web.get_info(task_id) - return web.load(task_id=task_id, path=task_data_path, verbose=False, lazy=self.lazy) + return web.load( + task_id=None if from_cache else task_id, + path=task_data_path, + verbose=self.verbose, + replace_existing=not (from_cache or self.is_downloaded), + lazy=self.lazy, + ) def __getitem__(self, task_name: TaskName) -> WorkflowDataType: """Get the simulation data object for a given ``task_name``.""" @@ -744,14 +845,20 @@ def run( rather it iterates over the task names and loads the corresponding data from file one by one. If no file exists for that task, it downloads it. 
""" + loaded = [job.load_if_cached for job in self.jobs.values()] self._check_path_dir(path_dir) - self.upload() - self.to_file(self._batch_path(path_dir=path_dir)) - if priority is None: - self.start() + if not all(loaded): + self.upload() + self.to_file(self._batch_path(path_dir=path_dir)) + if priority is None: + self.start() + else: + self.start(priority=priority) + self.monitor(path_dir=path_dir, download_on_success=True) else: - self.start(priority=priority) - self.monitor(path_dir=path_dir, download_on_success=True) + console = get_logging_console() + console.log("Found all simulations in cache.") + self.download(path_dir=path_dir) # moves cache files return self.load(path_dir=path_dir, skip_download=True) @cached_property @@ -1191,39 +1298,50 @@ def download( self._check_path_dir(path_dir=path_dir) self.to_file(self._batch_path(path_dir=path_dir)) - num_existing = 0 - for _, job in self.jobs.items(): - job_path = self._job_data_path(task_id=job.task_id, path_dir=path_dir) - if job_path.exists(): - num_existing += 1 - if num_existing > 0: - files_plural = "files have" if num_existing > 1 else "file has" - log.warning( - f"{num_existing} {files_plural} already been downloaded " - f"and will be skipped. To forcibly overwrite existing files, invoke " - "the load or download function with `replace_existing=True`.", - log_once=True, + # Warn about already-existing files if we won't overwrite them + if not replace_existing: + num_existing = sum( + os.path.exists(self._job_data_path(task_id=job.task_id, path_dir=path_dir)) + for job in self.jobs.values() ) + if num_existing > 0: + files_plural = "files have" if num_existing > 1 else "file has" + log.warning( + f"{num_existing} {files_plural} already been downloaded " + f"and will be skipped. To forcibly overwrite existing files, invoke " + "the load or download function with `replace_existing=True`.", + log_once=True, + ) - with ThreadPoolExecutor(max_workers=self.num_workers) as executor: - fns = [] - for task_name, job in self.jobs.items(): - job_path = self._job_data_path(task_id=job.task_id, path_dir=path_dir) - if job_path.exists(): - if replace_existing: - log.info(f"File '{job_path}' already exists. Overwriting.") - else: - log.info(f"File '{job_path}' already exists. Skipping.") - continue - if "error" in job.status: - log.warning(f"Not downloading '{task_name}' as the task errored.") + fns = [] + + for task_name, job in self.jobs.items(): + if "error" in job.status: + log.warning(f"Not downloading '{task_name}' as the task errored.") + continue + + job_path = self._job_data_path(task_id=job.task_id, path_dir=path_dir) + + if job_path.exists(): + if replace_existing: + log.info(f"File '{job_path}' already exists. Overwriting.") + else: + log.info(f"File '{job_path}' already exists. 
Skipping.") continue - def fn(job=job, job_path=job_path) -> None: - return job.download(path=job_path) + if job.load_if_cached: + job._materialize_from_stash(job_path) + continue + + def fn(job=job, job_path=job_path) -> None: + job.download(path=job_path) + + fns.append(fn) - fns.append(fn) + if not fns: + return + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: futures = [executor.submit(fn) for fn in fns] if self.verbose: @@ -1241,6 +1359,10 @@ def fn(job=job, job_path=job_path) -> None: for _ in concurrent.futures.as_completed(futures): completed += 1 progress.update(pbar, completed=completed) + else: + # Still ensure completion if verbose is off + for _ in concurrent.futures.as_completed(futures): + pass def load( self, @@ -1283,8 +1405,18 @@ def load( task_paths[task_name] = str(self._job_data_path(task_id=job.task_id, path_dir=path_dir)) task_ids[task_name] = self.jobs[task_name].task_id + loaded = {task_name: job.load_if_cached for task_name, job in self.jobs.items()} + + if not skip_download: + self.download(path_dir=path_dir, replace_existing=replace_existing) + data = BatchData( - task_paths=task_paths, task_ids=task_ids, verbose=self.verbose, lazy=self.lazy + task_paths=task_paths, + task_ids=task_ids, + verbose=self.verbose, + cached_tasks=loaded, + lazy=self.lazy, + is_downloaded=True, ) for task_name, job in self.jobs.items(): diff --git a/tidy3d/web/api/run.py b/tidy3d/web/api/run.py index a4b3cb7666..1713ec0615 100644 --- a/tidy3d/web/api/run.py +++ b/tidy3d/web/api/run.py @@ -6,6 +6,7 @@ from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType from tidy3d.config import config +from tidy3d.log import get_logging_console from tidy3d.web.api.autograd.autograd import run as run_autograd from tidy3d.web.api.autograd.autograd import run_async from tidy3d.web.api.container import DEFAULT_DATA_DIR, DEFAULT_DATA_PATH @@ -232,7 +233,14 @@ def run( key_prefix = "" if len(h2sim) == 1: - path = path if path is not None else Path(DEFAULT_DATA_PATH) + if path is not None: + # user may submit the same simulation multiple times and not specify an extension, but dir path + if not Path(path).suffixes: + path = f"{path}.hdf5" + console = get_logging_console() + console.log(f"Changed output path to {path}") + else: + path = DEFAULT_DATA_PATH h, sim = next(iter(h2sim.items())) data = { h: run_autograd( diff --git a/tidy3d/web/api/webapi.py b/tidy3d/web/api/webapi.py index 64911251e8..021f4ed806 100644 --- a/tidy3d/web/api/webapi.py +++ b/tidy3d/web/api/webapi.py @@ -26,6 +26,7 @@ POST_VALIDATE_STATES, STATE_PROGRESS_PERCENTAGE, ) +from tidy3d.web.cache import CacheEntry, resolve_local_cache from tidy3d.web.core.account import Account from tidy3d.web.core.constants import ( CM_DATA_HDF5_GZ, @@ -317,6 +318,119 @@ def _task_dict_to_url_bullet_list(data_dict: dict) -> str: return "\n".join([f"- {key}: '{value}'" for key, value in data_dict.items()]) +def _copy_simulation_data_from_cache_entry(entry: CacheEntry, path: PathLike) -> bool: + """ + Copy cached simulation data from a cache entry to a specified path. + + Parameters + ---------- + entry : CacheEntry + The cache entry containing simulation data and metadata. + path : PathLike + The target directory or file path where the cached data should be materialized. + + Returns + ------- + bool + True if the cached simulation data was successfully copied, False otherwise. 
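+
+    Notes
+    -----
+    Any exception raised while materializing the artifact (for example, an
+    unwritable target directory) is swallowed and reported as ``False`` so
+    the caller can fall back to a regular download.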
+ """ + if entry is not None: + try: + entry.materialize(Path(path)) + return True + except Exception: + return False + return False + + +def restore_simulation_if_cached( + simulation: WorkflowType, + path: Optional[PathLike] = None, + reduce_simulation: Literal["auto", True, False] = "auto", + verbose: bool = True, +) -> Optional[PathLike]: + """ + Attempt to restore simulation data from a local cache entry, if available. + + Parameters + ---------- + simulation : WorkflowType + The simulation or workflow object for which cached data may exist. + path : Optional[PathLike] = None + Optional path where the cached data should be copied. If not provided, + the path from the cache entry will be used. + reduce_simulation : Literal["auto", True, False] = "auto" + Whether to reduce the simulation for cache lookup. If "auto", reduction is applied + only when applicable (e.g., for mode solvers). + verbose : bool = True + If True, logs a message including a link to the cached task in the web UI. + + Returns + ------- + Optional[PathLike] + The path to the restored simulation data if found in cache, otherwise None. If no path is specified, the cache entry path is returned, otherwise the given path is returned. + """ + simulation_cache = resolve_local_cache() + retrieved_simulation_path = None + if simulation_cache is not None: + sim_for_cache = simulation + if isinstance(simulation, (ModeSolver, ModeSimulation)): + sim_for_cache = get_reduced_simulation(simulation, reduce_simulation) + entry = simulation_cache.try_fetch(simulation=sim_for_cache, verbose=verbose) + if entry is not None: + if path is not None: + copied = _copy_simulation_data_from_cache_entry(entry, path) + if copied: + retrieved_simulation_path = path + else: + retrieved_simulation_path = entry.artifact_path + cached_task_id = entry.metadata.get("task_id") + cached_workflow_type = entry.metadata.get("workflow_type") + if cached_task_id is not None and cached_workflow_type is not None and verbose: + console = get_logging_console() if verbose else None + url, _ = _get_task_urls(cached_workflow_type, simulation, cached_task_id) + console.log( + f"Loaded simulation from local cache.\nView cached task using web UI at [link={url}]'{url}'[/link]." + ) + return retrieved_simulation_path + + +def load_simulation_if_cached( + simulation: WorkflowType, + path: Optional[PathLike] = None, + reduce_simulation: Literal["auto", True, False] = "auto", +) -> Optional[WorkflowDataType]: + """ + Load simulation results directly from the local cache, if available. + + Parameters + ---------- + simulation : WorkflowType + The simulation or workflow object to check for cached results. + path : Optional[PathLike] = None + Optional path to which cached data should be restored before loading. + reduce_simulation : Literal["auto", True, False] = "auto" + Whether to use a reduced simulation when checking the cache. If "auto", + reduction is applied automatically for mode solvers. + + Returns + ------- + Optional[WorkflowDataType] + The loaded simulation data if found in cache, otherwise None. 
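+
+    Example
+    -------
+    A minimal sketch (``sim`` is illustrative; assumes the local cache is
+    enabled and ``sim`` has been run before)::
+
+        data = load_simulation_if_cached(sim)
+        if data is None:
+            data = run(sim, task_name="demo")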
+ """ + restored_path = restore_simulation_if_cached(simulation, path, reduce_simulation) + if restored_path is not None: + data = load( + task_id=None, + path=str(restored_path), + ) + if isinstance(simulation, ModeSolver): + simulation._patch_data(data=data) + return data + else: + return None + + @wait_for_connection def run( simulation: WorkflowType, @@ -420,27 +534,38 @@ def run( :meth:`tidy3d.web.api.container.Batch.monitor` Monitor progress of each of the running tasks. """ - task_id = upload( + restored_path = restore_simulation_if_cached( simulation=simulation, - task_name=task_name, - folder_name=folder_name, - callback_url=callback_url, - verbose=verbose, - progress_callback=progress_callback_upload, - simulation_type=simulation_type, - parent_tasks=parent_tasks, - solver_version=solver_version, + path=path, reduce_simulation=reduce_simulation, - ) - start( - task_id, verbose=verbose, - solver_version=solver_version, - worker_group=worker_group, - pay_type=pay_type, - priority=priority, ) - monitor(task_id, verbose=verbose) + + if not restored_path: + task_id = upload( + simulation=simulation, + task_name=task_name, + folder_name=folder_name, + callback_url=callback_url, + verbose=verbose, + progress_callback=progress_callback_upload, + simulation_type=simulation_type, + parent_tasks=parent_tasks, + solver_version=solver_version, + reduce_simulation=reduce_simulation, + ) + start( + task_id, + verbose=verbose, + solver_version=solver_version, + worker_group=worker_group, + pay_type=pay_type, + priority=priority, + ) + monitor(task_id, verbose=verbose) + else: + task_id = None + data = load( task_id=task_id, path=path, @@ -448,11 +573,34 @@ def run( progress_callback=progress_callback_download, lazy=lazy, ) + if isinstance(simulation, ModeSolver): simulation._patch_data(data=data) return data +def _get_task_urls( + task_type: str, + simulation: WorkflowType, + resource_id: str, + folder_id: Optional[str] = None, + group_id: Optional[str] = None, +) -> tuple[str, Optional[str]]: + """Log task and folder links to the web UI.""" + if (task_type in ["RF", "COMPONENT_MODELER", "TERMINAL_COMPONENT_MODELER"]) and isinstance( + simulation, TerminalComponentModeler + ): + url = _get_url_rf(group_id or resource_id) + else: + url = _get_url(resource_id) + + if folder_id is not None: + folder_url = _get_folder_url(folder_id) + else: + folder_url = None + return url, folder_url + + @wait_for_connection def upload( simulation: WorkflowType, @@ -573,16 +721,11 @@ def upload( f"Cost of {solver_name} simulations is subject to change in the future." 
) if task_type in GUI_SUPPORTED_TASK_TYPES: - if (task_type == "RF") and (isinstance(simulation, TerminalComponentModeler)): - url = _get_url_rf(group_id or resource_id) - folder_url = _get_folder_url(task.folder_id) - console.log(f"View task using web UI at [link={url}]'{url}'[/link].") - console.log(f"Task folder: [link={folder_url}]'{task.folder_name}'[/link].") - else: - url = _get_url(resource_id) - folder_url = _get_folder_url(task.folder_id) - console.log(f"View task using web UI at [link={url}]'{url}'[/link].") - console.log(f"Task folder: [link={folder_url}]'{task.folder_name}'[/link].") + url, folder_url = _get_task_urls( + task_type, simulation, resource_id, task.folder_id, group_id + ) + console.log(f"View task using web UI at [link={url}]'{url}'[/link].") + console.log(f"Task folder: [link={folder_url}]'{task.folder_name}'[/link].") remote_sim_file = SIM_FILE_HDF5_GZ if task_type == "MODE_SOLVER": @@ -668,7 +811,7 @@ def get_info(task_id: TaskId, verbose: bool = True) -> TaskInfo | BatchDetail: ---------- task_id : TaskId The unique identifier for the task or batch. - verbose : bool, optional + verbose : bool = True If ``True`` (default), display progress bars and status updates. If ``False``, the function runs silently. @@ -1197,7 +1340,7 @@ def download_log( @wait_for_connection def load( - task_id: TaskId, + task_id: Optional[TaskId], path: PathLike = "simulation_data.hdf5", replace_existing: bool = True, verbose: bool = True, @@ -1222,8 +1365,8 @@ def load( Parameters ---------- - task_id : str - Unique identifier of task on server. Returned by :meth:`upload`. + task_id : Optional[str] = None + Unique identifier of task on server. Returned by :meth:`upload`. If None, file is assumed to exist already from cache. path : PathLike Download path to .hdf5 data file (including filename). replace_existing : bool = True @@ -1241,31 +1384,41 @@ def load( Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`] Object containing simulation data. """ - # For component modeler batches, default to a clearer filename if the default was used. path = Path(path) + # For component modeler batches, default to a clearer filename if the default was used. 
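+    # e.g. "simulation_data.hdf5" -> "cm_data.hdf5"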
+ if ( + task_id + and _is_modeler_batch(task_id) + and path.name in {"simulation_data.hdf5", "simulation_data.hdf5.gz"} + ): + path = path.with_name(path.name.replace("simulation", "cm")) - if _is_modeler_batch(task_id): - if path.name == "simulation_data.hdf5": - path = path.with_name("cm_data.hdf5") - elif path.name == "simulation_data.hdf5.gz": - path = path.with_name("cm_data.hdf5.gz") - - if not path.exists() or replace_existing: - download( - task_id=task_id, - path=path, - verbose=verbose, - progress_callback=progress_callback, - ) + if task_id is None: + if not path.exists(): + raise FileNotFoundError("Cached file not found.") + elif not path.exists() or replace_existing: + download(task_id=task_id, path=path, verbose=verbose, progress_callback=progress_callback) - if verbose: + if verbose and task_id is not None: console = get_logging_console() if _is_modeler_batch(task_id): console.log(f"loading component modeler data from {path}") else: - console.log(f"loading simulation from {path}") + console.log(f"Loading simulation from {path}") stub_data = Tidy3dStubData.postprocess(path, lazy=lazy) + + simulation_cache = resolve_local_cache() + if simulation_cache is not None and task_id is not None: + info = get_info(task_id, verbose=False) + workflow_type = getattr(info, "taskType", None) + simulation_cache.store_result( + stub_data=stub_data, + task_id=task_id, + path=path, + workflow_type=workflow_type, + ) + return stub_data diff --git a/tidy3d/web/cache.py b/tidy3d/web/cache.py new file mode 100644 index 0000000000..20e9a0a712 --- /dev/null +++ b/tidy3d/web/cache.py @@ -0,0 +1,518 @@ +"""Local simulation cache manager.""" + +from __future__ import annotations + +import hashlib +import json +import os +import shutil +import tempfile +import threading +from collections.abc import Iterable +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Any, Optional + +from tidy3d import config +from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType +from tidy3d.log import log +from tidy3d.web.api.tidy3d_stub import Tidy3dStub +from tidy3d.web.core.constants import TaskId +from tidy3d.web.core.http_util import get_version as _get_protocol_version + +CACHE_ARTIFACT_NAME = "simulation_data.hdf5" +CACHE_METADATA_NAME = "metadata.json" + +TMP_PREFIX = "tidy3d-cache-" +TMP_BATCH_PREFIX = "tmp_batch" + +_CACHE: Optional[LocalCache] = None + + +@dataclass +class CacheEntry: + """Internal representation of a cache entry.""" + + key: str + root: Path + metadata: dict[str, Any] + + @property + def path(self) -> Path: + return self.root / self.key + + @property + def artifact_path(self) -> Path: + return self.path / CACHE_ARTIFACT_NAME + + @property + def metadata_path(self) -> Path: + return self.path / CACHE_METADATA_NAME + + def exists(self) -> bool: + return self.path.exists() and self.artifact_path.exists() and self.metadata_path.exists() + + def verify(self) -> bool: + if not self.exists(): + return False + checksum = self.metadata.get("checksum") + if not checksum: + return False + try: + actual_checksum, file_size = _copy_and_hash(self.artifact_path, None) + except FileNotFoundError: + return False + if checksum != actual_checksum: + log.warning( + "Simulation cache checksum mismatch for key '%s'. 
Removing stale entry.", self.key + ) + return False + if int(self.metadata.get("file_size", file_size)) != file_size: + self.metadata["file_size"] = file_size + _write_metadata(self.metadata_path, self.metadata) + return True + + def materialize(self, target: Path) -> Path: + """Copy cached artifact to ``target`` and return the resulting path.""" + target = Path(target) + target.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(self.artifact_path, target) + return target + + +class LocalCache: + """Manages storing and retrieving cached simulation artifacts.""" + + def __init__(self, directory: os.PathLike, max_size_gb: float, max_entries: int) -> None: + self.max_size_gb = max_size_gb + self.max_entries = max_entries + self._root = Path(directory) + self._lock = threading.RLock() + + @property + def root(self) -> Path: + return self._root + + def list(self) -> list[dict[str, Any]]: + """Return metadata for all cache entries.""" + with self._lock: + return [entry.metadata for entry in self._iter_entries()] + + def clear(self, hard=False) -> None: + """Remove all cache contents.""" + with self._lock: + if self._root.exists(): + try: + shutil.rmtree(self._root) + if not hard: + self._root.mkdir(parents=True, exist_ok=True) + except (FileNotFoundError, OSError): + pass + + def _fetch(self, key: str) -> Optional[CacheEntry]: + """Retrieve an entry by key, verifying checksum.""" + with self._lock: + entry = self._load_entry(key) + if not entry or not entry.exists(): + return None + if not entry.verify(): + self._remove_entry(entry) + return None + self._touch(entry) + return entry + + def __len__(self) -> int: + """Return number of valid cache entries.""" + with self._lock: + return sum(1 for _ in self._iter_entries()) + + def _store(self, key: str, source_path: Path, metadata: dict[str, Any]) -> Optional[CacheEntry]: + """Store a new cache entry from ``source_path``. + + Parameters + ---------- + key : str + Cache key computed from simulation hash and runtime context. + source_path : Path + Location of the artifact to cache. + metadata : dict[str, Any] + Additional metadata to persist alongside artifact. + + Returns + ------- + CacheEntry + Representation of the stored cache entry. 
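+
+        Notes
+        -----
+        The artifact is staged in a temporary directory inside the cache root
+        and promoted with ``os.replace`` so concurrent readers never observe a
+        partially written entry; an existing entry is kept as a backup until
+        the swap succeeds.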
+ """ + source_path = Path(source_path) + if not source_path.exists(): + raise FileNotFoundError(f"Cannot cache missing artifact: {source_path}") + os.makedirs(self._root, exist_ok=True) + tmp_dir = Path(tempfile.mkdtemp(prefix=TMP_PREFIX, dir=self._root)) + tmp_artifact = tmp_dir / CACHE_ARTIFACT_NAME + tmp_meta = tmp_dir / CACHE_METADATA_NAME + os.makedirs(tmp_dir, exist_ok=True) + + checksum, file_size = _copy_and_hash(source_path, tmp_artifact) + now_iso = _now() + metadata = dict(metadata) + metadata.setdefault("cache_key", key) + metadata.setdefault("created_at", now_iso) + metadata["last_used"] = now_iso + metadata["checksum"] = checksum + metadata["file_size"] = file_size + + _write_metadata(tmp_meta, metadata) + try: + with self._lock: + self._root.mkdir(parents=True, exist_ok=True) + self._ensure_limits(file_size) + final_dir = self._root / key + backup_dir: Optional[Path] = None + + try: + if final_dir.exists(): + backup_dir = final_dir.with_name( + f"{final_dir.name}.bak.{_timestamp_suffix()}" + ) + os.replace(final_dir, backup_dir) + # move tmp_dir into place + os.replace(tmp_dir, final_dir) + except Exception: + # restore backup if needed + if backup_dir and backup_dir.exists(): + os.replace(backup_dir, final_dir) + raise + else: + entry = CacheEntry(key=key, root=self._root, metadata=metadata) + if backup_dir and backup_dir.exists(): + shutil.rmtree(backup_dir, ignore_errors=True) + log.debug("Stored simulation cache entry '%s' (%d bytes).", key, file_size) + return entry + finally: + try: + if tmp_dir.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) + except FileNotFoundError: + pass + + def invalidate(self, key: str) -> None: + with self._lock: + entry = self._load_entry(key) + if entry: + self._remove_entry(entry) + + def _ensure_limits(self, incoming_size: int) -> None: + max_entries = self.max_entries + max_size_bytes = int(self.max_size_gb * (1024**3)) + + entries = list(self._iter_entries()) + if len(entries) >= max_entries > 0: + self._evict(entries, keep=max_entries - 1) + entries = list(self._iter_entries()) + + if max_size_bytes == 0: # no limit + return + + existing_size = sum(int(e.metadata.get("file_size", 0)) for e in entries) + allowed_size = max(max_size_bytes - incoming_size, 0) + if existing_size > allowed_size: + self._evict_by_size(entries, existing_size, allowed_size) + + def _evict(self, entries: Iterable[CacheEntry], keep: int) -> None: + sorted_entries = sorted(entries, key=lambda e: e.metadata.get("last_used", "")) + to_remove = sorted_entries[: max(0, len(sorted_entries) - keep)] + for entry in to_remove: + self._remove_entry(entry) + + def _evict_by_size( + self, entries: Iterable[CacheEntry], current_size: int, allowed_size: float + ) -> None: + if allowed_size < 0: + allowed_size = 0 + sorted_entries = sorted(entries, key=lambda e: e.metadata.get("last_used", "")) + reclaimed = 0 + for entry in sorted_entries: + if current_size - reclaimed <= allowed_size: + break + size = int(entry.metadata.get("file_size", 0)) + self._remove_entry(entry) + reclaimed += size + log.info(f"Simulation cache evicted entry '{entry.key}' to reclaim {size} bytes.") + + def _iter_entries(self) -> Iterable[CacheEntry]: + if not self._root.exists(): + return [] + entries: list[CacheEntry] = [] + for child in self._root.iterdir(): + if child.name.startswith(TMP_PREFIX) or child.name.startswith(TMP_BATCH_PREFIX): + continue + meta_path = child / CACHE_METADATA_NAME + if not meta_path.exists(): + continue + try: + metadata = 
json.loads(meta_path.read_text(encoding="utf-8")) + except Exception: + metadata = {} + entries.append(CacheEntry(key=child.name, root=self._root, metadata=metadata)) + return entries + + def _load_entry(self, key: str) -> Optional[CacheEntry]: + entry = CacheEntry(key=key, root=self._root, metadata={}) + if not entry.metadata_path.exists() or not entry.artifact_path.exists(): + return None + try: + metadata = json.loads(entry.metadata_path.read_text(encoding="utf-8")) + except Exception: + metadata = {} + entry.metadata = metadata + return entry + + def _touch(self, entry: CacheEntry) -> None: + entry.metadata["last_used"] = _now() + _write_metadata(entry.metadata_path, entry.metadata) + + def _remove_entry(self, entry: CacheEntry) -> None: + if entry.path.exists(): + shutil.rmtree(entry.path, ignore_errors=True) + + def try_fetch( + self, + simulation: WorkflowType, + verbose: bool = False, + ) -> Optional[CacheEntry]: + """ + Attempt to resolve and fetch a cached result entry for the given simulation context. + On miss or any cache error, returns None (the caller should proceed with upload/run). + + Notes + ----- + - Mirrors the exact cache key/context computation from `run`. + - Safe to call regardless of `use_cache` value; will no-op if cache is disabled. + """ + try: + simulation_hash = simulation._hash_self() + workflow_type = Tidy3dStub(simulation=simulation).get_type() + + versions = _get_protocol_version() + + cache_key = build_cache_key( + simulation_hash=simulation_hash, + version=versions, + ) + + entry = self._fetch(cache_key) + if not entry: + return None + + if verbose: + log.info( + f"Simulation cache hit for workflow '{workflow_type}'; using local results." + ) + + return entry + except Exception as e: + log.error("Failed to fetch cache results: " + str(e)) + + def store_result( + self, + stub_data: WorkflowDataType, + task_id: TaskId, + path: str, + workflow_type: str, + ) -> None: + """ + After we have the data (postprocess done), store it in the cache using the + canonical key (simulation hash + workflow type + environment + version). + Also records the task_id mapping for legacy lookups. + """ + try: + simulation_obj = getattr(stub_data, "simulation", None) + simulation_hash = simulation_obj._hash_self() if simulation_obj is not None else None + if not simulation_hash: + return + + version = _get_protocol_version() + + cache_key = build_cache_key( + simulation_hash=simulation_hash, + version=version, + ) + + metadata = build_entry_metadata( + simulation_hash=simulation_hash, + workflow_type=workflow_type, + task_id=task_id, + version=version, + path=Path(path), + ) + + self._store( + key=cache_key, + source_path=Path(path), + metadata=metadata, + ) + except Exception as e: + log.error(f"Could not store cache entry: {e}") + + +def _copy_and_hash( + source: Path, dest: Optional[Path], existing_hash: Optional[str] = None +) -> tuple[str, int]: + """Copy ``source`` to ``dest`` while computing SHA256 checksum. + + Parameters + ---------- + source : Path + Source file path. + dest : Path or None + Destination file path. If ``None``, no copy is performed. + existing_hash : str, optional + If provided alongside ``dest`` and ``dest`` already exists, skip copying when hashes match. + + Returns + ------- + tuple[str, int] + The hexadecimal digest and file size in bytes. 
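+
+    Example
+    -------
+    A minimal sketch (paths are illustrative)::
+
+        checksum, size = _copy_and_hash(Path("src.hdf5"), Path("dst.hdf5"))
+        checksum_only, size = _copy_and_hash(Path("src.hdf5"), None)  # hash without copying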
+ """ + source = Path(source) + if dest is not None: + dest = Path(dest) + sha256 = _Hasher() + size = 0 + with source.open("rb") as src: + if dest is None: + while chunk := src.read(1024 * 1024): + sha256.update(chunk) + size += len(chunk) + else: + dest.parent.mkdir(parents=True, exist_ok=True) + with dest.open("wb") as dst: + while chunk := src.read(1024 * 1024): + dst.write(chunk) + sha256.update(chunk) + size += len(chunk) + return sha256.hexdigest(), size + + +def _write_metadata(path: Path, metadata: dict[str, Any]) -> None: + tmp_path = path.with_suffix(".tmp") + with tmp_path.open("w", encoding="utf-8") as fh: + json.dump(metadata, fh, indent=2, sort_keys=True) + os.replace(tmp_path, path) + + +def _now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _timestamp_suffix() -> str: + return datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S%f") + + +class _Hasher: + def __init__(self): + self._hasher = hashlib.sha256() + + def update(self, data: bytes) -> None: + self._hasher.update(data) + + def hexdigest(self) -> str: + return self._hasher.hexdigest() + + +def clear() -> None: + """Remove all cache entries.""" + cache = resolve_local_cache(use_cache=True) + if cache is not None: + cache.clear() + + +def _canonicalize(value: Any) -> Any: + """Convert value into a JSON-serializable object for hashing/metadata.""" + + if isinstance(value, dict): + return { + str(k): _canonicalize(v) + for k, v in sorted(value.items(), key=lambda item: str(item[0])) + } + if isinstance(value, (list, tuple)): + return [_canonicalize(v) for v in value] + if isinstance(value, set): + return sorted(_canonicalize(v) for v in value) + if isinstance(value, Enum): + return value.value + if isinstance(value, Path): + return str(value) + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, bytes): + return value.decode("utf-8", errors="ignore") + return value + + +def build_cache_key( + *, + simulation_hash: str, + version: str, +) -> str: + """Construct a deterministic cache key.""" + + payload = { + "simulation_hash": simulation_hash, + "versions": _canonicalize(version), + } + encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + +def build_entry_metadata( + *, + simulation_hash: str, + workflow_type: str, + task_id: str, + version: str, + path: Path, +) -> dict[str, Any]: + """Create metadata dictionary for a cache entry.""" + + metadata: dict[str, Any] = { + "simulation_hash": simulation_hash, + "workflow_type": workflow_type, + "versions": _canonicalize(version), + "task_id": task_id, + "path": str(path), + } + return metadata + + +def resolve_local_cache(use_cache: Optional[bool] = None) -> Optional[LocalCache]: + """ + Returns LocalCache instance if enabled. + Returns None if use_cached=False or config-fetched 'enabled' is False. + Deletes old cache directory if existing. 
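+
+    Example
+    -------
+    A minimal sketch::
+
+        cache = resolve_local_cache()                          # honors config.local_cache.enabled
+        cache_forced = resolve_local_cache(use_cache=True)     # bypass the config flag
+        assert resolve_local_cache(use_cache=False) is None    # explicit opt-out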
+    """
+    global _CACHE
+
+    if use_cache is False or (use_cache is not True and not config.local_cache.enabled):
+        return None
+
+    if _CACHE is not None and _CACHE._root != Path(config.local_cache.directory):
+        log.debug(f"Clearing old cache directory {_CACHE._root}")
+        _CACHE.clear(hard=True)
+
+    try:
+        _CACHE = LocalCache(
+            directory=config.local_cache.directory,
+            max_entries=config.local_cache.max_entries,
+            max_size_gb=config.local_cache.max_size_gb,
+        )
+        return _CACHE
+    except Exception as err:
+        log.debug(f"Simulation cache unavailable: {err}")
+        return None
+
+
+resolve_local_cache()
diff --git a/tidy3d/web/core/http_util.py b/tidy3d/web/core/http_util.py
index 54d297c6fb..ba5bdf05d9 100644
--- a/tidy3d/web/core/http_util.py
+++ b/tidy3d/web/core/http_util.py
@@ -37,7 +37,7 @@ class ResponseCodes(Enum):
     NOT_FOUND = 404
 
 
-def get_version() -> None:
+def get_version() -> str:
     """Get the version for the current environment."""
     return core_config.get_version()