diff --git a/pixi.lock b/pixi.lock index 385b275..e7d7093 100644 --- a/pixi.lock +++ b/pixi.lock @@ -5,10 +5,13 @@ environments: - url: https://conda.anaconda.org/conda-forge/ indexes: - https://pypi.org/simple + options: + pypi-prerelease-mode: if-necessary-or-explicit packages: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/asttokens-3.0.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/basedpyright-1.36.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/binutils_impl_linux-64-2.45-h9d8b0ac_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hda65f42_8.conda @@ -20,9 +23,14 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/compiler-rt21-21.1.5-hb700be7_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/compiler-rt21_linux-64-21.1.5-hffcefe0_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/decorator-5.2.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.3.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/executing-2.2.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/icu-75.1-he02047a_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.3.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/ipython-9.8.0-pyh53cf698_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/ipython_pygments_lexers-1.1.1-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/jedi-0.19.2-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/kernel-headers_linux-64-4.18.0-he073ed8_8.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/keyutils-1.6.3-hb9d3cd8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.3-h659f571_0.conda @@ -58,24 +66,33 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.15.1-h26afc86_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-21.1.5-h4922eb0_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/matplotlib-inline-0.2.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/nodejs-24.9.0-heeeca48_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/nodejs-wheel-24.12.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.3.5-py314h2b28147_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.0-h26f9b46_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/parso-0.8.5-pyhcf101f3_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pexpect-4.9.0-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/prompt-toolkit-3.0.52-pyha770c72_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/ptyprocess-0.7.0-pyhd8ed1ab_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pure_eval-0.2.3-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-9.0.2-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.14.0-h32b2ec7_102_cp314.conda - conda: https://conda.anaconda.org/conda-forge/noarch/python_abi-3.14-8_cp314.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8c095d6_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/rhash-1.4.6-hb9d3cd8_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/stack_data-0.6.3-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/sysroot_linux-64-2.28-h4ee821c_8.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_ha0e22de_103.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tomli-2.3.0-pyhcf101f3_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/traitlets-5.14.3-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.15.0-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/wcwidth-0.2.14-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb8e6e7a_2.conda - pypi: ./ packages: @@ -100,6 +117,19 @@ packages: purls: [] size: 23621 timestamp: 1650670423406 +- conda: https://conda.anaconda.org/conda-forge/noarch/asttokens-3.0.1-pyhd8ed1ab_0.conda + sha256: ee4da0f3fe9d59439798ee399ef3e482791e48784873d546e706d0935f9ff010 + md5: 9673a61a297b00016442e022d689faa6 + depends: + - python >=3.10 + constrains: + - astroid >=2,<5 + license: Apache-2.0 + license_family: Apache + purls: + - pkg:pypi/asttokens?source=hash-mapping + size: 28797 + timestamp: 1763410017955 - conda: https://conda.anaconda.org/conda-forge/noarch/basedpyright-1.36.1-pyhcf101f3_0.conda sha256: 4e83264419ec97a2da8b07164ea5ad3aaf6d7842b6881b7c5807718d9fab8e9f md5: 12654a584fa0872deea98f0d6ae548a2 @@ -238,6 +268,17 @@ packages: purls: [] size: 51691276 timestamp: 1762315639532 +- conda: https://conda.anaconda.org/conda-forge/noarch/decorator-5.2.1-pyhd8ed1ab_0.conda + sha256: c17c6b9937c08ad63cb20a26f403a3234088e57d4455600974a0ce865cb14017 + md5: 9ce473d1d1be1cc3810856a48b3fab32 + depends: + - python >=3.9 + license: BSD-2-Clause + license_family: BSD + purls: + - pkg:pypi/decorator?source=hash-mapping + size: 14129 + timestamp: 1740385067843 - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.3.0-pyhd8ed1ab_0.conda sha256: ce61f4f99401a4bd455b89909153b40b9c823276aefcbb06f2044618696009ca md5: 72e42d28960d875c7654614f8b50939a @@ -249,12 +290,22 @@ packages: - pkg:pypi/exceptiongroup?source=hash-mapping size: 21284 timestamp: 1746947398083 +- conda: https://conda.anaconda.org/conda-forge/noarch/executing-2.2.1-pyhd8ed1ab_0.conda + sha256: 210c8165a58fdbf16e626aac93cc4c14dbd551a01d1516be5ecad795d2422cad + md5: ff9efb7f7469aed3c4a8106ffa29593c + depends: + - python >=3.10 + license: MIT + license_family: MIT + purls: + - pkg:pypi/executing?source=hash-mapping + size: 30753 + timestamp: 1756729456476 - pypi: ./ name: gibbsflux version: 0.0.1 - sha256: 76f2cca0a506c495fce951044dc36a51015ce2b0aef58830a23c0f1c024e4558 + sha256: 84e089bc957b5b405345d41aba29cc24a52e27678eb5a3610f31da11f3625b67 requires_python: '>=3.13' - editable: true - conda: https://conda.anaconda.org/conda-forge/linux-64/icu-75.1-he02047a_0.conda sha256: 71e750d509f5fa3421087ba88ef9a7b9be11c53174af3aa4d06aff4c18b38e8e md5: 8b189310083baabfb622af68fd9d3ae3 @@ -278,6 +329,52 @@ packages: - pkg:pypi/iniconfig?source=compressed-mapping size: 13387 timestamp: 1760831448842 +- conda: https://conda.anaconda.org/conda-forge/noarch/ipython-9.8.0-pyh53cf698_0.conda + sha256: 8a72c9945dc4726ee639a9652b622ae6b03f3eba0e16a21d1c6e5bfb562f5a3f + md5: fd77b1039118a3e8ce1070ac8ed45bae + depends: + - __unix + - pexpect >4.3 + - decorator >=4.3.2 + - ipython_pygments_lexers >=1.0.0 + - jedi >=0.18.1 + - matplotlib-inline >=0.1.5 + - prompt-toolkit >=3.0.41,<3.1.0 + - pygments >=2.11.0 + - python >=3.11 + - stack_data >=0.6.0 + - traitlets >=5.13.0 + - typing_extensions >=4.6 + - python + license: BSD-3-Clause + license_family: BSD + purls: + - pkg:pypi/ipython?source=compressed-mapping + size: 645145 + timestamp: 1764766793792 +- conda: https://conda.anaconda.org/conda-forge/noarch/ipython_pygments_lexers-1.1.1-pyhd8ed1ab_0.conda + sha256: 894682a42a7d659ae12878dbcb274516a7031bbea9104e92f8e88c1f2765a104 + md5: bd80ba060603cc228d9d81c257093119 + depends: + - pygments + - python >=3.9 + license: BSD-3-Clause + license_family: BSD + purls: + - pkg:pypi/ipython-pygments-lexers?source=hash-mapping + size: 13993 + timestamp: 1737123723464 +- conda: https://conda.anaconda.org/conda-forge/noarch/jedi-0.19.2-pyhd8ed1ab_1.conda + sha256: 92c4d217e2dc68983f724aa983cca5464dcb929c566627b26a2511159667dba8 + md5: a4f4c5dc9b80bc50e0d3dc4e6e8f1bd9 + depends: + - parso >=0.8.3,<0.9.0 + - python >=3.9 + license: Apache-2.0 AND MIT + purls: + - pkg:pypi/jedi?source=hash-mapping + size: 843646 + timestamp: 1733300981994 - conda: https://conda.anaconda.org/conda-forge/noarch/kernel-headers_linux-64-4.18.0-he073ed8_8.conda sha256: 305c22a251db227679343fd73bfde121e555d466af86e537847f4c8b9436be0d md5: ff007ab0f0fdc53d245972bba8a6d40c @@ -726,6 +823,18 @@ packages: purls: [] size: 3226046 timestamp: 1762315432827 +- conda: https://conda.anaconda.org/conda-forge/noarch/matplotlib-inline-0.2.1-pyhd8ed1ab_0.conda + sha256: 9d690334de0cd1d22c51bc28420663f4277cfa60d34fa5cad1ce284a13f1d603 + md5: 00e120ce3e40bad7bfc78861ce3c4a25 + depends: + - python >=3.10 + - traitlets + license: BSD-3-Clause + license_family: BSD + purls: + - pkg:pypi/matplotlib-inline?source=compressed-mapping + size: 15175 + timestamp: 1761214578417 - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda sha256: 3fde293232fa3fca98635e1167de6b7c7fda83caf24b9d6c91ec9eefb4f4d586 md5: 47e340acb35de30501a76c7c799c41d7 @@ -809,6 +918,29 @@ packages: - pkg:pypi/packaging?source=hash-mapping size: 62477 timestamp: 1745345660407 +- conda: https://conda.anaconda.org/conda-forge/noarch/parso-0.8.5-pyhcf101f3_0.conda + sha256: 30de7b4d15fbe53ffe052feccde31223a236dae0495bab54ab2479de30b2990f + md5: a110716cdb11cf51482ff4000dc253d7 + depends: + - python >=3.10 + - python + license: MIT + license_family: MIT + purls: + - pkg:pypi/parso?source=hash-mapping + size: 81562 + timestamp: 1755974222274 +- conda: https://conda.anaconda.org/conda-forge/noarch/pexpect-4.9.0-pyhd8ed1ab_1.conda + sha256: 202af1de83b585d36445dc1fda94266697341994d1a3328fabde4989e1b3d07a + md5: d0d408b1f18883a944376da5cf8101ea + depends: + - ptyprocess >=0.5 + - python >=3.9 + license: ISC + purls: + - pkg:pypi/pexpect?source=hash-mapping + size: 53561 + timestamp: 1733302019362 - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhd8ed1ab_0.conda sha256: a8eb555eef5063bbb7ba06a379fa7ea714f57d9741fe0efdb9442dbbc2cccbcc md5: 7da7ccd349dbf6487a7778579d2bb971 @@ -820,6 +952,41 @@ packages: - pkg:pypi/pluggy?source=hash-mapping size: 24246 timestamp: 1747339794916 +- conda: https://conda.anaconda.org/conda-forge/noarch/prompt-toolkit-3.0.52-pyha770c72_0.conda + sha256: 4817651a276016f3838957bfdf963386438c70761e9faec7749d411635979bae + md5: edb16f14d920fb3faf17f5ce582942d6 + depends: + - python >=3.10 + - wcwidth + constrains: + - prompt_toolkit 3.0.52 + license: BSD-3-Clause + license_family: BSD + purls: + - pkg:pypi/prompt-toolkit?source=hash-mapping + size: 273927 + timestamp: 1756321848365 +- conda: https://conda.anaconda.org/conda-forge/noarch/ptyprocess-0.7.0-pyhd8ed1ab_1.conda + sha256: a7713dfe30faf17508ec359e0bc7e0983f5d94682492469bd462cdaae9c64d83 + md5: 7d9daffbb8d8e0af0f769dbbcd173a54 + depends: + - python >=3.9 + license: ISC + purls: + - pkg:pypi/ptyprocess?source=hash-mapping + size: 19457 + timestamp: 1733302371990 +- conda: https://conda.anaconda.org/conda-forge/noarch/pure_eval-0.2.3-pyhd8ed1ab_1.conda + sha256: 71bd24600d14bb171a6321d523486f6a06f855e75e547fa0cb2a0953b02047f0 + md5: 3bfdfb8dbcdc4af1ae3f9a8eb3948f04 + depends: + - python >=3.9 + license: MIT + license_family: MIT + purls: + - pkg:pypi/pure-eval?source=hash-mapping + size: 16668 + timestamp: 1733569518868 - conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda sha256: 5577623b9f6685ece2697c6eb7511b4c9ac5fb607c9babc2646c811b428fd46a md5: 6b6ece66ebcae2d5f326c77ef2c5a066 @@ -913,6 +1080,20 @@ packages: purls: [] size: 193775 timestamp: 1748644872902 +- conda: https://conda.anaconda.org/conda-forge/noarch/stack_data-0.6.3-pyhd8ed1ab_1.conda + sha256: 570da295d421661af487f1595045760526964f41471021056e993e73089e9c41 + md5: b1b505328da7a6b246787df4b5a49fbc + depends: + - asttokens + - executing + - pure_eval + - python >=3.9 + license: MIT + license_family: MIT + purls: + - pkg:pypi/stack-data?source=hash-mapping + size: 26988 + timestamp: 1733569565672 - conda: https://conda.anaconda.org/conda-forge/noarch/sysroot_linux-64-2.28-h4ee821c_8.conda sha256: 0053c17ffbd9f8af1a7f864995d70121c292e317804120be4667f37c92805426 md5: 1bad93f0aa428d618875ef3a588a889e @@ -951,6 +1132,17 @@ packages: - pkg:pypi/tomli?source=compressed-mapping size: 20973 timestamp: 1760014679845 +- conda: https://conda.anaconda.org/conda-forge/noarch/traitlets-5.14.3-pyhd8ed1ab_1.conda + sha256: f39a5620c6e8e9e98357507262a7869de2ae8cc07da8b7f84e517c9fd6c2b959 + md5: 019a7385be9af33791c989871317e1ed + depends: + - python >=3.9 + license: BSD-3-Clause + license_family: BSD + purls: + - pkg:pypi/traitlets?source=hash-mapping + size: 110051 + timestamp: 1733367480074 - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.15.0-pyhcf101f3_0.conda sha256: 032271135bca55aeb156cee361c81350c6f3fb203f57d024d7e5a1fc9ef18731 md5: 0caa1af407ecff61170c9437a808404d @@ -970,6 +1162,17 @@ packages: purls: [] size: 122968 timestamp: 1742727099393 +- conda: https://conda.anaconda.org/conda-forge/noarch/wcwidth-0.2.14-pyhd8ed1ab_0.conda + sha256: e311b64e46c6739e2a35ab8582c20fa30eb608da130625ed379f4467219d4813 + md5: 7e1e5ff31239f9cd5855714df8a3783d + depends: + - python >=3.10 + license: MIT + license_family: MIT + purls: + - pkg:pypi/wcwidth?source=hash-mapping + size: 33670 + timestamp: 1758622418893 - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb8e6e7a_2.conda sha256: a4166e3d8ff4e35932510aaff7aa90772f84b4d07e9f6f83c614cba7ceefe0eb md5: 6432cb5d4ac0046c3ac0a8a0f95842f9 diff --git a/pyproject.toml b/pyproject.toml index f4f14bc..972f9cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,3 +25,4 @@ cmake = ">=4.1.2,<5" pytest = ">=9.0.2,<10" numpy = ">=2.3.5,<3" basedpyright = ">=1.36.1,<2" +ipython = ">=9.8.0,<10" diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..a9ad51a --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,4 @@ +{ + "venvPath": ".pixi/envs", + "venv": "default" +} diff --git a/src/gibbsflux/conjugacy/normal_form.py b/src/gibbsflux/conjugacy/normal_form.py new file mode 100644 index 0000000..a1d1d59 --- /dev/null +++ b/src/gibbsflux/conjugacy/normal_form.py @@ -0,0 +1,21 @@ +from gibbsflux.factors.syntax import terms as t + +type KeyType = tuple[int | str | KeyType, ...] + +def key(term: t.Term, blanket_var: str) -> KeyType: # pyright: ignore + match term: + case t.Var(_, var): + return (0,) if var == blanket_var else (1, var) + case _ if t.is_base_factor_term(term): + return (2, *t.arg_positions(term, blanket_var), type(term).__name__) # pyright: ignore + case t.Mix(arg, branches): + return (3, int(arg != blanket_var), + *(key(branch, blanket_var) for branch in branches)) + + +def normalize(markov_blanket: t.Program, blanket_var: str): + # precondition: markov_blanket is well-typed and actually encodes a Markov + # blanket factor graph + return t.Program(sorted(markov_blanket.terms, + key=lambda term: key(term, blanket_var))) + diff --git a/src/gibbsflux/factors/syntax/terms.py b/src/gibbsflux/factors/syntax/terms.py index 6bd116b..4de1cc7 100644 --- a/src/gibbsflux/factors/syntax/terms.py +++ b/src/gibbsflux/factors/syntax/terms.py @@ -84,14 +84,20 @@ def __eq__(self, other): ## helpers ## +def arg_positions(term: BaseFactor, var: str) -> tuple[int, ...]: + return tuple(i for (i, arg) in enumerate(astuple(term))\ + if isinstance(arg, str) and arg == var) + +def is_base_factor_term(term: Term) -> bool: + return any(isinstance(term, type) for type in BaseFactor.__args__) + def free_vars(term: Term) -> set[str]: match term: case Var(_, name): return {name} - case Normal() | Dirichlet() | Categorical(): + case _ if is_base_factor_term(term): return set(filter(lambda arg: isinstance(arg, str), astuple(term))) case Mix(arg, branches): argvars = {arg} if isinstance(arg, str) else set() return argvars.union(*map(free_vars, branches)) - diff --git a/test/fixtures.py b/test/fixtures.py new file mode 100644 index 0000000..643e018 --- /dev/null +++ b/test/fixtures.py @@ -0,0 +1,128 @@ +"""Factory functions for constructing common probabilistic models.""" + +import numpy as np +from gibbsflux.factors.syntax import types as ft +from gibbsflux.factors.syntax.terms import ( + Var, Normal, Dirichlet, Categorical, Mix, Program +) + + +def make_gmm(data: np.ndarray, num_components: int, **hypers) -> Program: + """ + Construct a Gaussian Mixture Model (GMM) for 1D data. + + Parameters + ---------- + data : np.ndarray + 1D float array of shape (N,) containing observed data points. + num_components : int + Number of mixture components (K). + **hypers : dict + Hyperparameters with the following options: + - mean_prior_mean (float, default=0.0): Prior mean for component means + - mean_prior_var (float, default=100.0): Prior variance for component means + - component_var (float, default=1.0): Fixed variance for each component + - weight_concentration (float, default=1.0): Symmetric Dirichlet concentration + + Returns + ------- + Program + A Program object representing the unnormalized posterior distribution. + Variables include: + - mu_1, ..., mu_K: Component means + - w: Mixture weights (K-dimensional simplex) + - z_1, ..., z_N: Mixture indicators for each data point + + Example + ------- + >>> data = np.array([1.0, 2.0, 5.0, 6.0]) + >>> model = make_gmm(data, num_components=2) + """ + # Extract hyperparameters with defaults + mean_prior_mean = hypers.get('mean_prior_mean', 0.0) + mean_prior_var = hypers.get('mean_prior_var', 100.0) + component_var = hypers.get('component_var', 1.0) + weight_concentration = hypers.get('weight_concentration', 1.0) + + N = len(data) + K = num_components + + return Program([ + # declare component means + *(Var(ft.R, f'mu_{k}') for k in range(1, K + 1)), + # declare mixture weights + Var(ft.Rn(K), 'w'), + # declare mixture indicators for each data point + *(Var(ft.Nn(K), f'z_{i}') for i in range(1, N + 1)), + # add priors on component means + *(Normal(f'mu_{k}', mean_prior_mean, mean_prior_var) for k in range(1, K + 1)), + # add prior on mixture weights (symmetric Dirichlet) + Dirichlet('w', weight_concentration * np.ones(K)), + # mixture indicator: z_i ~ Categorical(w) + *(Categorical(f'z_{i}', 'w') for i in range(1, N + 1)), + # likelihood: data[i-1] | z_i, mu ~ Normal(mu_{z_i}, component_var) + *(Mix(f'z_{i}', [Normal(data[i-1], f'mu_{k}', component_var) for k in range(1, K + 1)]) + for i in range(1, N + 1)) + ]) + + + +def make_discrete_hmm( + data: np.ndarray, + prior: np.ndarray, + transition: np.ndarray, + emission: np.ndarray +) -> Program: + """ + Construct a discrete Hidden Markov Model (HMM). + + Parameters + ---------- + data : np.ndarray + Shape (T,), observed sequence with integer values in [0, m). + prior : np.ndarray + Shape (n,), initial state distribution. Should sum to 1. + transition : np.ndarray + Shape (n, n), transition matrix where transition[i, j] = P(x_t=i | x_{t-1}=j). + Each column should sum to 1. + emission : np.ndarray + Shape (m, n), emission matrix where emission[i, j] = P(obs=i | x_t=j). + Each column should sum to 1. + + Returns + ------- + Program + A Program object representing the unnormalized posterior over hidden states. + Variables include: + - x_0, x_1, ..., x_{T-1}: Hidden states at each time step + + Notes + ----- + The function constructs a model for inference over the hidden state sequence + given fixed transition and emission parameters. The matrices are interpreted as: + - transition[:, j] gives the distribution over next states given current state j + - emission[:, j] gives the distribution over observations given state j + + Example + ------- + >>> data = np.array([0, 1, 1, 0]) + >>> prior = np.array([0.5, 0.5]) + >>> transition = np.array([[0.7, 0.3], [0.3, 0.7]]) + >>> emission = np.array([[0.9, 0.2], [0.1, 0.8]]) + >>> model = make_discrete_hmm(data, prior, transition, emission) + """ + T = len(data) + n = len(prior) + + return Program([ + # declare hidden state variables + *(Var(ft.Nn(n), f'x_{t}') for t in range(T)), + # initial state distribution: x_0 ~ Categorical(prior) + Categorical('x_0', np.log(prior)), + # transition factors: x_{t+1} | x_t ~ Categorical(transition[:, x_t]) + *(Mix(f'x_{t}', [Categorical(f'x_{t+1}', np.log(transition[:, j])) for j in range(n)]) + for t in range(T - 1)), + # observation/emission factors: data[t] | x_t ~ Categorical(emission[:, x_t]) + *(Mix(f'x_{t}', [Categorical(int(data[t]), np.log(emission[:, j])) for j in range(n)]) + for t in range(T)) + ]) diff --git a/test/test_conjugacy_normal_form.py b/test/test_conjugacy_normal_form.py new file mode 100644 index 0000000..bd04091 --- /dev/null +++ b/test/test_conjugacy_normal_form.py @@ -0,0 +1,30 @@ +import numpy as np +import pytest + +from gibbsflux.factors.syntax import types as ft +from gibbsflux.factors.syntax.terms import ( + Var, Normal, Dirichlet, Categorical, Mix, Program +) +from gibbsflux.interpreters import markov_blanket +from gibbsflux.conjugacy import normal_form + +@pytest.fixture +def gmm(): + return Program([ + Var(ft.R, 'mu1'), + Var(ft.R, 'mu2'), + Var(ft.Rn(2), 'w'), + Var(ft.Nn(2), 'z'), + Normal('mu1', 0, 100), + Normal('mu2', 0, 100), + Dirichlet('w', np.ones(2)), + Categorical('z', 'w'), + Mix('z', [Normal(1, 'mu1', 1), Normal(1, 'mu2', 1)]), + ]) + + +def test_markov_blanket_smoke(gmm): + for blanket_var in 'mu1 mu2 w z'.split(): + mb = markov_blanket.markov_blanket(gmm, blanket_var) + normal_form.normalize(mb, blanket_var) + diff --git a/test/test_factors.py b/test/test_factors.py index b163e0a..a503aa1 100644 --- a/test/test_factors.py +++ b/test/test_factors.py @@ -7,6 +7,8 @@ ) from gibbsflux.interpreters import factor_typechecker +from fixtures import make_discrete_hmm, make_gmm + @pytest.fixture def gmm(): return Program([ @@ -74,3 +76,22 @@ def test_type_checking_bad_promotion(): with pytest.raises(TypeError): factor_typechecker.interpret_type(bad_gmm) +def test_type_checking_make_gmm(): + num_datapoints = 5 + num_components = 3 + gmm = make_gmm(np.ones(num_datapoints), num_components=num_components) + target_type = ft.Potential([ft.R] * num_components + + [ft.Rn(num_components)] + + [ft.Nn(num_components)] * num_datapoints) + assert factor_typechecker.interpret_type(gmm) == target_type + +def test_type_checking_make_discrete_hmm(): + num_datapoints = 5 + num_states = 3 + num_emiss = 4 + hmm = make_discrete_hmm(data=np.arange(num_datapoints) & num_emiss, + prior=np.ones(num_states)/num_states, + transition=np.ones((num_states, num_states))/num_states, + emission=np.ones((num_emiss, num_states))/num_emiss) + target_type = ft.Potential([ft.Nn(num_states)] * num_datapoints) + assert factor_typechecker.interpret_type(hmm) == target_type