Add dataclass and state machine for fine grained subtask tracking #758
peterd-NV wants to merge 12 commits into
Review: Add FineGrainedSubtask dataclass and state machine
Summary: This PR introduces a fine-grained subtask tracking system with a FineGrainedSubtask dataclass and a FineGrainedStateMachine to monitor predicate progression across parallel environments. The design is well-structured — clean separation between the data definition (FineGrainedSubtask), the runner (FineGrainedSubtaskRunner), and the top-level orchestrator (FineGrainedStateMachine). Test coverage is thorough.
Findings
1. 🐛 Performance: Element-wise reset loop will be slow at scale
File: fine_grained_state_machine.py (lines 152–156)
def reset(self, env_ids) -> None:
    for group_name in self.fine_grained_subtask.group_names:
        for eid in env_ids:
            self.current_index[group_name][eid] = 0
            self.group_score[group_name][eid] = 0.0
            self.group_complete[group_name][eid] = False
This iterates element-by-element over env_ids for each group. With hundreds or thousands of environments, this will be significantly slower than vectorized indexing:
def reset(self, env_ids) -> None:
    if not isinstance(env_ids, torch.Tensor):
        env_ids = torch.as_tensor(env_ids, dtype=torch.long, device=self.device)
    for group_name in self.fine_grained_subtask.group_names:
        self.current_index[group_name][env_ids] = 0
        self.group_score[group_name][env_ids] = 0.0
        self.group_complete[group_name][env_ids] = False
Severity: Medium. Correctness is fine, but this will be a bottleneck in real training loops.
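To make the equivalence concrete, here is a small self-contained sketch that applies both reset styles to a stand-in state tensor. The sizes, the values, and the single-tensor layout are assumptions for illustration; the real runner keeps one tensor per group.

```python
import torch

num_envs = 8
env_ids = [1, 3, 6]  # environments being reset (illustrative values)

# Stand-in for one group's per-env state.
looped = torch.arange(num_envs, dtype=torch.long) + 1
vectorized = looped.clone()

# Element-wise reset: one Python iteration and one tensor write per env.
for eid in env_ids:
    looped[eid] = 0

# Vectorized reset: a single indexed assignment covers all env_ids at once.
ids = torch.as_tensor(env_ids, dtype=torch.long)
vectorized[ids] = 0

same_result = torch.equal(looped, vectorized)
```

Both paths end in the same state; the vectorized form simply replaces the per-env Python loop with one indexing operation.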
2. 🔍 State machine advances at most one predicate per step (by design, but undocumented)
File: fine_grained_state_machine.py (lines 103–136)
The step() method iterates over all chain indices in a single call. At first glance this suggests that if predicate p0 becomes True while p1 is already True, an env could advance through both in one step() call, because after advancing past p0 the loop continues to chain_idx=1 and finds the env at_position.
It cannot: the advanced mask prevents double-advancing within a single step. Once advance_mask is set for an env, ~advanced blocks further evaluation, so each env advances at most one predicate per step.
This is correct behavior, but it is worth a brief docstring note on the single-advance-per-step invariant, since it is subtle and reviewers or future maintainers may miss it.
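The invariant can be illustrated with a minimal sketch. The names at_position, advance_mask, and advanced follow the review text, but step_chain and the precomputed predicate results are hypothetical simplifications, not the PR's implementation.

```python
import torch


def step_chain(current_index: torch.Tensor, predicate_results: list) -> torch.Tensor:
    """Advance each env by at most one chain position per call."""
    advanced = torch.zeros_like(current_index, dtype=torch.bool)
    for chain_idx, satisfied in enumerate(predicate_results):
        at_position = current_index == chain_idx
        # ~advanced blocks envs that already moved during this call.
        advance_mask = at_position & satisfied & ~advanced
        current_index = torch.where(advance_mask, current_index + 1, current_index)
        advanced |= advance_mask
    return current_index


# Two envs; p0 and p1 are both True for env 0 on the same step.
p0 = torch.tensor([True, False])
p1 = torch.tensor([True, True])
index = torch.zeros(2, dtype=torch.long)

index = step_chain(index, [p0, p1])  # env 0 moves 0 -> 1 only
after_first = index.tolist()
index = step_chain(index, [p0, p1])  # env 0 moves 1 -> 2 on the next call
after_second = index.tolist()
```

Without the ~advanced term, env 0 would jump from 0 to 2 in the first call, which is the behavior the docstring note should rule out.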
3. ⚠️ is_complete() and overall_score_per_env() recompute on every get_state() call
File: fine_grained_state_machine.py (lines 159–174, 187–228)
get_state() calls runner.is_complete() and runner.overall_score_per_env() per runner per env in the output loop. Each of these does a torch.stack + reduction. In a tight training loop where get_state() is called every step (via fine_grained_subtask_step_func), this could add up.
Consider caching is_complete and overall_score_per_env results per step (invalidated at each step() call) or computing them once at the end of step().
Severity: Low-Medium — depends on number of FGS runners and training throughput requirements.
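One way to read the caching suggestion is sketched below in plain Python. CachedRunner, its list-based state, and the "all groups must be complete" rule are assumptions made for the example; the real runner works on tensors and supports several logical modes.

```python
from __future__ import annotations


class CachedRunner:
    """Hypothetical runner that caches a derived value until the next step()."""

    def __init__(self, group_complete: dict[str, list[bool]]) -> None:
        self.group_complete = group_complete
        self._is_complete_cache: list[bool] | None = None
        self.recompute_count = 0  # instrumentation for the example only

    def step(self, updates: dict[str, list[bool]]) -> None:
        self.group_complete.update(updates)
        self._is_complete_cache = None  # state changed, so drop the cached result

    def is_complete(self) -> list[bool]:
        if self._is_complete_cache is None:
            self.recompute_count += 1
            per_group = list(self.group_complete.values())
            # An env counts as complete when every group is complete for it.
            self._is_complete_cache = [all(flags) for flags in zip(*per_group)]
        return self._is_complete_cache


runner = CachedRunner({"a": [True, False], "b": [True, True]})
first = runner.is_complete()
second = runner.is_complete()  # served from the cache
count_before_step = runner.recompute_count
runner.step({"a": [True, True]})  # invalidates the cache
third = runner.is_complete()
```

Repeated reads between two step() calls cost one computation, which is the effect the finding asks for.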
4. 📝 _compute_composite_task_gating_mask assumes _current_subtask_idx is list-like or tensor
File: fine_grained_state_machine.py (lines 58–78)
if torch.is_tensor(current_idx):
    ci = current_idx.to(self.device)
else:
    ci = torch.as_tensor(current_idx, device=self.device)
return ci == int(self.fine_grained_subtask.parent_subtask_idx)
If _current_subtask_idx is a scalar int (not a list or tensor), torch.as_tensor(5) produces a 0-d tensor and ci == N returns a 0-d bool tensor, not (num_envs,). The comparison against composite_task_gating_mask downstream would then broadcast unexpectedly. This probably never happens in practice (composite tasks use per-env tensors), but a shape assertion or explicit broadcast would make it defensive.
Severity: Low — unlikely path but a potential silent logic error.
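A defensive version might look like the sketch below. The free function gating_mask and its num_envs parameter are hypothetical; the PR's method reads these values from the runner and the env instead.

```python
import torch


def gating_mask(current_idx, parent_subtask_idx: int, num_envs: int) -> torch.Tensor:
    """Return a (num_envs,) bool mask even when current_idx is a plain int."""
    ci = torch.as_tensor(current_idx)
    if ci.ndim == 0:
        # A scalar index applies to every env, so broadcast it explicitly.
        ci = ci.expand(num_envs)
    assert ci.shape == (num_envs,), f"expected ({num_envs},), got {tuple(ci.shape)}"
    return ci == int(parent_subtask_idx)


# Without a guard, a scalar index yields a 0-d result.
unguarded_shape = (torch.as_tensor(5) == 5).shape

scalar_mask = gating_mask(5, parent_subtask_idx=5, num_envs=4)
per_env_mask = gating_mask([5, 2, 5, 0], parent_subtask_idx=5, num_envs=4)
```

The explicit expand keeps the downstream mask arithmetic on a known shape, and the assertion turns any other unexpected shape into a loud failure rather than a silent broadcast.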
5. 💡 task_base.py imports are unconditional — consider lazy imports
File: task_base.py (lines 14–18)
from isaaclab_arena.tasks.fine_grained_state_machine import (
    make_fine_grained_subtask_events_cfg,
    make_fine_grained_subtask_termination_cfg,
)
from isaaclab_arena.tasks.fine_grained_subtask import FineGrainedSubtask
Every task (including those that don't use fine-grained tracking) now imports the full state machine module at class-load time. This pulls in torch and the state machine code even when unused. For a base class imported everywhere, lazy imports inside get_fine_grained_subtask_events_cfg() / get_fine_grained_subtask_termination_cfg() would keep startup lean.
Severity: Low — minor import overhead, not a correctness issue.
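The lazy-import pattern itself is small. In the sketch below, the standard-library module fractions stands in for the state machine module, and TaskBase.get_ratio_cfg is a hypothetical method used only to show where the import moves.

```python
import sys


class TaskBase:
    """Hypothetical base class; `fractions` stands in for a heavy optional module."""

    def get_ratio_cfg(self, numerator: int, denominator: int):
        # Imported on first use, so tasks that never call this method
        # do not pay the import cost at class-load time.
        from fractions import Fraction

        return Fraction(numerator, denominator)


task = TaskBase()
cfg = task.get_ratio_cfg(1, 2)
loaded_after_call = "fractions" in sys.modules
```

The trade-off is that an import error surfaces at first call rather than at module load, which is usually acceptable for an optional feature.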
6. 🧪 Missing test: predicate raising an exception mid-chain
The test suite covers happy paths thoroughly (sequential advance, logical modes, gating, reset). However, there's no test for what happens if a predicate callable raises an exception. Currently, it would propagate unhandled and potentially leave the state machine in a partially-advanced state for that step.
Consider either:
- Adding a test that documents this as expected behavior (exception propagates, caller handles it)
- Or wrapping predicate evaluation in a try/except that logs and skips
Severity: Low — documentation/robustness concern.
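The first option could be documented with a test along these lines. advance_chain is a hypothetical single-env stand-in for the runner's step logic, so this shows the shape of the test rather than the PR's API.

```python
def advance_chain(current_index: int, chain) -> int:
    """Hypothetical single-env stand-in: advance if the active predicate holds."""
    if current_index < len(chain) and chain[current_index]():
        return current_index + 1
    return current_index


def failing_predicate() -> bool:
    raise RuntimeError("sensor unavailable")


chain = [lambda: True, failing_predicate]

index = advance_chain(0, chain)  # p0 holds, so the index becomes 1
propagated = False
try:
    advance_chain(index, chain)  # p1 raises; nothing catches it here
except RuntimeError:
    propagated = True
```

In this simplified form the index is left untouched by the failed call. The vectorized runner updates several tensors per step, so a real test would also need to check that none of them were partially written.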
7. 📝 Minor: Typo in docstrings — "parallelenvironments"
File: fine_grained_state_machine.py (lines 36, 177)
"across all parallelenvironments" → "across all parallel environments"
Positives
- Clean architecture: The dataclass → runner → state machine layering is well-designed and extensible.
- Thorough validation in FineGrainedSubtask.__post_init__ catches misconfiguration early.
- Flexible input formats for predicate_groups with good normalization logic.
- Composite task integration via dataclasses.replace with namespaced names is elegant.
- Comprehensive test coverage (17 test cases) covering sequential advancement, out-of-order rejection, all logical modes, gating, reset isolation, and integration hooks.
Verdict
Good design, solid implementation, excellent tests. The vectorized reset (finding #1) is the most actionable improvement. The rest are minor hardening suggestions. 👍
Follow-up note: Noticed the class rename in the latest commit (FineGrainedStateMachine → FineGrainedSubtaskTrackingStateMachine). The new name better reflects the class's purpose — good clarification. The previous review findings still apply unchanged.
No new review generated — naming-only refactor.
Update (2026-06-05): Reviewed incremental changes 4c9aa2b..83f0ca3.
Previous findings status:
- ✅ Fixed — Finding #7 (typo "parallelenvironments" → "parallel environments")
- Findings #1–#6 — original comments stand (not addressed in this push, which is fine given the scope of changes)
Summary of new changes:
Large push with several new features alongside the rename refactor:
- Rename: FineGrainedSubtask → FineGrainedProgressObjective, FineGrainedSubtaskTrackingStateMachine → FineGrainedProgressTracker. Cleaner naming, comprehensive rename across all references and tests. 👍
- NotNextTo relation: new loss strategy with proper gradient (no flat plateau) and solver integration tests. Well-designed.
- Variations system: VariationBase / BuildTimeVariationBase / RunTimeVariationBase with cfg-driven sampler wiring. Clean architecture with enable/disable toggle and apply_cfg. Good test coverage for camera extrinsics, HDR image, build-time, and run-time variations.
- CLI override specs: graph YAML can declare swappable --flag arguments that remap node asset names. Properly validated (duplicates, unknown targets).
- Graph spec Pydantic migration: ArenaEnvGraphSpec now uses Pydantic validation instead of manual asserts. Spatial constraints use string-based registry lookup instead of enum. Removes special-case handling for at_pose / in.
- Chunked eval runner: --chunk_size splits long sweeps across subprocesses to reclaim leaked memory. Appropriate TODO for metrics aggregation.
- Pooled placer reproducibility: per-env RNGs via get_rngs() replace random.choice / random.choices, making sample_with_replacement reproducible under placement_seed.
No new issues found in this push. Clean, well-tested changes. 👍
Update (2026-06-05): Reviewed incremental changes 83f0ca3..de73e3d.
Key change: Termination term → RecorderTerm 🎉
The fine-grained progress tracking has been moved from a TerminationTermCfg (that always returned False) to a proper RecorderTerm (FineGrainedProgressRecorder.record_post_step). This directly addresses the architectural concern discussed in this PR — using a termination term as a per-step hook was an anti-pattern.
Benefits:
- Cleaner intent expression (this is observation/logging, not termination)
- record_post_step returns (None, None): it records nothing and just ticks the state machine
- No more zero-return termination term cluttering TerminationsCfg
- Tests properly updated to verify the new interface
Previous findings status:
- ✅ Fixed: Mid-function dead docstring (original finding about fine_grained_subtask_step_func placing docs mid-function). Now properly in the FineGrainedProgressRecorder class docstring.
- ✅ Addressed: "Termination term always returns False" architectural smell. Replaced with RecorderTerm, the correct hook point.
- Findings #1 (vectorized reset), #2 (split calls), #3–#6: original comments still stand (not in scope of this refactor).
Minor new finding:
- 📝 Typo in FineGrainedProgressRecorder docstring (line 289): "state nachine" → "state machine"
Verdict:
Clean, well-scoped refactor. The move to RecorderTerm is the right architectural choice. 👍
Update (2026-06-05): Reviewed incremental changes de73e3d..680fa6d.
Previous findings status:
- ✅ Fixed: Finding #2 (undocumented single-advance-per-step invariant). Now explicitly documented in step() and _step_group() docstrings.
- ✅ Fixed: Finding #3 (is_complete() / overall_score_per_env() recomputed per env in get_state()). Now pre-computed once as (num_envs,) tensors outside the env loop and passed into get_state_for_env().
- 📝 Typo "state nachine" (line 317 of fine_grained_progress_tracker.py): still present from previous push, not addressed here.
Summary of new changes:
- step() refactored: inner group-iteration logic extracted into _step_group(). Identical behavior, better readability and separation of concerns.
- get_state() optimized: completeness and scores tensors computed once per call, passed to new runner.get_state_for_env(env_idx, ...) method. Eliminates redundant torch.stack + reduction per env.
- Line-length fixes in fine_grained_progress_objective.py: cosmetic only.
No new issues found. Clean refactor addressing prior performance feedback. 👍
Update (2026-06-05): Reviewed incremental changes 680fa6d..24807a4.
Previous findings status:
- ✅ Fixed: Finding #1 (vectorized reset). The element-wise for eid in env_ids loop is now replaced with torch.as_tensor + direct tensor indexing (self.current_index[group_name][env_ids] = 0). Exactly the optimization suggested. 👍
- ✅ Fixed: Typo "state nachine" → "state machine" (line 317 of fine_grained_progress_tracker.py). Now corrected.
New change: Objective caching in task_base.py
- Added _resolve_fine_grained_progress_objectives(): resolves get_fine_grained_progress_objectives() once and caches the result so events cfg and recorder cfg share the same objective instances. Prevents duplicate instantiation and ensures consistency between reset and step paths. Clean, correct pattern.
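The resolve-once pattern described here can be sketched as follows. The two method names come from the review; the Task class, the object() placeholders, and the two *_cfg_objectives helpers are hypothetical.

```python
from __future__ import annotations


class Task:
    """Hypothetical task; only the resolve-once pattern is the point here."""

    def __init__(self) -> None:
        self._cached_objectives: list[object] | None = None

    def get_fine_grained_progress_objectives(self) -> list[object]:
        # Builds fresh instances on every call.
        return [object(), object()]

    def _resolve_fine_grained_progress_objectives(self) -> list[object]:
        if self._cached_objectives is None:
            self._cached_objectives = self.get_fine_grained_progress_objectives()
        return self._cached_objectives

    def events_cfg_objectives(self) -> list[object]:
        return self._resolve_fine_grained_progress_objectives()

    def recorder_cfg_objectives(self) -> list[object]:
        return self._resolve_fine_grained_progress_objectives()


task = Task()
shared = task.events_cfg_objectives() is task.recorder_cfg_objectives()
fresh_a = task.get_fine_grained_progress_objectives()
fresh_b = task.get_fine_grained_progress_objectives()
```

Because both config paths go through the resolver, the reset hook and the per-step hook see the same objective instances, whereas calling the getter directly would hand each path its own copies.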
Status of all original findings:
| # | Finding | Status |
|---|---|---|
| 1 | Vectorized reset | ✅ Fixed (this push) |
| 2 | Single-advance-per-step docstring | ✅ Fixed (prior push) |
| 3 | Cached is_complete/overall_score_per_env | ✅ Fixed (prior push) |
| 4 | Scalar _current_subtask_idx shape guard | ⚪ Low-risk, not addressed (acceptable) |
| 5 | Lazy imports in task_base.py | ⚪ Low priority, not addressed (acceptable) |
| 6 | Exception handling in predicates | ⚪ Low priority, not addressed (acceptable) |
| 7 | Typo "parallelenvironments" | ✅ Fixed (earlier push) |
No new issues found. PR is in good shape. 🚀
Update (00693fb): Reviewed incremental changes. New commits add the agentic environment generation pipeline, env-graph spec refactoring (parent/child → subject/reference, type → kind), task transitions, and @agent_ready decorator. These changes are orthogonal to the fine-grained subtask tracking discussed in earlier comments. No new issues found — the architecture is clean with good test coverage.
Greptile Summary
This PR introduces foundational machinery for fine-grained subtask progress tracking in Arena: a
Confidence Score: 4/5
Safe to merge for single-level composite tasks; nested CompositeTaskBase usage will silently track all inner FGPOs under the outer subtask index, losing inner gating.
The gating rewrite in CompositeTaskBase.get_fine_grained_progress_objectives() unconditionally sets parent_subtask_idx=i (the outer index) on every child FGPO via dataclasses.replace. When the child is itself a CompositeTaskBase, the inner per-subtask indices it already populated are overwritten: all inner FGPOs collapse to the outer index and fire together as soon as the outer composite reaches that subtask, bypassing any inner-subtask ordering. This is a current defect on the nested composite code path, not a future concern.
isaaclab_arena/tasks/composite_task_base.py: the parent_subtask_idx override in get_fine_grained_progress_objectives needs a guard for nested composite children.
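The overwrite described above can be reproduced in isolation. The Objective dataclass and the collect helper are hypothetical stand-ins for the FGPO type and for CompositeTaskBase.get_fine_grained_progress_objectives(); only the unconditional dataclasses.replace call mirrors the summary.

```python
from __future__ import annotations

import dataclasses


@dataclasses.dataclass(frozen=True)
class Objective:
    """Hypothetical stand-in for a fine-grained progress objective."""

    name: str
    parent_subtask_idx: int | None = None


def collect(children: list[list[Objective]]) -> list[Objective]:
    """Tag every child objective with the index of the subtask that owns it."""
    out = []
    for i, objectives in enumerate(children):
        for obj in objectives:
            # Unconditional overwrite: an index set by a nested composite is lost.
            out.append(dataclasses.replace(obj, parent_subtask_idx=i))
    return out


# Inner composite with two subtasks, each owning one objective.
inner = collect([[Objective("grasp")], [Objective("place")]])
inner_indices = [o.parent_subtask_idx for o in inner]

# Outer composite: subtask 0 is a plain task, subtask 1 is the inner composite.
outer = collect([[Objective("open")], inner])
outer_indices = [o.parent_subtask_idx for o in outer]
```

After the outer pass, "grasp" and "place" share index 1, so nothing distinguishes their inner ordering any more. How a guard should combine inner and outer gating is not specified in the summary, so the sketch stops at demonstrating the collapse.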
Sequence Diagram
sequenceDiagram
participant Env as IsaacLab Env
participant Rec as FineGrainedProgressRecorder
participant Tracker as FineGrainedProgressTracker
participant Runner as FineGrainedProgressObjectiveRunner
participant Extras as env.extras
Env->>Rec: record_post_step()
Rec->>Tracker: _ensure_progress_tracker(env, objectives)
Note over Tracker: lazy-init on first call, cached on env
Rec->>Tracker: step(env, step_index)
loop for each runner
Tracker->>Runner: step(env, step_index)
Runner->>Runner: _compute_composite_task_gating_mask(env)
loop for each group
Runner->>Runner: _step_group(env, group, chain, gating_mask)
Note over Runner: vectorised torch.where advance
Runner-->>Tracker: transition events
end
end
Tracker-->>Rec: events accumulated
Rec->>Tracker: get_state()
Rec->>Tracker: get_events()
Rec->>Extras: "extras[fine_grained_progress] = {states, events}"
Rec-->>Env: return (None, None)
Env->>Env: reset(env_ids)
Env->>Tracker: fine_grained_progress_reset_func(env, env_ids, objectives)
Tracker->>Runner: runner.reset(env_ids)
Note over Runner: vectorised index assignment
Tracker->>Tracker: "_events[eid] = [] for eid in env_ids"
Reviews (6): Last reviewed commit: "Merge branch 'main' of github.com:isaac-..."
Thanks for adding this! Going to be super useful once we have the full pipeline in.
When reading "subtask" anywhere, it's not clear which one is meant unless you're already familiar with the code. If I read it correctly, we now have two systems that do opposite jobs (but read like siblings in the code):
Currently "FineGrained" names the "resolution" of the Task, not the part that it has control over (can it end the episode or not?). My proposal is to name by that question, "Can it end the episode or not?":
So the renames could be roughly:
Two side benefits:
"""
env.extras["fine_grained_subtask"] = {
    "states": sm.get_state(),
Because this is a termination term, it runs on every step, so get_state() and get_events() rebuild the full per-env view each step even when nothing is reading those extras. Could we move this to a recorder term so it only runs when recording is active?
Touches on the renaming and suggestion to move this state tracking out of termination and into a recorder
Good call — using a termination term that always returns False just to get a per-step hook is definitely an architectural smell. Moving the state tracking to RecorderTerm.record_post_step would:
- Skip the overhead when not recording: the get_state() / get_events() rebuilds only happen when actually needed
- Better express intent: this is observation/logging, not termination logic
- Simplify task configs: no zero-return termination term cluttering TerminationsCfg
The only consideration is whether any real-time consumers (e.g., a live dashboard or curriculum logic) need that per-step state outside of recording. If so, might need a toggle or lazy evaluation pattern. But for pure data collection, the recorder approach is cleaner.
Thanks for the concrete suggestion — this is actionable feedback for @peterd-NV to consider.
Refactored to use RecorderManager instead of TerminationManager for driving the state machine step. TerminationTermCfg has been replaced with a recorder term that runs record_post_step. Since we are using the mechanism to "drive the state machine step" and not actually record data, the recorder term returns (None, None).
One note: the progress tracking should run at every step, regardless of whether recording is active. The events are complex dictionaries with many fields, so they are returned to the user via env.extras rather than via a RecorderTerm.
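The arrangement described in this reply can be sketched with stand-in classes. RecorderTermStub is not the Isaac Lab base class, and CountingTracker, ProgressRecorder, and FakeEnv are hypothetical; only the record_post_step name, the (None, None) return, and the fine_grained_progress extras key come from this thread.

```python
class RecorderTermStub:
    """Stand-in for a recorder-term base class; NOT the Isaac Lab class."""

    def __init__(self, env) -> None:
        self._env = env

    def record_post_step(self):
        return None, None


class CountingTracker:
    """Hypothetical tracker that only counts how often it was stepped."""

    def __init__(self) -> None:
        self.steps = 0

    def step(self) -> None:
        self.steps += 1

    def get_state(self) -> dict:
        return {"steps": self.steps}


class ProgressRecorder(RecorderTermStub):
    def __init__(self, env, tracker: CountingTracker) -> None:
        super().__init__(env)
        self._tracker = tracker

    def record_post_step(self):
        self._tracker.step()
        # Results go to env.extras rather than into the recorder's dataset.
        self._env.extras["fine_grained_progress"] = {"states": self._tracker.get_state()}
        return None, None  # nothing to record; the call only drives the tracker


class FakeEnv:
    def __init__(self) -> None:
        self.extras = {}


env = FakeEnv()
recorder = ProgressRecorder(env, CountingTracker())
result = recorder.record_post_step()
```

The term is used purely as a per-step hook here, which is why the design depends on the recorder manager invoking it on every step, as the note above points out.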
}
# Return all-False so it does not contribute to termination.
return torch.zeros(env.num_envs, dtype=torch.bool, device=env.device)
Since this always returns all-False/zeros it never contributes to termination, so it's really just being used as a "runs every step" hook. RecorderTerm.record_post_step could potentially give the same per-step call without adding a zero-return statement to every progress task's terminations config.
Updated to use recorder manager/term.
ci = torch.as_tensor(current_idx, device=self.device)
return ci == int(self.fine_grained_subtask.parent_subtask_idx)

def step(self, env, step_index: torch.Tensor | None) -> list[dict]:
This function is quite long and handles gating, groups and chain positions. Could we move parts out into functions and split this up into something like
def step(self, env, step_index) -> list[dict]:
    gating_mask = self._compute_composite_task_gating_mask(env)
    if not bool(gating_mask.any().item()):
        return []
    events = []
    for group_name, predicate_chain in self.fine_grained_subtask.canonical_predicate_groups.items():
        events += self._step_group(env, group_name, predicate_chain, gating_mask, step_index)
    return events
Updated to simplify step and move _step_group into a separate function.
This looks good! The step method is now clean and focused on orchestration, while _step_group handles the detailed per-group logic. Nice refactoring 👍
for eid in env_ids:
    self._events[eid] = []

def get_state(self) -> list[dict]:
Same here. Claude had a remark and a suggestion (obviously untested :) ). But it would be nice to split this up as well into digestible chunks.
AI suggestion: "Per-env loop currently recomputes whole tensors per env (runner.is_complete()[env_idx], runner.overall_score_per_env()[env_idx] are called once per env_idx, but each call builds the full (num_envs,) tensor, so it's O(num_envs²)). Extracting fixes both. Add a method on the runner that owns its own state:
# on FineGrainedSubtaskRunner
def state_for_env(self, env_idx, is_complete, score) -> dict:
    """Per-env view of this runner. is_complete/score passed in so the
    full-tensor computations happen once, outside the env loop."""
    fgs = self.fine_grained_subtask
    active_predicates = {}
    completed_groups = 0
    for group_name in fgs.group_names:
        chain = fgs.canonical_predicate_groups[group_name]
        cur = int(self.current_index[group_name][env_idx].item())
        if cur >= len(chain):
            active_predicates[group_name] = None
            completed_groups += 1
        else:
            active_predicates[group_name] = _predicate_repr(chain[cur][0])
    return {
        "completed_groups": completed_groups,
        "total_groups": len(fgs.group_names),
        "score": float(score),
        "is_complete": bool(is_complete),
        "active_predicates": active_predicates,
    }
Then get_state hoists the tensor computations out of the loop and just assembles:
def get_state(self) -> list[dict]:
    # Compute the per-runner tensors once, not once per env.
    completeness = [r.is_complete() for r in self.runners]
    scores = [r.overall_score_per_env() for r in self.runners]
    output = []
    for env_idx in range(self.num_envs):
        fgs_states = {
            r.fine_grained_subtask.name: r.state_for_env(env_idx, completeness[i][env_idx], scores[i][env_idx])
            for i, r in enumerate(self.runners)
        }
        overall_score = sum(
            r.fine_grained_subtask.score * fgs_states[r.fine_grained_subtask.name]["score"]
            for r in self.runners
        )
        all_complete = all(s["is_complete"] for s in fgs_states.values())
        output.append({
            "fine_grained_subtasks": fgs_states,
            "overall_score": overall_score,
            "all_complete": all_complete,
        })
    return output
"
Applied the change to separate this into getting state per_env and split up into different chunks
Nice, thanks for the refactor! 👍
|
Thanks @cvolkcvolk for the review. I've gone ahead and renamed "FineGrainedSubtasks" to "FineGrainedProgressObjective" and "FineGrainedProgressTracker". I kept the "FineGrained" terminology since I wanted to make it clear that the predicate chains we give tasks here can be made more detailed and include this notion of multiple "subtasks" in the progress (i.e. we can have multiple predicates or predicate groups for a single TaskBase).
…d/fine_grained_subtask_tracking
Summary
This PR adds foundational machinery required for fine grained subtask tracking in Arena.