|
14 | 14 |
|
15 | 15 | import warnings
|
16 | 16 |
|
| 17 | +from typing import List |
| 18 | + |
17 | 19 | import numpy as np
|
18 | 20 | import theano.tensor as tt
|
19 | 21 |
|
@@ -565,3 +567,78 @@ def jacobian_det(self, y):
|
565 | 567 | else:
|
566 | 568 | det += det_
|
567 | 569 | return det
|
| 570 | + |
| 571 | + |
def _extend_axis(array, axis):
    """Append one entry along ``axis`` so the result sums to zero there.

    Symbolic counterpart of ``_extend_axis_val``; the pieces are joined with
    ``tt.concatenate``.

    Parameters
    ----------
    array
        Input providing ``shape`` and ``sum(axis, keepdims=True)``.
    axis : int
        Axis along which one entry is appended.

    Returns
    -------
    The input with one additional entry along ``axis`` and every entry
    shifted by a common offset.
    """
    extended_len = array.shape[axis] + 1
    root = np.sqrt(extended_len)
    total = array.sum(axis, keepdims=True)

    # Common offset subtracted from every entry of the padded array.
    shift = total / (root + extended_len)

    # The appended entry ends up as ``-total / root`` after the shift, which
    # exactly cancels the sum of the shifted original entries.
    padded = tt.concatenate([array, shift - total / root], axis=axis)
    return padded - shift
| 580 | + |
| 581 | + |
| 582 | +def _extend_axis_rev(array, axis): |
| 583 | + if axis < 0: |
| 584 | + axis = axis % array.ndim |
| 585 | + assert axis >= 0 and axis < array.ndim |
| 586 | + |
| 587 | + n = array.shape[axis] |
| 588 | + last = tt.take(array, [-1], axis=axis) |
| 589 | + |
| 590 | + sum_vals = -last * np.sqrt(n) |
| 591 | + norm = sum_vals / (np.sqrt(n) + n) |
| 592 | + slice_before = (slice(None, None),) * axis |
| 593 | + return array[slice_before + (slice(None, -1),)] + norm |
| 594 | + |
| 595 | + |
| 596 | +def _extend_axis_val(array, axis): |
| 597 | + n = array.shape[axis] + 1 |
| 598 | + sum_vals = array.sum(axis, keepdims=True) |
| 599 | + norm = sum_vals / (np.sqrt(n) + n) |
| 600 | + fill_val = norm - sum_vals / np.sqrt(n) |
| 601 | + |
| 602 | + out = np.concatenate([array, fill_val], axis=axis) |
| 603 | + return out - norm |
| 604 | + |
| 605 | + |
| 606 | +def _extend_axis_rev_val(array, axis): |
| 607 | + n = array.shape[axis] |
| 608 | + last = np.take(array, [-1], axis=axis) |
| 609 | + |
| 610 | + sum_vals = -last * np.sqrt(n) |
| 611 | + norm = sum_vals / (np.sqrt(n) + n) |
| 612 | + slice_before = (slice(None, None),) * len(array.shape[:axis]) |
| 613 | + return array[slice_before + (slice(None, -1),)] + norm |
| 614 | + |
| 615 | + |
class ZeroSumTransform(Transform):
    """Transform between zero-sum arrays and arrays one entry shorter per axis.

    ``backward`` appends one entry along each configured axis (via
    ``_extend_axis``) so that the result sums to zero along it; ``forward``
    and ``forward_val`` remove that entry again.
    """

    name = "zerosum"

    # Axes along which the zero-sum constraint is applied.
    _zerosum_axes: List[int]

    def __init__(self, zerosum_axes):
        """Store the axes the transform operates on."""
        self._zerosum_axes = zerosum_axes

    def forward(self, x):
        """Symbolically drop the last entry along every configured axis."""
        reduced = x
        for ax in self._zerosum_axes:
            reduced = _extend_axis_rev(reduced, axis=ax)
        return floatX(reduced)

    def forward_val(self, x, point):
        """NumPy version of ``forward``; ``point`` is not used.

        The result is returned as computed, without a ``floatX`` cast.
        """
        reduced = x
        for ax in self._zerosum_axes:
            reduced = _extend_axis_rev_val(reduced, axis=ax)
        return reduced

    def backward(self, z):
        """Symbolically append one entry along every configured axis."""
        extended = tt.as_tensor_variable(z)
        for ax in self._zerosum_axes:
            extended = _extend_axis(extended, axis=ax)
        return floatX(extended)

    def jacobian_det(self, x):
        """Return a constant ``0.0`` regardless of ``x``."""
        return tt.constant(0.0)


# Lowercase alias for the class itself (not an instance): call it with the axes.
zerosum = ZeroSumTransform
0 commit comments