Skip to content
72 changes: 60 additions & 12 deletions backend/app/features/simulation/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime, timezone
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session, joinedload, selectinload

from app.common.dependencies import get_database_session
Expand Down Expand Up @@ -57,6 +57,36 @@ def list_cases(db: Session = Depends(get_database_session)) -> list[CaseOut]:
return resp


@case_router.get(
"/names",
response_model=list[str],
responses={
200: {"description": "List all case names."},
500: {"description": "Internal server error."},
},
)
def list_case_names(db: Session = Depends(get_database_session)) -> list[str]:
"""Return a sorted list of all case names.

This lightweight endpoint avoids loading nested simulation data,
making it suitable for populating filter dropdowns.

Parameters
----------
db : Session, optional
The database session dependency, by default provided by
`Depends(get_database_session)`.

Returns
-------
list[str]
Alphabetically sorted case names.
"""
names = db.query(Case.name).order_by(Case.name).all()

return [n[0] for n in names]


@case_router.get(
"/{case_id}",
response_model=CaseOut,
Expand Down Expand Up @@ -251,7 +281,17 @@ def create_simulation(
500: {"description": "Internal server error."},
},
)
def list_simulations(db: Session = Depends(get_database_session)):
def list_simulations(
db: Session = Depends(get_database_session),
case_name: str | None = Query(
None,
description="Filter simulations by exact case name.",
),
case_group: str | None = Query(
None,
description="Filter simulations by exact case group.",
),
):
"""
Retrieve a list of simulations from the database, ordered by creation date
in descending order.
Expand All @@ -261,24 +301,32 @@ def list_simulations(db: Session = Depends(get_database_session)):
db : Session, optional
The database session dependency, by default obtained via
`Depends(get_database_session)`.
case_name : str, optional
If provided, only simulations whose associated case name matches
exactly will be returned.
case_group : str, optional
If provided, only simulations whose associated case group matches
exactly will be returned.

Returns
-------
list
A list of `Simulation` objects, ordered by their `created_at` timestamp
in descending order.
"""
sims = (
db.query(Simulation)
.options(
joinedload(Simulation.case),
joinedload(Simulation.machine),
selectinload(Simulation.artifacts),
selectinload(Simulation.links),
)
.order_by(Simulation.created_at.desc())
.all()
query = db.query(Simulation).options(
joinedload(Simulation.case),
joinedload(Simulation.machine),
selectinload(Simulation.artifacts),
selectinload(Simulation.links),
)

if case_name is not None:
query = query.filter(Simulation.case.has(name=case_name))
if case_group is not None:
query = query.filter(Simulation.case.has(case_group=case_group))

sims = query.order_by(Simulation.created_at.desc()).all()
return [_simulation_to_out(s) for s in sims]


Expand Down
186 changes: 186 additions & 0 deletions backend/tests/features/simulation/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,25 @@ def test_endpoint_returns_cases_with_nested_simulations(
assert exec_ids["case-nested-exec-2"]["changeCount"] == 1


class TestListCaseNames:
def test_endpoint_returns_empty_list(self, client):
res = client.get(f"{API_BASE}/cases/names")
assert res.status_code == 200
assert res.json() == []

def test_endpoint_returns_case_names_sorted_alphabetically(
self, client, db: Session
):
_create_case(db, "zeta_case")
_create_case(db, "alpha_case")
_create_case(db, "beta_case")
db.commit()

res = client.get(f"{API_BASE}/cases/names")
assert res.status_code == 200
assert res.json() == ["alpha_case", "beta_case", "zeta_case"]


class TestGetCase:
def test_endpoint_returns_case_with_simulations(
self, client, db: Session, normal_user_sync, admin_user_sync
Expand Down Expand Up @@ -429,6 +448,173 @@ def test_endpoint_returns_simulations_with_data(
assert data[0]["caseName"] == "test_case_list"
assert data[0]["executionId"] == "list-test-exec-1"

def test_filter_by_case_name(
self, client, db: Session, normal_user_sync, admin_user_sync
):
machine = db.query(Machine).first()
assert machine is not None

case_a = _create_case(db, "case_alpha")
case_b = _create_case(db, "case_beta")

ingestion = Ingestion(
source_type=IngestionSourceType.BROWSER_UPLOAD,
source_reference="test_filter_case_name",
machine_id=machine.id,
triggered_by=normal_user_sync["id"],
status=IngestionStatus.SUCCESS,
created_count=2,
duplicate_count=0,
error_count=0,
)
db.add(ingestion)
db.flush()

for case, exec_id in [(case_a, "exec-a"), (case_b, "exec-b")]:
db.add(
Simulation(
case_id=case.id,
execution_id=exec_id,
compset="AQUAPLANET",
compset_alias="QPC4",
grid_name="f19_f19",
grid_resolution="1.9x2.5",
initialization_type="startup",
simulation_type="experimental",
status="created",
machine_id=machine.id,
simulation_start_date="2023-01-01T00:00:00Z",
created_by=normal_user_sync["id"],
last_updated_by=admin_user_sync["id"],
ingestion_id=ingestion.id,
)
)
db.commit()

# No filter returns both
res = client.get(f"{API_BASE}/simulations")
assert res.status_code == 200
assert len(res.json()) == 2

# Filter by case_name=case_alpha returns only one
res = client.get(f"{API_BASE}/simulations", params={"case_name": "case_alpha"})
assert res.status_code == 200
data = res.json()
assert len(data) == 1
assert data[0]["caseName"] == "case_alpha"

# Non-matching filter returns empty
res = client.get(f"{API_BASE}/simulations", params={"case_name": "nonexistent"})
assert res.status_code == 200
assert len(res.json()) == 0

def test_filter_by_case_group(
self, client, db: Session, normal_user_sync, admin_user_sync
):
machine = db.query(Machine).first()
assert machine is not None

case_g1 = Case(name="case_group1", case_group="ensemble_A")
case_g2 = Case(name="case_group2", case_group="ensemble_B")
db.add_all([case_g1, case_g2])
db.flush()

ingestion = Ingestion(
source_type=IngestionSourceType.BROWSER_UPLOAD,
source_reference="test_filter_case_group",
machine_id=machine.id,
triggered_by=normal_user_sync["id"],
status=IngestionStatus.SUCCESS,
created_count=2,
duplicate_count=0,
error_count=0,
)
db.add(ingestion)
db.flush()

for case, exec_id in [(case_g1, "exec-g1"), (case_g2, "exec-g2")]:
db.add(
Simulation(
case_id=case.id,
execution_id=exec_id,
compset="AQUAPLANET",
compset_alias="QPC4",
grid_name="f19_f19",
grid_resolution="1.9x2.5",
initialization_type="startup",
simulation_type="experimental",
status="created",
machine_id=machine.id,
simulation_start_date="2023-01-01T00:00:00Z",
created_by=normal_user_sync["id"],
last_updated_by=admin_user_sync["id"],
ingestion_id=ingestion.id,
)
)
db.commit()

res = client.get(f"{API_BASE}/simulations", params={"case_group": "ensemble_A"})
assert res.status_code == 200
data = res.json()
assert len(data) == 1
assert data[0]["caseGroup"] == "ensemble_A"

def test_filter_by_case_name_and_case_group(
self, client, db: Session, normal_user_sync, admin_user_sync
):
machine = db.query(Machine).first()
assert machine is not None

case = Case(name="combo_case", case_group="combo_group")
case_other = Case(name="other_case", case_group="combo_group")
db.add_all([case, case_other])
db.flush()

ingestion = Ingestion(
source_type=IngestionSourceType.BROWSER_UPLOAD,
source_reference="test_filter_combo",
machine_id=machine.id,
triggered_by=normal_user_sync["id"],
status=IngestionStatus.SUCCESS,
created_count=2,
duplicate_count=0,
error_count=0,
)
db.add(ingestion)
db.flush()

for c, exec_id in [(case, "exec-combo"), (case_other, "exec-other")]:
db.add(
Simulation(
case_id=c.id,
execution_id=exec_id,
compset="AQUAPLANET",
compset_alias="QPC4",
grid_name="f19_f19",
grid_resolution="1.9x2.5",
initialization_type="startup",
simulation_type="experimental",
status="created",
machine_id=machine.id,
simulation_start_date="2023-01-01T00:00:00Z",
created_by=normal_user_sync["id"],
last_updated_by=admin_user_sync["id"],
ingestion_id=ingestion.id,
)
)
db.commit()

# Both share same group, but filtering by both narrows to one
res = client.get(
f"{API_BASE}/simulations",
params={"case_name": "combo_case", "case_group": "combo_group"},
)
assert res.status_code == 200
data = res.json()
assert len(data) == 1
assert data[0]["caseName"] == "combo_case"
assert data[0]["caseGroup"] == "combo_group"


class TestGetSimulation:
def test_endpoint_succeeds_with_valid_id(
Expand Down
Loading