Skip to content

Implement batch evaluation of RV shapes in get_plates#8033

Open
RohanAich wants to merge 3 commits intopymc-devs:mainfrom
RohanAich:plate-shape-eval
Open

Implement batch evaluation of RV shapes in get_plates#8033
RohanAich wants to merge 3 commits intopymc-devs:mainfrom
RohanAich:plate-shape-eval

Conversation

@RohanAich
Copy link
Copy Markdown

@RohanAich RohanAich commented Jan 4, 2026

Implement batch evaluation of RV shapes in ModelGraph.get_plates

Description

Implements the TODO in ModelGraph.get_plates to evaluate all random variables (RV) shapes at once instead of indiviually.

Related Issue

  • Closes #
  • Related to TODO in ModelGraph.get_plates

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@codecov
Copy link
Copy Markdown

codecov Bot commented Jan 4, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 91.44%. Comparing base (056e80c) to head (21b603c).

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #8033   +/-   ##
=======================================
  Coverage   91.44%   91.44%           
=======================================
  Files         116      116           
  Lines       19002    19005    +3     
=======================================
+ Hits        17377    17380    +3     
  Misses       1625     1625           
Files with missing lines Coverage Δ
pymc/model_graph.py 84.69% <100.00%> (+0.15%) ⬆️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented Jan 5, 2026

Good chance to fix #8024?

We need something better than _cheap_eval_mode, because we need it to do shape inference first to handle flat RV. A possible strategy is to add this helper to pytensorf, and call it before we eval the chapes with the cheap mode:

def get_symbolic_rv_shapes(
    rvs: Sequence[Variable], raise_if_rvs_in_graph: bool = True
) -> tuple[TensorVariable]:
    # TODO: docstrings
    
    rv_shapes = [rv.shape for rv in rvs]
    shape_fg = FunctionGraph(outputs=rv_shapes, features=[ShapeFeature()], clone=True)
    with config.change_flags(optdb__max_use_ratio=10, cxx=""):
        infer_shape_db.default_query.rewrite(shape_fg)
    rv_shapes = shape_fg.outputs

    if raise_if_rvs_in_graph and (overlap := (set(rvs) & set(ancestors(rv_shapes)))):
        raise ValueError(f"rv_shapes still depend the following rvs {overlap}")

    return tuple(rv_shapes)

I would call it with raise_if_rvs_in_graph=False

@RohanAich
Copy link
Copy Markdown
Author

Thank you! This may take me a bit but I will try to get it done within the next few days.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants