diff --git a/notebooks/source/covtype.ipynb b/notebooks/source/covtype.ipynb new file mode 100644 index 000000000..45699f0c1 --- /dev/null +++ b/notebooks/source/covtype.ipynb @@ -0,0 +1,699 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "import os\n", + "import time\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as onp\n", + "\n", + "from jax import jit, lax, random, vmap\n", + "import jax.numpy as np\n", + "\n", + "import numpyro; numpyro.set_platform(\"cpu\"); numpyro.set_host_device_count(4)\n", + "from numpyro.contrib.autoguide import AutoBNAFNormal, AutoContinuousELBO\n", + "import numpyro.distributions as dist\n", + "from numpyro.examples.datasets import COVTYPE, load_dataset\n", + "from numpyro.handlers import scale\n", + "from numpyro.infer.hmc_util import consensus, parametric_draws\n", + "from numpyro.infer.util import initialize_model, transformed_potential_energy\n", + "from numpyro.infer import MCMC, NUTS, SVI, predictive\n", + "import numpyro.optim as optim" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### load data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Data shape: (581012, 55)\n", + "Label distribution: 211840 has label 1, 369172 has label 0\n" + ] + } + ], + "source": [ + "def _load_dataset():\n", + " _, fetch = load_dataset(COVTYPE, shuffle=False)\n", + " features, labels = fetch()\n", + "\n", + " # normalize features and add intercept\n", + " features = (features - features.mean(0)) / features.std(0)\n", + " features = np.hstack([features, np.ones((features.shape[0], 1))])\n", + "\n", + " # make binary feature\n", + " _, counts = onp.unique(labels, return_counts=True)\n", + " specific_category = np.argmax(counts)\n", + " labels = (labels == specific_category)\n", + "\n", + " N, dim = features.shape\n", + " print(\"Data shape:\", features.shape)\n", + " print(\"Label distribution: {} has label 1, {} has label 0\"\n", + " .format(labels.sum(), N - labels.sum()))\n", + " return features, labels\n", + "\n", + "X_full, y_full = _load_dataset()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train set contains 400000 (68.85%) data points.\n", + "Test set contains 181012 (31.15%) data points.\n" + ] + } + ], + "source": [ + "def get_train_shards_and_test_data(X, y, K, N, rng=None):\n", + " if rng is not None:\n", + " idxs = random.shuffle(rng, np.arange(X.shape[0]))\n", + " X = X[idxs]\n", + " y = y[idxs]\n", + " shards = []\n", + " for i in range(K):\n", + " shards.append((X[i * N: (i + 1) * N], y[i * N: (i + 1) * N]))\n", + " train_data = (X[:K * N], y[:K * N])\n", + " test_data = (X[K * N:], y[K * N:])\n", + " return shards, train_data, test_data\n", + "\n", + "K, N = 40, 10000\n", + "shards, (X_train, y_train), (X_test, y_test) = get_train_shards_and_test_data(\n", + " X_full, y_full, K, N, random.PRNGKey(0))\n", + "print(\"Train set contains {} ({}%) data points.\".format(\n", + " K * N, round(K * N / X_full.shape[0] * 100, 2)))\n", + "print(\"Test set contains {} ({}%) data points.\".format(\n", + " X_full.shape[0] - K * N, round(100 - K * N / X_full.shape[0] * 100, 2)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def model(X, y=None, prior_scale=1, likelihood_scale=1):\n", + " with scale(scale_factor=prior_scale):\n", + " coefs = numpyro.sample('coefs', dist.Normal(0, 1), sample_shape=X.shape[-1:])\n", + " with scale(scale_factor=likelihood_scale):\n", + " numpyro.sample('y', dist.Bernoulli(logits=np.dot(X, coefs)), obs=y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### sampling" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def get_subposterior(rng, shard, K):\n", + " X, y = shard\n", + " mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=2500, num_chains=4,\n", + " chain_method=\"parallel\", progress_bar=False)\n", + " mcmc.run(rng, X, y, prior_scale=1 / K)\n", + " return mcmc.get_samples()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " =============================== SUBPOSTERIOR 00 ===============================took 71.48939394950867\n", + "\n", + " =============================== SUBPOSTERIOR 01 ===============================" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'.results'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmakedirs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'.results'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'.results/subposterior_{:02d}.npy'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msamples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0msubposteriors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msamples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/numpy/lib/npyio.py\u001b[0m in \u001b[0;36msave\u001b[0;34m(file, arr, allow_pickle, fix_imports)\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0marr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masanyarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 535\u001b[0m format.write_array(fid, arr, allow_pickle=allow_pickle,\n\u001b[0;32m--> 536\u001b[0;31m pickle_kwargs=pickle_kwargs)\n\u001b[0m\u001b[1;32m 537\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 538\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mown_fid\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/numpy/lib/format.py\u001b[0m in \u001b[0;36mwrite_array\u001b[0;34m(fp, array, version, allow_pickle, pickle_kwargs)\u001b[0m\n\u001b[1;32m 631\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mpickle_kwargs\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 632\u001b[0m \u001b[0mpickle_kwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 633\u001b[0;31m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprotocol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpickle_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 634\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0marray\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf_contiguous\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0marray\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_contiguous\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 635\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misfileobj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/jax/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_forward_method\u001b[0;34m(attrname, self, fun, *args)\u001b[0m\n\u001b[1;32m 587\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 588\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_forward_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattrname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 589\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattrname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 590\u001b[0m \u001b[0m_forward_to_value\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_forward_method\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"_value\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 591\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/jax/jax/interpreters/pxla.py\u001b[0m in \u001b[0;36m_value\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_npy_value\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 389\u001b[0m \u001b[0mids\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ids\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 390\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy_to_host_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 391\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_npy_value\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_collect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_buffers\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_py\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mids\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 392\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_npy_value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/jax/jax/interpreters/pxla.py\u001b[0m in \u001b[0;36mcopy_to_host_async\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 376\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_npy_value\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 377\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_buffers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 378\u001b[0;31m \u001b[0mbuf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy_to_host_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 379\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdelete\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "rngs = random.split(random.PRNGKey(0), K)\n", + "subposteriors = []\n", + "for i, (rng, shard) in enumerate(zip(rngs, shards)):\n", + " start = time.time()\n", + " sep = '=' * 31\n", + " if i > 5:\n", + " break\n", + " print('\\n ' + sep + ' SUBPOSTERIOR {:02d} '.format(i) + sep, end='')\n", + " samples = get_subposterior(rng, shard, K)\n", + " if not os.path.exists('.results'):\n", + " os.makedirs('.results')\n", + " np.save('.results/subposterior_{:02d}.npy'.format(i), samples)\n", + " subposteriors.append(samples)\n", + " end = time.time()\n", + " print(\"Elapsed time:\", end - start)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### merge subposteriors" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "consensus_samples = consensus(subposteriors, 10000)\n", + "parametric_samples = parametric_draws(subposteriors, 10000)\n", + "np.save('.results/consensus_samples.npy', consensus_samples)\n", + "np.save('.results/parametric_samples.npy', parametric_samples)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Consensus accuaracy: 0.771\n", + "Parametric accuaracy: 0.771\n" + ] + } + ], + "source": [ + "y_consensus = predictive(random.PRNGKey(0), model, consensus_samples, X_test)\n", + "y_consensus = (y_consensus.sum(axis=0) / y_consensus.shape[0]) >= 0.5\n", + "acc = (y_consensus == y_test).sum() / y_test.shape[0]\n", + "print('Consensus accuaracy: {}'.format(round(acc.item(), 4)))\n", + "\n", + "y_parametric = predictive(random.PRNGKey(1), model, parametric_samples, X_test)\n", + "y_parametric = (y_parametric.sum(axis=0) / y_parametric.shape[0]) >= 0.5\n", + "acc = (y_parametric == y_test).sum() / y_test.shape[0]\n", + "print('Parametric accuaracy: {}'.format(round(acc.item(), 4)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### train bnaf" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 1000\n", + "iters_per_epoch = X_train.shape[0] // batch_size\n", + "guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])\n", + "svi = SVI(model, guide, optim.Adam(0.01), AutoContinuousELBO(), likelihood_scale=iters_per_epoch)\n", + "svi_state = svi.init(random.PRNGKey(0), X_train[:1], y_train[:1])\n", + "\n", + "def epoch_train(epoch, svi_state):\n", + " idx = random.shuffle(random.fold_in(random.PRNGKey(1), epoch), np.arange(X_train.shape[0]))\n", + " X, y = X_train[idx], y_train[idx]\n", + "\n", + " def body_fn(state, i):\n", + " X_batch = lax.dynamic_slice_in_dim(X, i * batch_size, batch_size)\n", + " y_batch = lax.dynamic_slice_in_dim(y, i * batch_size, batch_size)\n", + " return svi.update(state, X_batch, y_batch)\n", + "\n", + " svi_state, losses = lax.scan(body_fn, svi_state, np.arange(iters_per_epoch))\n", + " return svi_state, losses" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 01 - loss 231555.4531 - acc 0.7705 - time 71.38712668418884\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0macc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0my_iaf\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m print(\"Epoch {:02d} - loss {:.4f} - acc {} - time {}\".format(\n\u001b[0;32m---> 12\u001b[0;31m epoch, np.mean(epoch_loss), round(acc.item(), 4), time.time() - tic))\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0mlosses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlosses\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch_loss\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/jax/jax/interpreters/xla.py\u001b[0m in \u001b[0;36mitem\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 662\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcomplex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 663\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0missubdtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloating\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 664\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 665\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0missubdtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minteger\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 666\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/jax/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_forward_method\u001b[0;34m(attrname, self, fun, *args)\u001b[0m\n\u001b[1;32m 587\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 588\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_forward_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattrname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 589\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattrname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 590\u001b[0m \u001b[0m_forward_to_value\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_forward_method\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"_value\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 591\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/jax/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_value\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_if_deleted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 611\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_npy_value\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 612\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_npy_value\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_buffer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_py\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 613\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_npy_value\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwriteable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 614\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_npy_value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/numpy/core/_internal.py\u001b[0m in \u001b[0;36m_dtype_from_pep3118\u001b[0;34m(spec)\u001b[0m\n\u001b[1;32m 597\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 598\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 599\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0m_dtype_from_pep3118\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 600\u001b[0m \u001b[0mstream\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_Stream\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 601\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0malign\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m__dtype_from_pep3118\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstream\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_subdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "losses = np.array([])\n", + "num_epochs = 2\n", + "for epoch in range(1, num_epochs + 1):\n", + " tic = time.time()\n", + " svi_state, epoch_loss = epoch_train(epoch, svi_state)\n", + " params = svi.get_params(svi_state)\n", + " posterior = guide.sample_posterior(random.PRNGKey(2 * epoch), params, sample_shape=(10000,))\n", + " y_iaf = predictive(random.PRNGKey(2 * epoch + 1), model, posterior, X_test)[\"y\"]\n", + " y_iaf = (y_iaf.sum(axis=0) / y_iaf.shape[0]) >= 0.5\n", + " acc = (y_iaf == y_test).sum() / y_test.shape[0]\n", + " print(\"Epoch {:02d} - loss {:.4f} - acc {} - time {}\".format(\n", + " epoch, np.mean(epoch_loss), round(acc.item(), 4), time.time() - tic))\n", + " losses = np.concatenate([losses, epoch_loss])" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAD4CAYAAAAZ1BptAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de3xU9Z3/8deHhKuKoICLgAQtVdFWhahUa8UrF21xt2vXdl2pdaVV221ru4K/ttpqbenV1mp1bbGCbVXUdmW5iBRES71AUK5yi4AQiRAICYTck8/vj/kSJsOcyQSTGZD38/GYx5zzOd9zvp85k5zPnNuMuTsiIiLJdMh2AiIicuhSkRARkUgqEiIiEklFQkREIqlIiIhIpNxsJ9DWevXq5Xl5edlOQ0TksLJkyZId7t47MZ5WkTCzrwM3Awb8zt1/ZWbHAU8DecAm4HPuvsvMDPg1MAaoBL7o7m+G5YwDvhsW+0N3nxLiw4DHga7ALODr7u5RfaTKNS8vj4KCgnReloiIBGb2brJ4i4ebzOxMYgXiPOAs4GozGwxMBOa5+2BgXhgHGA0MDo/xwMNhOccBdwPnh2XdbWY9wzwPh7b75hsV4lF9iIhIBqRzTuJ04HV3r3T3euBl4J+BscCU0GYKcE0YHgtM9ZjXgR5m1hcYCcx199KwNzAXGBWmdXf31zx2Z9/UhGUl60NERDIgnSKxEviUmR1vZt2IHUYaAJzg7sUA4blPaN8P2BI3f1GIpYoXJYmTog8REcmAFs9JuPtqM/sJsU/+FcAyoD7FLJZsMQcRT5uZjSd2uIqTTjqpNbOKiEgKaV0C6+6T3X2ou38KKAXWA9vCoSLC8/bQvIjYnsY+/YGtLcT7J4mToo/E/B5193x3z+/d+4CT8yIicpDSKhJm1ic8nwT8C/AkMB0YF5qMA54Pw9OBGyxmOFAeDhXNAa40s57hhPWVwJwwbY+ZDQ9XRt2QsKxkfYiISAake5/Ec2Z2PFAH3BYudZ0ETDOzm4DNwLWh7Sxi5y0KiV0CeyOAu5ea2b3A4tDuHncvDcO3sP8S2NnhARDVh4iIZIB92L4qPD8/3w/mPomiXZUUbq9gxKk6Ny4iRx4zW+Lu+YnxD90d1wfr8l++THVdI5smXZXtVEREDhn67qaguq4x2ymIiBxyVCRERCSSioSIiERSkRARkUgqEiIiEklFQkREIqlIiIhIJBUJERGJpCIhIiKRVCRERCSSioSIiERSkRARkUgqEiIiEklFQkREIqlIiIhIJBWJBMXlVdlOQUTkkKEikWBnRW22UxAROWSoSIiISKS0ioSZfdPMVpnZSjN70sy6mNkgM3vDzNab2dNm1im07RzGC8P0vLjl3Bnia81sZFx8VIgVmtnEuHjSPtqTWXv3ICJy+GixSJhZP+C/gHx3PxPIAa4DfgLc7+6DgV3ATWGWm4Bd7v4R4P7QDjMbEuY7AxgF/NbMcswsB3gIGA0MAT4f2pKiDxERyYB0DzflAl3NLBfoBhQDlwLPhulTgGvC8NgwTph+mZlZiD/l7jXuvhEoBM4Lj0J33+DutcBTwNgwT1QfIiKSAS0WCXd/D/g5sJlYcSgHlgBl7l4fmhUB/cJwP2BLmLc+tD8+Pp4wT1T8+BR9NGNm482swMwKSkpKWnpJIiKSpnQON/UkthcwCDgROIrYoaFEvm+WiGltFT8w6P6ou+e7e37v3r2TNUmbJe1WROTIlM7hpsuBje5e4u51wF+AC4Ae4fATQH9gaxguAgYAhOnHAqXx8YR5ouI7UvQhIiIZkE6R2AwMN7Nu4TzBZcDbwEvAv4Y244Dnw/D0ME6YPt/dPcSvC1c/DQIGA4uAxcDgcCVTJ2Int6eHeaL6EBGRDEjnnMQbxE4evwmsCPM8CkwAbjezQmLnDyaHWSYDx4f47cDEsJxVwDRiBeYF4DZ3bwjnHL4KzAFWA9NCW1L0ISIiGZDbchNw97uBuxPCG4hdmZTYthq4NmI59wH3JYnPAmYliSftQ0REMkN3XIuISCQViQS641pEZD8VCRERiaQiISIikVQkREQkkopEAp2TEBHZT0VCREQiqUiIiEgkFQkREYmkIpFA3wIrIrKfioSIiERSkRARkUgqEiIiEklFQkREIqlIiIhIJBWJBFV1DdlOQUTkkKEikeDWPy7JdgoiIocMFYkEW8urs52CiMghQ0UiicZGz3YKIiKHhBaLhJmdamZL4x67zewbZnacmc01s/XhuWdob2b2gJkVmtlyMxsat6xxof16MxsXFx9mZivCPA+Yxb6LNaqP9vZ0wZZMdCMicshrsUi4+1p3P9vdzwaGAZXAX4GJwDx3HwzMC+MAo4HB4TEeeBhiG3zgbuB84Dzg7riN/sOh7b75RoV4VB/tqnRvbSa6ERE55LX2cNNlwDvu/i4wFpgS4lOAa8LwWGCqx7wO9DCzvsBIYK67l7r7LmAuMCpM6+7ur7m7A1MTlpWsDxERyYDWFonrgCfD8AnuXgwQnvuEeD8g/nhNUYilihcliafqoxkzG29mBWZWUFJS0sqXJCIiUdIuEmbWCfgM8ExLTZPE/CDiaXP3R909393ze/fu3ZpZRUQkhdbsSYwG3nT3bWF8WzhURHjeHuJFwIC4+foDW1uI908ST9WHiIhkQGuKxOfZf6gJYDqw7wqlccDzcfEbwlVOw4HycKhoDnClmfUMJ6yvBOaEaXvMbHi4qumGhGUl60NERDIgN51GZtYNuAL4clx4EjDNzG4CNgPXhvgsYAxQSOxKqBsB3L3UzO4FFod297h7aRi+BXgc6ArMDo9UfYiISAakVSTcvRI4PiG2k9jVToltHbgtYjmPAY8liRcAZyaJJ+1DREQyQ3dcB/16dG0ajtU5ERFRkQhyc/Tb1iIiiVQkghxTkRARSaQiISIikVQkkvi/ZcVs2rE322mIiGSdikQSa7ftYfSv/57tNEREsk5FIoJ+xlREREVCRERSUJEIdGeEiMiBVCRERCSSioSIiERSkRARkUgqEiIiEklFQkREIqlIiIhIJBWJQF8PLiJyIBUJERGJpCIhIiKRVCRERCRSWkXCzHqY2bNmtsbMVpvZJ8zsODOba2brw3PP0NbM7AEzKzSz5WY2NG4540L79WY2Li4+zMxWhHkeMIv9AlBUHyIikhnp7kn8GnjB3U8DzgJWAxOBee4+GJgXxgFGA4PDYzzwMMQ2+MDdwPnAecDdcRv9h0PbffONCvGoPtqcTluLiByoxSJhZt2BTwGTAdy91t3LgLHAlNBsCnBNGB4LTPWY14EeZtYXGAnMdfdSd98FzAVGhWnd3f01j11iNDVhWcn6EBGRDEhnT+JkoAT4g5m9ZWa/N7OjgBPcvRggPPcJ7fsBW+LmLwqxVPGiJHFS9NGMmY03swIzKygpKUnjJYmISDrSKRK5wFDgYXc/B9hL6sM+liTmBxFPm7s/6u757p7fu3fv1swqIiIppFMkioAid38jjD9LrGhsC4eKCM/b49oPiJu/P7C1hXj/JHFS9CEiIhnQYpFw9/eBLWZ2aghdBrwNTAf2XaE0Dng+DE8HbghXOQ0HysOhojnAlWbWM5ywvhKYE6btMbPh4aqmGxKWlayPNqcbrkVEDpSbZruvAX8ys07ABuBGYgVmmpndBGwGrg1tZwFjgEKgMrTF3UvN7F5gcWh3j7uXhuFbgMeBrsDs8ACYFNGHiIhkQFpFwt2XAvlJJl2WpK0Dt0Us5zHgsSTxAuDMJPGdyfpoD5bszIiIyBFOd1yLiEgkFQkREYmkIhHoxLWIyIFUJEREJJKKRNDzqE7ZTkFE5JCjIhHcOfq0bKcgInLIUZEIunXKyXYKIiKHHBUJERGJpCIhIiKRVCRERCSSikSQ00HfyyEikkhFIhjSt3u2UxAROeSoSASmb/gTETmAioSIiERSkRARkUgqEiIiEklFQkREIqlIiIhIpLSKhJltMrMVZrbUzApC7Dgzm2tm68NzzxA3M3vAzArNbLmZDY1bzrjQfr2ZjYuLDwvLLwzzWqo+REQkM1qzJ3GJu5/t7vt+63oiMM/dBwPzwjjAaGBweIwHHobYBh+4GzgfOA+4O26j/3Bou2++US30ISIiGfBBDjeNBaaE4SnANXHxqR7zOtDDzPoCI4G57l7q7ruAucCoMK27u7/m7g5MTVhWsj5ERCQD0i0SDrxoZkvMbHyIneDuxQDhuU+I9wO2xM1bFGKp4kVJ4qn6aMbMxptZgZkVlJSUpPmSRESkJblptrvQ3beaWR9grpmtSdE22a3LfhDxtLn7o8CjAPn5+fq1ahGRNpLWnoS7bw3P24G/EjunsC0cKiI8bw/Ni4ABcbP3B7a2EO+fJE6KPkREJANaLBJmdpSZHbNvGLgSWAlMB/ZdoTQOeD4MTwduCFc5DQfKw6GiOcCVZtYznLC+EpgTpu0xs+HhqqYbEpaVrI92Mecbn2rPxYuIHHbSOdx0AvDXcFVqLvBnd3/BzBYD08zsJmAzcG1oPwsYAxQClcCNAO5eamb3AotDu3vcvTQM3wI8DnQFZocHwKSIPtpFl466bUREJF6LRcLdNwBnJYnvBC5LEnfgtohlPQY8liReAJyZbh8iIpIZ+ugcx5KeQxcROXKpSIiISCQViTj63SERkeZUJEREJJKKhIiIRFKREBGRSCoScXROQkSkORUJERGJpCIRx7QrISLSjIpEnIrq+mynICJySFGRiFNb35jtFEREDikqEnFO6N452ymIiBxSVCTi9OneJdspiIgcUlQkREQkkoqEiIhEUpEQEZFIKhIiIhJJRUJERCKpSKQw9+1t2U5BRCSr0i4SZpZjZm+Z2YwwPsjM3jCz9Wb2tJl1CvHOYbwwTM+LW8adIb7WzEbGxUeFWKGZTYyLJ+0jUyYv3JDJ7kREDjmt2ZP4OrA6bvwnwP3uPhjYBdwU4jcBu9z9I8D9oR1mNgS4DjgDGAX8NhSeHOAhYDQwBPh8aJuqj4xo1A3YInKES6tImFl/4Crg92HcgEuBZ0OTKcA1YXhsGCdMvyy0Hws85e417r4RKATOC49Cd9/g7rXAU8DYFvrIiAb3THYnInLISXdP4lfAHcC+z9bHA2Xuvu8b8YqAfmG4H7AFIEwvD+2b4gnzRMVT9dGMmY03swIzKygpKUnzJSXXIe6LYBsaVSRE5MjWYpEws6uB7e6+JD6cpKm3MK2t4gcG3R9193x3z+/du3eyJmnrEPd14SoRInKky02jzYXAZ8xsDNAF6E5sz6KHmeWGT/r9ga2hfREwACgys1zgWKA0Lr5P/DzJ4jtS9NFuYkUiVh6WbSmjuq6BLh1z2rtbEZFDUot7Eu5+p7v3d/c8Yiee57v7vwMvAf8amo0Dng/D08M4Yfp8d/cQvy5c/TQIGAwsAhYDg8OVTJ1CH9PDPFF9tJsOCWvkoZcK27tLEZFD1ge5T2ICcLuZFRI7fzA5xCcDx4f47cBEAHdfBUwD3gZeAG5z94awl/BVYA6xq6emhbap+mg3lnCUa3dVXXt3KSJyyErncFMTd18ALAjDG4hdmZTYphq4NmL++4D7ksRnAbOSxJP20Z466BdMRUSa6I7rBA9+YWizcZ28FpEjmYpEgjNO7N5svFH3SojIEUxFIlHC4SbVCBE5kqlIiIhIJBWJFhzMjkR1XQPbdle3eS4iIpmmItECP4jjTTdPLeD8H81rh2xERDJLRaIFB3NO4u/rd7R9IiIiWaAikSDxZjoRkSOZikSCzh2br5LXN+zMUiYiItmnIpGge5eOzcY37azMUiYiItmnIpHEv+Xv/1La0/t2T9FSROTDTUUiibxeRzUNd87VKhKRI5e2gEncfNGgpuGlW8qymImISHapSCSRm6PVIiICKhJp2VN9cL8pcTA34omIHEpUJNLwse+/eFAbfNUIETncqUikqfwgfqFONUJEDncqEmk6mDuxdbhJRA53KhJpOuueF3llXUmr5lGJEJHDXYtFwsy6mNkiM1tmZqvM7AchPsjM3jCz9Wb2tJl1CvHOYbwwTM+LW9adIb7WzEbGxUeFWKGZTYyLJ+0jW55ZUtSq9tqREJHDXTp7EjXApe5+FnA2MMrMhgM/Ae5398HALuCm0P4mYJe7fwS4P7TDzIYA1wFnAKOA35pZjpnlAA8Bo4EhwOdDW1L0kRWNja3b6rv2JUTkMNdikfCYijDaMTwcuBR4NsSnANeE4bFhnDD9MjOzEH/K3WvcfSNQCJwXHoXuvsHda4GngLFhnqg+sqKqriGb3YuIZFxa5yTCJ/6lwHZgLvAOUObu9aFJEdAvDPcDtgCE6eXA8fHxhHmi4sen6CMxv/FmVmBmBSUlrTtv0BoNrd2T0I6EiBzm0ioS7t7g7mcD/Yl98j89WbPwnOwyIG/DeLL8HnX3fHfP7927d7ImbaJRW30ROcK06uomdy8DFgDDgR5mlhsm9Qe2huEiYABAmH4sUBofT5gnKr4jRR9ZoT0JETnSpHN1U28z6xGGuwKXA6uBl4B/Dc3GAc+H4elhnDB9vsduGJgOXBeufhoEDAYWAYuBweFKpk7ETm5PD/NE9ZEVrS4SOnEtIoe53Jab0BeYEq5C6gBMc/cZZvY28JSZ/RB4C5gc2k8GnjCzQmJ7ENcBuPsqM5sGvA3UA7e5ewOAmX0VmAPkAI+5+6qwrAkRfbS7DgaJNaG1m3ztSYjI4a7FIuHuy4FzksQ3EDs/kRivBq6NWNZ9wH1J4rOAWen2kQnv/GgMg+5snlJrvzZcNUJEDne64zpC7Arc5mrrGynaVcmFk+Zz38y3W1yGvpZDRA53KhKttLuqnvfKqvjd3ze22FYlQkQOdyoSKVx6Wp8DYq05Ga0dCRE53KlIpPDYF889IPbDGauzkImISHaoSLTSaxt2pt9YexIicphTkWhHuk9CRA53KhLtSOckRORwpyLRjlQjRORwpyLRjnSfhIgc7lQkWtCjW8eU0xsbnZr65L8zoRIhIoc7FYkW3HX1kJTTfzRrNad+9wVq6xsPmKYdCRE53KlItKBbp5yU059ctBmAd3fuJW/iTKYt3sK+b/TQ1U1tY/6abfxolu5PEckGFYkWDD2pZ+S0h14qZG9t7FDTFfe/AsCTizfvb6Aa0Sa+9HgBj76yIdtpiByRVCRa0Kd7FzZNuirptJ/NWZvhbEREMktFIk1RhSJR/HkI7UiIyOFORaIVfn3d2S22WbqlrKlQ1Dc6eRNncs49L1LfcOCJbRGRQ52KRCsku4IplX2FYVdlHU+8/m6zae+VVbFrb22b5SYi0h5UJFph9Mf6tqp9/M+f7qmuZ97qbeRNnMmD89dz4aT5XDBpfhtnKCLStlosEmY2wMxeMrPVZrbKzL4e4seZ2VwzWx+ee4a4mdkDZlZoZsvNbGjcssaF9uvNbFxcfJiZrQjzPGDhZ+Gi+siWozvnMuu/Lkq7ffyv1/1y7jpumlIAwM9fXAdAVV0D75RUsLOipqlddV0D33x6Ke+VVaVc9sYdeymvrEsrj4qaerbvrk4774P1uf95jfvnrkvZZktpJd9+Zhl1DY1U1zXwy7nrqK5LfjNiIt3BfniYVrCF8qr0/jbT9fbW3eRNnEnh9ooDpu3aW8u6bXvatD/ZL509iXrgW+5+OjAcuM3MhgATgXnuPhiYF8YBRgODw2M88DDENvjA3cD5xH63+u64jf7Doe2++UaFeFQfWTPkxO784DNnpNX2b6u3t9jmsl+8zLAf/o2Pfmc2VbUNzFu9nb++9R4XTppPdV0DNz2+mLueX8kX/7CI2SuKAaipb+CSny/grHtebPZPU1ZZy96a+gP6OPPuOZz3o3nNYivfK+eSny9gd3Xb/TMv2ljKr+etT9nm288s49klRSzeVMqUVzfxwLz1PPaP6F/5W/LurqbhxoOoEXUNjTy7pIjGNGduaHRG/OwlZi4vblU/7t6siG3fXZ12EY9XuH0P65Ns8KrrGtia8MHhlXUlTH1tU7PYGxt28tBLhXzliSUHvObGRmdPC+93RU09lbUH/g2lsqW0ko079gKwung3dzy7nG8/syxp20UbS2mIy2vX3tq0PiRMX7YVgDmr3m+K7aio4fd/38DVv1nIleES9H2eeG0TL61t+f+vvW3bXc3L60oipz+3pIjnl76XwYxaL7elBu5eDBSH4T1mthroB4wFRoRmU4AFwIQQn+qx/5jXzayHmfUNbee6eymAmc0FRpnZAqC7u78W4lOBa4DZKfrIqo+ecEybL7O2oZHT73qBY7vu/xqQj3//RWrjTngvWFvCpklXUV27P/Z28W465hgX/2xBU+yWEafw9csGc/PUAnp069QU31Ndx8zlxVz18b5cP/kNyirrmLZ4C2M+1pcLJs3ndzfkc8WQEwCorK2ngxmN7tw7YzVVtfVccEovPnfugKblnXPPi3zzio8ya0Vxs7xXvlfO5IUb+f6nz6BDBzimS0feKamgorqeNzaWAtA5N4c/vvFuU/u8iTP5xMnH84cbz6VLx9gNjIs3lXLtI681Lbeh0cnpcOBvj0d55OV3eGvzLuas2gZAxxzj6o+fGLmMuoZGdlTUsGlnJV978k1+OqcbF5zSix//y8ea2qx9fw9/W72Ns/r3YOmWXXz10sFs3lnJVQ/8nT019Wz88RjMjPN+NI+OOcb6+8YAsY28O3TtlENFTT1vb93NeYOOA2Ibu4ZG54TuXbj8l7GN3cYfx+abvHAjQ/p251vPLKO4vLrZVXY3PLYo9vyJvKbYvz36etNwZV0D3TrmMG/Ndi4/vQ/3zVrN5IUbWXPvqKZ1XNfQSHFZNX9etJlhA3ty89QCunTswJp7R1Nd18Bp33uBW0acwoRRp1G0q5LC7RWMOHX/LzY+vXgzE55bAcDfbr+46RN9cfmBe8KvFu7gC79/gztGncqtIz7Cnuo6zrl3LgB3jj6NL198StL3ZeV75Tzy8jvA/r3JsQ8uZFlRedL2AN97fhVw4FWJ5VV1HNUph9ycDkxftpVjOudySfgFyryJM+lzTGcWfefyyOXuqKhhd1UdJ/c+GoDZK4pxYMzH+vLqOztoaHQ2l1aSP/A4tpZVcePjiwG4d+wZfP//3mb9D0fToYM1/X99KxTTvsd2ZeDx3Tihexcg9vdSXlXXNA6w5v3dfPo3C/nG5R/ltks+0hTfWVHDT19YywnHduH2Kz4amfvBstbswptZHvAKcCaw2d17xE3b5e49zWwGMMndF4b4PGIb9hFAF3f/YYh/D6gituGf5O6Xh/hFwAR3v9rMypL1kSSv8cT2RDjppJOGvfvuu4lN2lzh9gou/+XL7d5PojX3jmLu29v42pNvtXreEaf2ZsHaAz/VXHpaH+aviX3quvvTQ7h++EAGf2d20mU8d8sFfPbhV1vV70NfGMptf36zWezmiwal/J3wp8YP57q4DR7EXntuB2PKa+9yyam92VFRy+f+5zVuvDCPk3sdxZATj2XYwJ78fM5aJi/cSFWST6gdc4zX77wMM2P+mu2MOvOfePTld3hgfmFkLm9+7wqO6hzbqJ763ReaTevXo2uzQ4Ofy+/PuAvyuOqBhQD8/Y5LWLV1N1/54xIA3vreFfz3s8v52+pt/OrfzqbX0Z25fvIbAPzzOf3461uxT5W/+fw5HNU5hy89XtCsv1U/GElFTT0vryvhjmeXA/D7G/L5zv+uYNvummZtZ3ztk9z6pzfZXFrJrSNOYcqrm9hb28Btl5zCyDP+ic88+I/I1zx5XD47K2q547lYH8new//85CB+v3D/e3hM51z2xO3JFt43mkmz1/DRE47h2vz+3PX8qgMu4Ejla5d+hAE9uzXlANDr6M784Yvn8ukHFx7QPr4g5E2cCcT+Xitr69lbU89H+hzdVITjLZxwCT26deLMu+c0xe7/t7P4xMm9qK5roG+PLnTOzaGuobHp/+IL55/EScd1Y9LsNQBccmpvXkryv5XMsV07Rh6OW/79K/nRzNU8tXgLEPtaoH49u/LlJ5Y0a/fFC/J4/NVNvH3PSIbctT/vdC/VT8bMlrh7/gHxdIuEmR0NvAzc5+5/idqAm9lM4McJReIO4FKgc0KRqCRWdH6cUCTucPdPp1sk4uXn53tBQUGqJm2itr6Rj343+YZUpC30OrozOypqWm7YTq45+0T+d+nWNlnWZ4f257k3i9pkWbLf5af3aXZYe+UPRnJ05xYPECUVVSTSurrJzDoCzwF/cve/hPC2cBiJ8Lwv0yJgQNzs/YGtLcT7J4mn6iPrcuMOWXTK0UVi0vayWSCANisQgApEO0k877mtHS5QSefqJgMmA6vd/Zdxk6YD+65QGgc8Hxe/IVzlNBwoD+c15gBXmlnPcML6SmBOmLbHzIaHvm5IWFayPrKuQ1yRmPKl87KYiYhITI6lf84uXensl1wI/AewwsyWhtj/AyYB08zsJmAzcG2YNgsYAxQSO5x0I4C7l5rZvcDi0O6efSexgVuAx4GuxE5Y7zuOE9XHIWHf8b9F4WTs2QN6sHRLWTZTEpEjWF6vo9p8melc3bQQiCpPlyVp78BtEct6DHgsSbyA2MnwxPjOZH0cavbtVOR0sKbCUVZZyzUP/YNNOyt5/MZzefSVDTz4haEMDVdzfPeq0/nhTH39tYi0jX49urbLcnUwvQ2Ee/9ojLsIoEe3Tjw5fji/uPYsRpzahz/fPJzjjurES98ewTcuH8xNnxzEicd2Sbq8c06Knas/vW93XvzmpwA4L+84Lhrcq8VcZv3XReQP7Mn1w0+KbNM5t3Vv+yv/fQlnDejRckPgkeuHtWrZyZzSu+0/DbXGHaNOzWr/h4P/GD6wafjfzz+Jn3724+3e5zVnn9hsvIPRdMn2wbrxwjy+M+b0D7SMlnzmrBNbbgQ8+IVz6N/z4Db08751Mf+YeOlBzduSVl0CezjI1NVN8d4vr2b4j+fx3yNPbXb9cjpu+eMSZq+M3SB0/fCTuPmik+l1dGc27dzLGSceC8DL60o4u38Pju3Wkdkrinltw07O6t+j6Rrrt753BQsLd3B0l1wuibuGva6hkTXFe/j0gwvp1imHX193Drur6vjssP5s3119wA120796Id065TZd2vvXWy/gnPB7Gt98emnT5Zn5A3J076YAAArNSURBVHtSEHeT28gzTmDJu7vYUVHLxh+P4bF/bKK2vpHPDuvHeffN48YL87h++EAu+8X+S4b/qXsX5n3rYs4Ilx0+cv1QLhrcm6M657Jxx14u+fkCBh7fjW9cPphvPt38xqxjOufy2v+7jNkrijEzHnqpkI079tIptwM9unZk+57mJ3zPOLE7q7bujnwPeh3dmUtO7c0zS4r48qdO5s4xpzflsM+/DO3HX95sftPTXVcPYdwFefz3s8uapvU9tgvF5bGTh6/deSl/e3tb0zX7ie4ZewYdzFhdvJstu6p4Jdx09af/PJ+fvrCm2X0A1507oOmyyGRuGXEKDy94p1mszzGd+fbIU5sulQW46mN9mbmi5RsFJ4/LD3fGOxVxl7V+5qwTObn3UXz9ssGs317BMV1y6XtsV2rqG1i4fgeXnNqHb05byvPhpPfi71xOcXkVt09bxs0XDWq6p+LxG8+lZ7dOTR8+8ibOpGOOcUyXjpz2T8dw16eHsHxLOZ87dwBzVr3Pl59YwlcuPoXPDu3X9NstX7n4FCaOPo0XV73Prspa1r5fkfLGzH3yju/Glz45iGEDezb9jz3y8jtMmr2Gt+8ZSdeOOcxe+T5r3t9D59wO/McnBlK2t44XVhVz7bABnHPvXH75ubM4unMu4+MuTf2/r36Sn724tul9hNgHsqV3XUnX8ONliZfOn9X/2Kb3eeOPx7Bzby35P/xbi69h06SrqG9o5OtPL+XWEac0vY4PIurqpqY7RT8sj2HDhnk2lFbUeENDY6vnq66r961llQfV56KNO/35pe+12O7Vwh1eU9dwQLyxMZbvg/PX+6Mvv9MUHzhhhg+cMKNZ24rqOv/tS4U+f802b2ho9F++uNaLdlX6N556y4t2RedfV9/Q1I+7+1OL3vWBE2b4j2a97e7uK98r850VNQf0NXDCDH960WZvbGz0le+VeUNDo89fvc3HT13sb75bekA/28qrvKq23jfv3NuU/61/XNI0fdfeGp+/eps/U7DF/3PKYn9n+x6fs7LYSytqfG9N7LUNnDDDF6zd3jRPVW2919bvX2/v7ar0mroGr29o9Lc272p6XTV1Df7g/PW+pXSv766q9Vv/uKRp/TU2NvrM5Vt9dXF5U153PLPsgPVbVlnrs5ZvPeB17dpb4/8oLHF391Xvlfu3py31gRNm+BOvbfKq2npf8m6pr3qv3N3dVxSV+YqiMq9vaPTfzFvnxWVVvmlHhQ+cMMN/9sIaX/f+bnd3f2FlsY/59SteXVfvl/78JR84YYaP+FnseXdVrW8rr2rq/7z75vrACTN86qsb/f24eEseWVDozy3Z0ixWW9+Q9G/L3b1sb61X1dYnXVZjY6M/W7DFq+ti08+550UfOGFGs/dmn4JNO5v6uOWPBf7TF1Z7cVmVPzh/ve+tqfOyytqm5bSF63//uj+yoLBpfPvuav/3373uv3vlHS+vqk06z+W/WNBsHawoKvPHFm5oeq3//rvX/c6/LG96HQMnzPCF60t8Z0VN5PprC0CBJ9mmZn2j3taPbBWJD5OSPdW+bXf6G4TWqK1v8F/NXeeVNW33j5po5vKtPnDCDL/96aVpz1Pf0OgL15e0Sf+19Q1emlD43N3fL6/yrWWV3tDQmHQDl47dVbU+afbqpEU/yrs79np9xAeYv75Z5AMnzPAtpXuTTp+1fKsPvefFNtuw/mN9iT/wt3VtsqwoX3mioN02pG1hT3Vdyg9W7u5Fuyp94IQZPvbBhc3iAyfM8PFTF7dLXlFFQoeb5EOnvqGRn7+4jq9cfHKzryUROZz86Y13ueL0E+gT99Ucu6vr6Noxh47tcG/WB77j+nChIiEi0nof6I5rERE5MqlIiIhIJBUJERGJpCIhIiKRVCRERCSSioSIiERSkRARkUgqEiIiEulDdzOdmZUAB/sj172AHW2YTltRXq2jvFpHebXOhzWvge7eOzH4oSsSH4SZFSS74zDblFfrKK/WUV6tc6TlpcNNIiISSUVCREQiqUg092i2E4igvFpHebWO8mqdIyovnZMQEZFI2pMQEZFIKhIiIhJJRSIws1FmttbMCs1sYob73mRmK8xsqZkVhNhxZjbXzNaH554hbmb2QMhzuZkNbeNcHjOz7Wa2Mi7W6lzMbFxov97MxrVTXt83s/fCeltqZmPipt0Z8lprZiPj4m32PpvZADN7ycxWm9kqM/t6iGd1faXIK6vrKyyvi5ktMrNlIbcfhPggM3sjvP6nzaxTiHcO44Vhel5LObdxXo+b2ca4dXZ2iGfybz/HzN4ysxlhPLPrKtlvmh5pDyAHeAc4GegELAOGZLD/TUCvhNhPgYlheCLwkzA8BpgNGDAceKONc/kUMBRYebC5AMcBG8JzzzDcsx3y+j7w7SRth4T3sDMwKLy3OW39PgN9gaFh+BhgXeg7q+srRV5ZXV+hLwOODsMdgTfCupgGXBfijwC3hOFbgUfC8HXA06lyboe8Hgf+NUn7TP7t3w78GZgRxjO6rrQnEXMeUOjuG9y9FngKGJvlnMYCU8LwFOCauPhUj3kd6GFmfduqU3d/BSj9gLmMBOa6e6m77wLmAqPaIa8oY4Gn3L3G3TcChcTe4zZ9n9292N3fDMN7gNVAP7K8vlLkFSUj6yvk4+5eEUY7hocDlwLPhnjiOtu3Lp8FLjMzS5FzW+cVJSPvpZn1B64Cfh/GjQyvKxWJmH7AlrjxIlL/U7U1B140syVmNj7ETnD3Yoj90wN9QjwbubY2l0zm+NWwu//YvsM62cgr7NqfQ+wT6CGzvhLygkNgfYXDJ0uB7cQ2ou8AZe5en6SfphzC9HLg+PbILTEvd9+3zu4L6+x+M+ucmFdC/22d16+AO4DGMH48GV5XKhIxliSWyWuDL3T3ocBo4DYz+1SKttnONV5ULpnK8WHgFOBsoBj4RTbyMrOjgeeAb7j77lRNs5zXIbG+3L3B3c8G+hP7RHt6in4ylltiXmZ2JnAncBpwLrFDSBMylZeZXQ1sd/cl8eEUy2+XnFQkYoqAAXHj/YGtmerc3beG5+3AX4n942zbdxgpPG/PYq6tzSUjObr7tvCP3Qj8jv270BnLy8w6EtsQ/8nd/xLCWV9fyfI6FNZXPHcvAxYQO6bfw8xyk/TTlEOYfiyxw47tlltcXqPCoTt39xrgD2R2nV0IfMbMNhE71HcpsT2LzK6rD3JC5cPyAHKJnWAaxP4TdGdkqO+jgGPihl8ldgzzZzQ/+fnTMHwVzU+YLWqHnPJofoK4VbkQ+8S1kdiJu55h+Lh2yKtv3PA3iR13BTiD5ifqNhA7Cdum73N43VOBXyXEs7q+UuSV1fUV+uoN9AjDXYG/A1cDz9D8ZOytYfg2mp+MnZYq53bIq2/cOv0VMClLf/sj2H/iOqPrqk03Lofzg9jVCuuIHR/9Tgb7PTm8gcuAVfv6JnYscR6wPjwfF/fH+lDIcwWQ38b5PEnsUEQdsU8gNx1MLsCXiJ0gKwRubKe8ngj9Lgem03wj+J2Q11pgdHu8z8Anie22LweWhseYbK+vFHlldX2F5X0ceCvksBK4K+7/YFF4/c8AnUO8SxgvDNNPbinnNs5rflhnK4E/sv8KqIz97YdljmB/kcjoutLXcoiISCSdkxARkUgqEiIiEklFQkREIqlIiIhIJBUJERGJpCIhIiKRVCRERCTS/wdxDdS1nGMMowAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(losses)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Flow HMC" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'guide' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtransform\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_transform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopt_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0munpack_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munpack_latent\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mlatent_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlatent_size\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'guide' is not defined" + ] + } + ], + "source": [ + "transform = guide.get_transform(opt_state)\n", + "unpack_fn = guide.unpack_latent\n", + "latent_size = guide.latent_size" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [], + "source": [ + "def make_transformed_pe(potential_fn, transform, unpack_fn, prior_scale):\n", + " def transformed_potential_fn(z):\n", + " u, intermediates = transform.call_with_intermediates(z)\n", + " logdet = transform.log_abs_det_jacobian(z, u, intermediates=intermediates) * prior_scale\n", + " return potential_fn(unpack_fn(u)) + logdet\n", + "\n", + " return transformed_potential_fn" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [], + "source": [ + "def get_flow_subposterior(rng, shard, K):\n", + " X, y = shard\n", + " init_params = random.normal(rng, (4, latent_size))\n", + " _, potential_fn, _ = initialize_model(rng, model, X, y, prior_scale=1 / K)\n", + " transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn, 1 / K)\n", + " samples = mcmc(num_warmup=1000, num_samples=2500, init_params=init_params,\n", + " num_chains=4, potential_fn=transformed_potential_fn)\n", + " return samples" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "warmup: 100%|██████████| 1000/1000 [03:07<00:00, 5.63it/s, 1023 steps of size 6.92e-05. acc. prob=0.78]\n", + "sample: 100%|██████████| 2500/2500 [07:29<00:00, 5.58it/s, 1023 steps of size 6.92e-05. acc. prob=0.88]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + " mean std median 5.0% 95.0% n_eff r_hat\n", + " Param:0[0] -13.94 0.25 -14.01 -14.26 -13.50 6.02 1.01\n", + " Param:0[1] 3.60 0.10 3.58 3.46 3.77 4.80 1.81\n", + " Param:0[2] -1.00 0.36 -1.06 -1.56 -0.42 4.99 1.35\n", + " Param:0[3] -5.20 0.33 -5.32 -5.57 -4.54 3.37 1.68\n", + " Param:0[4] -2.45 0.67 -2.45 -3.60 -1.51 2.59 2.69\n", + " Param:0[5] 6.63 0.44 6.74 5.81 7.18 3.79 1.65\n", + " Param:0[6] 0.34 0.57 0.30 -0.56 1.29 7.26 1.01\n", + " Param:0[7] -1.72 0.30 -1.72 -2.22 -1.24 4.35 1.86\n", + " Param:0[8] 10.02 0.59 9.89 9.13 11.04 4.90 1.15\n", + " Param:0[9] -6.03 0.49 -5.89 -6.84 -5.33 3.51 1.69\n", + "Param:0[10] 11.40 0.53 11.44 10.59 12.21 3.57 2.19\n", + "Param:0[11] 15.31 0.38 15.24 14.73 15.97 12.57 1.01\n", + "Param:0[12] 30.34 0.88 30.54 28.70 31.54 5.96 1.01\n", + "Param:0[13] -0.24 0.57 -0.25 -1.15 0.54 4.01 1.41\n", + "Param:0[14] -8.89 0.47 -8.82 -9.67 -8.16 4.70 1.05\n", + "Param:0[15] 0.82 0.64 0.97 -0.55 1.62 4.95 1.11\n", + "Param:0[16] 0.62 0.32 0.54 0.20 1.18 4.59 1.12\n", + "Param:0[17] -10.90 1.00 -10.76 -12.52 -9.27 5.20 1.45\n", + "Param:0[18] 0.62 0.16 0.59 0.37 0.86 4.88 1.16\n", + "Param:0[19] -6.58 0.50 -6.55 -7.44 -5.74 8.45 1.00\n", + "Param:0[20] -14.14 0.68 -14.13 -15.11 -13.16 2.59 2.48\n", + "Param:0[21] 5.05 0.25 4.98 4.72 5.46 3.38 1.72\n", + "Param:0[22] 1.46 0.41 1.49 0.84 2.22 3.58 1.89\n", + "Param:0[23] -4.04 0.42 -4.07 -4.73 -3.47 3.97 1.82\n", + "Param:0[24] -7.94 1.89 -7.75 -11.41 -5.26 2.97 2.36\n", + "Param:0[25] -0.50 0.55 -0.55 -1.22 0.30 2.99 2.22\n", + "Param:0[26] -2.05 0.53 -2.09 -2.85 -1.06 9.19 1.39\n", + "Param:0[27] 5.06 1.25 5.22 3.21 7.47 4.65 1.36\n", + "Param:0[28] 8.44 0.82 8.39 7.00 9.66 4.58 1.64\n", + "Param:0[29] 0.19 0.28 0.26 -0.38 0.54 3.76 1.61\n", + "Param:0[30] 14.60 1.78 15.01 11.66 17.62 3.66 1.63\n", + "Param:0[31] 20.82 0.90 20.76 19.62 22.35 5.42 1.29\n", + "Param:0[32] -31.48 3.68 -31.31 -37.17 -25.47 5.27 1.00\n", + "Param:0[33] -42.52 1.82 -42.42 -45.33 -39.70 26.26 1.01\n", + "Param:0[34] -49.77 3.01 -50.29 -54.42 -46.00 2.74 2.70\n", + "Param:0[35] 17.45 2.33 17.69 13.41 21.19 4.13 1.76\n", + "Param:0[36] 4.63 6.34 5.31 -5.43 13.28 3.42 1.61\n", + "Param:0[37] 13.86 4.16 12.31 8.42 20.54 4.86 1.02\n", + "Param:0[38] 0.78 1.44 0.71 -1.81 2.61 13.06 1.09\n", + "Param:0[39] 28.69 2.32 27.99 25.57 32.94 4.43 1.47\n", + "Param:0[40] -20.61 2.50 -20.76 -23.80 -16.05 7.42 1.47\n", + "Param:0[41] -53.15 9.93 -51.85 -68.62 -35.56 7.21 1.08\n", + "Param:0[42] 2.50 2.45 2.10 -1.07 5.54 10.09 1.02\n", + "Param:0[43] 0.90 0.26 0.88 0.53 1.41 3.39 2.19\n", + "Param:0[44] -99.67 4.76 -98.73 -107.37 -92.83 6.72 1.09\n", + "Param:0[45] -115.56 11.62 -116.65 -134.54 -93.48 8.19 1.00\n", + "Param:0[46] -1.78 0.41 -1.73 -2.51 -1.21 4.37 1.60\n", + "Param:0[47] 106.94 5.34 107.00 98.59 115.20 6.45 1.00\n", + "Param:0[48] 9.21 1.05 9.52 7.30 10.58 3.92 1.52\n", + "Param:0[49] 23.49 11.20 26.47 1.60 37.24 3.97 1.60\n", + "Param:0[50] -3.87 0.64 -3.84 -5.00 -2.86 11.39 1.01\n", + "Param:0[51] -30.30 4.26 -29.63 -39.68 -24.57 5.93 1.08\n", + "Param:0[52] 9.28 0.64 9.41 7.95 10.14 6.93 1.01\n", + "Param:0[53] 22.34 4.76 20.76 16.82 31.59 3.29 1.54\n", + "Param:0[54] 262.73 36.32 270.82 191.77 306.70 8.09 1.05\n", + "\n", + "\n" + ] + } + ], + "source": [ + "init_params = random.normal(random.PRNGKey(1), (latent_size,))\n", + "_, potential_fn, _ = initialize_model(rng, model, X, y, prior_scale=1 / K)\n", + "transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn, 1 / K)\n", + "samples = mcmc(num_warmup=1000, num_samples=2500, init_params=init_params,\n", + " num_chains=1, potential_fn=transformed_potential_fn)" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [], + "source": [ + "flow_samples = vmap(lambda x: unpack_fn(transform(x)))(samples)" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [], + "source": [ + "from numpyro.diagnostics import summary" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray(0.7682474, dtype=float32)" + ] + }, + "execution_count": 97, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rngs = random.split(random.PRNGKey(2), 2500)\n", + "y_flow = vmap(partial(predict, model, X_test))(rngs, flow_samples)\n", + "y_flow = (y_flow.sum(axis=0) / y_flow.shape[0]) >= 0.5\n", + "acc = (y_flow == y_test).sum() / y_test.shape[0]\n", + "acc" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + " mean std median 5.0% 95.0% n_eff r_hat\n", + " coefs[0] 1.97 0.07 1.97 1.86 2.09 17.99 1.01\n", + " coefs[1] -0.02 0.02 -0.02 -0.05 0.02 21.80 1.02\n", + " coefs[2] -0.12 0.02 -0.12 -0.14 -0.08 8.34 1.05\n", + " coefs[3] -0.31 0.02 -0.31 -0.34 -0.28 7.72 1.17\n", + " coefs[4] -0.11 0.02 -0.11 -0.14 -0.08 10.66 1.01\n", + " coefs[5] -0.10 0.02 -0.09 -0.12 -0.07 65.96 1.00\n", + " coefs[6] 0.01 0.02 0.01 -0.03 0.05 7.04 1.15\n", + " coefs[7] -0.49 0.03 -0.49 -0.53 -0.44 14.74 1.00\n", + " coefs[8] 0.24 0.02 0.23 0.20 0.27 5.82 1.47\n", + " coefs[9] -0.01 0.01 -0.01 -0.03 0.02 6.66 1.48\n", + " coefs[10] 1.67 0.04 1.68 1.61 1.74 6.33 1.03\n", + " coefs[11] 0.53 0.02 0.53 0.49 0.57 11.40 1.22\n", + " coefs[12] 1.44 0.03 1.44 1.39 1.50 7.69 1.05\n", + " coefs[13] -0.80 0.22 -0.81 -1.13 -0.50 3.51 1.65\n", + " coefs[14] -2.65 0.16 -2.65 -2.92 -2.38 10.49 1.06\n", + " coefs[15] -0.95 0.16 -0.92 -1.23 -0.68 5.82 1.14\n", + " coefs[16] -0.74 0.12 -0.78 -0.90 -0.53 4.76 1.15\n", + " coefs[17] -0.21 0.02 -0.20 -0.25 -0.17 17.49 1.02\n", + " coefs[18] -0.06 0.11 -0.07 -0.23 0.12 5.46 1.10\n", + " coefs[19] -1.63 0.07 -1.63 -1.73 -1.50 8.42 1.22\n", + " coefs[20] -2.43 0.15 -2.45 -2.66 -2.20 2.83 2.25\n", + " coefs[21] 0.00 0.02 -0.00 -0.03 0.03 6.51 1.23\n", + " coefs[22] -0.03 0.02 -0.03 -0.07 0.00 3.49 2.16\n", + " coefs[23] -0.14 0.04 -0.14 -0.22 -0.09 3.17 2.54\n", + " coefs[24] -0.17 0.04 -0.17 -0.25 -0.11 4.80 1.64\n", + " coefs[25] -0.02 0.03 -0.02 -0.05 0.03 4.50 1.61\n", + " coefs[26] -0.04 0.01 -0.04 -0.06 -0.02 10.97 1.32\n", + " coefs[27] -1.01 0.05 -1.00 -1.10 -0.94 6.59 1.02\n", + " coefs[28] 0.36 0.36 0.35 -0.23 0.93 2.95 2.12\n", + " coefs[29] 0.09 0.02 0.09 0.06 0.12 5.30 1.13\n", + " coefs[30] 0.10 0.03 0.10 0.05 0.15 32.94 1.14\n", + " coefs[31] 0.05 0.04 0.05 -0.02 0.11 9.05 1.01\n", + " coefs[32] 0.12 0.01 0.12 0.10 0.13 4.93 1.31\n", + " coefs[33] 0.11 0.01 0.11 0.09 0.14 9.51 1.19\n", + " coefs[34] 0.07 0.03 0.07 0.02 0.12 40.27 1.01\n", + " coefs[35] 0.35 0.02 0.35 0.32 0.38 4.23 1.46\n", + " coefs[36] 0.36 0.02 0.36 0.33 0.38 10.62 1.04\n", + " coefs[37] 0.16 0.01 0.16 0.14 0.18 11.52 1.00\n", + " coefs[38] -0.02 0.02 -0.02 -0.05 0.01 27.97 1.04\n", + " coefs[39] -0.01 0.03 -0.01 -0.05 0.03 17.64 1.04\n", + " coefs[40] 0.00 0.02 0.00 -0.04 0.04 16.42 1.15\n", + " coefs[41] -1.13 0.11 -1.12 -1.31 -0.99 8.21 1.09\n", + " coefs[42] 0.09 0.02 0.09 0.07 0.12 9.71 1.28\n", + " coefs[43] -0.08 0.02 -0.07 -0.12 -0.04 8.79 1.26\n", + " coefs[44] 0.14 0.02 0.15 0.10 0.18 5.45 1.52\n", + " coefs[45] 0.09 0.02 0.08 0.06 0.12 15.35 1.09\n", + " coefs[46] 0.16 0.02 0.16 0.14 0.19 20.08 1.02\n", + " coefs[47] -0.01 0.03 -0.01 -0.06 0.05 22.08 1.16\n", + " coefs[48] -0.09 0.02 -0.09 -0.11 -0.05 27.91 1.19\n", + " coefs[49] 1.08 0.15 1.07 0.83 1.31 6.21 1.15\n", + " coefs[50] -1.99 0.29 -1.95 -2.48 -1.56 15.78 1.00\n", + " coefs[51] -0.11 0.02 -0.11 -0.14 -0.07 11.80 1.24\n", + " coefs[52] -0.13 0.02 -0.12 -0.16 -0.10 18.72 1.00\n", + " coefs[53] -0.15 0.02 -0.15 -0.19 -0.12 48.88 1.03\n", + " coefs[54] -2.20 0.06 -2.21 -2.28 -2.09 4.78 1.54\n", + "\n", + "\n" + ] + } + ], + "source": [ + "summary({'coefs': real_samples['coefs'][None, ...]})" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "X, y = shards[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " =============================== SUBPOSTERIOR 00 ===============================\n", + "\n", + " mean std median 5.0% 95.0% n_eff r_hat\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0msep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'='\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m31\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'\\n '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msep\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' SUBPOSTERIOR {:02d} '\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m''\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0msamples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_flow_subposterior\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshard\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mK\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'.results/flow'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmakedirs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'.results/flow'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mget_flow_subposterior\u001b[0;34m(rng, shard, K)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mtransformed_potential_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmake_transformed_pe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpotential_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munpack_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m samples = mcmc(num_warmup=1000, num_samples=2500, init_params=init_params,\n\u001b[0;32m----> 7\u001b[0;31m num_chains=4, potential_fn=transformed_potential_fn)\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msamples\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/numpyro/numpyro/mcmc.py\u001b[0m in \u001b[0;36mmcmc\u001b[0;34m(num_warmup, num_samples, init_params, num_chains, sampler, constrain_fn, print_summary, **sampler_kwargs)\u001b[0m\n\u001b[1;32m 425\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 426\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mprint_summary\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 427\u001b[0;31m \u001b[0msummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msamples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 428\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msamples_flat\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 429\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/numpyro/numpyro/diagnostics.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(samples, prob)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0mrow_format\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mname_format\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f}'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 225\u001b[0;31m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdevice_get\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 226\u001b[0m \u001b[0mvalue_flat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 227\u001b[0m \u001b[0mmean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_flat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api.py\u001b[0m in \u001b[0;36mdevice_get\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 1109\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0my\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtree_leaves\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1110\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1111\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy_to_host_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1112\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1113\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/interpreters/pxla.py\u001b[0m in \u001b[0;36mcopy_to_host_async\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 448\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_npy_value\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 449\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_buffers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 450\u001b[0;31m \u001b[0mbuf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy_to_host_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 451\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 452\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdelete\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "rngs = random.split(random.PRNGKey(1), K)\n", + "subposteriors = []\n", + "for i, (rng, shard) in enumerate(zip(rngs, shards)):\n", + " if i > 0:\n", + " break\n", + " sep = '=' * 31\n", + " print('\\n ' + sep + ' SUBPOSTERIOR {:02d} '.format(i) + sep, end='')\n", + " samples = get_flow_subposterior(rng, shard, K)\n", + " if not os.path.exists('.results/flow'):\n", + " os.makedirs('.results/flow')\n", + " np.save('.results/flow/subposterior_{:02d}.npy'.format(i), samples)\n", + " subposteriors.append(samples)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (pydata)", + "language": "python", + "name": "pydata" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}