Skip to content

Commit 4a6cc02

Browse files
brianwa84Copybara-Service
authored andcommitted
Adds the VonMisesFisher distribution over points on the unit hypersphere.
The sampling algorithm is due to Wood'94. The pdf currently uses a recurrence for bessel_iN which isn't great for large N (WIP on better bessel_ive). Tests are based on importance sampling the surface area of a [hyper]spherical cap. Fixes tensorflow/tensorflow#6141 PiperOrigin-RevId: 205891121
1 parent ae307bf commit 4a6cc02

File tree

4 files changed

+787
-0
lines changed

4 files changed

+787
-0
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ py_library(
7575
":vector_laplace_linear_operator",
7676
":vector_sinh_arcsinh_diag",
7777
":vector_student_t",
78+
":von_mises_fisher",
7879
":wishart",
7980
# numpy dep,
8081
# tensorflow dep,
@@ -497,6 +498,15 @@ py_library(
497498
],
498499
)
499500

501+
py_library(
502+
name = "von_mises_fisher",
503+
srcs = ["von_mises_fisher.py"],
504+
deps = [
505+
# numpy dep,
506+
# tensorflow dep,
507+
],
508+
)
509+
500510
py_library(
501511
name = "wishart",
502512
srcs = ["wishart.py"],
@@ -1026,6 +1036,19 @@ py_test(
10261036
],
10271037
)
10281038

1039+
py_test(
1040+
name = "von_mises_fisher_test",
1041+
size = "medium",
1042+
srcs = ["von_mises_fisher_test.py"],
1043+
deps = [
1044+
":distributions",
1045+
# numpy dep,
1046+
# tensorflow dep,
1047+
"//tensorflow_probability",
1048+
"//tensorflow_probability/python/internal:test_util",
1049+
],
1050+
)
1051+
10291052
py_test(
10301053
name = "vector_student_t_test",
10311054
size = "medium",

tensorflow_probability/python/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
from tensorflow_probability.python.distributions.vector_exponential_diag import VectorExponentialDiag
7575
from tensorflow_probability.python.distributions.vector_laplace_diag import VectorLaplaceDiag
7676
from tensorflow_probability.python.distributions.vector_sinh_arcsinh_diag import VectorSinhArcsinhDiag
77+
from tensorflow_probability.python.distributions.von_mises_fisher import VonMisesFisher
7778
from tensorflow_probability.python.distributions.wishart import Wishart
7879

7980
from tensorflow_probability.python.internal.distribution_util import fill_triangular
@@ -174,6 +175,7 @@
174175
'VectorDiffeomixture',
175176
'VectorLaplaceDiag',
176177
'VectorSinhArcsinhDiag',
178+
'VonMisesFisher',
177179
'Wishart',
178180
'TransformedDistribution',
179181
'QuantizedDistribution',

0 commit comments

Comments
 (0)