Skip to content

Commit 6a98d9d

Browse files
[BUG] Replace deprecated batched_dot with pt.sum in KroneckerNormal
- Fixes Issue #7878 - Replace pt.batched_dot(sqrt_quad.T, sqrt_quad.T) with pt.sum(sqrt_quad.T ** 2, axis=-1) - Computes squared norm per sample using modern PyTensor operations - Eliminates deprecation warnings and ensures future compatibility
1 parent 3a0186e commit 6a98d9d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pymc/distributions/multivariate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2124,8 +2124,8 @@ def logp(value, rng, size, mu, sigma, *covs):
21242124
sqrt_quad = sqrt_quad / pt.sqrt(eigs[:, None])
21252125
logdet = pt.sum(pt.log(eigs))
21262126

2127-
# Square each sample
2128-
quad = pt.batched_dot(sqrt_quad.T, sqrt_quad.T)
2127+
# Square each sample - compute squared norm for each sample
2128+
quad = pt.sum(sqrt_quad.T ** 2, axis=-1)
21292129
if onedim:
21302130
quad = quad[0]
21312131

0 commit comments

Comments
 (0)