Skip to content

Commit a835e87

Browse files
authored
add smi replay simulator (#19)
* add smi replay simulator * update threading in replay simulator
1 parent 24deeb6 commit a835e87

2 files changed

Lines changed: 306 additions & 0 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ all = [
7373
"pandas",
7474
"python-dotenv",
7575
"pyarrow>=14.0.1",
76+
"aiosqlite",
7677

7778
]
7879

Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
import asyncio
2+
import logging
3+
import os
4+
import aiosqlite
5+
from datetime import datetime
6+
7+
import typer
8+
import zmq
9+
import zmq.asyncio
10+
import msgpack
11+
from tiled.client import from_uri
12+
13+
from ..config import settings
14+
from ..schemas import (
15+
RawFrameEvent,
16+
SASStart,
17+
SASStop,
18+
SerializableNumpyArrayModel,
19+
)
20+
21+
"""
22+
Simulates image retrieval by reading Tiled URLs from a local SQLite database
23+
and sends the fetched images onto ZMQ.
24+
"""
25+
26+
# Runtime defaults, read once from the environment when the module is imported.
DEFAULT_DB_PATH = os.environ.get("DB_PATH", "latent_vectors.db")
DEFAULT_API_KEY = os.environ.get("TILED_API_KEY")
# Which Tiled deployment to target; normalised to lower case ("dev" or "prod").
TILED_ENV = os.environ.get("TILED_ENV", "dev").lower()

# Per-environment prefix of the Tiled "array/full" endpoint.
TILED_URLS = dict(
    dev=dict(url_pattern="http://tiled-dev.nsls2.bnl.gov/api/v1/array/full/"),
    prod=dict(url_pattern="http://tiled.nsls2.bnl.gov/api/v1/array/full/"),
)

# Logging: request INFO-level root configuration and create the module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
45+
46+
app = typer.Typer()
47+
48+
49+
async def get_urls_from_db(db_path, limit=None):
    """Get a list of Tiled URLs from the database asynchronously.

    Args:
        db_path: Path to the SQLite database file.
        limit: Maximum number of rows to return. A falsy value (``None`` or
            ``0``) means no limit.

    Returns:
        The ``(id, tiled_url)`` rows of the ``vectors`` table ordered by id,
        or an empty list when the table is empty or the read fails (the
        failure is logged rather than raised).
    """
    query = "SELECT id, tiled_url FROM vectors ORDER BY id"
    params = ()

    try:
        if limit:
            # Bind the limit as a parameter instead of formatting it into the
            # SQL text, so the value can never alter the statement itself.
            query += " LIMIT ?"
            params = (int(limit),)

        async with aiosqlite.connect(db_path) as conn:
            async with conn.execute(query, params) as cursor:
                results = await cursor.fetchall()
    except Exception as e:
        # Best-effort read: callers treat an empty list as "nothing to do".
        logger.error(f"Error reading from database: {e}")
        return []

    if not results:
        logger.warning(f"No Tiled URLs found in database {db_path}")
        return []

    logger.info(f"Found {len(results)} Tiled URLs in database")
    return results
69+
70+
71+
def transform_url_for_env(tiled_url, env):
    """
    Transform a Tiled URL to match the specified environment format.

    Args:
        tiled_url: The original Tiled URL (typically from dev environment)
        env: The target environment ('dev' or 'prod'); any other value is
            logged and treated as 'dev'

    Returns:
        str: Transformed URL for the target environment. The input is returned
        unchanged when it already points at the target host or when it cannot
        be parsed.
    """
    from urllib.parse import urlsplit

    if env not in TILED_URLS:
        logger.warning(f"Unknown environment '{env}', falling back to 'dev'")
        env = "dev"

    url_pattern = TILED_URLS[env]["url_pattern"]

    # A URL that already names the target host needs no transformation.
    # Without this guard a prod URL would be re-parsed with "smi" as its UUID
    # and come out corrupted.
    target_host = urlsplit(url_pattern).netloc
    if target_host and target_host in tiled_url:
        return tiled_url

    # Separate the query string (the slice parameter) from the path; partition
    # keeps the whole query even if it contains a further '?'.
    url_without_query, _, slice_param = tiled_url.partition('?')

    # Everything after "array/full/" is "<uuid>/<stream path>".
    _, _, path_after_full = url_without_query.partition('array/full/')
    uuid, _, stream_path = path_after_full.partition('/')

    if not uuid or not stream_path:
        logger.error(f"Could not parse Tiled URL: {tiled_url}")
        return tiled_url  # Return original if parsing fails

    if env == "prod":
        # Prod format:
        # http://tiled.nsls2.bnl.gov/api/v1/array/full/smi/raw/{uuid}/primary/data/{image_name}?slice=...
        image_name = stream_path.rsplit('/', 1)[-1]
        if not image_name:
            # Stream path ended with '/', so there is no image name to use.
            logger.error(f"Could not extract image name from stream path: {stream_path}")
            return tiled_url
        new_url = f"{url_pattern}smi/raw/{uuid}/primary/data/{image_name}"
    else:
        # Dev format:
        # http://tiled-dev.nsls2.bnl.gov/api/v1/array/full/{uuid}/{stream_path}?slice=...
        new_url = f"{url_pattern}{uuid}/{stream_path}"

    # Re-attach the query string if there was one
    if slice_param:
        new_url = f"{new_url}?{slice_param}"

    logger.debug(f"Transformed URL: {tiled_url} -> {new_url}")
    return new_url
136+
137+
138+
def _slice_start_index(tiled_url):
    """Return the start index of the first dimension of ``?slice=``, or 0."""
    from urllib.parse import parse_qs, urlsplit

    slice_values = parse_qs(urlsplit(tiled_url).query).get("slice")
    if not slice_values:
        return 0

    # Only the first dimension selects the frame:
    # "5:6,0:1679,0:1475" -> "5:6" -> "5". A bare "5" is accepted as well.
    start = slice_values[0].split(',')[0].split(':')[0].strip()
    return int(start) if start.isdecimal() else 0


def _read_image_from_tiled_url_sync(tiled_url, api_key=None):
    """
    Read an image from a Tiled URL.

    Args:
        tiled_url: The Tiled URL (already transformed for the appropriate environment)
        api_key: API key for Tiled authentication

    Returns:
        tuple: (image_data, index). On any failure the error is logged and
        (None, 0) is returned instead of raising.
    """
    try:
        # Frame index comes from the slice query parameter (default 0)
        index = _slice_start_index(tiled_url)

        # Split the URL (without its query) around the API prefix
        url_without_query = tiled_url.split('?')[0]
        url_parts = url_without_query.split('/api/v1/')

        if len(url_parts) != 2:
            logger.error(f"Invalid Tiled URL format: {tiled_url}")
            return None, 0

        server_root, full_path = url_parts

        # The client connects to the metadata endpoint rather than array/full
        base_uri = f"{server_root}/api/v1/metadata"

        # Dataset key is whatever follows "array/full/"; if that marker is
        # absent, the whole path after the API prefix is used.
        _, marker, after_marker = full_path.partition('array/full/')
        dataset_uri = after_marker if marker else full_path

        logger.debug(f"Base URI: {base_uri}, Dataset URI: {dataset_uri}, Index: {index}")

        # Connect to the Tiled server
        client = from_uri(base_uri, api_key=api_key)

        # Access the dataset
        tiled_data = client[dataset_uri]
        logger.debug(f"Dataset shape: {tiled_data.shape}, dtype: {tiled_data.dtype}")

        # Retrieve the image at the specified index
        image = tiled_data[index]

        return image, index

    except Exception as e:
        logger.error(f"Error reading from Tiled URL {tiled_url}: {e}")
        return None, 0
201+
202+
203+
async def read_image_from_tiled_url(tiled_url, api_key=None):
    """Run _read_image_from_tiled_url_sync in a worker thread and await its result."""
    fetch_result = await asyncio.to_thread(
        _read_image_from_tiled_url_sync,
        tiled_url,
        api_key,
    )
    return fetch_result
206+
207+
208+
@app.command()
def main(
    db_path: str = typer.Option(DEFAULT_DB_PATH, help="Path to the SQLite database containing Tiled URLs"),
    max_frames: int = typer.Option(10000, help="Maximum number of frames to process"),
    api_key: str = typer.Option(DEFAULT_API_KEY, help="API key for Tiled authentication"),
    env: str = typer.Option(TILED_ENV, help="Tiled environment to use ('dev' or 'prod')")
):
    """
    Run the image simulator that reads Tiled URLs from a database, fetches the images, and publishes them via ZMQ.

    Configuration can be set via environment variables:
    - DB_PATH: Path to the SQLite database
    - TILED_API_KEY: API key for Tiled authentication
    - TILED_ENV: Environment to use ('dev' or 'prod')

    Command-line arguments override environment variables.
    """
    # Log the configuration
    logger.info("Starting DB Image Simulator with:")
    logger.info(f"- Database path: {db_path}")
    logger.info(f"- Max frames: {max_frames}")
    logger.info(f"- API key provided: {api_key is not None}")
    logger.info(f"- Tiled environment: {env}")

    async def run():
        # Check if database exists
        if not os.path.exists(db_path):
            logger.error(f"Database file not found: {db_path}")
            return

        # Setup ZMQ socket
        context = zmq.asyncio.Context()
        socket = context.socket(zmq.PUB)
        try:
            address = settings.tiled_poller.zmq_frame_publisher.address
            logger.info(f"Binding to ZMQ address: {address}")
            socket.bind(address)

            # Get URLs from database
            urls = await get_urls_from_db(db_path, limit=max_frames)
            if not urls:
                logger.error("No URLs found in database, cannot continue")
                return

            # Send start event
            current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            start = SASStart(
                width=1679,  # Default values - these will be updated from real data
                height=1475,
                data_type="uint32",
                tiled_url=f"{env}://latent_vectors",
                run_name=f"{env}_tiled_run",
                run_id=str(current_time),
            )
            logger.info("Sending start event")
            await socket.send(msgpack.packb(start.model_dump()))

            # Count only frames actually published, so the stop event and the
            # final log line do not include frames that failed to load.
            frames_sent = 0

            # Process each URL
            for db_id, tiled_url in urls:
                try:
                    logger.info(f"Processing URL from DB record {db_id}: {tiled_url}")

                    # Transform the URL for the current environment before processing
                    transformed_url = transform_url_for_env(tiled_url, env)
                    logger.info(f"Transformed URL: {transformed_url}")

                    # Read image data from transformed Tiled URL
                    image_data, index = await read_image_from_tiled_url(transformed_url, api_key)

                    if image_data is None:
                        logger.error(f"Failed to read image from {transformed_url}")
                        continue

                    # Send the frame event with transformed URL
                    event = RawFrameEvent(
                        image=SerializableNumpyArrayModel(array=image_data),
                        frame_number=index,
                        tiled_url=transformed_url,
                    )
                    logger.info(f"Sending frame {index}")
                    await socket.send(msgpack.packb(event.model_dump()))
                    frames_sent += 1

                    # Small delay between frames
                    await asyncio.sleep(0.1)

                except Exception as e:
                    # One bad frame must not abort the whole replay
                    logger.error(f"Error processing frame from {tiled_url}: {e}")

            # Send stop event
            stop = SASStop(num_frames=frames_sent)
            logger.info("Sending stop event")
            await socket.send(msgpack.packb(stop.model_dump()))
            logger.info(f"Complete - sent {frames_sent} frames")
        finally:
            # Release the ZMQ resources on every exit path, including the
            # early return above when the database yields no URLs.
            socket.close()
            context.term()

    asyncio.run(run())
302+
303+
304+
# Script entry point: when run directly, hand control to the Typer app, which
# parses the command-line options and invokes the registered ``main`` command.
if __name__ == "__main__":
    app()

0 commit comments

Comments
 (0)