Skip to content

Commit 4dd5222

Browse files
authored
Merge pull request #77 from mihiarc/fix/evalid-filter-all-plt-cn-tables
Fix EVALID filtering for all tables with PLT_CN column
2 parents (2fd7dd7 + 788960f), commit 4dd5222

2 files changed

Lines changed: 175 additions & 7 deletions

File tree

src/pyfia/core/fia.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -292,7 +292,9 @@ def _apply_spatial_filter(self, df: pl.LazyFrame, table_name: str) -> pl.LazyFra
292292
return df
293293
if table_name == "PLOT":
294294
return df.filter(pl.col("CN").is_in(self._spatial_plot_cns))
295-
elif table_name in ["TREE", "COND"]:
295+
# Filter any table with a PLT_CN column by the spatial plot CNs
296+
schema = self._reader.get_table_schema(table_name)
297+
if "PLT_CN" in schema:
296298
return df.filter(pl.col("PLT_CN").is_in(self._spatial_plot_cns))
297299
return df
298300

@@ -346,9 +348,15 @@ def load_table(
346348
pl.LazyFrame
347349
Polars LazyFrame of the requested table.
348350
"""
349-
# Build base WHERE clause for state filter
351+
# Inspect table schema to determine which filters apply
352+
table_schema = self._reader.get_table_schema(table_name)
353+
table_columns = set(table_schema.keys())
354+
has_plt_cn = "PLT_CN" in table_columns
355+
has_statecd = "STATECD" in table_columns
356+
357+
# Build base WHERE clause for state filter (any table with STATECD)
350358
base_where_clause = None
351-
if self.state_filter and table_name in ["PLOT", "COND", "TREE"]:
359+
if self.state_filter and has_statecd:
352360
state_list = ", ".join(str(s) for s in self.state_filter)
353361
base_where_clause = f"STATECD IN ({state_list})"
354362

@@ -359,9 +367,9 @@ def load_table(
359367
else:
360368
base_where_clause = where
361369

362-
# EVALID filter via PLT_CN for TREE, COND tables
363-
# This is a critical optimization - it reduces data load by 90%+ for GRM estimates
364-
if self.evalid and table_name in ["TREE", "COND"]:
370+
# EVALID filter via PLT_CN for any table that has a PLT_CN column
371+
# This is a critical optimization - it reduces data load by 90%+
372+
if self.evalid and has_plt_cn:
365373
valid_plot_cns = self._get_valid_plot_cns()
366374
if valid_plot_cns:
367375
from .utils import batch_query_by_values
@@ -404,7 +412,7 @@ def query_batch(batch: list) -> pl.LazyFrame:
404412
self.tables[table_name] = result
405413
return self.tables[table_name]
406414

407-
# Default path - no EVALID filtering or not a filterable table
415+
# Default path - no EVALID filtering or table has no PLT_CN column
408416
df = self._reader.read_table(
409417
table_name,
410418
columns=columns,
Lines changed: 160 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,160 @@
1+
"""Unit tests for load_table() EVALID and state filtering.
2+
3+
Verifies that load_table() applies PLT_CN-based EVALID filtering and
4+
STATECD-based state filtering to any table that has those columns,
5+
not just a hardcoded allowlist of table names.
6+
"""
7+
8+
from unittest.mock import MagicMock, patch
9+
10+
import polars as pl
11+
import pytest
12+
13+
14+
@pytest.fixture
15+
def mock_fia():
16+
"""Create a mock FIA instance with the real load_table method."""
17+
from pyfia.core.fia import FIA
18+
19+
with patch.object(FIA, "__init__", lambda self: None):
20+
db = FIA()
21+
db.tables = {}
22+
db.evalid = None
23+
db.state_filter = None
24+
db._polygon_attributes = None
25+
db._spatial_plot_cns = None
26+
db._valid_plot_cns = None
27+
db._reader = MagicMock()
28+
return db
29+
30+
31+
class TestEVALIDFilteringByColumn:
32+
"""EVALID filtering should apply to any table with PLT_CN, not just TREE/COND."""
33+
34+
def test_tree_grm_component_gets_evalid_filtered(self, mock_fia):
35+
"""TREE_GRM_COMPONENT has PLT_CN and should be EVALID-filtered."""
36+
mock_fia.evalid = [132303]
37+
mock_fia._valid_plot_cns = ["100", "200", "300"]
38+
mock_fia._reader.get_table_schema.return_value = {
39+
"CN": "VARCHAR",
40+
"TRE_CN": "VARCHAR",
41+
"PLT_CN": "VARCHAR",
42+
"COMPONENT": "VARCHAR",
43+
"TPA_UNADJ": "DOUBLE",
44+
}
45+
mock_fia._reader.read_table.return_value = pl.DataFrame(
46+
{"TRE_CN": ["1"], "PLT_CN": ["100"], "TPA_UNADJ": [1.0]}
47+
).lazy()
48+
49+
mock_fia.load_table("TREE_GRM_COMPONENT")
50+
51+
# Verify read_table was called with a PLT_CN IN (...) WHERE clause
52+
call_args = mock_fia._reader.read_table.call_args
53+
where_clause = call_args.kwargs.get("where", "") or ""
54+
assert "PLT_CN IN" in where_clause
55+
56+
def test_table_without_plt_cn_skips_evalid_filter(self, mock_fia):
57+
"""Tables without PLT_CN (e.g. POP_EVAL) should not get EVALID filtering."""
58+
mock_fia.evalid = [132303]
59+
mock_fia._valid_plot_cns = ["100", "200"]
60+
mock_fia._reader.get_table_schema.return_value = {
61+
"CN": "VARCHAR",
62+
"EVALID": "INTEGER",
63+
"EVAL_DESCR": "VARCHAR",
64+
}
65+
mock_fia._reader.read_table.return_value = pl.DataFrame(
66+
{"CN": ["1"], "EVALID": [132303], "EVAL_DESCR": ["test"]}
67+
).lazy()
68+
69+
mock_fia.load_table("POP_EVAL")
70+
71+
# Should use default path without PLT_CN filtering
72+
call_args = mock_fia._reader.read_table.call_args
73+
where_clause = call_args.kwargs.get("where", "") or ""
74+
assert "PLT_CN IN" not in where_clause
75+
76+
def test_no_evalid_set_skips_filter(self, mock_fia):
77+
"""When no EVALID is set, PLT_CN filtering should be skipped."""
78+
mock_fia.evalid = None
79+
mock_fia._reader.get_table_schema.return_value = {
80+
"CN": "VARCHAR",
81+
"PLT_CN": "VARCHAR",
82+
"TPA_UNADJ": "DOUBLE",
83+
}
84+
mock_fia._reader.read_table.return_value = pl.DataFrame(
85+
{"CN": ["1"], "PLT_CN": ["100"], "TPA_UNADJ": [1.0]}
86+
).lazy()
87+
88+
mock_fia.load_table("TREE_GRM_COMPONENT")
89+
90+
# Should use default path
91+
call_args = mock_fia._reader.read_table.call_args
92+
where_clause = call_args.kwargs.get("where", "") or ""
93+
assert "PLT_CN IN" not in where_clause
94+
95+
96+
class TestStateFilteringByColumn:
97+
"""State filtering should apply to any table with STATECD."""
98+
99+
def test_table_with_statecd_gets_filtered(self, mock_fia):
100+
"""Any table with STATECD should get state filtering."""
101+
mock_fia.state_filter = [13] # Georgia
102+
mock_fia._reader.get_table_schema.return_value = {
103+
"CN": "VARCHAR",
104+
"PLT_CN": "VARCHAR",
105+
"STATECD": "INTEGER",
106+
}
107+
mock_fia._reader.read_table.return_value = pl.DataFrame(
108+
{"CN": ["1"], "PLT_CN": ["100"], "STATECD": [13]}
109+
).lazy()
110+
111+
mock_fia.load_table("SEEDLING")
112+
113+
call_args = mock_fia._reader.read_table.call_args
114+
where_clause = call_args.kwargs.get("where", "") or ""
115+
assert "STATECD IN (13)" in where_clause
116+
117+
def test_table_without_statecd_skips_filter(self, mock_fia):
118+
"""Tables without STATECD should not get state filtering."""
119+
mock_fia.state_filter = [13]
120+
mock_fia._reader.get_table_schema.return_value = {
121+
"CN": "VARCHAR",
122+
"EVALID": "INTEGER",
123+
}
124+
mock_fia._reader.read_table.return_value = pl.DataFrame(
125+
{"CN": ["1"], "EVALID": [132303]}
126+
).lazy()
127+
128+
mock_fia.load_table("POP_EVAL")
129+
130+
call_args = mock_fia._reader.read_table.call_args
131+
where_clause = call_args.kwargs.get("where", "") or ""
132+
assert "STATECD" not in where_clause
133+
134+
135+
class TestSpatialFilteringByColumn:
136+
"""Spatial filtering should apply to any table with PLT_CN."""
137+
138+
def test_spatial_filter_applies_to_grm_table(self, mock_fia):
139+
"""Tables with PLT_CN should get spatial filtering when active."""
140+
mock_fia._spatial_plot_cns = ["100", "200"]
141+
mock_fia._reader.get_table_schema.return_value = {
142+
"CN": "VARCHAR",
143+
"PLT_CN": "VARCHAR",
144+
"TPA_UNADJ": "DOUBLE",
145+
}
146+
data = pl.DataFrame(
147+
{
148+
"CN": ["1", "2", "3"],
149+
"PLT_CN": ["100", "200", "999"],
150+
"TPA_UNADJ": [1.0, 2.0, 3.0],
151+
}
152+
).lazy()
153+
mock_fia._reader.read_table.return_value = data
154+
155+
result = mock_fia.load_table("TREE_GRM_COMPONENT")
156+
157+
# Should filter to only spatial plot CNs
158+
collected = result.collect()
159+
assert collected.shape[0] == 2
160+
assert set(collected["PLT_CN"].to_list()) == {"100", "200"}

0 commit comments

Comments
 (0)