-
Notifications
You must be signed in to change notification settings - Fork 149
Determinant of factorized matrices #1785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| [det] = node.outputs | ||
| [x] = node.inputs | ||
|
|
||
| only_used_by_abs = all( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any Op that that maps (-1, 1) to the same value is actually fine, At the very least should include square as well
| match core_op: | ||
| case Cholesky(): | ||
| L = client.outputs[0] | ||
| new_det = matrix_diagonal_product(L) ** 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: Add the positive tag here.
Possibly also rewrite for log(x ** 2) -> log(x) * 2, when we know x is positive
| case QR(): | ||
| R = client.outputs[-1] | ||
| # if mode == "economic", R may not be square and this rewrite could hide a shape error | ||
| # That's why it's tagged as `shape_unsafe` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This rewrite isn't tagged shape_unsafe
| new_det = ones(x.shape[:-2], dtype=det.dtype) | ||
| case QR(): | ||
| # if mode == "economic", Q/R may not be square and this rewrite could hide a shape error | ||
| # That's why it's tagged as `shape_unsafe` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it worth handling this case in a separate rewrite so as to not tag the others as shape_unsafe (since they aren't)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've thought about that as well. That's almost always the case, only a subset of the matching cases is actually unsafe in a rewrite.
OTOH the tag is mostly a debug thing, if you're getting an odd result or shape error you may want to exclude to see if it goes away or the error is more obvious.
You never really want to exclude them at the end of the day
The old
local_det_cholrewrite is extended to cover more cases of a matrix that is factorized elsewhere, not just with Cholesky, but also LU, LUFactor, or SVD, QR (the latter two only if the sign isn't needed)A new rewrite is added for the determinant of a factorization itself. The logic is slightly different, for instance det(LUFactor) is non-sensical, and the determinant for some outputs of SVD/ QR can always be computed even if the determinant of the whole factorization cannot.
Also extended the rewrite of log(prod(x)) to sum(log(x)), which should increase the stability of many of these when we want the log determinant (or log(abs(determintant))).
Still missing tests
Closes #1679
Related to #573