diff --git a/lab3/solutions/Lab3_Bias_And_Uncertainty.ipynb b/lab3/solutions/Lab3_Bias_And_Uncertainty.ipynb index 33cbff92..a210f0f0 100644 --- a/lab3/solutions/Lab3_Bias_And_Uncertainty.ipynb +++ b/lab3/solutions/Lab3_Bias_And_Uncertainty.ipynb @@ -2,142 +2,73 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "IgYKebt871EK" - }, "source": [ - "# Laboratory 3: Detecting and mitigating bias and uncertainty in Facial Detection Systems\n", - "In this lab, we'll continue to explore how to mitigate algorithmic bias in facial recognition systems. In addition, we'll explore the notion of *uncertainty* in datasets, and learn how to reduce both data-based and model-based uncertainty.\n", - "\n", - "As we've seen in lecture 5, bias and uncertainty underlie many common issues with machine learning models today, and these are not just limited to classification tasks. Automatically detecting and mitigating uncertainty is crucial to deploying fair and safe models. \n", - "\n", - "In this lab, we'll be using [CAPSA](https://github.com/themis-ai/capsa/), a software package developed by [Themis AI](https://themisai.io/), which automatically *wraps* models to make them risk-aware and plugs into training workflows. We'll explore how we can use CAPSA to diagnose uncertainties, and then develop methods for automatically mitigating them.\n", - "\n", - "\n", - "Run the next code block for a short video from Google that explores how and why it's important to consider bias when thinking about machine learning:" - ] + "<table align=\"center\">\n", + " <td align=\"center\"><a target=\"_blank\" href=\"http://introtodeeplearning.com\">\n", + " <img src=\"https://i.ibb.co/Jr88sn2/mit.png\" style=\"padding-bottom:5px;\" />\n", + " Visit MIT Deep Learning</a></td>\n", + " <td align=\"center\"><a target=\"_blank\" href=\"https://colab.research.google.com/github/aamini/introtodeeplearning/blob/2023/lab3/solutions/Lab3_Bias_And_Uncertainty.ipynb\">\n", + " <img src=\"https://i.ibb.co/2P3SLwK/colab.png\" style=\"padding-bottom:5px;\" />Run in Google Colab</a></td>\n", + " <td align=\"center\"><a target=\"_blank\" href=\"https://github.com/aamini/introtodeeplearning/blob/2023/lab3/solutions/Lab3_Bias_And_Uncertainty.ipynb\">\n", + " <img src=\"https://i.ibb.co/xfJbPmL/github.png\" height=\"70px\" style=\"padding-bottom:5px;\" />View Source on GitHub</a></td>\n", + "</table>\n", + "\n", + "# Copyright Information" + ], + "metadata": { + "id": "Kxl9-zNYhxlQ" + } }, { "cell_type": "code", "source": [ - "!git clone https://github.com/slolla/capsa-intro-deep-learning.git\n", - "!cd capsa-intro-deep-learning/ && git checkout HistogramVAEWrapper\n" + "# Copyright 2023 MIT Introduction to Deep Learning. All Rights Reserved.\n", + "# \n", + "# Licensed under the MIT License. You may not use this file except in compliance\n", + "# with the License. 
Use and/or modification of this code outside of MIT Introduction\n", "# to Deep Learning must reference:\n", "#\n", "# © MIT Introduction to Deep Learning\n", "# http://introtodeeplearning.com\n", "#" ], "metadata": { "id": "aAcJJN3Xh3S1" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "IgYKebt871EK" }, "source": [ "# Laboratory 3: Debiasing, Uncertainty, and Robustness\n", "\n", "# Part 2: Mitigating Bias and Uncertainty in Facial Detection Systems\n", "\n", "In Lab 2, we defined a semi-supervised VAE (SS-VAE) to diagnose feature representation disparities and biases in facial detection systems. In Lab 3 Part 1, we gained experience with [Capsa](https://github.com/themis-ai/capsa/) and its ability to build risk-aware models automatically through wrapping. Now in this lab, we will put these two together: using Capsa to build systems that can *automatically* uncover and mitigate bias and uncertainty in facial detection systems.\n", "\n", "As we have seen, automatically detecting and mitigating bias and uncertainty is crucial to deploying fair and safe models. Building off our foundation with Capsa, developed by [Themis AI](https://themisai.io/), we will now use Capsa for the facial detection problem, in order to diagnose risks in facial detection models. You will then design and create strategies to mitigate these risks, with the goal of improving model performance across the entire facial detection dataset.\n", "\n", "**Your goal in this lab -- and the associated competition -- is to design a strategic solution for bias and uncertainty mitigation, using Capsa.** The approaches and solutions with outstanding performance will be recognized with outstanding prizes! Details on the submission process are at the end of this lab.\n", "\n", "" ], "metadata": { "id": "Kxl9-zNYhxlQ" } }, { "cell_type": "code", "source": [ - "!git clone https://github.com/slolla/capsa-intro-deep-learning.git\n", - "!cd capsa-intro-deep-learning/ && git checkout HistogramVAEWrapper\n" ], "metadata": { - "id": "5Ll7uZ8q72hm", - "outputId": "56b3117b-e344-481b-a9fc-2798b76d7a60", - "colab": { - "base_uri": "https://localhost:8080/" - } + "id": "aAcJJN3Xh3S1" }, - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "fatal: destination path 'capsa-intro-deep-learning' already exists and is not an empty directory.\n", - "Already on 'HistogramVAEWrapper'\n", - "Your branch is up to date with 'origin/HistogramVAEWrapper'.\n" - ] - } - ] + "execution_count": null, + "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "6JTRoM7E71EU" }, "source": [ - "Let's get started by installing the relevant dependencies:" + "Let's get started by installing the necessary dependencies:" ] }, { "cell_type": "code", "source": [ - "%cd capsa-intro-deep-learning/\n", - "%pip install -e .\n", - "%cd .." ], "metadata": { - "id": "SjAn-WZK9lOv", - "outputId": "35e24600-85b4-4320-c436-061856e56861", - "colab": { - "base_uri": "https://localhost:8080/" - } }, - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "/content/capsa-intro-deep-learning\n", - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Obtaining file:///content/capsa-intro-deep-learning\n", - " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "Installing collected packages: capsa\n", - " Attempting uninstall: capsa\n", - " Found existing installation: capsa 0.1.2\n", - " Can't uninstall 'capsa'. 
No files were found to uninstall.\n", - " Running setup.py develop for capsa\n", - "Successfully installed capsa-0.1.2\n", - "/content\n" - ] - } - ] - }, - { - "cell_type": "code", "source": [ - "!git clone https://github.com/aamini/introtodeeplearning.git\n", - "!cd introtodeeplearning/ && git checkout 2023\n", - "%cd introtodeeplearning/\n", - "%pip install -e .\n", - "%cd .." - ], - "metadata": { - "id": "3pzGVPrh-4LQ", - "outputId": "f4588f12-d290-4746-d819-501a0e3ba390", - "colab": { - "base_uri": "https://localhost:8080/" - } - }, - "execution_count": 3, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "fatal: destination path 'introtodeeplearning' already exists and is not an empty directory.\n", - "Already on '2023'\n", - "Your branch is up to date with 'origin/2023'.\n", - "/content/introtodeeplearning\n", - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Obtaining file:///content/introtodeeplearning\n", - " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from mitdeeplearning==0.3.0) (1.21.6)\n", - "Requirement already satisfied: regex in /usr/local/lib/python3.8/dist-packages (from mitdeeplearning==0.3.0) (2022.6.2)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.8/dist-packages (from mitdeeplearning==0.3.0) (4.64.1)\n", - "Requirement already satisfied: gym in /usr/local/lib/python3.8/dist-packages (from mitdeeplearning==0.3.0) (0.25.2)\n", - "Requirement already satisfied: importlib-metadata>=4.8.0 in /usr/local/lib/python3.8/dist-packages (from gym->mitdeeplearning==0.3.0) (5.2.0)\n", - "Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.8/dist-packages (from gym->mitdeeplearning==0.3.0) (0.0.8)\n", - "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.8/dist-packages (from gym->mitdeeplearning==0.3.0) (1.5.0)\n", - "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.8/dist-packages (from importlib-metadata>=4.8.0->gym->mitdeeplearning==0.3.0) (3.11.0)\n", - "Installing collected packages: mitdeeplearning\n", - " Attempting uninstall: mitdeeplearning\n", - " Found existing installation: mitdeeplearning 0.3.0\n", - " Can't uninstall 'mitdeeplearning'. No files were found to uninstall.\n", - " Running setup.py develop for mitdeeplearning\n", - "Successfully installed mitdeeplearning-0.3.0\n", - "/content\n" - ] - } + "Let's get started by installing the necessary dependencies:" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "id": "2PdAhs1371EU" }, @@ -152,11 +83,14 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from tqdm import tqdm\n", - "from capsa import *\n", + "\n", "# Download and import the MIT 6.S191 package\n", - "from mitdeeplearning import lab3 \n", + "!pip install git+https://github.com/aamini/introtodeeplearning.git@2023\n", + "import mitdeeplearning as mdl\n", + "\n", "# Download and import capsa\n", - "#!pip install capsa\n" + "!pip install capsa\n", + "import capsa" ] }, { @@ -165,49 +99,33 @@ "id": "6VKVqLb371EV" }, "source": [ - "## 3.1 Datasets\n", + "# 3.1 Datasets\n", "\n", - "We'll be using the same datasets from lab 2 in this lab. Note that in this dataset, we've intentionally perturbed some of the samples in some ways (it's up to you to figure out how!) that are not necessarily present in the actual dataset. 
\n", + "Since we are again focusing on the facial detection problem, we will use the same datasets from Lab 2. To remind you, we have a dataset of positive examples (i.e., of faces) and a dataset of negative examples (i.e., of things that are not faces).\n", "\n", - "1. **Positive training data**: [CelebA Dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). A large-scale (over 200K images) of celebrity faces. \n", - "2. **Negative training data**: [ImageNet](http://www.image-net.org/). Many images across many different categories. We'll take negative examples from a variety of non-human categories. \n", - "[Fitzpatrick Scale](https://en.wikipedia.org/wiki/Fitzpatrick_scale) skin type classification system, with each image labeled as \"Lighter'' or \"Darker''.\n", + "1. **Positive training data**: [CelebA Dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). A large-scale dataset (over 200K images) of celebrity faces. \n", + "2. **Negative training data**: [ImageNet](http://www.image-net.org/). A large-scale dataset with many images across many different categories. We will take negative examples from a variety of non-human categories.\n", "\n", - "Like before, let's begin by importing these datasets. We've written a class that does a bit of data pre-processing to import the training data in a usable format.\n", + "We will evaluate trained models on an independent test dataset of face images to diagnose and mitigate potential issues with *bias, fairness, and confidence*. This will be a larger test dataset for evaluation purposes.\n", "\n", - "Also note that in this lab, we'll be using a much larger test dataset for evaluation purposes." + "We begin by importing these datasets. We have defined a `DatasetLoader` class that does a bit of data pre-processing to import the training data in a usable format." ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { - "id": "HIA6EA1D71EW", - "outputId": "df98738c-00d5-4987-bd58-938dd17c8ef4", - "colab": { - "base_uri": "https://localhost:8080/" - } + "id": "HIA6EA1D71EW" }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Opening /root/.keras/datasets/train_face_2023_v2.h5\n", - "Loading data into memory...\n", - "Opening /root/.keras/datasets/train_face_2023_v2.h5\n", - "Loading data into memory...\n" - ] - } - ], + "outputs": [], "source": [ "batch_size = 32\n", "\n", "# Get the training data: both images from CelebA and ImageNet\n", "path_to_training_data = tf.keras.utils.get_file('train_face_2023_v2.h5', 'https://www.dropbox.com/s/b5z1cd317y5u1tr/train_face_2023_v2.h5?dl=1')\n", "# Instantiate a DatasetLoader using the downloaded dataset\n", - "train_loader = lab3.DatasetLoader(path_to_training_data, training=True, batch_size= batch_size)\n", - "test_loader = lab3.DatasetLoader(path_to_training_data, training=False, batch_size = batch_size)" + "train_loader = mdl.lab3.DatasetLoader(path_to_training_data, training=True, batch_size=batch_size)\n", + "test_loader = mdl.lab3.DatasetLoader(path_to_training_data, training=False, batch_size=batch_size)" ] }, { @@ -216,13 +134,11 @@ "id": "cREmhMWJ71EX" }, "source": [ - "### Recap: Thinking about bias and uncertainty\n", + "### Building robustness to bias and uncertainty\n", "\n", - "Remember that we'll be training our facial detection classifiers on the large, well-curated CelebA dataset (and ImageNet), and then evaluating their accuracy by testing them on an independent test dataset. 
Our goal is to build a model that trains on CelebA *and* achieves high classification accuracy on the the test dataset across all demographics, and to thus show that this model does not suffer from any hidden bias. \n", + "Remember that we'll be training our facial detection classifiers on the large, well-curated CelebA dataset (and ImageNet), and then evaluating their accuracy by testing them on an independent test dataset. We want to mitigate the effects of unwanted bias and uncertainty on the model's predictions and performance. Your goal is to build the best-performing, most robust model, one that achieves high classification accuracy across the entire test dataset.\n", "\n", - "In addition to thinking about bias, we want to detect areas of high *aleatoric* uncertainty in the dataset, which is defined as data noise: in the context of facial detection, this means that we may have very similar inputs with different labels-- think about the scenario where one face is labeled correctly as a positive, and another face is labeled incorrectly as a negative. \n", - "\n", - "Finally, we want to look at samples with high *epistemic*, or predictive, uncertainty. These may be samples that are anomalous or out of distribution, samples that contain adversarial noise, or samples that are \"harder\" to learn in some way. Importantly, epistemic uncertainty is not the same as bias! We may have well-represented samples that still have high epistemic uncertainty. " + "To achieve this, you may want to consider the three metrics introduced with Capsa: (1) representation bias, (2) data or aleatoric uncertainty, and (3) model or epistemic uncertainty. Note that all three of these metrics are different! For example, we can have well-represented examples that still have high epistemic uncertainty. Think about how you may use these metrics to improve the performance of your model." ] }, { "cell_type": "markdown", "metadata": { "id": "1NhotGiT71EY" }, "source": [ - "# 3.2 Bias\n", "\n", - "In the previous lab, we used a variational autoencoder (VAE) to automatically learn the latent structure of our database, and we developed a scoring mechanism for samples to determine their bias. In this lab, we'll show that we can use CAPSA to do the same thing in one line! Then, our goal will be to continue our implementation of the DB-VAE and use the latent variables learned via a VAE to adaptively re-sample the CelebA data during training. Specifically, we will alter the probability that a given image is used during training based on how often its latent features appear in the dataset. So, faces with rarer features (like dark skin, sunglasses, or hats) should become more likely to be sampled during training, while the sampling probability for faces with features that are over-represented in the training dataset should decrease (relative to uniform random sampling across the training data)." ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "niy4he0m71EZ" - }, - "source": [ - "Just like the last lab, let's define a standard classifier that we'll use as the base encoder of our network." + "# 3.2 Risk-aware facial detection with Capsa\n", + "\n", + "In Lab 2, we built a semi-supervised variational autoencoder (SS-VAE) to learn the latent structure of our database and to uncover feature representation disparities, inspired by the approach used to [uncover hidden biases](http://introtodeeplearning.com/AAAI_MitigatingAlgorithmicBias.pdf). 
In this lab, we'll show that we can use Capsa to build the same VAE in one line!\n", "\n", "This sets the foundation for quantifying a key risk metric -- representation bias -- for the facial detection problem. In working to improve your model's performance, you will want to consider representation bias carefully and think about how you could mitigate the effect of representation bias.\n", "\n", "Just like in Lab 2, we begin by defining a standard CNN-based classifier. We will then use Capsa to wrap the model and build the risk-aware VAE variant." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5hQb75Vm71EZ" }, "outputs": [], "source": [ - "### Define the CNN model ###\n", - "\n", - "n_filters = 12 # base number of convolutional filters\n", + "### Define the CNN classifier model ###\n", "\n", "'''Function to define a standard CNN model'''\n", - "def make_standard_classifier(n_outputs=1):\n", + "def make_standard_classifier(n_outputs=1, n_filters=12):\n", " Conv2D = functools.partial(tf.keras.layers.Conv2D, padding='same', activation='relu')\n", " BatchNormalization = tf.keras.layers.BatchNormalization\n", " Flatten = tf.keras.layers.Flatten\n", @@ -278,6 +187,9 @@ " Conv2D(filters=6*n_filters, kernel_size=3, strides=2),\n", " BatchNormalization(),\n", "\n", + " Conv2D(filters=8*n_filters, kernel_size=3, strides=2),\n", + " BatchNormalization(),\n", + "\n", " Flatten(),\n", " Dense(512),\n", " Dense(n_outputs, activation=None),\n", @@ -291,29 +203,35 @@ "id": "LgTG6buf71Ea" }, "source": [ - "Let's use CAPSA's `HistogramVAEWrapper` to analyze the latent space distribution as we did previously. The `HistogramVAEWrapper` constructs a histogram with `num_bins` bins across every dimension of the latent space, and then calculates the joint probability of every sample according to the histograms. The samples with the lowest joint probability have the lowest bias, and we want to oversample these. Conversely, we want to undersample the areas of the dataset with the highest bias." ] }, { "cell_type": "markdown", "metadata": { "id": "FivHOdGE71Ea" }, "source": [ - "The `HistogramVAEWrapper` class takes in a number of arguments: namely, the number of bins we want to discretize our distribution into, the number of samples we want to track at any given point, and whether we're using the output of a hidden layer (good for higher-dimensional data) or the input data itself (good for lower-dimensional data). Since this is a variational autoencoder, we need to also pass in a decoder. Let's define the same decoder as the previous lab:" + "### Capsa's `HistogramVAEWrapper`\n", + "\n", + "With our base classifier defined, Capsa allows us to automatically build a VAE on top of it. Capsa's [`HistogramVAEWrapper`](https://themisai.io/capsa/api_documentation/HistogramVAEWrapper.html) builds this VAE to analyze the latent space distribution, just as we did in Lab 2. \n", + "\n", + "Specifically, `capsa.HistogramVAEWrapper` constructs a histogram with `num_bins` bins across every dimension of the latent space, and then calculates the joint probability of every sample according to the constructed histograms. The samples with the lowest joint probability have the lowest representation; the samples with the highest joint probability have the highest representation.\n", + "\n", + "`capsa.HistogramVAEWrapper` takes in a number of arguments including:\n", + "1. 
`base_model`: the model to be transformed into the risk-aware variant.\n", + "2. `num_bins`: the number of bins we want to discretize our distribution into. \n", + "3. `queue_size`: the number of samples we want to track at any given point.\n", + "4. `decoder`: the decoder architecture for the VAE.\n", + "\n", + "We define the same decoder as in Lab 2:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zTat3K8E71Eb" }, "outputs": [], "source": [ + "### Define the decoder architecture for the facial detection VAE ###\n", + "\n", "def make_face_decoder_network(n_filters=12):\n", " # Functionally define the different layer types we will use\n", - " Conv2DTranspose = functools.partial(tf.keras.layers.Conv2DTranspose, padding='same', activation='relu')\n", + " Conv2DTranspose = functools.partial(tf.keras.layers.Conv2DTranspose, \n", + " padding='same', activation='relu')\n", " BatchNormalization = tf.keras.layers.BatchNormalization\n", " Flatten = tf.keras.layers.Flatten\n", " Dense = functools.partial(tf.keras.layers.Dense, activation='relu')\n", @@ -322,10 +240,11 @@ " # Build the decoder network using the Sequential API\n", " decoder = tf.keras.Sequential([\n", " # Transform to pre-convolutional generation\n", - " Dense(units=4*4*6*n_filters), # 4x4 feature maps (with 6N occurances)\n", - " Reshape(target_shape=(4, 4, 6*n_filters)),\n", + " Dense(units=2*2*8*n_filters), # 2x2 feature maps (with 8N occurrences)\n", + " Reshape(target_shape=(2, 2, 8*n_filters)),\n", "\n", " # Upscaling convolutions (inverse of encoder)\n", + " Conv2DTranspose(filters=6*n_filters, kernel_size=3, strides=2),\n", " Conv2DTranspose(filters=4*n_filters, kernel_size=3, strides=2),\n", " Conv2DTranspose(filters=2*n_filters, kernel_size=3, strides=2),\n", " Conv2DTranspose(filters=1*n_filters, kernel_size=5, strides=2),\n", @@ -335,162 +254,93 @@ " return decoder" ] }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "i4JmvmMA71Ec" - }, - "outputs": [], - "source": [ - "standard_classifier = make_standard_classifier()\n", - "wrapped_classifier = HistogramVAEWrapper(standard_classifier, num_bins=5, queue_size=20000, latent_dim = 100, decoder=make_face_decoder_network())" - ] - }, { "cell_type": "markdown", "metadata": { - "id": "valYm5LH71Ec" - }, - "source": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "A527wdyV71Ec" + "id": "SzFGcrhv71Ed" }, "source": [ - "Now, let's train the wrapped classifier! As we did in the previous lab, in addition to updating the weights of the model, the wrapped classifier also tracks feature distributions. We can use the joint probabilities of these feature distributions to determine the bias of a given sample in this dataset. We'll make use of the `Model.fit` API here, but note that we can achieve the same behavior with a custom training loop as well." + "We are ready to create the wrapped model using `capsa.HistogramVAEWrapper` by passing in the relevant arguments!\n", + "\n", + "Just like in the wrappers in the Introduction to Capsa lab, we can take our standard CNN classifier, wrap it with `capsa.HistogramVAEWrapper`, and build the wrapped model. The wrapper then enables semi-supervised training for the facial detection task. As the wrapped model trains, the classifier weights are updated, and the VAE-wrapped model learns to track feature distributions over the latent space. 
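\n", "\n", "To build intuition for this density estimation, here is a minimal NumPy sketch of histogram-based joint-probability scoring (our own illustration of the idea, not Capsa's internal implementation):\n", "\n", "```python\n", "import numpy as np\n", "\n", "def histogram_density(latents, num_bins=5):\n", "    # latents: (N, latent_dim) array of latent encodings\n", "    density = np.ones(len(latents))\n", "    for d in range(latents.shape[1]):\n", "        counts, edges = np.histogram(latents[:, d], bins=num_bins)\n", "        probs = counts / counts.sum()  # per-bin probability along dimension d\n", "        bins = np.clip(np.digitize(latents[:, d], edges) - 1, 0, num_bins - 1)\n", "        density *= probs[bins]  # joint probability, treating dimensions as independent\n", "    return density  # low density = under-represented sample\n", "```\n", "\n", "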
More details of the `HistogramVAEWrapper` and how it can be used are [available here](https://themisai.io/capsa/api_documentation/HistogramVAEWrapper.html).\n", + "\n", + "We can then evaluate the representation bias of the classifier on the test dataset. By calling the `wrapped_model` on our test data, we can automatically generate representation bias and uncertainty scores that are normally manually calculated. Let's wrap our base CNN classifier using Capsa, build and train the resulting model, and start to process the test data: " ] }, { "cell_type": "code", "source": [ - "learning_rate = 1e-5\n", "\n", - "# compile model using desired optimizers and losses\n", - "wrapped_classifier.compile(\n", - " optimizer=tf.keras.optimizers.Adam(learning_rate),\n", - " loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", - " metrics=[tf.keras.metrics.BinaryAccuracy()],\n", + "### Estimating representation bias with Capsa HistogramVAEWrapper ###\n", + "\n", + "model = make_standard_classifier()\n", + "# Wrap the CNN classifier for latent encoding with a VAE wrapper\n", + "wrapped_model = capsa.HistogramVAEWrapper(model, num_bins=5, queue_size=20000, \n", + " latent_dim = 32, decoder=make_face_decoder_network())\n", + "\n", + "# Build the model for classification, defining the loss function, optimizer, and metrics\n", + "wrapped_model.compile(\n", + " optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4),\n", + " loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), # for classification\n", + " metrics=[tf.keras.metrics.BinaryAccuracy()], # for classification\n", " run_eagerly=True\n", ")\n", "\n", - "# fit the model to our training data\n", - "history = wrapped_classifier.fit(\n", + "# Train the wrapped model for 6 epochs by fitting to the training data\n", + "history = wrapped_model.fit(\n", " train_loader,\n", " epochs=6,\n", " batch_size=batch_size,\n", " )\n", + "\n", + "## Evaluation\n", + "\n", + "# Get all faces from the testing dataset\n", + "test_imgs = test_loader.get_all_faces()\n", + "\n", + "# Call the Capsa-wrapped classifier to generate outputs: predictions, uncertainty, and bias!\n", + "predictions, uncertainty, bias = wrapped_model.predict(test_imgs, batch_size=512)" ], "metadata": { "id": "YqsBHBf3yUlm" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Xtc0kjE471Ee"? }, "source": [ - "Let's see what the bias looks like on our test dataset! Note that in this lab, we're using a much larger test dataset than the one in Lab 2. By calling the `wrapped_classifier` on our test set, we can automatically generate the same bias scores that we manually calculated in the last lab. 
" - ] - }, - { - "cell_type": "code", - "execution_count": null, + " )\n", + "\n", + "## Evaluation\n", + "\n", + "# Get all faces from the testing dataset\n", + "test_imgs = test_loader.get_all_faces()\n", + "\n", + "# Call the Capsa-wrapped classifier to generate outputs: predictions, uncertainty, and bias!\n", + "predictions, uncertainty, bias = wrapped_model.predict(test_imgs, batch_size=512)" + ], "metadata": { - "id": "1dCqvPFH71Ed" + "id": "YqsBHBf3yUlm" }, - "outputs": [], - "source": [ - "test_imgs = test_loader.get_all_faces() # Get all faces from the testing dataset\n", - "predictions, _, bias = wrapped_classifier.predict(test_imgs) # use CAPSA-wrapped classifier to obtain estimates for bias and the output" - ] - }, - { - "cell_type": "code", "execution_count": null, - "metadata": { - "id": "Pt7_FlRW71Ee" - }, - "outputs": [], - "source": [ - "tf.config.list_physical_devices('GPU')" - ] + "outputs": [] }, { "cell_type": "markdown", - "metadata": { - "id": "Xtc0kjE471Ee" - }, - "source": [ - "Now, we have an estimate for the bias score! Let's visualize what the samples with the highest bias and those with the lowest bias look like. Before you run the next code block, which faces would you expect to be underrepresented in the dataset? Which ones do you think will be overrepresented?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "OYMRqq5E71Ee" - }, - "outputs": [], "source": [ - "indices = np.argsort(bias, axis=None) \n", - "sorted_images = test_imgs[indices] # sort images from lowest to highest bias\n", - "sorted_biases = bias[indices]\n", - "sorted_preds = predictions[indices]" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "# 3.3 Analyzing representation bias with Capsa\n", + "\n", + "From the above output, we have an estimate for the representation bias score! We can analyze the representation scores to start to think about manifestations of bias in the facial detection dataset. Before you run the next code block, which faces would you expect to be underrepresented in the dataset? Which ones do you think will be overrepresented?" 
+ ], "metadata": { - "id": "UAYaFUj-71Ee" - }, - "outputs": [], - "source": [ - "lab3.plot_k(sorted_images[:20]) # These are the samples with the lowest representation (least bias) in our test dataset" - ] + "id": "629ng-_H6WOk" + } }, { "cell_type": "code", "execution_count": null, "metadata": { - "id": "CnbR3qAF71Ef" + "id": "OYMRqq5E71Ee" }, "outputs": [], "source": [ - "lab3.plot_k(sorted_images[-20:]) # These are the samples with the highest representation (most bias) in our test dataset" + "### Analyzing representation bias scores ###\n", + "\n", + "# Sort according to lowest to highest representation scores\n", + "indices = np.argsort(bias, axis=None) # sort the score values themselves\n", + "sorted_images = test_imgs[indices] # sort images from lowest to highest representations\n", + "sorted_biases = bias[indices] # order the representation bias scores\n", + "sorted_preds = predictions[indices] # order the prediction values\n", + "\n", + "\n", + "# Visualize the 20 images with the lowest and highest representation in the test dataset\n", + "fig, ax = plt.subplots(1, 2, figsize=(16, 8))\n", + "ax[0].imshow(mdl.util.create_grid_of_images(sorted_images[-20:], (4, 5)))\n", + "ax[0].set_title(\"Over-represented\")\n", + "\n", + "ax[1].imshow(mdl.util.create_grid_of_images(sorted_images[:20], (4, 5)))\n", + "ax[1].set_title(\"Under-represented\");" ] }, { @@ -499,7 +349,7 @@ "id": "-JYmGMJF71Ef" }, "source": [ - "Now, we'll spend some time looking at the bias by *percentile* in our dataset. First, let's plot the accuracy as the bias increases. Remember that we use bias to quantify the level of representation in our dataset, so increasing bias means increasing representation. How do you expect the accuracy to change?" + "We can also quantify how the representation density relates to the classification accuracy by plotting the two against each other:" ] }, { @@ -510,7 +360,10 @@ }, "outputs": [], "source": [ - "averaged_imgs = lab3.plot_accuracy_vs_risk(sorted_images, sorted_biases, sorted_preds, \"Bias vs. Accuracy\")" + "# Plot the representation density vs. the accuracy\n", + "plt.xlabel(\"Density (Representation)\")\n", + "plt.ylabel(\"Accuracy\")\n", + "averaged_imgs = mdl.lab3.plot_accuracy_vs_risk(sorted_images, sorted_biases, sorted_preds, \"Bias vs. Accuracy\")" ] }, { @@ -519,19 +372,20 @@ "id": "i8ERzg2-71Ef" }, "source": [ - "Now, for a super interesting visualization, let's look at the *percentiles* of bias: what does the average face in the 10th percentile of bias look like? What about the 90th percentile? What changes across these faces?" + "These representations scores relate back to data examples, so we can visualize what the average face looks like for a given *percentile* of representation density:" ] }, { "cell_type": "code", - "execution_count": null, + "source": [ + "fig, ax = plt.subplots(figsize=(15,5))\n", + "ax.imshow(mdl.util.create_grid_of_images(averaged_imgs, (1,10)))" + ], "metadata": { - "id": "1cd590UP71Ef" + "id": "kn9IpPKYSECg" }, - "outputs": [], - "source": [ - "lab3.plot_percentile(averaged_imgs)" - ] + "execution_count": null, + "outputs": [] }, { "cell_type": "markdown", @@ -539,7 +393,13 @@ "id": "cRNV-3SU71Eg" }, "source": [ - "Now that we know what the bias in our dataset looks like, let's adaptively resample from our dataset! Since we can calculate this score on-the-fly *during training*, we can adjust the probability of samples being chosen. 
But first, let's also take a look at the *epistemic* uncertainty of this dataset" + "#### **TODO: Scoring representation densities with Capsa**\n", + "\n", + "Write short answers to the questions below to complete the `TODO`s:\n", + "\n", + "1. How does accuracy relate to the representation score? From this relationship, what can you determine about the bias underlying the dataset?\n", + "2. What does the average face in the 10th percentile of representation density look like (i.e., the face for which 10% of the data have lower probability of occurring)? What about the 90th percentile? What changes across these faces?\n", + "3. What could be potential limitations of the `HistogramVAEWrapper` approach as it is implemented now?" ] }, { "cell_type": "markdown", "metadata": { "id": "ww5lx7ue71Eg" }, "source": [ - "# 3.3 Epistemic Uncertainty\n", + "# 3.4 Analyzing epistemic uncertainty with Capsa\n", "\n", - "Recall from lecture that *epistemic* uncertainty, or a model's uncertainty in its prediction, can arise from out of distribution data, or samples that are harder to learn. This does not necessarily correlate with bias! Imagine the scenario of training an object detector for self-driving cars: even if the model is presented with many cluttered scenes, these samples still may be harder to learn than scenes with very few objects in them. In this part of the lab, we'll analyze the epistemic uncertainty of the VAE that we've trained on this dataset. \n", + "Recall that *epistemic* uncertainty, or a model's uncertainty in its prediction, can arise from out-of-distribution data, missing data, or samples that are harder to learn. This does not necessarily correlate with representation bias! Imagine the scenario of training an object detector for self-driving cars: even if the model is presented with many cluttered scenes, these samples still may be harder to learn than scenes with very few objects in them.\n", "\n", - "From lecture 6, we saw that most methods of estimating epistemic uncertainty are *sampling-based*, but we can also use *reconstruction-based* methods. If a model is unable to provide a good reconstruction for a given data point, it has not learned that area of the underlying data distribution well, and therefore has high epistemic uncertainty. \n", + "We will now use our VAE-wrapped facial detection classifier to analyze and estimate the epistemic uncertainty of the model trained on the facial detection task.\n", "\n", - "Since we've already used a VAE to calculate the histograms for bias quantification, we can use the same VAE to shed insight into epistemic uncertainty! CAPSA helps us do exactly that: when call the model, we get the bias, reconstruction loss, and prediction for every sample." + "While most methods of estimating epistemic uncertainty are *sampling-based*, we can also use ***reconstruction-based*** methods -- like using VAEs -- to estimate epistemic uncertainty. 
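\n", "\n", "For intuition, here is a minimal sketch of a reconstruction-error uncertainty score (our own illustration, not the wrapper's internal code; `encode` and `decode` are hypothetical helpers mapping images to latent vectors and back):\n", "\n", "```python\n", "import tensorflow as tf\n", "\n", "def reconstruction_uncertainty(x, encode, decode):\n", "    x_hat = decode(encode(x))  # reconstruct the input through the VAE\n", "    # Per-image mean squared reconstruction error: higher = less well-learned region\n", "    return tf.reduce_mean(tf.square(x - x_hat), axis=[1, 2, 3])\n", "```\n", "\n", "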
If a model is unable to provide a good reconstruction for a given data point, it has not learned that area of the underlying data distribution well, and therefore has high epistemic uncertainty.\n", "\n" ] }, { "cell_type": "markdown", "source": [ "Since we've already used the `HistogramVAEWrapper` to calculate the histograms for representation bias quantification, we can use the exact same VAE wrapper to shed insight into epistemic uncertainty! Capsa helps us do exactly that. When we called the model, it returned the classification prediction, uncertainty, and bias for every sample:\n", "`predictions, uncertainty, bias = wrapped_model.predict(test_imgs, batch_size=512)`.\n", "\n", "Let's analyze these estimated uncertainties:" ], "metadata": { "id": "NEfeWo2p7wKm" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AwGPvdZm71Eg" }, "outputs": [], "source": [ - "predictions, reconstruction_loss, bias = wrapped_classifier.predict(test_imgs) # note that we're estimating both bias and uncertainty in a single shot!\n", "\n", - "epistemic_indices = np.argsort(reconstruction_loss, axis=None) \n", - "epistemic_images = test_imgs[epistemic_indices] # sort images by reconstruction loss this time!\n", - "sorted_epistemic = reconstruction_loss[epistemic_indices]\n", - "sorted_epistemic_preds = predictions[epistemic_indices]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kB8Iqrfb71Eg" - }, - "outputs": [], - "source": [ - "lab3.plot_k(epistemic_images[:20]) # samples with the LEAST epistemic uncertainty" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "miu5h2Pc71Eh" - }, - "outputs": [], - "source": [ - "lab3.plot_k(epistemic_images[-20:]) # samples with the MOST epistemic uncertainty" + "### Analyzing epistemic uncertainty estimates ###\n", + "\n", + "# Sort according to epistemic uncertainty estimates\n", + "epistemic_indices = np.argsort(uncertainty, axis=None) # sort the uncertainty values\n", + "epistemic_images = test_imgs[epistemic_indices] # sort images from lowest to highest uncertainty\n", + "sorted_epistemic = uncertainty[epistemic_indices] # order the uncertainty scores\n", + "sorted_epistemic_preds = predictions[epistemic_indices] # order the prediction values\n", + "\n", + "\n", + "# Visualize the 20 images with the LEAST and MOST epistemic uncertainty\n", + "fig, ax = plt.subplots(1, 2, figsize=(16, 8))\n", + "ax[0].imshow(mdl.util.create_grid_of_images(epistemic_images[:20], (4, 5)))\n", + "ax[0].set_title(\"Least Uncertain\");\n", + "\n", + "ax[1].imshow(mdl.util.create_grid_of_images(epistemic_images[-20:], (4, 5)))\n", + "ax[1].set_title(\"Most Uncertain\");" ] }, { "cell_type": "markdown", "metadata": { "id": "L0dA8EyX71Eh" }, "source": [ - "Let's run the same analysis: check how the accuracy varies with epistemic uncertainty!" + "We quantify how the epistemic uncertainty relates to the classification accuracy by plotting the two against each other:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8RSY2zJB71Eh"? }, "outputs": [], "source": [ - "_ = lab3.plot_accuracy_vs_risk(epistemic_images, sorted_epistemic, sorted_epistemic_preds, \"Epistemic Uncertainty vs. Accuracy\")" + "# Plot epistemic uncertainty vs. classification accuracy\n", + "plt.xlabel(\"Epistemic Uncertainty\")\n", + "plt.ylabel(\"Accuracy\")\n", + "_ = mdl.lab3.plot_accuracy_vs_risk(epistemic_images, sorted_epistemic, sorted_epistemic_preds, \"Epistemic Uncertainty vs. 
Accuracy\")" ] }, { @@ -621,7 +485,13 @@ "id": "iyn0IE6x71Eh" }, "source": [ - "How do these compare to the bias plots? Was this expected or unexpected?" + "#### **TODO: Estimating epistemic uncertainties with Capsa**\n", + "\n", + "Write short answers to the questions below to complete the `TODO`s:\n", + "\n", + "1. How does accuracy relate to the epistemic uncertainty?\n", + "2. How do the results for epistemic uncertainty compare to the results for representation bias? Was this expected or unexpted? Why?\n", + "3. What may be instances in the facial detection task that could have high representation density but also high uncertainty? " ] }, { @@ -632,67 +502,15 @@ "source": [ "# 3.4 Resampling based on risk metrics\n", "\n", - "Finally, let's use both the bias score and the reconstruction loss to adaptively resample from our dataset. Since we can calculate this score on-the-fly *during training*, we can adjust the probability of samples being chosen. \n", + "Finally, we will use the risk metrics just computed to actually *mitigate* the issues of bias and uncertainty in the facial detection classifier.\n", "\n", - "Note that we want to debias and amplify only the *positive* samples in the dataset, so we're going to only adjust probabilities and calculate scores for these samples. \n", + "Specifically, we will use the latent variables learned via the VAE to adaptively re-sample the face (CelebA) data during training, following the approach of [recent work](http://introtodeeplearning.com/AAAI_MitigatingAlgorithmicBias.pdf). We will alter the probability that a given image is used during training based on how often its latent features appear in the dataset. So, faces with rarer features (like dark skin, sunglasses, or hats) should become more likely to be sampled during training, while the sampling probability for faces with features that are over-represented in the training dataset should decrease (relative to uniform random sampling across the training data).\n", "\n", - "We want to *amplify*, or increase the probability of sampling, of images with high epistemic uncertainty, since these data points come from areas of the latent distribution that the model hasn't learned very well yet. We also want to amplify images with very low representation bias, since otherwise, the model won't see enough of these samples during training. Let's define two functions below to do this:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hRL5nUBs71Ei" - }, - "source": [ - "First, let's do this for the bias. We have a smoothing parameter `alpha` that we can tune: as `alpha` increases, the probabilities will tend towards a uniform distribution, and as `alpha` decreases, the probabilities will correlate more directly with the bias. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0wR2bMw571Ei" - }, - "outputs": [], - "source": [ - "def score_to_probability_bias(score, alpha):\n", - " score = score + alpha\n", - " probabilities = 1/score\n", - " probabilities = probabilities/sum(probabilities)\n", - " return probabilities" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TUs-0O_v71Ei" - }, - "source": [ - "Let's now define a similar function for the epistemic probabilities: note that in this case, we want high epistemic uncertainty to correlate with a higher probability!" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hLWGKvc971Ei" - }, - "outputs": [], - "source": [ - "def score_to_probability_epistemic(score, beta):\n", - " score = score + beta\n", - " probabilities = score/sum(score)\n", - " return probabilities" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "meZxtxFS71Ei" - }, - "source": [ - "Now, let's redefine and re-train our debiasing model!" + "Note that we want to debias and amplify only the *positive* samples in the dataset -- the faces -- so we are going to only adjust probabilities and calculate scores for these samples. We focus on using the representation bias scores to implement this adaptive resampling to achieve model debiasing.\n", + "\n", + "We re-define the wrapped model with `HistogramVAEWrapper`, and then define the adaptive resampling operation for training. At each training epoch, we compute the predictions, uncertainties, and representation bias scores, then recompute the data sampling probabilities according to the *inverse* of the representation bias score. That is, samples with higher representation densities will end up with lower re-sampling probabilities; samples with lower representations will end up with higher re-sampling probabilities.\n", + "\n", + "Let's do all this below!" ] }, { @@ -703,11 +521,19 @@ }, "outputs": [], "source": [ - "standard_classifier = make_standard_classifier()\n", - "dbvae = HistogramVAEWrapper(standard_classifier, latent_dim=100, num_bins=5, queue_size=2000, decoder=make_face_decoder_network())\n", - "dbvae.compile(optimizer=tf.keras.optimizers.Adam(1e-4),\n", - " loss=tf.keras.losses.BinaryCrossentropy(),\n", - " metrics=[tf.keras.metrics.BinaryAccuracy()])\n", + "### Define the standard CNN classifier and wrap with HistogramVAE ###\n", + "\n", + "classifier = make_standard_classifier()\n", + "# Wrap with HistogramVAE\n", + "wrapper = capsa.HistogramVAEWrapper(classifier, latent_dim=32, num_bins=5, \n", + " queue_size=2000, decoder=make_face_decoder_network())\n", + "\n", + "# Build the wrapped model for the classification task\n", + "wrapper.compile(optimizer=tf.keras.optimizers.Adam(5e-4),\n", + " loss=tf.keras.losses.BinaryCrossentropy(),\n", + " metrics=[tf.keras.metrics.BinaryAccuracy()])\n", + "\n", + "# Load training data\n", "train_imgs = train_loader.get_all_faces()" ] }, @@ -719,26 +545,32 @@ }, "outputs": [], "source": [ - "# The training loop -- outer loop iterates over the number of epochs\n", - "for i in range(6):\n", + "### Debiasing via resampling based on risk metrics ###\n", "\n", - " print(\"Starting epoch {}/{}\".format(i+1, 6))\n", + "# The training loop -- outer loop iterates over the number of epochs\n", + "num_epochs = 6\n", + "for i in range(num_epochs):\n", + " print(\"Starting epoch {}/{}\".format(i+1, num_epochs))\n", " \n", - " # get a batch of training data and compute the training step\n", + " # Get a batch of training data and compute the training step\n", " for step, data in enumerate(train_loader):\n", - " metrics = dbvae.train_step(data)\n", + " metrics = wrapper.train_step(data)\n", " if step % 100 == 0:\n", " print(step)\n", - " _, recon_loss, bias_scores = dbvae(train_imgs)\n", - " recon_loss = np.squeeze(recon_loss)\n", - "\n", - " # Recompute data sampling proabilities\n", - " p_faces = score_to_probability_bias(bias_scores.numpy(), 1e-7)\n", - " p_recon = score_to_probability_epistemic(recon_loss, 1e-7)\n", - " p_final = (p_faces + p_recon)/2\n", - " p_final /= sum(p_final)\n", - " \n", - " 
train_loader.p_pos = p_final" + "\n", + " # After the epoch is done, recompute data sampling proabilities \n", + " # according to the inverse of the bias\n", + " pred, unc, bias = wrapper(train_imgs)\n", + "\n", + " # Increase the probability of sampling under-represented datapoints by setting \n", + " # the probability to the **inverse** of the biases\n", + " inverse_bias = 1.0 / (bias.numpy() + 1e-7)\n", + "\n", + " # Normalize the inverse biases in order to convert them to probabilities\n", + " p_faces = inverse_bias / np.sum(inverse_bias)\n", + "\n", + " # Update the training data loader to sample according to this new distribution\n", + " train_loader.p_pos = p_faces" ] }, { @@ -747,18 +579,11 @@ "id": "SwXrAeBo71Ej" }, "source": [ - "Now, we should have a debiased model that also mitigates some forms of uncertainty! Let's see how well our model does:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MXiB-DMH71Ej" - }, - "source": [ - "# 3.5 Evaluation\n", + "That's it! We should have a debiased model (we hope!). Let's see how the model does.\n", "\n", - "Let's run the same analyses as before, and plot the accuracy vs. the bias and accuracy vs. epistemic uncertainty. We want the model to do better on less biased and more uncertain samples than it did previously\n" + "### Evaluation\n", + "\n", + "Let's run the same analyses as before, and plot the classification accuracy vs. the representation bias and classification accuracy vs. epistemic uncertainty. We want the model to do better across the data samples, achieving higher accuracies on the under-represented and more uncertain samples compared to previously.\n" ] }, { @@ -769,37 +594,21 @@ }, "outputs": [], "source": [ - "predictions, reconstruction_loss, bias = dbvae.predict(test_imgs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zCXVIsaJ71Ej" - }, - "outputs": [], - "source": [ + "### Evaluation of debiased model ###\n", + "\n", + "# Get classification predictions, uncertainties, and representation bias scores\n", + "pred, unc, bias = wrapper.predict(test_imgs)\n", + "\n", + "# Sort according to lowest to highest representation scores\n", "indices = np.argsort(bias, axis=None)\n", - "bias_images = test_imgs[indices]\n", - "sorted_bias = bias[indices]\n", - "sorted_bias_preds = predictions[indices]\n", - "_ = lab3.plot_accuracy_vs_risk(bias_images, sorted_bias, sorted_bias_preds, \"Bias vs. Accuracy\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "P6p2j_xa71Ej" - }, - "outputs": [], - "source": [ - "indices = np.argsort(reconstruction_loss, axis=None)\n", - "epistemic_images = test_imgs[indices]\n", - "sorted_epistemic = bias[indices]\n", - "sorted_epistemic_preds = predictions[indices]\n", - "_ = lab3.plot_accuracy_vs_risk(epistemic_images, sorted_epistemic, sorted_epistemic_preds, \"Epistemic Uncertainty vs. Accuracy\")" + "bias_images = test_imgs[indices] # sort the images\n", + "sorted_bias = bias[indices] # sort the representation bias scores\n", + "sorted_bias_preds = pred[indices] # sort the predictions\n", + "\n", + "# Plot the representation bias vs. the accuracy\n", + "plt.xlabel(\"Density (Representation)\")\n", + "plt.ylabel(\"Accuracy\")\n", + "_ = mdl.lab3.plot_accuracy_vs_risk(bias_images, sorted_bias, sorted_bias_preds, \"Bias vs. 
Accuracy\")" ] }, { @@ -808,26 +617,45 @@ "id": "d1cEEnII71Ej" }, "source": [ - "# 3.6 Conclusion\n", + "# 3.5 Competition!\n", + "\n", + "Now, you are well equipped to submit to the competition to dig in deeper into deep learning models, uncover their deficiencies with Capsa, address those deficiencies, and submit your findings!\n", + "\n", + "**Below are some potential areas to start investigating -- the goal of the competition is to develop creative and innovative solutions to address bias and uncertainty, and to improve the overall performance of deep learning models.**\n", + "\n", + "We encourage you to identify other questions that could be solved with Capsa and use those as the basis of your submission. But, to help get you started, here are some interesting questions that you might look into solving with these new tools and knowledge that you've built up: \n", + "\n", + "1. In this lab, you learned how to build a wrapper that can estimate the bias within the training data, and take the results from this wrapper to adaptively re-sample during training to encourage learning on under-represented data. \n", + " * Can we apply a similar approach to mitigate epistemic uncertainty in the model? \n", + " * Can this approach be combined with your original bias mitigation approach to achieve robustness across both bias *and* uncertainty? \n", + "\n", + "2. In this lab, you focused on the `HistogramVAEWrapper`. \n", + " * How can you use other methods of uncertainty in Capsa to strengthen your uncertainty estimates? Checkout [Capsa documentation](https://themisai.io/capsa/api_documentation/index.html) for a list of all wrappers, and ask for help if you run into trouble applying them to your model!\n", + " * Can you combine uncertainty estimates from different wrappers to achieve greater robustness in your estimates? \n", + "\n", + "3. So far in this part of the lab, we focused only on bias and epistemic uncertainty. What about aleatoric uncetainty? \n", + " * We've curated a dataset (available at [this URL](https://www.dropbox.com/s/wsdyma8a340k8lw/train_face_2023_perturbed_large.h5?dl=0)) of faces with greater amounts of aleatoric uncertainty -- can you use Capsa to wrap your model, estimate aleatoric uncertainty, and remove it from the dataset? \n", + " * Does removing aleatoric uncertainty help improve your training accuracy on this new dataset? \n", + " * Can you develop an approach to incorporate this aleatoric uncertainty estimation into the predictive training pipeline in order to improve accuracy? You may find some surprising results!!\n", "\n", - "We encourage you to think about and maybe even address some questions raised by the approach and results outlined here:\n", + "4. How can the performance of the classifier above be improved even further? We purposely did not optimize hyperparameters to leave this up to you!\n", "\n", - "* We did not analyze the *aleatoric* uncertainty of the above dataset. Try to develop a similar approach (assigning probabilities based on aleatoric uncertainty) and incorporate this as well! You may find some surprising results :)\n", + "5. Are there other applications that you think Capsa and bias/uncertainty estimation would be helpful in? \n", + " * Try integrating Capsa into another domain or dataset and submit your findings!\n", + " * Are there applications where you may *not* want to debias your model? \n", "\n", - "* How can the performance of the classifier above be improved even further? 
We purposely did not optimize hyperparameters to leave this up to you!\n", + "4. How can the performance of the classifier above be improved even further? We purposely did not optimize hyperparameters to leave this up to you!\n", "\n", - "* How can you use other methods of uncertainty in CAPSA to strengthen your uncertainty estimates?\n", + "5. Are there other applications that you think Capsa and bias/uncertainty estimation would be helpful in? \n", + " * Try integrating Capsa into another domain or dataset and submit your findings!\n", + " * Are there applications where you may *not* want to debias your model? \n", "\n", - "* In which applications (either related to facial detection or not!) would debiasing in this way be desired? Are there applications where you may not want to debias your model?\n", "\n", - "* Try to optimize your model to achieve improved performance. MIT students and affiliates will be eligible for prizes during the IAP offering. To enter the competition, MIT students and affiliates should upload the following to the course Canvas:\n", + "**To enter the competition, please upload the following to the [lab submission site](https://www.dropbox.com/request/TTYz3Ikx5wIgOITmm5i2):**\n", "\n", + "* Written short-answer responses to `TODO`s from Lab 3, Part 2 on Facial Detection.\n", + "* Description of the wrappers, algorithms, and approach you used. What was your strategy? What wrappers did you implement? What debiasing or mitigation strategies did you try? How and why did these modifications affect performance? Describe *any* modifications or implementations you made to the template code, and what their effects were. Written text, visual diagram, and plots welcome!\n", + "* Jupyter notebook with the code you used to generate your results (along with all plots/visuals generated).\n", "\n", - "* Jupyter notebook with the code you used to generate your results;\n", - "copy of the line plots from section 3.5 showing the performance of your model;\n", - "* a description and/or diagram of the architecture and hyperparameters you used -- if there are any additional or interesting modifications you made to the template code, please include these in your description;\n", - "* discussion of why these modifications helped improve performance.\n", "\n", + "**Name your file in the following format: `[FirstName]_[LastName]_Face`, followed by the file format (.zip, .ipynb, .pdf, etc).** ZIP files are preferred over individual files. If you submit individual files, you must name the individual files according to the above nomenclature (e.g., `[FirstName]_[LastName]_Face_TODO.pdf`, `[FirstName]_[LastName]_Face_Report.pdf`, etc.). **Submit your files [here](https://www.dropbox.com/request/TTYz3Ikx5wIgOITmm5i2).**\n", "\n", + "We encourage you to think about and maybe even address some questions raised by this lab and dig into any questions that you may have about the risks inherent to neural networks and their data. \n", "\n", - "Hopefully this lab has shed some light on a few concepts, from vision based tasks, to VAEs, to algorithmic bias. We like to think it has, but we're biased ;)." 
+ "<img src=\"https://i.ibb.co/BjLSRMM/ezgif-2-253dfd3f9097.gif\" />" ] } ], diff --git a/lab3/solutions/Lab3_Part_1_Introduction_to_CAPSA.ipynb b/lab3/solutions/Lab3_Part_1_Introduction_to_CAPSA.ipynb index c3d6d3fd..09b73c8b 100644 --- a/lab3/solutions/Lab3_Part_1_Introduction_to_CAPSA.ipynb +++ b/lab3/solutions/Lab3_Part_1_Introduction_to_CAPSA.ipynb @@ -2,37 +2,77 @@ "cells": [ { "cell_type": "markdown", + "source": [ + "<table align=\"center\">\n", + " <td align=\"center\"><a target=\"_blank\" href=\"http://introtodeeplearning.com\">\n", + " <img src=\"https://i.ibb.co/Jr88sn2/mit.png\" style=\"padding-bottom:5px;\" />\n", + " Visit MIT Deep Learning</a></td>\n", + " <td align=\"center\"><a target=\"_blank\" href=\"https://colab.research.google.com/github/aamini/introtodeeplearning/blob/2023/lab3/solutions/Lab3_Part_1_Introduction_to_CAPSA.ipynb\">\n", + " <img src=\"https://i.ibb.co/2P3SLwK/colab.png\" style=\"padding-bottom:5px;\" />Run in Google Colab</a></td>\n", + " <td align=\"center\"><a target=\"_blank\" href=\"https://github.com/aamini/introtodeeplearning/blob/2023/lab3/solutions/Lab3_Part_1_Introduction_to_CAPSA.ipynb\">\n", + " <img src=\"https://i.ibb.co/xfJbPmL/github.png\" height=\"70px\" style=\"padding-bottom:5px;\" />View Source on GitHub</a></td>\n", + "</table>\n", + "\n", + "# Copyright Information" + ], "metadata": { - "id": "ckzz5Hus-hJB" - }, + "id": "SWa-rLfIlTaf" + } + }, + { + "cell_type": "code", "source": [ - "## Part 1: Introduction to CAPSA" - ] + "# Copyright 2023 MIT Introduction to Deep Learning. All Rights Reserved.\n", + "# \n", + "# Licensed under the MIT License. You may not use this file except in compliance\n", + "# with the License. Use and/or modification of this code outside of MIT Introduction\n", + "# to Deep Learning must reference:\n", + "#\n", + "# © MIT Introduction to Deep Learning\n", + "# http://introtodeeplearning.com\n", + "#" + ], + "metadata": { + "id": "-LohleBMlahL" + }, + "execution_count": null, + "outputs": [] }, { "cell_type": "markdown", "metadata": { - "id": "gTpt_Hj5j-FZ" + "id": "ckzz5Hus-hJB" }, "source": [ - "As we saw in lecture 6, it is critical to be able to estimate bias and uncertainty robustly: we need benchmarks that uniformly measure how uncertain a given model is, and we need principled ways of measuring bias and uncertainty. To that end, in this lab, we'll utilize [CAPSA](https://github.com/themis-ai/capsa), a risk-estimation wrapping library developed by [Themis AI](https://themisai.io/). CAPSA supports the estimation of three different types of *risk*, defined as measures of how trustworthy our model is. These are:\n", - "1. Representation bias: using a histogram estimation approach, CAPSA calculates how likely combinations of features are to appear in a given dataset. Often, certain combinations of features are severely underrepresented in datasets, which means models learn them less well. Since evaluation metrics are often also biased in the same manner, these biases are not caught through traditional validation pipelines.\n", - "2. Aleatoric uncertainty: we can estimate the uncertainty in *data* by learning a layer that predicts a standard deviation for every input. This is useful to determine when sensors have noise, classes in datasets have low separations, and generally when very similar inputs lead to drastically different outputs.\n", - "3. 
Epistemic uncertainty: also known as predictive or model uncertainty, epistemic uncertainty captures the areas of our underlying data distribution that the model has not yet learned. Areas of high epistemic uncertainty can be due to out of distribution (OOD) samples or data that is harder to learn.\n"
+    "# Laboratory 3: Debiasing, Uncertainty, and Robustness\n",
+    "\n",
+    "# Part 1: Introduction to Capsa\n",
+    "\n",
+    "In this lab, we'll explore different ways to make deep learning models more **robust** and **trustworthy**.\n",
+    "\n",
+    "To achieve this, it is critical to be able to identify and diagnose issues of bias and uncertainty in deep learning models, as we explored in the Facial Detection Lab 2. We need benchmarks that uniformly measure how uncertain a given model is, and we need principled ways of measuring bias and uncertainty. To that end, in this lab, we'll utilize [Capsa](https://github.com/themis-ai/capsa), a risk-estimation wrapping library developed by [Themis AI](https://themisai.io/). Capsa supports the estimation of three different types of ***risk***, defined as measures of how robust and trustworthy our model is. These are:\n",
+    "1. **Representation bias**: reflects how likely combinations of features are to appear in a given dataset. Often, certain combinations of features are severely under-represented in datasets, which means models learn them less well and can thus lead to unwanted bias.\n",
+    "2. **Data uncertainty**: reflects noise in the data, for example when sensors have noisy measurements, when classes in datasets have low separation, and generally when very similar inputs lead to drastically different outputs. Also known as *aleatoric* uncertainty.\n",
+    "3. **Model uncertainty**: captures the areas of our underlying data distribution that the model has not yet learned or has difficulty learning. Areas of high model uncertainty can be due to out-of-distribution (OOD) samples or data that is harder to learn. Also known as *epistemic* uncertainty."
   ]
  },
  {
-   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "o02MyoDrnNqP"
   },
   "source": [
-    "The core ideology behind CAPSA is that models can be *wrapped* in a way that makes them *risk-aware*. \n",
+    "## Capsa overview\n",
+    "\n",
+    "This lab introduces Capsa and its functionality, and then builds automated tools on top of Capsa to mitigate the underlying issues of bias and uncertainty.\n",
+    "\n",
+    "The core idea behind [Capsa](https://themisai.io/capsa/) is that any deep learning model of interest can be ***wrapped*** -- just like wrapping a gift -- to be made ***aware of its own risks***. Risk is captured in representation bias, data uncertainty, and model uncertainty.\n",
    "\n",
    "\n",
    "\n",
-    "This means that CAPSA augments or modifies the user's original model minimally to create a risk-aware variant while preserving the model's underlying structure and training pipeline. CAPSA is a one-line addition to any training workflow in Tensorflow. "
+    "This means that Capsa takes the user's original model as input, and modifies it minimally to create a risk-aware variant while preserving the model's underlying structure and training pipeline. Capsa is a one-line addition to any training workflow in TensorFlow. A minimal sketch of this wrapping workflow follows below. 
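To make the one-line claim concrete, here is a minimal sketch of the wrapping workflow. The small Keras model is a hypothetical stand-in, and the `capsa.HistogramWrapper` call reuses the arguments that appear later in this notebook; it is a sketch, not the lab's required solution.

import tensorflow as tf
import capsa

# A hypothetical stand-in: any standard Keras model of interest.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(32, "relu"),
    tf.keras.layers.Dense(1),
])

# The one-line addition: wrap the model to make it risk-aware.
wrapped_model = capsa.HistogramWrapper(model, target_hidden_layer=False)

# The wrapped model keeps the original compile/fit training pipeline.
wrapped_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-3),
    loss=tf.keras.losses.MeanSquaredError(),
)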
In this part of the lab, we'll apply Capsa's risk estimation methods to a simple regression problem to further explore the notions of bias and uncertainty.\n",
+    "\n",
+    "Please refer to [Capsa's documentation](https://themisai.io/capsa/) for additional details."
   ]
  },
  {
@@ -41,51 +81,34 @@
    "id": "hF0uSqk-nwmA"
   },
   "source": [
-    "Let's first install necessary dependencies:"
+    "Let's get started by installing the necessary dependencies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "NdXF4Reyj6yy",
-    "outputId": "e21a92b6-cb80-4da3-9b25-f447bf28482b"
+    "id": "NdXF4Reyj6yy"
   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[pip install log removed: capsa and mitdeeplearning requirements already satisfied]"
-     ]
-    }
-   ],
+   "outputs": [],
   "source": [
+    "# Import Tensorflow 2.0\n",
+    "%tensorflow_version 2.x\n",
    "import tensorflow as tf\n",
-    "import numpy as np\n",
-    "import matplotlib.pyplot as plt\n",
-    "!pip install capsa\n",
    "\n",
-    "from capsa import *\n",
-    "from helper import gen_data_regression\n",
+    "import IPython\n",
+    "import functools\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "from tqdm import tqdm\n",
    "\n",
+    "# Download and import the MIT Introduction to Deep Learning package\n",
    "!pip install mitdeeplearning\n",
    "import mitdeeplearning as mdl\n",
-    "import tqdm"
+    "\n",
+    "# Download and import Capsa\n",
+    "!pip install capsa\n",
+    "import capsa"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xzEcxjKHn8gc"
   },
   "source": [
-    "### 1.1 Datasets \n",
-    "Next, let's construct a dataset that we'll analyze. As shown in lecture, we'll look at the curve `y = x^3` with epistemic and aleatoric noise added to certain parts of the dataset. The blue points below are the test data: note that there are regions where we have no train data but we have test data! Do you expect these areas to have higher or lower uncertainty? What type of uncertainty?"
+    "## 1.1 Dataset\n",
+    "\n",
+    "We will build understanding of bias and uncertainty by training a neural network for a simple 2D regression task: modeling the function $y = x^3$. 
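In symbols, the data-generating process used throughout this part (taken directly from the `gen_data` cell that follows) is:

$$ y = \frac{x^3}{6} + \varepsilon, \qquad \varepsilon \sim \mathcal{N}\big(0,\ \sigma(x)^2\big), \qquad \sigma(x) = 2\,e^{-(x+1)^2} + 0.2, $$

where training inputs $x$ are drawn from a triangular distribution on $[-4, 4]$ with mode $2$, and test inputs lie on an even grid over $[-6, 6]$ with $\sigma = 0$.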
We will use Capsa to analyze this dataset and the performance of the model. Noise and missing-ness will be injected into the dataset.\n",
+    "\n",
+    "Let's generate the dataset and visualize it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 265
-    },
-    "id": "fH40EhC1j9dH",
-    "outputId": "c6936767-2162-4b6c-b430-e5717c70bb75"
+    "id": "fH40EhC1j9dH"
   },
-   "outputs": [
-    {
-     "data": {
-      "image/png": "[base64 PNG data removed: scatter plot of the noisy train data and the test data]",
-      "text/plain": [
-       "<Figure size 432x288 with 1 Axes>"
-      ]
-     },
-     "metadata": {
-      "needs_background": "light"
-     },
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
   "source": [
+    "# Get the data for the cubic function, injected with noise and missing-ness\n",
+    "# This is just a toy dataset that we can use to test some of the wrappers on\n",
    "def gen_data(x_min, x_max, n, train=True):\n",
+    "  if train: \n",
    "    x = np.random.triangular(x_min, 2, x_max, size=(n, 1))\n",
+    "  else: \n",
+    "    x = np.linspace(x_min, x_max, n).reshape(n, 1)\n",
    "\n",
-    "  sigma = np.exp(-(x+1)**2/1) + 0.2 if train else np.zeros_like(x)\n",
-    "  y = x**3/6 + np.random.normal(0, sigma).astype(np.float32)\n",
+    "  sigma = 2*np.exp(-(x+1)**2/1) + 0.2 if train else np.zeros_like(x)\n",
+    "  y = x**3/6 + np.random.normal(0, sigma).astype(np.float32)\n",
    "\n",
-    "  return x, y\n",
+    "  return x, y\n",
    "\n",
-    "x, y = gen_data(-4, 4, 2000)\n",
-    "x_val, y_val = gen_data(-6, 6, 500)\n",
-    "plt.scatter(x_val,y_val, s=1.5, label='test data')\n",
-    "plt.scatter(x,y, s=1.5, label='train data')\n",
+    "# Plot the dataset and visualize the train and test datapoints\n",
+    "x_train, y_train = gen_data(-4, 4, 2000, train=True) # train data\n",
+    "x_test, y_test = gen_data(-6, 6, 500, train=False) # test data\n",
    "\n",
-    "plt.legend()\n",
-    "plt.show()"
+    "plt.figure(figsize=(10, 6))\n",
+    "plt.plot(x_test, y_test, c='r', zorder=-1, label='ground truth')\n",
+    "plt.scatter(x_train, y_train, s=1.5, label='train data')\n",
+    "plt.legend()"
   ]
  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "In the plot above, the blue points are the training data, which will be used as inputs to train the neural network model. The red line is the ground truth data, which will be used to evaluate the performance of the model.\n",
+    "\n",
+    "#### **TODO: Inspecting the 2D regression dataset**\n",
+    "\n",
+    "Write short (~1 sentence) answers to the questions below to complete the `TODO`s:\n",
+    "\n",
+    "1. What are your observations about where the train data and test data lie relative to each other?\n",
+    "2. What, if any, areas do you expect to have high/low aleatoric (data) uncertainty?\n",
+    "3. What, if any, areas do you expect to have high/low epistemic (model) uncertainty?"
+   ],
+   "metadata": {
+    "id": "Fz3UxT8vuN95"
+   }
+  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mXMOYRHnv8tF"
   },
   "source": [
-    "### 1.2 Vanilla regression\n",
-    "Let's define a small model that can predict `y` given `x`: this is a classical regression task!"
+    "## 1.2 Regression on the cubic dataset\n",
+    "\n",
+    "Next we will define a small dense neural network model that can predict `y` given `x`: this is a classical regression task! We will build the model and use the [`model.fit()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit) function to train the model -- normally, without any risk-awareness -- using the train dataset that we visualized above."
   ]
  },
  {
@@ -159,17 +191,30 @@
   },
   "outputs": [],
   "source": [
-    "def create_standard_classifier():\n",
+    "### Define and train a dense NN model for the regression task ###\n",
+    "\n",
+    "'''Function to define a small dense NN'''\n",
+    "def create_dense_NN():\n",
    "  return tf.keras.Sequential(\n",
    "      [\n",
    "          tf.keras.Input(shape=(1,)),\n",
-    "          tf.keras.layers.Dense(8, \"relu\"),\n",
-    "          tf.keras.layers.Dense(8, \"relu\"),\n",
+    "          tf.keras.layers.Dense(32, \"relu\"),\n",
+    "          tf.keras.layers.Dense(32, \"relu\"),\n",
+    "          tf.keras.layers.Dense(32, \"relu\"),\n",
    "          tf.keras.layers.Dense(1),\n",
    "      ]\n",
    "  )\n",
    "\n",
-    "standard_classifier = create_standard_classifier()"
+    "dense_NN = create_dense_NN()\n",
+    "\n",
+    "# Build the model for regression, defining the loss function and optimizer\n",
+    "dense_NN.compile(\n",
+    "  optimizer=tf.keras.optimizers.Adam(learning_rate=5e-3),\n",
+    "  loss=tf.keras.losses.MeanSquaredError(), # MSE loss for the regression task\n",
+    ")\n",
+    "\n",
+    "# Train the model for 30 epochs using model.fit().\n",
+    "loss_history = dense_NN.fit(x_train, y_train, epochs=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ovwYBUG3wTDv"
   },
   "source": [
-    "Let's first train this model normally, without any wrapping. Which areas would you expect the model to do well in? Which areas should it do worse in?"
+    "Now, we are ready to evaluate our neural network. We use the test data to assess performance on the regression task, and visualize the predicted values against the true values.\n",
+    "\n",
+    "Given your observation of the data in the previous plot, where do you expect the model to perform well? (The short sketch below plots the injected noise profile as a hint.) 
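As a hint for the question above, the noise profile injected by `gen_data` can be plotted directly. This sketch only re-evaluates the `sigma` expression already defined in the dataset cell; it introduces no new assumptions.

import numpy as np
import matplotlib.pyplot as plt

# Training-noise scale from gen_data: peaks near x = -1 (about 2.2)
# and flattens to the 0.2 baseline elsewhere.
x = np.linspace(-4, 4, 500)
sigma = 2 * np.exp(-(x + 1)**2 / 1) + 0.2

plt.figure(figsize=(10, 4))
plt.plot(x, sigma, label='injected noise scale sigma(x)')
plt.legend()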
Let's test the model and see:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "oPNxsGBRwaNA",
-    "outputId": "0598cef9-350c-4785-a7a9-51ed3b54fd4b"
+    "id": "fb-EklZywR4D"
   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[training log removed: 10 epochs of model.fit() output, loss 5.5708 -> 0.8937]"
-     ]
-    }
-   ],
+   "outputs": [],
   "source": [
-    "standard_classifier.compile(\n",
-    "    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-3),\n",
-    "    loss=tf.keras.losses.MeanSquaredError(),\n",
-    ")\n",
-    "\n",
-    "history = standard_classifier.fit(x, y, epochs=10)\n"
+    "# Pass the test data through the network and predict the y values\n",
+    "y_predicted = dense_NN.predict(x_test)\n",
+    "\n",
+    "# Visualize the true (x, y) pairs for the test data vs. the predicted values\n",
+    "plt.figure(figsize=(10, 6))\n",
+    "plt.scatter(x_train, y_train, s=1.5, label='train data')\n",
+    "plt.plot(x_test, y_test, c='r', zorder=-1, label='ground truth')\n",
+    "plt.plot(x_test, y_predicted, c='b', zorder=0, label='predicted')\n",
+    "plt.legend()"
   ]
  },
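To complement the visual comparison with a number, a short optional sketch follows; it uses only variables defined in the cell above and is not part of the lab's required code.

# Quantify test-set performance with the same MSE metric used for training.
test_mse = tf.keras.losses.MeanSquaredError()(y_test, y_predicted).numpy()
print(f"Test MSE: {test_mse:.4f}")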
"iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXxU9b3/8ddnEkLYAyGELZCFJQQCASKCLAFRQEFRFAGFuiOIWn+9rUvb2z7uvfWWVqu4sMitirW41N26ILKLigohbAFCEggESAiBKAFDMpnv748zk40EEjPJTGY+z8cDZ+acmXO+J8h7vvmc7/keMcaglFLKN9k83QCllFINR0NeKaV8mIa8Ukr5MA15pZTyYRrySinlwwI93YCKOnbsaCIjIz3dDKWUalK2bdt20hgTVt06rwr5yMhItm7d6ulmKKVUkyIiWTWt03KNUkr5MA15pZTyYRrySinlw7yqJl+dkpISsrOzKSoq8nRTfF5wcDDdu3enWbNmnm6KUspN6h3yIhIMbAKaO7f3jjHmjyISBbwJhALbgDnGmOK6bj87O5s2bdoQGRmJiNS3uaoGxhjy8/PJzs4mKirK081RSrmJO8o154ErjTGDgARgkogMB/4CPGOM6QWcBu7+ORsvKioiNDRUA76BiQihoaH6G5NSPqbeIW8shc6XzZx/DHAl8I5z+avADT93HxrwjUN/zkr5HreceBWRABFJAU4AXwAZQIExxu58SzbQrYbPzhWRrSKyNS8vzx3NUUqpJqXY7mBNai7Fdofbt+2WkDfGlBpjEoDuwDAgtg6fXW6MSTTGJIaFVXvBlkcVFBSwZMmSOn9uxYoVHDt2rOx1ZGQkJ0+edGfTlFI+YlNaHvP+uY1Nae7v6Lp1CKUxpgBYD4wAQkTEdWK3O3DUnftqLDWFvN1ur+bd5aqGvFJK1WRMnzCWzR7KmD7u7+i6Y3RNGFBijCkQkRbA1VgnXdcDN2ONsLkd+LC++/KExx57jIyMDBISEmjWrBnBwcG0b9+effv2sXr1aqZMmcLu3bsBeOqppygsLGTAgAFs3bqV2267jRYtWvDNN98A8Pzzz/Pvf/+bkpIS3n77bWJja/0Lj1LKhwUF2rgqLrxBtu2OnnwXYL2I7AS+B74wxnwMPAr8SkTSsYZRvuSGfTW6hQsXEhMTQ0pKCk8++STJyck8++yzpKWl1fiZm2++mcTERFauXElKSgotWrQAoGPHjiQnJzN//nyeeuqpxjoEpZQfq3dP3hizExhczfJMrPp8oyu2O9iUlseYPmEEBbr3ot5hw4b97HHk06ZNA2Do0KG899577myWUkpVyyenNWjIkxitWrUqex4YGIjDUX42/FJjzJs3bw5AQEDAJWv6SinlDj4Z8u48idGmTRvOnDlT7brw8HBOnDhBfn4+58+f5+OPP67V55RSqrF4/dw1P4c7T2KEhoYycuRIBgwYQIsWLQgPL99us2bN+MMf/sCwYcPo1q1bpROpd9xxB/Pmzat04lUppRqbGGM83YYyiYmJpupNQ/bu3Uu/fv081CL/oz9vpZoeEdlmjEmsbp1PlmuUUkpZNOSVUqqRNOT0BTXRkFdKqUbSkCP/aqIhr5RSjaQhpy+oiU+OrlFKKW/UkNMX1ER78kop1cA8UYt30ZBvZBs2bGDKlCkAfPTRRyxcuLDG91adAfPYsWPcfPPNDd5GpZT7FNsdPLf2QKPX4l005N2ktLS0zp+5/vrreeyxx2pcXzXku3btyjvvvFPj+5VS3mdTWh5LN6QzLymmUWvxLhrytXDo0CFiY2O57bbb6NevHzfffDPnzp0jMjKSRx99lCFDhvD222+zevVqRowYwZAhQ5g+fTqFhdZdEVetWkVsbCxDhgypNDHZihUreOCBBwDIzc3lxhtvZNCgQQwaNIivv/660jTHv/nNbzh06BADBgwArHly7rzzTuLj4xk8eDDr168v2+a0adOYNGkSvXv35pFHHgGsL6E77riDAQMGEB8fzzPPPNOYP0Kl/FJhkZ3krFM8OzOBh8b3dvuEibWhJ15raf/+/bz00kuMHDmSu+66q6yHHRoaSnJyMidPnmTatGmsWbOGVq1a8Ze//IWnn36aRx55hHvvvZd169bRq1cvZsyYUe32H3roIZKSknj//fcpLS2lsLCQhQsXsnv3blJSUgDry8Zl8eLFiAi7du1i3759TJgwoWz645SUFLZv307z5s3p27cvDz74ICdOnODo0aNlc98XFBQ04E9LKQWwbGMGSzZm8sC4XkwZ5Jk+tW/25O3FsP8z69FNIiIiGDlyJACzZ89m8+bNAGWhvWXLFlJTUxk5ciQJCQm8+uqrZGVlsW/fPqKioujduzciwuzZs6vd/rp165g/fz5gzVLZrl27i7Zn8+bNZduKjY2lZ8+eZSE/fvx42rVrR3BwMHFxcWRlZREdHU1mZiYPPvggq1atom3btvX/oSilLmpeUgwPjOvFvKQYj7XBN0M+Yy28Ncd6dBMRqfa1a+phYwxXX301KSkppKSkkJqayksveeY+Ka4pjaF8WuP27duzY8cOxo4dy7Jly7jnnns80jal/Enr4EB+PbEvrYM9VzTxzZCPGQ8zXrMe3eTw4cNls0m+/vrrjBo1qtL64cOH89VXX5Geng7A2bNnSUtLIzY2lkOHDpGRkQHAG2+8Ue32x48fz9KlSwGrfv7DDz9cdLri0aNHs3LlSgDS0tI4fPgwffv2rbH9J0+exOFwcNNNN/GnP/2J5OTkOhy9Uqqp8s2QDwyCvtdYj27St29fFi9eTL9+/Th9+nRZacUlLCyMFStWMGvWLAYOHMiIESPYt28fwcHBLF++nMmTJzNkyBA6depU7fafffZZ1q9fT3x8PEOHDiU1NbXSNMe/+c1vKr3//vvvx+FwEB8fz4wZM1ixYkWlHnxVR48eZezYsSQkJDB79mz+/Oc/1/+HopTyejrVcC0cOnSo0g27fZk3/LyVaqoa8tajF6NTDSulVCPwxARkl6IhXwuRkZF+0YtXStWPJyYgu5QmMU7eGHPB6Bblft5UulOqKfLEBGSX4vU9+eDgYPLz8zWAGpgxhvz8fIKDgz3dFKWUG3l9T7579+5kZ2eTl+c9NS5fFRwcTPfu3T3dDKWUG3l9yDdr1oyoqChPN0MppZokry/XKKWU+vnqHfIiEiEi60UkVUT2iMgvncs7iMgXInLA+di+/s1VSinv4MkbgdSFO3ryduA/jDFxwHBggYjEAY8Ba40xvYG1ztdKKeUT1u3NZe5rW1m3N9fTTbmoeoe8Mea4MSbZ+fwMsBfoBkwFXnW+7VXghvruSymlvEGx3cHOowVgAC8f3e3WmryIRAKDgW+BcGPMceeqHKDawaMiMldEtorIVh1Bo5RqCjal5bF800HuH9eLK2O9a1x8VW6bu0ZEWgMbgSeMMe+JSIExJqTC+tPGmIvW5aubu0YppbyNp+aoqUmDz10jIs2Ad4GVxhjX/e1yRaSLc30X4IQ79qWUUp7iOtkKcFVcuFcE/K
W4Y3SNAC8Be40xT1dY9RFwu/P57cCH9d2XUkp5kjdOQHYp7rgYaiQwB9glIinOZb8FFgL/EpG7gSzgFjfsSymlPKLY7sBe6uCFWwd71QRkl1LvkDfGbKbm88vuuzWTUkp50Ka0PB54YzvLZg9tEmUal6bTUqWU8iBvnEa4NjTklVLqIpriydaKmlZrlVKqkTXFk60Vef0slEop5QmusfDDo0ObZJnGRXvySilVDVcPfktmfpMs07g0zVYrpVQDa6onWqvSkFdKqSq8bdqC+mjarVdKqQbQ1E+2VqQhr5RSVfhKqQZ0dI1SSl0gKNDGVXHePYVwbWlPXimlfJj25JVSfq3Y7rBu4SdwZWzTHSpZE986GqWUqqN1e3OZvzKZBSu3+8SJ1qo05JVSfst1r1YB7kuK8okTrVVpuUYp5bc2peXx4sZM7h/Xi4fG9/a5Ug1oyCul/NiYPmG8OCfRJy56qolvHpVSSlXDNW1wsd0BlA+V9NWABw15pZQf8aUrWWtLQ14p5Td86UrW2tKQV0r5BV+adKwu/OdIlVJ+p7DIzl8/28vHO46ybm+u35VqQEfXKKV82LKNGSzZmIkAz88a7HelGtCevFLKh909MoohPUIACLSJz4+kqY5/Ha1Sym8UFtn5/Ye72XGkgAXjenFlP9+YVbKuNOSVUj7DNQ6+sMjOo+/u5JNdx5k0oIvPXs1aG26pyYvIy8AU4IQxZoBzWQfgLSASOATcYow57Y79KaVUVcV2B4vW7GfZhkwm9Q9n9d4TTI7vwl9uGui3AQ/u68mvACZVWfYYsNYY0xtY63ytlFINwpqH5iAO4NM9ucxLiuGZGQm0Dvbv8SVuCXljzCbgVJXFU4FXnc9fBW5wx76UUqo6w6NDuSq2EwBzR0f7dYmmoob8igs3xhx3Ps8Bqj3rISJzgbkAPXr0aMDmKKV8UbHdwbp9uew48gOr9+ZiExjSM0QD3qlRfo8xxhgRMTWsWw4sB0hMTKz2PUopVZ1iu4Pn1h5g6YYMDIZ5STEMimjHlbH+OZKmOg0Z8rki0sUYc1xEugAnGnBfSik/U1hk55F3Uli1O5d7x0QzOCKEK/v53zj4S2nIn8ZHwO3O57cDHzbgvpRSfmbZxgw+3Z2LAxgcEcKk+C4a8NVw1xDKN4CxQEcRyQb+CCwE/iUidwNZwC3u2JdSyn+5brp99ryd/Tk/cucVkQztEeK3FzrVhltC3hgzq4ZV492xfaWUctXfl2xIxxgwwOT4LkxJ6Obppnk1/x5AqpRqMtbtzWXx+nTmjo6iT3gb1uw9wZ+mDvB0s9zDXgwZayFmPAQGuXXTWsBSSjUJdodBBOK7h3BTYgRL5wylfWv3BmKDsxfD/s+sx4oy1sJbc6xHN9OevFLKa7lq8AhgrBHWgTbxbKPq4nwhfPk3CIuDZsHWsnfuhBmvQd9ryt8XM95aFuP+CrcY4z1D0xMTE83WrVs93QyllAe57uA0tEd7fvv+Lj7bk4MNWDJ7CIE2m/fc2am6EkvVZev+BJueBATEBtNfhYBAt5dlRGSbMSaxunXak1dKeZXVe3J48I3tJPQIYfvhAgSYNzaaK2M9PAbeXgxpn1vPY8ZZPfTNi+CWf0DcddZyV9nF1VMf+TA4Sst78n0mur3mfika8kopr1Fsd/BxylEMsONIAdf078zkgV2Y0L9z4wZ81UA/9CWU2uFt56U/o38FXz174eeqll2at4ar/tg4ba6BhrxSyiu45oD/fK91cfy9o6P4jwmxDRvu9mLY+wkc3w5dEyB2itXTzlhbOdA3L4KbX7HKLWAFf+dB1vM+E8u3FxhUudbuBTTklVJeYfH6dD7ZdZyEiHbcMyqKCf3ddAWrq04eORoy1lvLXGWTjLXw7l2Aw6qZz3zdCumY8ZUDvdvQC+vorhKNl9OQV0p5xOnCYn73wU6u7BtG6xZB9A1vgwA7j/xAcLPAnxfwrjJLqR1MCWQnw9Hv4WgyjHq4vMQyc2V5mN/0cnlP3lVmCQyqHOJe1juvCx1do5RqVMV2B6v3HOfp1QfIzD8LQIAIi28d7BwqSe0nGnMNUezUD4zAnnch7TPnSufGXG5aYY1sAY+cAG1IOrpGKeU1NqXl8eAbKRXjl/uSomoX7K6eevEZSP0YzubC0a1YgQ5lod7tMrh8rtWT/+EwxN0I/Sb7VLDXloa8UqrBFRbZ+dvn+9mefZonbxpIVMdWZJ48iwDzx0bz8FV9Lwx410nRY9uh62CIHGGdDD3ynfPCKGegRwyDy+61evJHt8GP2XD9c9CyAwzUeRE15JVSDcJVlkk99iM/lZTyyjdZAEx/8RsKfrIzpEcId42MKh8eee4UfPgQtAyF3F3QfSh8t9y5NRv0GAaHt1ihPvROqyffIQrGPW4NVQQYNN0zB+vFNOSVUm7nmjHyhfXpAES0Dy5bV/CTnRv7t2Nh2/dovnEdZI2D6FGw5wPY/+/yjRzbZj3GXgf9b7J68qseg8lPW730hFsb85CaLA15pZTbuKYkKCopZemGDMLbBJF7pphzp3NY12YRHQN+4kj0bcTmvE1ARob1oVOZsP0VuGE5OBzlPfnEOyC4feWTpNNXeOrQmiwNeaVUvbkmEtt+5DTLNx1kSt82PGJ7jaifsohrcZq2QQG0OX8EKYH+qQvLP9g8xOqlR4+yLkTScovbacgrpX62wiI7y9fsYFjWi5jj6RSYBBYH7kAyDJOabbPOjRqQ80D7aOuE6eX3WdMEtK9ST1cNQkNeKVUnxQXHKPzHLA6FTyHnwBYWnF9DkA0IgEliBbtDwN7rWvJ//IGwkuMEjPoVDJpRXnYZMd+jx+BPNOSVUhd3vhD72ic4eXgvZ1t1oUv2GtoXnaBdfgqDoezWQwXSimZX/ReFqZ8Tetl0msXfSGc/HJfubTTklVKVnTsFHz0ErTvBmRxKHYaAA58S7rp6SeCYoz2fh/6CqPN7GVq4hkNEkDnhFW4YNZTWo+71aPNVZRrySvkrezHsetcai96yvVU/LyqgpEUYgRmryt5mQ1hlH0oAhvzAcE6XBPBC6TT+OnY45wNsDFqZzPykaB4e0ddzx6JqpCGvlL84kwOvz4CfTkFwCLTthimb56Xc8Zb92GW/jBOmHeGSzyrHFXzmuJz4Hh3ZfriASf0789eBnZnQvwsA//eLRO+5W5O6gIa8alKK7Q7W7cut2yRWlI/f9rswOneK0g8WcPJ0AR3ztxLgKHZOBnAYcnZyLGwsuTmHyTctsBFAOznHax3+yEenHGVzy9wzKorne7ZnVO8wtmTmX/AzvCou3AMHpmpLQ141KZvS8liwcjsGw/I5Vg+yNuG9KS2Pef/cxrLZQ2sVSk3uS8E1G2P7aGsCr74T4NRBOHkAW9qndDKAWBWZQlqSXhrKiXaJvGy/lW9LSipt6p7wHiwd3gF7qYNAm63Sl6kGetOjUw0rjygssrNsYwbzkmJoHVz7vkbVnnxN4V01pKv7DeBiQb5q93EWrNzO4tsGM2lAl2rb4dEvgYLD8M+brfuLprwOWV/D0e+p+q+5aMh8Th/bT1ZOHp1Kc
1nvGMTTpbM4R/k0A9EdW3Hw5FnuHhVJs4AAFozrVae/E+V5Hp1qWEQmAc8CAcDfjTELL/ER5QeWbcwom9fk1xNrf8IuKNBWKXTH9AnjhVsHYy91UFhkLysnVAx/12sMPPDG9rJlz609wLKNGdX37g0Y6z9lKga7a/sv3DqYQJutTmFf7RfQ3lwQLrhZteu9wzuVUPL6LNqGdiVg0v9S+uIYbMWFOF6+BlvRKQBOdUjgy7bX0DxjDQfajsBWcIj9p67l40OjLwj/Hu2bc1W/LgzpEcLY2PBqyzDKNzRoyItIALAYuBrIBr4XkY+MMakNuV/lXarr9c5Liil7vFjI1fR5l6BAGxiYvzKZwT1C2HGkgBfnJDI8OpS5o6MoKrGzbl8uD7y+nQn9OrFoxiDG9Alj3d5cFq9PZ35SNGP6hJXtZ/Xu46Tm/MgdI6K4f2wvRvUOK9u/vdTBgteTmT+2F3ePjGJeUgx2u4MFbyVzX1I0A7uFVKpbg1UmGh4dyoZ9uaTm/Mj9Y3uzJTO/7AtoeHQoj767k092HUeACXHhXB8fyoTmqZjI0Xzy3gpy931Hv+bb6WrPhlM7sR/bQUBxIWdKg3mpx7ME73kLA7xwbBrnjgUD8Ug+XBt/I13aBWM4Q0L3tpQa2HX0RwAemdSPKYO6lf0ctQzjuxq6Jz8MSDfGZAKIyJvAVEBD3kdV10t9bu0Blm5I58U5iWVh0jo4kF9P7Fu2fsmGdARh8W2DrQ1VKKtUV5KpuB+7wzpJmHy4gMnxXcpCfMnGTAR4fmYCE/t35pNdx4nu1IYpg2wgYBMhrmvbSr3zB9+0bmaRlf8Tn+/JIbZzGz7ZeZzPU3N4bmYC88f2YtnGDA6ePMuq3ce5c2QkgyJCWLohE4BJ/Tvzxd5c5iXFEBXakl+/Y93ebu3+PABKHTC4RwgvzBrM0B7tueOV7ziSdYC1QX/mm9K+dEz7EVuaA1tgCunRd3J9xgpsAQaHHdIdXcg0nXm1ZC6z7Cv4vf1uok51YGfprRhABCbGhbMm9QT3JUXx8FXWz/f4D0Ws2n2cZ2cmsOfYGfp3bcuE/p0b938M5TENWpMXkZuBScaYe5yv5wCXG2MeqPCeucBcgB49egzNyspqsPaohrcmNbdSILtq2/eNieLhqy+8McSqXceZtzKZq/t1IjqsFfHd2vHLN3dUOrFasZcPVu/Y7nCwYKXVq47rKGS991/EhAZzxajxZJwu5XhIPPLvh+glOZxJmEvfHuFkff0O3S+bSt6hnXRtHcipYwc4bDrxWnZnbpw5l8u7N+fwirv5IbgrMTc/wT+S8ykqKeXvmw8iwNLZQ7g8MpR7X9vK1qzT9AprRXqedfu6qNBWHMo/iwEGR4SwM7uA5oHCuZLyf18CzB/RmbbfP01CRAjfdppF8NbFzApYSxs5f8HNjVa3v4VOvYfTumAPJ1vF8srpARhbEONjQ3ns/VQEePKmgbQODsTucBAYYGNUrwtHwHj8/IFqcF59+z9jzHJgOVgnXj3cHFVPY/qEldW8Aex2Bw5jiOvStvqTnWJdFb923wnW7oMltw6xevOGst71/3vje9648ixBva9j19o3sH/zFl0Sb2RRXBGBm59hcNeWXNNsI/wI5tN36W+E7qEJdAjcYdXWd/0PskuIxeBYvYFI50iTzkBnA5c1E0oDhpP/7j/pV7AJBLJWCbPSPuffsf/LO82eojh8CIk9L+ezD5Zwz9GPmBncgZiCdH4Kbkl4aR5ZP4QhgUKEnODLYwMZ2jyGyaWrCQgWUtpMon/BKmytw4mVeIIDP4Zj0Pt8Kh2aJQNwjmCOdL+Gw1lZtAmP4oeSAC7/xV9o36EDAFmpuayrcI4hpGVwjeWtqqWXoECblmP8WEOH/FEgosLr7s5lyk8EBtoQIDXnRybYu1xQerkyNpwltw2h5Owpem15nOisfti6DuTI/mSkZAhjc/awPiqN8K/WYecA/b/7GwMCgO3fMxCs0/m5zk5wt0Tsl80lNbeYfpePx/7RL8k/vIdnzk1k9si+xJ/ZTGmvq0neuoU92Se5osNZmofF0H3QWJr1uZoOEZeT+0YJHSP6ELH7XWwBp7jnwAME2OxwMh1ZVcL16R9AoPOEbADWow2iOVF23+iYgDXAGms9EH9mifX8p0xOHzxLkIEDHZKI/sUyDn3yV/bt203wtMWMjO9NdloeQ6vpcVf88gwKtDEp/sIRP0pVp6HLNYFAGjAeK9y/B241xuyp7v06hLLpKRv9ER3Klsx87A4HD7y+nRdutXrjdodhz7Ef+b8vMyuNdBkT2YKgLc/BZffAV89zLvlNWpw/WblkIYK4ahf9b2Rdnz/w4b/+zuORaXROvIkS4yB/2/u0GziFnIOpdLv+9wS1bFvWNlfpaF5SDA+N710WnIVFdh59dyerdh+vdJ6gomXvfsF1KXP5PG4hN55YTJveIwgc+ygcWEvpnvc4bG9F97MHOHjGRmdHLufa9KTUOPgpN5PAmHFEDhgOW1eA2GDgDNj5FrTpzJrev+XNDz5g+i23M3FQTy2lKLe4WLmmwcfJi8i1wCKsvszLxpgnanqvhrz3qxrqRcV2HnorhXtHRfF/mw/y3IwEgoMCsZc6uP/1ZAThhWmRRH7zn4ROfBzbu3fQrlsszcL7wpYXoMdwOLzFinIDha17Ejz+UY7sT6ZLrwRy0rYREdqGwHGPUhzQslIgVgzx6oZCVhegZSd616dz/7helcK/4vuL7Y46j+OvTWBrqKuG4NGQrwsNee9zQaiXlPLLN1O4LymK5ZsOMndMFC9uPMg9oyP5+5eHWDqjDxNO/AN7qZ3d9p5gc9Bn59O0KDrBT9KCFuYnEJDY66BTbFlPvuTEftLzf+LTqN9z36TLCAq0seiL/SzZmMn9SdHVnrR1XVB198goth0+XavgrKl3X3Fdba+KVcpbePWJV+W9KvZ6h/RsT3LWae5LisFhDJm5hfz1uijCt/6Naa13cvrMdG4ZmEvPnG8wW54jABiEOMvUBiPwTts7ueL0++QFdWfo5EUEtelo7WjSE2xMzWXuP7biyD2FaZFBQkQIL26yRrW8uOkgQ3p2uCB4t2Tms2yj9d7qQrm6XnPV2nZFVU8aK+ULtCfv5y5Woli16zjzVyaTEBHC/iM5/Hf7z5iUNIavv9nEwZNnGdjyFMOLtwDgAGwCMuKXfHson5TDp+jc73KmDuhEznfv8eesPlx5wz28tjWHlCMFF4x5X70nh5TDpwkIEB68sg9BgTbW7cvFbncQGGir80VSoD1z5T+0XKNq5BrHfufISLZlnWL7kR+4dkA4i2YOYcPuTA688990jY6j/8EV9LYdx0jlc6PGQE5AN/Li7mJA20ICk35NoQmu9GVRMYyBC4J5TWou9722rWxs/MUCuS41ba1/K3+hIa+qVWx38MwXaby0cT+TbN9yje1rbMCnZiTXz5xHUu4rBGz+W9n7TzbvyZ8Lr+GhuPNEhraiJHwgqSeK6Td2OkHNg2veUS3aUdvpg7V3rtSFtCavytmLKdn9PvnfvkVWaSi5RzvwYZf9xJ5e
X/aWCaRgD7iCwNG/Agx0iIIDa2h37dNce6SUrn3CINBGM2BQHXdfXe+66qRjF6N1c6XqRnvyvu58Iaz/M6V5+/kxP4fgkfNo/skCxPnX7iq/nOgynpTsUwzqFkLnK26FuBugAW7CrD1xpdxPe/L+4nwhbPwrlBaDvQRydkC3RPhuKTYgxMDRjS/yl+IFXGvbQrYJY9jlYxjU8iTtr/h/BBz6iQ7OXnpD0Z64Uo1LQ76pct0JqHM8xE6xet1fLYKvnwXKp0EvNZAZczulufsx507T7pZXmXi6OSnZs4nv3o5+/btAoI0g4Kq4tjXuzl10HhWlGpeGfFNiL7Zu7VZqh9T3Ye+H1mXzM1+HvtfAyIet95QWk5lzitOZ29jQ/j95YVshMBGAyV8WsPi2IZXmEldK+S4NeW/k6qU7T3hy3SJo2QEy1sLbt1vjFgH6TYX+N0DMeGnx8bMAAAzHSURBVACKA1qyrst8ikpKWZFxiJSSqVx9riVQiAAJEe3409QBHjsspVTj05D3FvZiK8Rjxltll81Pl6+z2WD6Cmvd9FetnnxAIPSZWOnk6Lq9ucxbmVxps5MHdOamId1rnJZWKeXbNOQbk6vcAhBxGax6DCY/Xd5Lf2sOzHjNKrs4Sst78pOdgR8YBHHX1bx950gZEbg6Lpzr4rswYUAXDXal/JgOoWwoFQM9Zhwc+tLqgb99u7Us4jI4vAX632j10iv25H/m0MVL3StVKeWb9IrXxuAK6cjRFwb66F/B5kVw8yvl76/ak68jDXSllIuOk3cn10nR8IHQb3J5r9tVbhn1cHmgT3/VWhczDroNvbCXPn3Fz2qCazqCZRszsInw4hy9sEgpVT0N+dpwBXtYHOz7tzV0ERvMcg5dBCvAZ7xm9eSrC3TX++rBNcfLjiM/sGxjBgD3JUXphUVKqRppyFdUsS4O1Yx2EetPv6kQd2P5+8AKdFeQuyHQq7MpLY8FK7djMMxPimZgRIiWapRSF6UhXzHYK45wgQtHu4TFQbPgC4YuNpYxfcJYfNvgWs3WqJRS4A8nXqueEHU9usop+z8rD3NX0FftyXsg0F030thz7EcWjOtV6/uMKqX8z8VOvPpeV/B8Iaz7k/UI5b3zrxZVfsxYa6131dJdYd73Guux4vNG5rrt3kNvbmfpxgyWrD/Q6G1QSvkG3+geVhyTfiy5/GrRK39/4QnRiidGoXIt3UtsSstj6YZ0ro4LZ/WeXOK6NvzEYUop3+Qb5Zr9n8Gbt1nPp70EeXusOnrz1u5tYCMpLLKzbGMGd4+MYtvh03r7OqXURfn+OHnXnC5gnRSNv9Gz7amnLZn5LNuYQUJEiI5/V0rVi2+E/KXmdGkCKt4WT2+soZRyF60BeIlNaXnM++c2NqXlld1YQ0s0Sqn60hTxkMIiO099vp/CIjugt8VTSjWMeoW8iEwXkT0i4hCRxCrrHheRdBHZLyIT69dM37NsYwYvrE8vm55Ae+9KqYZQ35r8bmAa8GLFhSISB8wE+gNdgTUi0scYU1rP/fmMu0dGcfDkWe4eGeXppiilfFi9uo3GmL3GmP3VrJoKvGmMOW+MOQikA8Pqsy9fs+3waT7fk8O2w6c93RSllA9rqNE13YAtFV5nO5ddQETmAnMBevTo0UDN8Q6uWSQxMKq31uCVUg3vkiEvImuAztWs+p0x5sP6NsAYsxxYDtbFUPXdnjerOIvk8jmJOgZeKdXgLhnyxpirfsZ2jwIRFV53dy7zS64x8MOjQ8tmkdQevFKqMTTUUI6PgJki0lxEooDewHcNtC+v5xoDvyUzn0kDujApXm+urZRqHPUdQnmjiGQDI4BPRORzAGPMHuBfQCqwCljgzyNrdAy8UspTfGOCMi9VcaoC7bkrpRqKf80n7yVcc8K7pipQSilP0JBvIK454eclxWiZRinlMb4xC6UXGtMnjBfnJGqpRinlUZo+blJsd7AmNZfCIjtrUnMBdC4apZTHaQK5iWuY5LKNGVqHV0p5DS3XuIlrmOTw6FASIkK0Dq+U8goa8m7imioY0OkKlFJeQ8s1SinlwzTklVLKh2nIK6WUD9OQ/xlcwyWL7Q5PN0UppS5KQ74OXOG+bl+uDpNUSjUJGvJ14BoLj0FnlVRKNQka8nUwpk8Yi2YMYmd2AcOjQ/VqVqWU19Nx8rVQWGRnyYYDxHVuS+rxH1myMRObzcavJ/b1dNOUUuqiNORrYdnGDJZsyESA52cl8MC4XsxLivF0s5RS6pI05GthXlIMDuMgrnNbJvTvwpRBWqZRSjUNGvK10Do4kEcm9fN0M5RSqs60S6qUUj5MQ14ppXyYhnwN9KpWpZQv0JCvgevCJ72qVSnVlGnIV1Cx9+66CYhe1aqUaso05Cuo2Ht33QREr2pVSjVlmmAVaO9dKeVr6hXyIvKkiOwTkZ0i8r6IhFRY97iIpIvIfhGZWP+mNjztvSulfE190+wLYIAxZiCQBjwOICJxwEygPzAJWCIiAfXcl1JKqTqqV8gbY1YbY+zOl1uA7s7nU4E3jTHnjTEHgXRgWH32pZRSqu7cWZe4C/jM+bwbcKTCumznsguIyFwR2SoiW/PyGn+4oo6HV0r5skuGvIisEZHd1fyZWuE9vwPswMq6NsAYs9wYk2iMSQwLa/wTnjoeXinlyy45QZkx5qqLrReRO4ApwHhjjHEuPgpEVHhbd+cyr6MjapRSvqy+o2smAY8A1xtjzlVY9REwU0Sai0gU0Bv4rj77aig6okYp5cvqO9XwC0Bz4AsRAdhijJlnjNkjIv8CUrHKOAuMMaX13JdSSqk6qlfIG2N6XWTdE8AT9dm+Ukqp+tEahVJK+TC/C3kdMqmU8id+F/I6ZFIp5U/8LuR1yKRSyp/4Tci7yjSADplUSvkNv0k6LdMopfyR34S8lmmUUv7I50NeyzRKKX/m84mnZRqllD/z+ZDXMo1Syp/Vd+4ar+eagEwppfyRz/fklVLKn2nIK6WUD9OQV0opH+ZzIa8TkCmlVDmfC/l1+3K577VtrNuX6+mmKKWUx/lcyGPAWP9RSim/53NDKK/sF87yOYk6Ll4ppfDBkNdx8UopVc5nyjV6wlUppS7kEyFfbHfw3NoDOkeNUkpV4RMhvyktj6Ub0pmXFKO1eKWUqsAnavJj+oTxovNkq04lrJRS5Xwi5PVkq1JKVU+7vUop5cPqFfIi8j8islNEUkRktYh0dS4XEXlORNKd64e4p7lKKaXqor49+SeNMQONMQnAx8AfnMuvAXo7/8wFltZzP0oppX6GeoW8MebHCi9bUT6ZwFTgH8ayBQgRkS712ZdSSqm6q/eJVxF5AvgF8AMwzrm4G3CkwtuyncuO13d/Simlau+SPXkRWSMiu6v5MxXAGPM7Y0wEsBJ4oK4NEJG5IrJVRLbm5emFTEop5U6X7MkbY66q5bZWAp8CfwSOAhEV1nV3Lqtu+8uB5QCJiYk6d6RSSrlRvco1ItLbGHPA+XIqsM/5/CPgARF5E7gc+MEYc8lSzbZt206KyFngZH3a5cU6osfWFPnqsfnqcYH/HVvPmt5c35r
8QhHpCziALGCec/mnwLVAOnAOuLM2GzPGhInIVmNMYj3b5ZX02JomXz02Xz0u0GOrqF4hb4y5qYblBlhQn20rpZSqP73iVSmlfJg3hvxyTzegAemxNU2+emy+elygx1ZGrMqKUkopX+SNPXmllFJuoiGvlFI+zGtDXkQeFJF9IrJHRP7q6fa4m4j8h4gYEeno6ba4i4g86fw72yki74tIiKfbVB8iMklE9jtnU33M0+1xFxGJEJH1IpLq/Pf1S0+3yZ1EJEBEtovIx55uizuJSIiIvOP8N7ZXREbU5nNeGfIiMg7r4qpBxpj+wFMebpJbiUgEMAE47Om2uNkXwABjzEAgDXjcw+352UQkAFiMNaNqHDBLROI82yq3sQP/YYyJA4YDC3zo2AB+Cez1dCMawLPAKmNMLDCIWh6jV4Y8MB9YaIw5D2CMOeHh9rjbM8AjlM/a6ROMMauNMXbnyy1Y01k0VcOAdGNMpjGmGHgTq+PR5Bljjhtjkp3Pz2CFRTfPtso9RKQ7MBn4u6fb4k4i0g4YA7wEYIwpNsYU1Oaz3hryfYDRIvKtiGwUkcs83SB3cU7sdtQYs8PTbWlgdwGfeboR9VDTTKo+RUQigcHAt55tidsswupAOTzdEDeLAvKAV5ylqL+LSKvafNBj93gVkTVA52pW/Q6rXR2wfpW8DPiXiESbJjLe8xLH9lusUk2TdLFjM8Z86HzP77BKAisbs22qbkSkNfAu8HCVe0M0SSIyBThhjNkmImM93R43CwSGAA8aY74VkWeBx4D/rM0HPeJis1uKyHzgPWeofyciDqxJeZrEXMQ1HZuIxGN9I+8QEbDKGckiMswYk9OITfzZLjUrqYjcAUwBxjeVL+Ua1Hom1aZIRJphBfxKY8x7nm6Pm4wErheRa4FgoK2I/NMYM9vD7XKHbCDbGOP6jesdrJC/JG8t13yA8wYkItIHCMIHZpQzxuwyxnQyxkQaYyKx/uKGNJWAvxQRmYT1q/L1xphznm5PPX0P9BaRKBEJAmZiza7a5InVw3gJ2GuMedrT7XEXY8zjxpjuzn9bM4F1PhLwODPiiHNCSIDxQGptPuuxnvwlvAy8LCK7gWLg9ibeK/QXLwDNgS+cv6lsMcbMu/hHvJMxxi4iDwCfAwHAy8aYPR5ulruMBOYAu0Qkxbnst8aYTz3YJnVpDwIrnZ2OTGo5u69Oa6CUUj7MW8s1Siml3EBDXimlfJiGvFJK+TANeaWU8mEa8kop5cM05JVSyodpyCullA/7/60hSkBWGud+AAAAAElFTkSuQmCC", - "text/plain": [ - "<Figure size 432x288 with 1 Axes>" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "cell_type": "markdown", "source": [ - "plt.scatter(x_val, y_val, s=0.5, label='truth')\n", - "plt.scatter(x_val, standard_classifier(x_val), s=0.5, label='predictions')\n", - "plt.legend()" - ] + "\n", + "#### **TODO: Analyzing the performance of standard regression model**\n", + "\n", + "Write short (~1 sentence) answers to the questions below to complete the `TODO`s:\n", + "\n", + "1. Where does the model perform well?\n", + "2. Where does the model perform poorly?" + ], + "metadata": { + "id": "7Vktjwfu0ReH" + } }, { "cell_type": "markdown", @@ -275,8 +268,13 @@ "id": "7MzvM48JyZMO" }, "source": [ - "### 1.3 Bias Identification\n", - "Now that we've seen what the predictions from this model look like, let's see what the uncertainty and bias look like! To do this, we'll wrap a model first with a `HistogramWrapper`. For low-dimensional data, the HistogramWrapper bins the input directly into discrete categories and measures the density. " + "## 1.3 Evaluating bias\n", + "\n", + "Now that we've seen what the predictions from this model look like, we will identify and quantify bias and uncertainty in this problem. We first consider bias.\n", + "\n", + "Recall that *representation bias* reflects how likely combinations of features are to appear in a given dataset. Capsa calculates how likely combinations of features are by using a histogram estimation approach: the `capsa.HistogramWrapper`. For low-dimensional data, the `capsa.HistogramWrapper` bins the input directly into discrete categories and measures the density. 
More details of the `HistogramWrapper` and how it can be used are [available here](https://themisai.io/capsa/api_documentation/HistogramWrapper.html).\n", + "\n", + "We start by taking our `dense_NN` and wrapping it with the `capsa.HistogramWrapper`:" ] }, { @@ -287,10 +285,15 @@ }, "outputs": [], "source": [ - "standard_classifier = create_standard_classifier()\n", - "bias_wrapped_classifier = HistogramWrapper(standard_classifier, \n", - " queue_size=2000, # how many samples to track\n", - " target_hidden_layer=False) # for low-dimensional data, we can estimate densities directly from data\n" + "### Wrap the dense network for bias estimation ###\n", + "\n", + "standard_dense_NN = create_dense_NN()\n", + "bias_wrapped_dense_NN = capsa.HistogramWrapper(\n", + " standard_dense_NN, # the original model\n", + " num_bins=20,\n", + " queue_size=2000, # how many samples to track\n", + " target_hidden_layer=False # for low-dimensional data (like this dataset), we can estimate biases directly from data\n", + ")" ] }, { @@ -299,108 +302,29 @@ "id": "UFHO7LKcz8uP" }, "source": [ - "Now that we've wrapped the classifier, let's re-train it to update the biases as we train. We can use the exact same training pipeline as above to accomplish this!" + "Now that we've wrapped the classifier, let's re-train it to update the bias estimates as we train. We can use the exact same training pipeline, using `compile` to build the model and `model.fit()` to train the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "SkyD3rsqy2ff", - "outputId": "7cd6b5fa-c61a-4306-faed-02a5b9d6a3e3" + "id": "SkyD3rsqy2ff" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/30\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Gradients do not exist for variables ['dense_47/kernel:0', 'dense_47/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?\n", - "WARNING:tensorflow:Gradients do not exist for variables ['dense_47/kernel:0', 'dense_47/bias:0'] when minimizing the loss. 
If you're using `model.compile()`, did you forget to provide a `loss`argument?\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[training log removed: 30 epochs of model.fit() output; histogram_compiled_loss 3.9734 -> 0.3604, histogram_wrapper_loss 7.2317 -> 0.6658]"
-     ]
-    }
-   ],
+   "outputs": [],
   "source": [
-    "bias_wrapped_classifier.compile(\n",
+    "### Compile and train the wrapped model! ###\n",
+    "\n",
+    "# Build the model for regression, defining the loss function and optimizer\n",
+    "bias_wrapped_dense_NN.compile(\n",
    "    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-3),\n",
-    "    loss=tf.keras.losses.MeanSquaredError(),\n",
+    "    loss=tf.keras.losses.MeanSquaredError(), # MSE loss for the regression task\n",
    ")\n",
    "\n",
-    "history = bias_wrapped_classifier.fit(x, y, epochs=30)"
+    "# Train the wrapped model for 30 epochs.\n",
+    "loss_history_bias_wrap = bias_wrapped_dense_NN.fit(x_train, y_train, epochs=30)\n",
+    "\n",
+    "print(\"Done training model with Bias Wrapper!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_6iVeeqq0f_H"
   },
   "source": [
-    "To access the bias for a given testing input, we can simply call the method as we would normally. In addition to outputting the prediction, this risk-aware model now also outputs an additional bias score per output."
+    "We can now use our wrapped model to assess the bias for a given test input. With the wrapping capability, Capsa neatly allows us to output a *bias score* along with the predicted target value. This bias score reflects the density of data surrounding an input point -- the higher the score, the greater the data representation and density. (A short sketch of the underlying histogram idea follows below.) 
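The bias score is, at its core, a histogram density estimate. The following plain-NumPy sketch illustrates the idea only; it is not Capsa's internal implementation.

import numpy as np

def histogram_bias_scores(x_train, x_query, num_bins=20):
    # Bin the training inputs and normalize counts into a density estimate.
    counts, edges = np.histogram(x_train, bins=num_bins)
    density = counts / counts.sum()
    # Score each query point by the training density of its bin; points in
    # sparsely populated bins receive low scores (high representation bias).
    idx = np.clip(np.digitize(x_query, edges) - 1, 0, num_bins - 1)
    return density[idx]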
The wrapped, risk-aware model outputs the predicted target and bias score after it is called!\n",
+    "\n",
+    "Let's see how it is done:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 287
-    },
-    "id": "tZ17eCbP0YM4",
-    "outputId": "4da00423-1115-4bf2-95e6-966b8697b5b6"
+    "id": "tZ17eCbP0YM4"
   },
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "<matplotlib.legend.Legend at 0x7fe11cd97190>"
-      ]
-     },
-     "execution_count": 110,
-     "metadata": {},
-     "output_type": "execute_result"
-    },
-    {
-     "data": {
-      "image/png": "[base64 PNG data removed: figure output of the bias-wrapped model; the diff is truncated here]"
4J+kP4IGkrCl3RcSpycczezclMzNrjXxWDKOAqohYGRG7gBnA2EZ1xgIPJttPABdIUlI+IyJ2RsQqoCrpj4h4FfiwDeZgZmZtKJ9g6AOsydmvTsqarBMRtcBWoCTPtk25SdKi5HRTzzzqm5lZGzkYLz7/BjgeOBVYB/xzU5UkTZBUKamypqbmQI7PzKxTyycY1gL9cvb7JmVN1pFUCPQANuXZNiUiNkTE7oioA35HcuqpiXrTIiITEZnSUr8zpplZW8knGOYBQyQNlFRE9mJyRaM6FcC1yfY4YFZk/3KuAhif3LU0EBgCzG3uxSSV5ex+E1i8p7pmZtb2WnwTvYiolXQT8BzQBbg/IpZIuhWojIgKYDrwsKQqsheUxydtl0iaCSwFaoFJEbEbQNK/AecCvSRVAz+PiOnArySdSvbtS94Dvt+WEzYzs+b5LTHMzA5RfksMMzPLi4PBzMxSHAxmZpbiYDAzsxQHg5mZpTgYzMwsxcFgZmYpDgYzM0txMJiZWYqDwczMUhwMZmaW4mAwM7MUB4OZmaU4GMzMLMXBYGZmKQ4GMzNLcTCYmVmKg8HMzFIcDGZmluJgMDOzFAeDmZmlOBjMzCzFwWBmZikOBjMzS3EwmJlZioPBzMxS8goGSWMkrZBUJWlyE8e7SXo8OT5H0oCcY1OS8hWSLskpv1/SRkmLG/V1jKTnJb2TfO6579MzM7O91WIwSOoCTAUuBcqBb0sqb1TtemBzRAwG7gJuT9qWA+OB4cAY4J6kP4AHkrLGJgMvRsQQ4MVk38zMDpB8VgyjgKqIWBkRu4AZwNhGdcYCDybbTwAXSFJSPiMidkbEKqAq6Y+IeBX4sInXy+3rQeDKvZiPmZm1Uj7B0AdYk7NfnZQ1WSciaoGtQEmebRvrHRHrku31QO88xmhmZm3koL74HBEBRFPHJE2QVCmpsqam5gCPzMys88onGNYC/XL2+yZlTdaRVAj0ADbl2baxDZLKkr7KgI1NVYqIaRGRiYhMaWlpHtMwM7N85BMM84AhkgZKKiJ7MbmiUZ0K4NpkexwwK/ltvwIYn9y1NBAYAsxt4fVy+7oWeCqPMZqZWRtpMRiSawY3Ac8By4CZEbFE0q2SvpFUmw6USKoCfkxyJ1FELAFmAkuBPwOTImI3gKR/A14HTpBULen6pK9fAhdJege4MNk3M7MDRNlf7Du2TCYTlZWV7T0MM7MORdL8iMg0Lj+oLz6bmdmB52AwM7MUB4OZmaU4GMzMLMXBYGZmKQ4GMzNLcTCYmVmKg8HMzFIcDGZmluJgMDOzFAeDmZmlOBjMzCzFwWBmZikOBjMzS3EwmJlZioPBzMxSHAxmZpbiYDAzsxQHg5mZpTgYzMwsxcFgZmYpDgYzM0txMJiZWYqDwczMUhwMZmaW4mAwM7OUvIJB0hhJKyRVSZrcxPFukh5Pjs+RNCDn2JSkfIWkS1rqU9IDklZJWph8nNq6KZqZ2d4obKmCpC7AVOAioBqYJ6kiIpbmVLse2BwRgyWNB24HviWpHBgPDAe+DLwgaWjSprk+/yYinmiD+TVrV20ds5ZtoLYuIILCwgLOH9abosKChmMIzh/WG4BX367h7KGlDcfr93OP1W+fOaiE196pobYuqN29m2XrPuLEsh4Udy3g/BOzr1E/hob6VTUQMHpIacP2+SdmX3vW8g0QcHr/ntzy9FJ+MXYEPbsXpcaxq7aOf33xHT6rraOoUPzggqF0Ly5k+45a7n3lXb575nE88PoqhpZ2p7hbYcNct++o5dcvrGD1pk/42pAS/n3hOn4xdjhTnlzM0Yd3ZUjv7kw6dwivvL2BP721nrIeh9G1SwEn9elB7e46nl2yDoCvn9yHy04uY1dtHfe+8i4Tzzme7sUt/hczs4NMPt+1o4CqiFgJIGkGMBbIDYaxwM3J9hPA3ZKUlM+IiJ3AKklVSX/k0ed+9+rbNfz1YwsAiIACid9eM5ILy3s3HBPZMoCJj8zn3qs/P16/n3usfnviOcdzz8tVDX0HIECCaddkuLD887Cpr/+bl98lCP763MEN29OuyQAw6dE3CILT+/ekcvVmAKZ+5/TUOBau2cK0/7eyYX5dC7vw00tO4N5X3uXul6qYvXJTQ9suOXO995V3ue+19wB4ftlGAP7qt6/z8a46AF555wPWbd3JM2+tI3L+/QSp/VnLa+heXMjCNVu4+6Xs3H96yQmt+yKZ2QGXTzD0Adbk7FcDZ+ypTkTUStoKlCTlsxu17ZNsN9fnbZL+HngRmJwES5s7e2gp91x1emrFUP9bf/0xREPZvVePTB3P3W+8feagEkZ8+agmVwz19XL7OXNQCSP6HNWwYqjfrq879TunfWHF0HgcZw4qYVdtXcOKYeI5xwM0fG68Yqjve+I5x7Pjs9oWVwwXnVja4oqhfhy5r2tmHYsiovkK0jhgTETckOxfA5wRETfl1Fmc1KlO9t8l+4P+ZmB2RDySlE8Hnk2aNdmnpDJgPVAETAPejYhbmxjXBGACQP/+/UeuXr163/4FzMwOUZLmR0SmcXk+F5/XAv1y9vsmZU3WkVQI9AA2NdN2j31GxLrI2gn8ns9PPaVExLSIyEREprS0tKkqZma2D/IJhnnAEEkDJRWRvZhc0ahOBXBtsj0OmBXZpUgFMD65a2kgMASY21yfyYqB5BrFlcDi1kzQzMz2TovXGJJrBjcBzwFdgPsjYomkW4HKiKgApgMPJxeXPyT7g56k3kyyF5VrgUkRsRugqT6Tl3xUUinZa5sLgYltN10zM2tJi9cYOoJMJhOVlZXtPQwzsw6lNdcYzMzsEOJgMDOzFAeDmZmldIprDJJqgI+BD9p7LPtJLzy3jqazzgs8t46qqbkdFxFfuN+/UwQDgKTKpi6idAaeW8fTWecFnltHtTdz86kkMzNLcTCYmVlKZwqGae09gP3Ic+t4Ouu8wHPrqPKeW6e5xmBmZm2jM60YzMysDXS6YJD0A0nLJS2R9Kv2Hk9bk/QTSSGpV3uPpS1I+qfk67VI0h8lHd3eY2qtlh6F21FJ6ifpJUlLk++vH7b3mNqSpC6S3pD0dHuPpS1JOlrSE8n32TJJX22pTacKBknnkX0S3CkRMRy4o52H1KYk9QMuBt5v77G0oeeBERFxMvA2MKWdx9MqOY/CvRQoB76dPOK2M6gFfhIR5cCZwKRONDeAHwLL2nsQ+8G/AH+OiGHAKeQxx04VDMCNwC/rn/gWERvbeTxt7S7gf5N+omaHFhH/GRG1ye5sss/m6MgaHoUbEbuA+sfWdnjJs1IWJNsfkf0B06f5Vh2DpL7A5cB97T2WtiSpB3A22XfAJiJ2RcSWltp1tmAYCnxN0hxJr0j6SnsPqK1IGgusjYg323ss+9F1fP6Ev46qqUfhdoofnrkkDQBOA+a070jazK/J/tJV194DaWMDgRrg98lpsvskHdFSo3ye+XxQkfQCcGwTh/6O7HyOIbvM/QowU9Kg6CC3XrUwt5+RPY3U4TQ3r4h4Kqnzd2RPVTx6IMdme09Sd+APwI8iYlt7j6e1JF0BbIyI+ZLObe/xtL
FC4HTgBxExR9K/AJOB/9tSow4lIi7c0zFJNwJPJkEwV1Id2fcHqTlQ42uNPc1N0klkk//N7IPt6AsskDQqItYfwCHuk+a+ZgCSvgdcAVzQUUK8Gfk8CrfDktSVbCg8GhFPtvd42shZwDckXQYUA0dJeiQirm7ncbWFaqA6IupXdk+QDYZmdbZTSf8OnAcgaShQRCd4Q6yIeCsivhQRAyJiANkv9ukdIRRaImkM2SX8NyLik/YeTxvI51G4HVLyuN3pwLKIuLO9x9NWImJKRPRNvrfGk300cWcIBZKfEWsknZAUXUD2iZrN6nArhhbcD9wvaTGwC7i2E/wG2tndDXQDnk9WQ7MjosM+znVPj8Jt52G1lbOAa4C3JC1Myn4WEc+045isZT8g+8jkImAl8D9aauC/fDYzs5TOdirJzMxaycFgZmYpDgYzM0txMJiZWYqDwczMUhwMZmaW4mAwM7MUB4OZmaX8fwHMF12qq64UAAAAAElFTkSuQmCC", - "text/plain": [ - "<Figure size 432x288 with 1 Axes>" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "predictions, bias = bias_wrapped_classifier(np.sort(x_val))\n", - "plt.scatter(np.sort(x_val), bias, label='bias', s=0.5)\n", - "plt.legend()" + "### Generate and visualize bias scores for data in test set ###\n", + "\n", + "# Call the risk-aware model to generate scores\n", + "predictions, bias = bias_wrapped_dense_NN(x_test)\n", + "\n", + "# Visualize the relationship between the input data x and the bias\n", + "fig, ax = plt.subplots(2, 1, figsize=(8,6))\n", + "ax[0].plot(x_test, bias, label='bias')\n", + "ax[0].set_ylabel('Estimated Bias')\n", + "ax[0].legend()\n", + "\n", + "# Let's compare against the ground truth density distribution\n", + "# should roughly align with our estimated bias in this toy example\n", + "ax[1].hist(x_train, 50, label='ground truth')\n", + "ax[1].set_xlim(-6, 6)\n", + "ax[1].set_ylabel('True Density')\n", + "ax[1].legend();" ] }, { "cell_type": "markdown", - "metadata": { - "id": "PvS8xR_q27Ec" - }, "source": [ - "## 1.3 Aleatoric Estimation\n", - "Now, let's do the same thing but for aleatoric estimation! The method we use here is Mean and Variance Estimation (MVE) since we're trying to estimate both mean and variance for every input. As presented in lecture 5, we measure the accuracy of these predictions negative likelihood loss in addition to mean squared error. However, capsa *automatically* does this for us, so we only have to specify the loss function that we want to use for evaluating the predictions, not the uncertainty." - ] + "#### **TODO: Evaluating bias with wrapped regression model**\n", + "\n", + "Write short (~1 sentence) answers to the questions below to complete the `TODO`s:\n", + "\n", + "1. How does the bias score relate to the train/test data density from the first plot?\n", + "2. What is one limitation of the Histogram approach that simply bins the data based on frequency?" + ], + "metadata": { + "id": "HpDMT_1FERQE" + } }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": { - "id": "sxmm-2sd3G9u" + "id": "PvS8xR_q27Ec" }, - "outputs": [], "source": [ - "standard_classifier = create_standard_classifier()\n", - "mve_wrapped_classifier = MVEWrapper(standard_classifier)\n" + "# 1.4 Estimating data uncertainty\n", + "\n", + "Next we turn our attention to uncertainty, first focusing on the uncertainty in the data -- the aleatoric uncertainty.\n", + "\n", + "As introduced in Lecture 5 on Robust & Trustworthy Deep Learning, in regression we can estimate aleatoric uncertainty by training the model to predict both a target value and a variance for every input. Because we estimate both a mean and variance for every input, this method is called Mean Variance Estimation (MVE). 
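For intuition about what the histogram-based wrapper is doing under the hood, the sketch below scores each query point by the training-data frequency of its histogram bin, so sparse bins get high scores. This is a simplified illustration under our own assumptions; the helper `histogram_bias_score` and its `n_bins` parameter are illustrative and not part of the Capsa API:

```python
import numpy as np

def histogram_bias_score(x_train, x_query, n_bins=20):
    """Score each query point by the inverse training-data frequency of its bin.

    Underrepresented regions of the training distribution get scores near 1,
    mirroring the intuition behind histogram-based bias estimation.
    """
    counts, edges = np.histogram(x_train, bins=n_bins)
    density = counts / counts.sum()  # normalized bin frequencies
    # Map each query point to a bin; clip so out-of-range points fall in edge bins
    bin_ids = np.clip(np.digitize(x_query, edges) - 1, 0, n_bins - 1)
    return 1.0 - density[bin_ids]

# Points far from the training mass receive the highest scores
x_demo = np.random.normal(0, 1, size=1000)
print(histogram_bias_score(x_demo, np.array([-4.0, 0.0, 4.0])))
```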
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PvS8xR_q27Ec"
      },
      "source": [
-        "## 1.3 Aleatoric Estimation\n",
-        "Now, let's do the same thing but for aleatoric estimation! The method we use here is Mean and Variance Estimation (MVE), since we're trying to estimate both the mean and variance for every input. As presented in Lecture 5, we measure the accuracy of these predictions using a negative log-likelihood loss in addition to mean squared error. However, Capsa *automatically* does this for us, so we only have to specify the loss function that we want to use for evaluating the predictions, not the uncertainty."
+        "# 1.4 Estimating data uncertainty\n",
+        "\n",
+        "Next we turn our attention to uncertainty, first focusing on the uncertainty in the data -- the aleatoric uncertainty.\n",
+        "\n",
+        "As introduced in Lecture 5 on Robust & Trustworthy Deep Learning, in regression we can estimate aleatoric uncertainty by training the model to predict both a target value and a variance for every input. Because we estimate both a mean and variance for every input, this method is called Mean Variance Estimation (MVE). MVE involves modifying the output layer to predict both the mean and variance, and changing the loss to reflect the prediction likelihood.\n",
+        "\n",
+        "Capsa automatically implements these changes for us: we can wrap a given model using `capsa.MVEWrapper` to use MVE to estimate aleatoric uncertainty. All we have to do is define the model and the loss function to evaluate its predictions! More details of the `MVEWrapper` and how it can be used are [available here](https://themisai.io/capsa/api_documentation/MVEWrapper.html).\n",
+        "\n",
+        "Let's take our standard network, wrap it with `capsa.MVEWrapper`, build the wrapped model, and then train it for the regression task. Finally, we evaluate the performance of the resulting model by quantifying the aleatoric uncertainty across the data space:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "Yr0yIJEc26yM",
-        "outputId": "5dc23258-613b-4a83-a672-70da11ebbe81"
+        "id": "sxmm-2sd3G9u"
      },
-      "outputs": [ ...removed cell output: TensorFlow "Gradients do not exist for variables" warnings and 30 epochs of training logs (mve_compiled_loss, mve_wrapper_loss)... ],
+      "outputs": [],
      "source": [
-        "standard_classifier = create_standard_classifier()\n",
-        "mve_wrapped_classifier = MVEWrapper(standard_classifier)\n",
-        "mve_wrapped_classifier.compile(\n",
+        "### Estimating data uncertainty with Capsa wrapping ###\n",
+        "\n",
+        "standard_dense_NN = create_dense_NN()\n",
+        "# Wrap the dense network for aleatoric uncertainty estimation\n",
+        "mve_wrapped_NN = capsa.MVEWrapper(standard_dense_NN)\n",
+        "\n",
+        "# Build the model for regression, defining the loss function and optimizer\n",
+        "mve_wrapped_NN.compile(\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2),\n",
-        "    loss=tf.keras.losses.MeanSquaredError(),\n",
+        "    loss=tf.keras.losses.MeanSquaredError(), # MSE loss for the regression task\n",
        ")\n",
        "\n",
-        "history = mve_wrapped_classifier.fit(x, y, epochs=30)"
+        "# Train the wrapped model for 30 epochs.\n",
+        "loss_history_mve_wrap = mve_wrapped_NN.fit(x_train, y_train, epochs=30)\n",
+        "\n",
+        "# Call the uncertainty-aware model to generate outputs for the test data\n",
+        "x_test_clipped = np.clip(x_test, x_train.min(), x_train.max())\n",
+        "prediction = mve_wrapped_NN(x_test_clipped)"
      ]
    },
"execution_count": 126, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de3xU5bXw8d+aDDGvooAQMQJyDxAuAZJyeD+0hIACXl6p9ZyqeO2p5SLoaaut9rRve04vn948p4oCkWptxXvbY8t7RLGSkNQqlQAKJsGQBJVrCMjFlBPHYdb7x2TSEHKZJJPZM3uv7+cDk5nZM7N2sveaZ6/97OcRVcUYY0zy8zkdgDHGmNiwhG6MMS5hCd0YY1zCEroxxriEJXRjjHEJv1MfPGDAAB02bJhTH2+MMUlp69atR1Q1vbXnHEvow4YNo7S01KmPN8aYpCQiH7T1nJVcjDHGJSyhG2OMS1hCN8YYl3Cshm6MF3366afs27ePhoYGp0MxCS4tLY3BgwfTq1evqF9jCd2YONq3bx/nn38+w4YNQ0ScDsckKFXl6NGj7Nu3j+HDh0f9Oiu5GBNHDQ0N9O/f35K5aZeI0L9//04fyVlCNybOLJmbaHRlO7GEbpJCfUOQBza8R31D0OlQjElYltBNUlhVtJtHiqpYVbTb6VBca9iwYRw5cqRLr33wwQc5depUp193xx13UF5e3qXPjIdo1yua9fjDH/7Q4+tqCd0kvEAwxGlVfMBpVQLBkNMhmRa6ktBPnz7NY489RlZWVg9F1T2nT5+Oer2iWQ9L6MYAJZV1PPbnPVwxMYPH/ryHkso6p0NKap///OfJyclh/PjxrFmzptVlnnrqKaZNm8bkyZNZvHgxp0+fBmDp0qXk5uYyfvx4vve97wGwYsUKDhw4QH5+Pvn5+QA8++yzTJw4kQkTJnDfffc1vW/v3r255557yM7O5s0332TWrFlNQ4C88sorTJ06lezsbObMmXNWTL/+9a9Zvnx50/2rr76aTZs2Nb3vt7/9bbKzs5k+fTq1tbUA1NbWcu2115KdnU12djZvvPFGu+vXPL4f/ehHZ61Xa+sPnLEercXyxhtvsG7dOr7xjW8wefJkqqurmTp1atPrd+/efcb9LlNVR/7l5OSoMe355NPT+vLOA/r/tu/Tl3ce0I//51P9U9kh/eTT006H1mXl5eVOh6BHjx5VVdVTp07p+PHj9ciRI6qqOnToUK2rq9Py8nK9+uqrNRAIqKrq0qVL9Te/+c0Zrw0Gg5qXl6fvvPPOGa9VVd2/f78OGTJEDx8+rJ9++qnm5+friy++qKqqgD7//PNNseTl5emWLVv08OHDOnjwYK2pqTnjc5p74okndNmyZU33r7rqKi0qKmp633Xr1qmq6je+8Q39wQ9+oKqqX/ziF/UXv/hFU8zHjx9vd/1axtd8vdpb/8h6tBfLbbfdpr/97W+b3mvWrFm6fft2VVX91re+pStWrDhrnVvbXoBSbSOvWgvdJKzCXbXc+fQ27n7ubfw+H73T/FyWNZBUv7c220AwxGvltTErNa1YsaKp9bh371527z7zvMTGjRvZunUrn/nMZ5g8eTIbN26kpqYGgBdeeIGpU6cyZcoUysrKWi0hbNmyhVmzZpGeno7f7+emm26ipKQEgJSUFK677rqzXrN582ZmzpzZ1Of6wgsv7NQ6paamcvXVVwOQk5PD+++/D0BhYSFLly5t+uw+ffq0u35txRcRzfq3FUtLd9xxB0888QSnT5/m+eefZ+HChZ1a59bYhUUmIQWCId7ZewIUFueNZGZmOoFgiJLKOmZmpnsqqZdU1rHkqa0U3JzDZVkDu/VemzZt4rXXXuPNN9/k3HPPZdasWWf1dVZVbrvtNn784x+f8fiePXt44IEH2LJlC/369eP222/vdD/ptLQ0UlJSuhS73+8nFPr7l1rzz+7Vq1dTN7+UlBSCwbZ7Q7W1fh3FF+36RxvLddddx7//+78ze/ZscnJy6N+/f5sxR8s7e4VJKoW7allTUsOSWSP42uWZpPp9FFbUsmhtKYUVtU6HF1czM9MpuDmHmZmtDoHdKSdOnKBfv36ce+657Nq1i82bN5+1zJw5c/jd737H4cOHAfjoo4/44IMPOHnyJOeddx59+vShtraWl19+uek1559/Ph9//DEA06ZNo7i4mCNHjnD69GmeffZZ8vLy2o1r+vTplJSUsGfPnqbPbGnYsGG8/fbbhEIh9u7dy1tvvdXh+s6ZM4fVq1cD4ZOcJ06caHP9WtN8vdpb/2g0fy8If3nMmzePpUuX8qUvfalT79UWS+gmMSkoyqRBff/eGheQ8H+ekur3xazUNH/+fILBIOPGjeP+++9n+vTpZy2TlZXFD3/4Q+bOncukSZO4/PLLOXjwINnZ2UyZMoWxY8eycOFCZsyY0fSaRYsWMX/+fPLz88nIyOAnP/kJ+fn5ZGdnk5OTw4IFC9qNKz09nTVr1vCFL3yB7Oxsrr/++rOWmTFjBsOHDycrK4u77747qpOIDz30EEVFRUycOJGcnBzKy8vbXL/WNF+v9tY/GjfccAM///nPmTJlCtXV1QDcdNNN+Hw+5s6d26n3aouEa+zxl5ubqzbBhWkpUlaZPqI/m2uOnlFeCQRDvPruQcoPneTOWaPpnZZ8FcOKigrGjRvndBgmQTzwwAOcOHGCH/zgB60+39r2IiJbVTW3teWTb48wrlXfEOS+3+/glXcP8ugtuWfVi1P9PnbV1rNqUw0+8XHvvDEORWpM91177bVUV1dTWFgYs/fsMKGLyK+Aq4HDqjqhlecFeAi4EjgF3K6q22IWofGEQDDEfb/fwUs7D3LVxIw268VL8kaecWtMsnrxxRdj/p7RFOV+Dcxv5/krgNGN/xYBq7sflvGawopa1u88yBXjL+an101qs17cO83PvfPGJGW5JcKpMqdJLl3ZTjpM6KpaApx9yvnvFgBPNvZ53wz0FZGMTkdivE3AJ8KCKZe0m6xj3Sc73tLS0jh69KglddMubRwPPS0trVOvi0UzZxCwt9n9fY2PnXXaWEQWEW7Fc+mll8bgo41bzB47kEdv6bhrXiz7ZDth8ODB7Nu3j7o6G77AtC8yY1FnxPW4VVXXAGsg3Mslnp9tEldnLhiKZZ9sJ/Tq1atTM9AY0xmx6Ie+HxjS7P7gxseMiUqk1R3NoFupfh8zM9MpqaxL2rKLMT0lFgl9HXCrhE0HTqhq6730jWlFZ1vdnfkCMMZLoum2+CwwCxggIvuA7wG9AFS1AFhPuMtiFeFui7G5htV4QlfGZ0n2sosxPaXDhK6qN3bwvALLYhaR8ZSunOSMXApvjDmTjeViHBMIhgiGQjxy45ROt7aTvfuiMT3BErpxTEllHcuf2Y4/xdfpgaesjm7M2SyhG8dMH9GfJXkjmT
6i8+NAWx3dmLNZQjeO2VxzlILiajbXHO30a2M5pKwxbmF7g3GMtbKNiS1L6MYR9Q1BVmzczfQR/a2VbUyM2J5kHFFQXM0jRVUUFFd3+T2sp4sxZ7KEbhyxJG8ky/NHdWtcc+vpYsyZLKGbuDt8ooHbn3iLW6cP7da45laDN+ZMltBN3C16qpTSD46x6KnuzSlrPV2MOZPtCSbupg7pd8atMSY2knceL5O0vj53DOee47d5QY2JMWuhm7g6Vh/gvt/v4Mszhif1vKDGJCJL6CauvvPHd3lp50G+88d3nQ7FGNexJpKJqx8umHDGbXcFgiEKd9WCwuxxdoLUeJtt/SZu6huCPP6XPfz0ukn0650ak/csqaxj2dPbufOZbdYf3XietdBN3ESuDgW4d96YmLznzMx0HrphMmUHTnZp1EZj3MQSuombSK+WWPZuSfX7SOuVwi//XEPO0H42k5HxNAnPIBd/ubm5WlravQtLjIGuzUtqTLISka2qmtvac9ZCN0nP5hg1JsyaM8YY4xKW0E2Pq28I8sCG96hvCPbYZ9hQusZYQjdxEIuxzztSWFHLorWlFFbU9thnGJPorIZuelxP9G5pKRgKgTbeGuNRltBNj+ud5o9Zv/O2BE+DNt4a41VWcjE9Jh6184j3aj9GG2+N8SproZse0xNXhrZlWf4oUnxiQ/IaT7MWuukxX54xnKsmZvDlGcN7/LN6p/m5e85oNtcctZ4uxrOiSugiMl9E3hORKhG5v5XnLxWRIhHZLiI7ROTK2Idqkkl9Q5Dv/PFdNpQdYuuHx+LymTZptPG6DhO6iKQAK4ErgCzgRhHJarHYd4AXVHUKcAOwKtaBmuTycGElL+08yJxxF8VtEmebNNp4XTQt9GlAlarWqGoAeA5Y0GIZBS5o/LkPcCB2IZpkEwiGeP/IKQCGXXhu3MZXsUmjjddFs+UPAvY2u7+v8bHm/g24WUT2AeuBu1p7IxFZJCKlIlJaV2eHxW5VUlnHxl2HuWpiBnfNyXQ6HGM8I1ZNmRuBX6vqYOBKYK2InPXeqrpGVXNVNTc93Q6L3SpS+vjF9ZNt3lBj4iiahL4fGNLs/uDGx5r7MvACgKq+CaQBA2IRoEkuNpStMc6JZo/bAowWkeEikkr4pOe6Fst8CMwBEJFxhBO61VQ86NWygyx6spRXyw46HYoxntNhQlfVILAc2ABUEO7NUiYi3xeRaxoXuwf4ioi8AzwL3K5OzZxhHBMIhnhpxyFCQPmBk06HY4znRFXgVNX1hE92Nn/su81+LgdmxDY0k2xeLTvIK2WHmJc1kDvzRzsdjjGeY0VOEzPlB06iwMj08+xkqDEOsL3OxEQgGCLrkj4szRvJnfmjnA7HGE+yFrqJicKKWu5+bjvZg/tY69wYh1hCN90WCIbYse84goA4H4tNRWe8yhK66baSyjoeLalh6ayRzB470PFYbIAu41XiVO/C3NxcLS0tdeSzTWwl0sVEiRSLMT1BRLaqam5rz1mx03RbZFCsRJBIsRgTb9aEMd2SaDXrRIvHmHiyhG66LBAMsWLjbhavLU2YmnWkhr5i425L6sZzLKGbLivcVcvqTdUsnjkiYSaVmJmZzpK8kazeVJUwXzLGxIvV0E2XBYMhQqpkZVyQMCcgU/0+7p4zmslD+ibMl4wx8ZIYe6FJSn6/D58I/gRJ5sZ4ne2Jpstmjx3Io7fkON73vCXri268yhK66bRITxIgIefwtMmijVcl1p5okkJhRS2L1pZSWFHrdCitssmijVfZFm86T0iIcVuMMWeyhG46JRAMgcLKm6YkXO3cGK+zhG46paSyjuXPbsfv81lJw5gEY3ukiVogGCIYCvHIjVPshKMxCcgSuona+h37WfrUNho+DVrr3JgEZHuliUp9Q5CHC6tR4NWyxOzdYozXWUI3UVm1aTfVR/7GiAHn8aNrJzkdjjGmFZbQTVSyLr4AAb5+2Wj69U51OhxjTCssoZsOBYIh/Ck+Vt88lbkTMpwOJyo2LrrxIkvopkPJ2FXRxnMxXmRzipoOJeM8nckYszHRaG9OUdvSTZvqG4I8sOE9AsFQ0o2NYuO5GC+yrd20aWVRFY8UVbGyqMrpUIwxUYgqoYvIfBF5T0SqROT+Npb5ooiUi0iZiDwT2zBNvNU3BKmpq8cnMP6SC5wOp0vsxKjxmg4TuoikACuBK4As4EYRyWqxzGjgW8AMVR0PfLUHYjVxVFBczYbyWq6YkMHc8Rc7HU6XJPowv8bEWjRzik4DqlS1BkBEngMWAOXNlvkKsFJVjwGo6uFYB2ria0neyKbbpK1D2zC/xmOi2VMHAXub3d/X+FhzmUCmiPxFRDaLyPzW3khEFolIqYiU1tVZd7JEVd8QpKC4miV5I+mdlrzziCfqFHnG9JRYNb38wGhgFnAj8EsR6dtyIVVdo6q5qpqbnm6j9SWqguJqHimqoqC42ulQusV6uhiviWZL3w8MaXZ/cONjze0D1qnqp6q6B6gknOBNkjlWH6Cq9iR3zBjWVHYxxiSHaBL6FmC0iAwXkVTgBmBdi2X+QLh1jogMIFyCqYlhnCZOvvPHd3ml/DAHT36S1OUWY7yowz1WVYMishzYAKQAv1LVMhH5PlCqqusan5srIuXAaeAbqnq0JwM3sRcIhrh8bDqqIX64YILT4RhjOimqJpiqrgfWt3jsu81+VuDrjf9MkiqprOPe3++k4OYcG1HRmCRkZ4sMEO7ZsvWDYzx4/WRXTS8XGb6gviHodCjG9DhL6AYI92xZXVzNrkMfu6pXiFt67BgTDTvrZYAzLyRyE7eulzGtsYRuAOid5ufeeWOcDiPmUv0+Jg/p66qjDmPaYlu5cbXIRBcrNu62QbqM61lC9zAvjEY4MzOdJXkjWb2pymYvMq5nCd3DvDBNW6rfx91zRvPoLbmu6r1jTGushu5Rbu2m2JrImC7GuJ210D1q1abdrC6upvzgCdefMPRCackYsITuSYFgiODp8DDhWRcn52xEneGF0pIxYAndk0oq63j89RqW5Y9i7oQMp8PpcdNH9GdJ3kimj+jvdCjG9ChL6B40fUR/ls4aldyzEXXC5pqjrN4UvlrUyi7Gzdy/N5uzvL67jlWbqnh9tzdKEDMz01k6axQFxdVWdjGuZr1cvMhjc21Gui5OHtLX9T16jLdZQvegyFybXkpu1nXReIEldA+y5GaMO1kN3WOsT7Yx7mUJ3WOsT7Yx7mUJ3WNmZqZTcLO36ufGeIUldI+J1M+90P+8JSs3Gbfz3l5tPMvKTcbtLKF7iNdbqFZuMm5nCd1DCnfVsnjtVgp31TodiiO8XG4y3mBbtpcoaPg/Y4wLWUL3iEAwBAKrFk5l9ji7qMgYN7KE7hEllXUsf2Y7/hSflRyMcSnbsz3CTgga4342lotH2PgtxrhfVC10EZkvIu+JSJWI3N/OcteJiIpIbuxCNN3l9e6KxnhFhwldRFKAlcAVQBZwo4hktbLc+cC/AH+NdZCme7zeXdEYr4imhT4NqFLVGlUNAM8BC1pZ7gfAT4GGGMZnYsG6KxrjCdEk9EHA3
mb39zU+1kREpgJDVPWlGMZmYmT2uIGsuSXXuisa43LdPikqIj7gP4Hbo1h2EbAI4NJLL+3uR5so2QlRY7whmhb6fmBIs/uDGx+LOB+YAGwSkfeB6cC61k6MquoaVc1V1dz0dOs+Fw92QtQY74gmoW8BRovIcBFJBW4A1kWeVNUTqjpAVYep6jBgM3CNqpb2SMSmU14tO8SitaW8WnbI6VCMMT2sw4SuqkFgObABqABeUNUyEfm+iFzT0wGa7ik7cJKQhm/N39mRi3GjqGroqroeWN/ise+2seys7odlYmVZ/ihSfMKSvJFOh5JQCitqufOZbaxaOJX5EzOcDseYmLArRV2ud5qfe+eNcTqMxCMg4f+McQ0by8XFrKzQts+OSmfprJF8dpSdnDfuYQndxWzKtba9vruOVZuqeH23/W6Me1hCdzEbYbEdVnIxLmQJ3cVsyrW2zR47kJULpwBYScq4hu3pxpNS/T78KT6WP7PdSlLGNSyhu5CdDI2OlaSM21hCd6HCiloWrS2lsMKGy22PlaSM29iW7EZ2ws8YT7KE7jL1DUF27D3OQzdkM3usjbBojJdYQneZlUVVrCquoezAx1ZKMMZjbI93mfGXXIBPwremY/UNQR7Y8B71DUGnQzGm2yyhu8yMkQO4YkIGM0YOcDqUpFBQXM0jRVXc/sRbltRN0rOE7jKP/2UPL+08yON/2eN0KElhSd5Ipl7al9IPjrGqaLfT4Zh2WHfcjllCd5H6hiCBYIjFM0fYcLlR6p3m59bpQxEgc+D5Todj2lG4q5ZFT5by4GvvWVJvgyV0FykormbNn2voleKjd5qNjByttF4piIRvTQJTUODR4j2s2LjbknorbK93kUir3FrnnTN73EDW3JJrV4wmsEAwBAIP3zCZ8oMnWVVURfXheq6alMHc8Rdbj65G9ltwkchkFtY67xy7YjTxlVTWsfyZ7aSl+vnq5WO4YmIGL5cdYvmz21m/46DT4SUM24KNMQktEAwRPB3ikYVTmJmZTqrfx0+vm8TIAecC8HBhpfVQamQJ3QWsL3X3WQ+KxBWZ/xWl6Siqd5qf3y2Zwaj086g+copVRbvt74cldFd4eGMljxRV8fDGSqdDSVo2u1MCa2Nson69U/nDss+yPH8UmQPP544nS7lxzZscqw84E2cCsITuBtLi1nTazMx0HvxiNts+/MiOdBJIIBgiGAyxOG94q/O/Rs4bRXoobf3wOHc8ucWzLXVL6C5w1+xMlueP4q7ZmU6HkrRS/T7KD51k1aYaVm2yC4wSRUllHXc//zaPFu9hc83RNpebPW4gd8wYBsD2D4979kjLukMkufqGIAXF1SzJG2m9W7op6+ILkMZbkxhmZqazauFUENrtVprq9/HNK8aRO+xCEJg+oj+vldc2nUT1CssASay+IcjtT7xF6QfHALh33hiHI0pucydksPomH0j4UN9LiSARBYIhSirrmD0uui6lqX4f8ydmAPBaeS2L15aydNYo7p4z2jN/S2+spUsVFFdT+sExcof2s4uJYiDVH07my57eTuEum+3JaYW7alm8dmuX/hYzM9NZOmsUBcXVniq/WAs9iTW/MtTKLTGioJFrzI2zuvG3SPX7uHvOaCYP6eupK4BF1ZktNzc3V0tLSx35bDeIHI56rUbY0+z3mjhi/bcIBEPh1r4SdRknEYnIVlXNbe255FwjYxNB9xAbBiBxxPpvUVJZx51PbWPJ09t48E/uHLExqt+UiMwXkfdEpEpE7m/l+a+LSLmI7BCRjSIyNPahmjPYRNDGxXri6ueZmeksmTUCH/BoyR5X1tY7LLyKSAqwErgc2AdsEZF1qlrebLHtQK6qnhKRpcDPgOt7ImATNnvsQB69JcdT9cF4stKLcwLBEPf9fgcv7QwPuhWr3lupfh9fvWwMkwb17bAbZLKKZkudBlSpao2qBoDngAXNF1DVIlU91Xh3MzA4tmGalqw00LOspOWckso6NpQd4qqJGTHvvRXp2jh/QkbTvuOmcXyiyQaDgL3N7u9rfKwtXwZebu0JEVkkIqUiUlpX577DHeMiVtJyzMzMdApuzuEX10+OS++tkso6Fq8tdcWkGTH9bYnIzUAukNfa86q6BlgD4V4usfxsY2Jp9tiBrLxpCqhdZBRvkaPPeGneZ33CoAvw+3xJW2qLJuL9wJBm9wc3PnYGEbkM+DZwjap+EpvwjHFGqt8HCnc+s83KLnHkRPkj0me94OYcUFjyVPhipmQsw0ST0LcAo0VkuIikAjcA65ovICJTgEcJJ/PDsQ/TuKnOlzSs7BJXgWCIFRt3OzKMceSoYPa4gU2JPRnLMB0mdFUNAsuBDUAF8IKqlonI90XkmsbFfg70Bn4rIm+LyLo23s50kY3XHX+zxw7koRsm887eEzakbhwUVtSyqqiKr3xuhGM9UJon9kgZprCillfePcgrOw8mfHKPqoauquuB9S0e+26zny+LcVymmZZTcJn4SPX7KDtwktXF1fx1z1Ge/Od/sCEWekggGGLH/uOICNlD+jhev24+dEAwFGLZ09sJqbL6pqlNA4AlouSr+ntM5DB02TPb8Pt8jm/oXjP+kvCQuts+PG7jpPegwl21FGyqYXHecGaPjd8J0fY0tdbHDmRx3vBw5S3By2+WHRJc4a5aVm+qZvFM5w5DvWzu+Iv58meHAeEv10Q/5E5aCghMGtQ34RotkQuS1tyae8aXTSKe10qs35w5W+OIc5MGJ96G7gWpfh+5Qy9EgF/95X0bVrcHBIIhEFi1cCqzxyVG67yl1i7kS8TzWpYhElTk2/+zo9NZc0tuwm7oXjB73ECW5I1EFXbsPZ5QLbJkV98Q5GvPv83yZ7bjT0mukmLkAqjIkXN9Q5CfvVzBf7+z37FtxM7wJKjCilrufGYbqxYm9kkYL0j1+/ja5Zk0BD5ldXENI9N7c13ukI5faNoVCIb45m/fZn1ZLVeOH5h0JcWWF0AVFFezqrgGAZbl1zsyU5Il9ERlfaATSqrfx/Z9J1Dgp6/sYt6EDOvx0k2Fu2p5uawWAa6clJFUrfPWLMkbSSgU4rRq+KrTSy7AnxLfq06T+zfoYpHRFBPljL+Bm6eFR4U+XB+wHi/ddKw+wJqSGpRwIpw7PvmPQnun+fnmFeO4Z+7Y8MVJQtxr7JbQE5SNpph4rpkyiEWfGwHA6dNqtfRu+NcXd7Ltw+NMvbQvX7s801XbefPujs1r7BE92TvGPb9FY3pYqt/HvfPGcGfeCB573Xq8dMe88QMR4NbpQ12VzJtrq1HWk71j3PmbTEI9MUOLib1Uv49Jg/tyWpV/++O77P/oVMcvMk0irdPLsi7ml7fmcuWkS5wOKe5a9o6JJUvoCaC+IcjtT7zFI0VVFBRXOx2O6cDscQO5+PxUDn0c4LrVf7HSSydEWqeba456tqTYk+VU7/02E0wkmZd+cIzcof1iPkOLib1wN8bwtGiHPg4k1IUliay+IcjWD47x4PWTk66LYrKwflcOKyiubkrmv/7SNOsKlySunTqYPuf2ouGTINs+/IjpI/rb364dh080
sGDl6xw8+QnL80dxdbb3Si3xYFugwyIt8iV5Iy0hJJFUv4/5EzL42SsVrNpUw/t1f+PBG6d6soTQkcMnGsj/jyL+FgiRccE5dhTag2zrc0Dzbku90/zcO2+MJfMklXXxBQCsL6vliwVv8Iete62m3syx+gDzHizmb4EQ56X6+OOyz9q23oMsoTsgEQf1MV0zd0IGV44PX/z19r4TfPW3O7jiwWJ+X2qJPRAM8ZW1pRz7nyD9/pefonvyuahPmtNhuZp9VTqgJ7stmfhK9fv42T9NJqNvJR9+dIqq2hNUHznFPb/bAcCeo6c8WU47fKKBG3+5mT1H/0bOpX157NbP0K93qtNhuZ63tjKH1DcEKSiubtqx4z2ruelZvdP8/N//kwWEE9k1j/yZQx8HWPvXD3h77wneqD7iqdmO9n90itn/uYlPggrAVz43wpJ5nHhjC3PYyqIqVhdXczqk3HfFWKfDMT3ooj5pvHZPPgXF1Qzvfy7v7N3Btg+P8/DGSgKnla0fHCNnaD/umeu+8yaRhstLOw7wSVA5xy/8aMF4G/o5jty1RSWQQOUtyGwAAAirSURBVDAUvjRcYczA8/FJeDoz436RE92BYIhzeqVQduAkIQ3xxBvvA7Bj/wkKd9Uye8xF5Ay7kLnjL0763jH1DUEWrnmDHQc+5h+nXkKKT/jNl6Yx6MJznQ7NU0RVHfng3NxcLS0tdeSz4+G18loWr92KoqxaODXuw2iaxFLfEOQ/Xn2Pl3Yc4HB9oOlxAeZmDeQnX5jE1g+PJd02Emm4vLh1PxsqwmPbTBnSlxeXzXA4MvcSka2qmtvqc5bQYyMQDFFSWcf0Ef3ZXHOU6SP683pVHWj4UvFk2klNz6lvCPLga5XsOVLP4Y8/Yef+kwBcNTGDl3ceZO64i8jom0aqP4W75mQmdFnmWH2AO57cwva9x1ENTws6YsB5PPeV6dabpQe1l9ATd2tJEpFEHgyFWP7MdpbkjaSguJqCm3OYPyH5x3g2sdU7zc93rg6fQK1vCPJw4W5UYWnjxTYv7TzYtOxrFYd5/LZcfrS+AoCfXZedECcX6xuCrCyqYvOeo2z/8DgAiz43gqlD+zJ7rDVenGQt9C5o3hovKK5m9aYqVjaWVSIt9GQ7dDbOCyf4SqoO17NxV/gahYwLzuHgyU8AmDP2IkZddB6BT0OkpAij0s+lqPIo+ZnpfPDR/7Asf1S3W/SBYIj1O/bz8ru15I8ZQPmBk+zYf5LPT8lgQ1ktW2qOkdZLqP9U8Um4vHLLP1zKldmDbHuPEyu5xNhr5bUseWprU2t8Sd5IR+YPNO4UCIZY/85+Xq2o5ZvzxvKlJ7aw56NTTB7Uh7f3n2jzdfOyBjKgdyp/rjzMOakpHKn/hAG90/jk09Pc8bnh7K49xdt7P6Lveb3wEZ6Q+dtXjuPHL+8i/fxzqD3RQEbfNJ7c/GGHMfZNS+GH105yxQndZGMJvYvqG4KsKtpN1iUXMHf83+c8bFkvt9a46UmR7oC3Th/K43+pabWFXlx5hA3lhwh1cndufgTQ3Ij+57Fo5rBWW+gD+6TywuIZ1oPFIZbQO6G+IciqTbvJTO/Nq+W1rC+rxQesuTXXLgYyCSsQDPFq2SHerD7SrRb6ZVkDef/oqZiUb0zPsITeikgrO3L5feTnFRt380hRFRAe6Gbe+Iu5atLFZ7TQjTHGKd3u5SIi84GHgBTgMVX9SYvnzwGeBHKAo8D1qvp+d4LuikAwRGFFLQjMHhtuTUcSdarfd0YSjwyQVXBzDkDTz0vyRhLSEJnpvUk7x29n7Y0xSaPDFrqIpACVwOXAPmALcKOqljdb5k5gkqouEZEbgGtV9fr23rerLfRI0g6GFFTx+31NSfe18loWrS1FEB695cxEfVnWwKaTmZGBsVproVvyNsYksm6VXETkfwP/pqrzGu9/C0BVf9xsmQ2Ny7wpIn7gEJCu7bx5VxN6JGmHYwCfhJP3ZVkDO9VCt8RtjElG3S25DAL2Nru/D/iHtpZR1aCInAD6A0daBLIIWARw6aWXRhV8SzMz01m1cOoZLfRIKzvV72P+xDMv5ml+ItNGOTTGuFlcT2Or6hpgDYRb6F15j9aStjHGmOhmLNoPDGl2f3DjY60u01hy6UP45Kgxxpg4iSahbwFGi8hwEUkFbgDWtVhmHXBb48//CBS2Vz83xhgTex2WXBpr4suBDYS7Lf5KVctE5PtAqaquAx4H1opIFfAR4aRvjDEmjqKqoavqemB9i8e+2+znBuCfYhuaMcaYzrC+e8YY4xKW0I0xxiUsoRtjjEs4NjiXiNQBf6PFxUcuMgBbt2Tj1vUCW7dk1dq6DVXV9NYWdiyhA4hIaVuXsCY7W7fk49b1Alu3ZNXZdbOSizHGuIQldGOMcQmnE/oahz+/J9m6JR+3rhfYuiWrTq2bozV0Y4wxseN0C90YY0yMWEI3xhiXSIiELiJ3icguESkTkZ85HU+sicg9IqIiMsDpWGJBRH7e+PfaISIvikhfp2PqLhGZLyLviUiViNzvdDyxIiJDRKRIRMob969/cTqmWBKRFBHZLiL/7XQssSQifUXkd437WUXjzHEdcjyhi0g+sADIVtXxwAMOhxRTIjIEmAt86HQsMfQnYIKqTiI83+y3HI6nWxrnzV0JXAFkATeKSJazUcVMELhHVbOA6cAyF60bwL8AFU4H0QMeAl5R1bFANlGuo+MJHVgK/ERVPwFQ1cMOxxNrvwC+Cbjm7LOqvqqqwca7mwlPepLMpgFVqlqjqgHgOcKNjKSnqgdVdVvjzx8TTgyDnI0qNkRkMHAV8JjTscSSiPQBZhIelhxVDajq8WhemwgJPRP4nIj8VUSKReQzTgcUKyKyANivqu84HUsP+mfgZaeD6KbW5s11RdJrTkSGAVOAvzobScw8SLixFHI6kBgbDtQBTzSWkx4TkfOieWFc5hQVkdeAi1t56tuNMVxI+HDwM8ALIjIiWWY86mDd/pVwuSXptLdeqvrHxmW+TfiQ/ul4xmY6T0R6A78HvqqqJ52Op7tE5GrgsKpuFZFZTscTY35gKnCXqv5VRB4C7gf+bzQv7HGqellbz4nIUuC/GhP4WyISIjwgTV08YuuuttZNRCYS/qZ9R0QgXJbYJiLTVPVQHEPskvb+ZgAicjtwNTAnWb582xHNvLlJS0R6EU7mT6vqfzkdT4zMAK4RkSuBNOACEXlKVW92OK5Y2AfsU9XIkdTvCCf0DiVCyeUPQD6AiGQCqbhg5DRV3amqF6nqMFUdRviPNDUZknlHRGQ+4UPda1T1lNPxxEA08+YmJQm3Jh4HKlT1P52OJ1ZU9VuqOrhx37qB8DzGbkjmNOaIvSIypvGhOUB5NK+NSwu9A78CfiUi7wIB4DYXtPjc7hHgHOBPjUcfm1V1ibMhdV1b8+Y6HFaszABuAXaKyNuNj/1r47SSJnHdBTzd2MCoAb4UzYvs0n9jjHGJRCi5GGOMiQFL6MYY4xKW0I0xxiUsoRtjjEtYQjfGGJewhG6MMS5hCd0YY1zi/wPCRpV2fve6CgAAAABJRU5
ErkJggg==", - "text/plain": [ - "<Figure size 432x288 with 1 Axes>" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], "source": [ - "outputs = mve_wrapped_classifier(x_val)\n", - "plt.scatter(x_val, outputs.aleatoric, label='aleatoric uncertainty', s=0.5)\n", + "# Capsa makes the aleatoric uncertainty an attribute of the prediction!\n", + "pred = np.array(prediction.y_hat).flatten()\n", + "unc = np.sqrt(prediction.aleatoric).flatten() # out.aleatoric is the predicted variance\n", + "\n", + "# Visualize the aleatoric uncertainty across the data space\n", + "plt.figure(figsize=(10, 6))\n", + "plt.scatter(x_train, y_train, s=1.5, label='train data')\n", + "plt.plot(x_test, y_test, c='r', zorder=-1, label='ground truth')\n", + "plt.fill_between(x_test_clipped.flatten(), pred-2*unc, pred+2*unc, \n", + " color='b', alpha=0.2, label='aleatoric')\n", "plt.legend()" - ] + ], + "metadata": { + "id": "dT2Rx8JCg3NR" + }, + "execution_count": null, + "outputs": [] }, { "cell_type": "markdown", @@ -623,7 +451,12 @@ "id": "ZFeArgRX9U9s" }, "source": [ - "We can see that in the areas of high label noise-- where small changes in the input lead to large changes in the output-- aleatoric uncertainty spikes!" + "#### **TODO: Estimating aleatoric uncertainty**\n", + "\n", + "Write short (~1 sentence) answers to the questions below to complete the `TODO`s:\n", + "\n", + "1. For what values of $x$ is the aleatoric uncertainty high or increasing suddenly?\n", + "2. How does your answer in (1) relate to how the $x$ values are distributed?" ] }, { @@ -632,172 +465,100 @@ "id": "6FC5WPRT5lAb" }, "source": [ - "## 1.4 Epistemic Estimation\n", - "Finally, let's do the same thing but for epistemic estimation! In this example, we'll use ensembles, which essentially copy the model `N` times and average predictions across all runs for a more robust prediction, and also calculate the variance of the `N` runs. Feel free to play around with any of the epistemic methods shown in the github repository! Which methods perform the best? Why do you think this is?" + "# 1.5 Estimating model uncertainty\n", + "\n", + "Finally, we use Capsa for estimating the uncertainty underlying the model predictions -- the epistemic uncertainty. In this example, we'll use ensembles, which essentially copy the model `N` times and average predictions across all runs for a more robust prediction, and also calculate the variance of the `N` runs to estimate the uncertainty.\n", + "\n", + "Capsa provides a neat wrapper, `capsa.EnsembleWrapper`, to make an ensemble from an input model. Just like with aleatoric estimation, we can take our standard dense network model, wrap it with `capsa.EnsembleWrapper`, build the wrapped model, and then train it for the regression task. 
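To check your answers empirically in this toy problem, you can measure the label noise directly by computing the variance of the targets within bins of x. The helper below is a rough sketch under our own assumptions (the name `binned_label_variance` and its equal-width binning scheme are illustrative, not part of the lab's codebase); regions where it is large should roughly track where the MVE-predicted aleatoric variance is large:

```python
import numpy as np

def binned_label_variance(x, y, n_bins=20):
    """Empirical variance of y within equal-width bins of x."""
    x, y = np.asarray(x).flatten(), np.asarray(y).flatten()
    edges = np.linspace(x.min(), x.max(), n_bins + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])  # bin midpoints for plotting
    bin_ids = np.clip(np.digitize(x, edges) - 1, 0, n_bins - 1)
    variances = np.array([y[bin_ids == b].var() if np.any(bin_ids == b) else np.nan
                          for b in range(n_bins)])
    return centers, variances

# Usage sketch, with the notebook's training data:
# centers, label_var = binned_label_variance(x_train, y_train)
```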
    {
      "cell_type": "markdown",
      "metadata": {
@@ -632,172 +465,100 @@
        "id": "6FC5WPRT5lAb"
      },
      "source": [
-        "## 1.4 Epistemic Estimation\n",
-        "Finally, let's do the same thing but for epistemic estimation! In this example, we'll use ensembles, which essentially copy the model `N` times and average the predictions across all runs for a more robust prediction, and also calculate the variance of the `N` runs. Feel free to play around with any of the epistemic methods shown in the GitHub repository! Which methods perform the best? Why do you think this is?"
+        "# 1.5 Estimating model uncertainty\n",
+        "\n",
+        "Finally, we use Capsa for estimating the uncertainty underlying the model predictions -- the epistemic uncertainty. In this example, we'll use ensembles, which essentially copy the model `N` times, average the predictions across all runs for a more robust prediction, and also calculate the variance of the `N` runs to estimate the uncertainty.\n",
+        "\n",
+        "Capsa provides a neat wrapper, `capsa.EnsembleWrapper`, to make an ensemble from an input model. Just like with aleatoric estimation, we can take our standard dense network model, wrap it with `capsa.EnsembleWrapper`, build the wrapped model, and then train it for the regression task. More details of the `EnsembleWrapper` and how it can be used are [available here](https://themisai.io/capsa/api_documentation/EnsembleWrapper.html).\n",
+        "\n",
+        "We then evaluate the resulting model by quantifying the epistemic uncertainty on the test data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "SuRlhq2c5Fob",
-        "outputId": "b1f81f5a-69da-4e40-af2a-2a908a1da639"
+        "id": "SuRlhq2c5Fob"
      },
-      "outputs": [ ...removed cell output: 30 epochs of training logs for the five ensemble members (usermodel_0_compiled_loss through usermodel_4_compiled_loss)... ],
+      "outputs": [],
      "source": [
-        "standard_classifier = create_standard_classifier()\n",
-        "ensemble_wrapper = EnsembleWrapper(standard_classifier, num_members=5)\n",
+        "### Estimating model uncertainty with Capsa wrapping ###\n",
+        "\n",
+        "standard_dense_NN = create_dense_NN()\n",
+        "# Wrap the dense network for epistemic uncertainty estimation with an Ensemble\n",
+        "ensemble_NN = capsa.EnsembleWrapper(standard_dense_NN)\n",
        "\n",
-        "ensemble_wrapper.compile(\n",
+        "# Build the model for regression, defining the loss function and optimizer\n",
+        "ensemble_NN.compile(\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-3),\n",
-        "    loss=tf.keras.losses.MeanSquaredError(),\n",
+        "    loss=tf.keras.losses.MeanSquaredError(), # MSE loss for the regression task\n",
        ")\n",
        "\n",
-        "history = ensemble_wrapper.fit(x, y, epochs=30)"
+        "# Train the wrapped model for 30 epochs.\n",
+        "loss_history_ensemble = ensemble_NN.fit(x_train, y_train, epochs=30)\n",
+        "\n",
+        "# Call the uncertainty-aware model to generate outputs for the test data\n",
+        "prediction = ensemble_NN(x_test)"
      ]
    },
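Conceptually, the ensemble's epistemic estimate is just the disagreement among its members: average the member predictions for the point estimate, and take their variance as the uncertainty. Below is a minimal sketch of that aggregation, assuming a hypothetical list `members` of independently trained Keras models (`EnsembleWrapper` manages the copies and this aggregation internally):

```python
import numpy as np

def ensemble_mean_and_epistemic(members, x):
    """Aggregate an ensemble: mean prediction plus variance as epistemic uncertainty."""
    # Stack member predictions into shape (num_members, num_points)
    preds = np.stack([np.asarray(m(x)).flatten() for m in members])
    y_hat = preds.mean(axis=0)      # robust averaged prediction
    epistemic = preds.var(axis=0)   # disagreement across members
    return y_hat, epistemic

# Usage sketch (members would be N separately trained copies of the model):
# y_hat, epistemic = ensemble_mean_and_epistemic(members, x_test)
```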
"iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de3yU5Z338c8VAkbOp4icAygIgRBIpLQqh6BAEbu2lccqUrGsHDy0trW7Wp9H23112+6q7VbRIlutFbTblSr1VbWiBMiugppwkkMMCQQJhxCQU4A4TOb3/JFMGiDnzGTuzHzfrxfkMPfc+d2TzHeuue7rvi5nZoiIiHfFRboAERGpm4JaRMTjFNQiIh6noBYR8TgFtYiIx8WHY6c9e/a0pKSkcOxaRCQq5eTkHDGzxJpuC0tQJyUlkZ2dHY5di4hEJefc3tpuU9eHiIjHKahFRDxOQS0i4nFh6aOuyblz5ygqKqKsrKylfqREkYSEBPr160fbtm0jXYpIi2uxoC4qKqJTp04kJSXhnGupHytRwMw4evQoRUVFDBo0KNLliLS4Fuv6KCsro0ePHgppaTTnHD169NC7MYlZLdpHrZCWptLfjsQynUwUEQkBnz/AezuK8fkDId+3groOjz76KO+9916tt69cuZIdO3aE7OfNmDGD48ePh2x/ofbzn/+8Qds15DhefPFFDhw4EIqyRDwhK6+EhctzyMorCfm+XUMWDnDOFQKngHLAb2bpdW2fnp5uF16ZuHPnToYPH970Sj1o7ty5zJw5k1tuuSXSpYSVmWFmdO7cmdLS0pDsc9KkSTzxxBOkp9f5p3SeaPwbkujh8wfIyithwtBE2sU3vg3snMupLVsbs7fJZpZaX0h72fLlyxk3bhypqaksWLCA8vJyADp27Mj3v/99kpOTmTJlCiUlFa+Ic+fOZcWKFQA89NBDjBgxgpSUFB588EE++OAD3njjDX70ox+RmppKQUEBBQUFTJ8+nbS0NK677jpyc3Or9rNo0SLGjx/P4MGDWbt2Ld/5zncYPnw4c+fOraovKSmJI0eOAPDSSy+RkpLC6NGjmTNnzkXH8pOf/IQnnnii6uuRI0dSWFhIYWEhw4cP5+677yY5OZmpU6dy9uxZAPLz87n++usZPXo0Y8eOpaCgAIDHH3+cq6++mpSUFB577DEACgsLGTZsGN/+9rcZOXIk8+bN4+zZs6SmpjJ79mwAbr75ZtLS0khOTmbp0qUXHUdttaxYsYLs7Gxmz55Namoqb775JjfffHPV/d99912+/vWvN+M3LdLy2sXHcf2IXk0K6XoFW0t1/QMKgZ4N2dbMSEtLswvt2LHjou+1pB07dtjMmTPN5/OZmdmiRYvsD3/4g5mZAbZ8+XIzM/vpT39q9957r5mZ3Xnnnfbqq6/akSNHbOjQoRYIBMzM7NixY+fdHpSRkWF5eXlmZrZhwwabPHly1Xa33nqrBQIBW7lypXXq1Mm2bt1q5eXlNnbsWNu0aZOZmQ0cONBKSkps27ZtduWVV1pJSYmZmR09evSi43nsscfs8ccfr/o6OTnZ9uzZY3v27LE2bdpU7XPWrFm2bNkyMzMbN26cvfbaa2ZmdvbsWTt9+rS98847dvfdd1sgELDy8nK78cYbbd26dbZnzx5zztn69eurfkaHDh3OqyFY15kzZyw5OdmOHDly3nHUVcvEiRPt448/NjOzQCBgw4YNs8OHD5uZ2W233WZvvPFGjb9DkWgFZFstmdrQcdQGrHLOGfCcmS29cAPn3HxgPsCAAQNC8RrS7LcS1a1evZqcnByuvvpqAM6ePctll10GQFxcHLfeeisAd9xxB9/4xjfOu2+XLl1ISEhg3rx5zJw5k5kzZ160/9LSUj744ANmzZpV9b0vvvii6vObbroJ5xyjRo2iV69ejBo1CoDk5GQKCwtJTU2t2jYzM5NZs2bRs2dPALp3796oYx00aFDV/tLS0igsLOTUqVPs37+/qqWakJAAwKpVq1i1ahVjxoypOo5du3YxYMAABg4cyPjx42v9OU899RSvv/46APv27WPXrl306NGj3lou5Jxjzpw5LF++nLvuuov169fz0ksvNeqYRaJZQ4P6WjPb75y7DHjXOZdrZlnVN6gM76VQ0UcdiuKCnfNL7kjj+hG9mrUvM+POO+/kF7/4Rb3bXjgULD4+no8++ojVq1ezYsUKFi9eTGZm5nnbBAIBunbtyubNm2vc5yWXXAJUvCgEPw9+7ff7G3s4xMfHEwj8/exy9THG1fffpk2bqq6PmpgZDz/8MAsWLDjv+4WFhXTo0KHW+61du5b33nuP9evX0759eyZNmlTjOOeG1nLXXXdx0003kZCQwKxZs4iPb7FrsUQ8r0HNVDPbX/nxMPA6MC6cRQVNGJrIkjvSmDC0xilaG2XKlCmsWLGCw4cPA/D555+zd2/FrIKBQKCqL/qVV17h2muvPe++paWlnDhxghkzZvDrX/+aLVu2ANCpUydOnToFQOfOnRk0aBCvvvoqUBGAwe0aKyMjg1dffZWjR49W1XqhpKQkNm7cCMDGjRvZs2dPnfvs1KkT/fr1Y+XKlUBFa//MmTNMmzaNF154oeok4f79+6seowu1bduWc+fOAXDixAm6detG+/btyc3NZcOGDY06xuqPHUCfPn3o06cPP/vZz7jrrrsatS+RaFdvUDvnOjjnOgU/B6YC28JdGIS2c37EiBH87Gc/Y+rUqaSkpHDDDTdw8OBBADp06MBHH33EyJEjyczM5NFHHz3vvqdOnWLmzJmkpKRw7bXX8qtf/QqAb33rWzz++OOMGTOGgoICXn75ZZ5//nlGjx5NcnIyf/nLX5pUa3JyMo888ggTJ05k9OjR/OAHP7hom29+85t8/vnnJCcns3jxYoYOHVrvfpctW8ZTTz1FSkoKX/nKVzh06BBTp07l9ttv58tf/jKjRo3illtuOS9Aq5s/fz4pKSnMnj2b6dOn4/f7GT58OA899FCdXSQ1mTt3LgsXLiQ1NbWqlT179mz69++vkR0iF6h3eJ5zbjAVrWio6Cp5xcz+ta77tLbheR07dgzZsDNpuvvuu48xY8Ywb968Gm/38t+QSHPVNTyv3o5AM9sNjA55VSLVpKWl0aFDB5588slIlyLiOTpjA2pNe0BOTk6kSxDxrBa9hLy+bhaR2uhvR2JZiwV1QkICR48e1RNOGs0q56MOjv0WibRwTsBUkxbr+ujXrx9FRUVVl2eLNEZwhRcRLwjlNR4N0WJB3bZtW63OISJRIZTXeDSETiaKiDRS8BqPlqL5qEVEPE5BLSLicQpqEZEGaOmRHtUpqEVEGiCcS23VR0EtItIALT3SozoFtYhIPUK5iElTKKhFROrg8wf49bt5zF+WTebO4ojUoKAWEalDVl4Jz62rWAgaV/e24aILXkRE6jBhaCLPzh4LDjKuarmLXKpTUIuI1KFdfBzTR/WOaA3q+hAR8TgFtYiIxymoRUQ8TkEtIuJxCmoREY9TUIuIeJyCWkSkmkjOklcbBbWISD
WRnCWvNgpqEZFqIjlLXm10ZaKISDUtvR5iQ6hFLSLicQpqERGPU1CLiHicglpExOMaHNTOuTbOuU3Oub+GsyARETlfY1rU3wN2hqsQERGpWYOC2jnXD7gR+F14yxERkQs1tEX9H8A/AbVeU+mcm++cy3bOZZeUeOeKHhGRupSW+XninU8pLfNHupRa1RvUzrmZwGEzy6lrOzNbambpZpaemOidK3pEROqyZF0Bi9fksyS4gK0HNaRFfQ3wNedcIfBfQIZzbnlYqxIRaQE+f4CrLu/EoolDWDhxSKTLqVW9QW1mD5tZPzNLAr4FZJrZHWGvTEQkzLLySnjgT5tJG9iNjgnenVFD46hFJGZ5cQKmmjTqJcTM1gJrw1KJiEgL8+IETDVRi1pExOMU1CISU7y4gkt9FNQiElO8uIJLfRTUIhJTxg/uwcKJQxg/uEekS2kwBbWIxJQNu4+yZF0BG3YfjXQpDaagFpGY0lqG5FXn3RHeIiJh0FqG5FWnFrWIxITWONojSEEtIjGhNY72CFJQi0hMaI1900HqoxaRmNAa+6aD1KIWEfE4BbWIiMcpqEUkarXmkR7VKahFJGq15pEe1SmoRSRqteaRHtUpqEUkKvn8AbLySpgwNJF28a076lp39SIitYiWbg9QUItIlIqWbg/QBS8iEqVa8wUuF1KLWkTE4xTUIiIep6AWEfE4BbWIiMcpqEVEPE5BLSJRIVrm9aiJglpEosKq7YeYvyybVdsPRbqUkFNQi0hU2H7gJAGr+BhtdMGLiESFeydfQZs4x8KJQyJdSsipRS0irVqwb7pdfBwPThtGx4Toa3/WG9TOuQTn3EfOuS3Oue3OuZ+2RGEiIg0RTZMv1aYhLz1fABlmVuqcawv8r3PubTPbEObaRETqFU2TL9Wm3qA2MwNKK79sW/nPwlmUiEhDRdPkS7VpUB+1c66Nc24zcBh418w+rGGb+c65bOdcdklJ9L4FERFpaQ0KajMrN7NUoB8wzjk3soZtlppZupmlJyZG71sQEZGW1qhRH2Z2HFgDTA9POSIicqGGjPpIdM51rfz8UuAGIDfchYmISIWGtKh7A2ucc1uBj6noo/5reMsSEaldNM/rUZN6g9rMtprZGDNLMbORZvYvLVGYiEhtYmHsdHW6MlFEWhWfP4C/PMDi28dE9djp6hTUItKqZOWVcN8fNxEfF0e7+NiIsNg4ShGJGrFwJeKFom/2EhGJarFwJeKF1KIWEfE4BbWItAqxNiSvOgW1iHiezx/gqdW7YmpIXnUKahHxvKy8En67Np+FE4fE1EnEIJ1MFBFPO1bq47WcfTw5K4UZKX1jZkhedbF3xCLSqjyycitvbS9mVeVyW7EoNo9aRFoFnz9A7y7tAZg6PLaG5FWnoBYRz8rKK+EP6wu5b/IVzBjdN9LlRIz6qEXEs6pfhRir3R6gFrWIeJTPHyArryTmQxoU1CLiQbE+bvpCCmoR8ZxYHzd9IQW1iHhKaZmfjwuP8o/XDmLhxCEx3+0BOpkoIh7zzJp8nsvagwOuHtQj5mbKq4leqkTEM3z+AOf85Thg3rVJ6vaopKAWEc/Iyivh+fcLcUD6wO7q9qikrg8R8YwJQxP57eyx4CDjKnV5BCmoRcQz2sXHMX1U70iX4Tl6XyEinhDLCwPUR0EtIp6QubOY+cuyydxZHOlSPEdBLSLe4MBV/CcXUB+1iHhCxlW9eG5Omobk1UAtahGJqGDfNMD1I3ppSF4N9IiISERl5hazYFkOmbnqm66NglpEIsrvDxAww6/RHrWqN6idc/2dc2ucczucc9udc99ricJEJDbEx8cR5xzx6vKoVUNOJvqBH5rZRudcJyDHOfeume0Ic20iEuV8/gAYPDN7jK5ErEO9L2FmdtDMNlZ+fgrYCcTu4mUiEjKZO4u555WNYOgkYh0a9cg455KAMcCH4ShGRGKHzx9g6/7jGjvdAA0OaudcR+DPwANmdrKG2+c757Kdc9klJVo6R0TqlplbzJK1u1kwcZC6PerRoKB2zrWlIqRfNrPXatrGzJaaWbqZpScmasC6iNTDAAcpfbuq26Me9Z5MdM454Hlgp5n9KvwliUgsyBjei6Vz0nUlYgM05GXsGmAOkOGc21z5b0aY6xKRKFZa5uep1bsYP7iHWtMNUG+L2sz+F3X1i0gILVlXwOI1+QA8OG1YhKvxPk3KJCItbuHEIed9lLopqEWkxXVMiFdLuhHUOSQi4nEKahFpMaVlfp5451NKy/yRLqVVUVCLSIt5Zk0+i9fk80zliURpGAW1iLSY5D6diXMVH6XhdDJRRFrM1OTLdZFLEyioRaTFtIuP4/oRmtejsdT1ISLicQpqEQm74AK2Pi231SQKahEJOy1g2zwKahEJK58/wNZ9xwmYVUxtKo2moBaRsMrMLea5rD0smjiYjOE6kdgUCmoRCS8Dw0jppwUCmkrD80QkrLRAQPPp5U1EwsbnD5CVV8KEoYlqTTeDpx45DeERiS5ZeSUsXJ5DVp4WvG4OTwV15s5i5i/LJnOnhvCItHY+fwB/IMDi28ao26OZPBXUOHA4ys6VaypEkVYuM7eYe1/eBA51ezSTpx69jKt68dycNPKKT7F4TT5L1hVEuiQRaQKfP8CWfScwjZ0OCU+N+ghO2DJ+cA/i4uK0nppIK5WZW8zSrN0s1NjpkPBUUAdpPTWR1q3M5ydgxtBeHdXtEQKt4hHUaBCR1iWvuBSr/CjN1yqCOjjEJ3NnsQJbxON8/gAj+nTmnomDuWfylZEuJyq0iqCeMDSRJXekgUNjMkU8btX2g3z3j5sZ0aczHRM82bva6rSKoA6eZMy4qhdL7khj/OAealmLeJDPH+CtrQcJADsOnIx0OVGjVQR1UDCwN+w+qpa1iAdl5ZXwzo5ibhzVW90eIeTMQj/IMT093bKzs0O+3yDNHyDiTXpuNp1zLsfM0mu6rVV2IGmBTBFv0nMzPKLmJU9D+EQkWtUb1M65F5xzh51z21qioKbSLF0ikaXGUvg0pEX9IjA9zHU0W3AIn0aEiESGGkvhU29Qm1kW8HkL1NIsGhEiElnBxpKmNA29kPVRO+fmO+eynXPZJSWRC0n9sYhERrCxpNEeoReyR9TMlppZupmlJyZGLiQv/GNRv5lIeJWW+TV/fJhF/Uuf+s1Ewuvp1XksXpPP06vzIl1K1GqV46gbo/pJxr99chBcxQIFensm0nw+f4DdR04DUB6Gi+ekQkOG5/0RWA8Mc84VOefmhb+s0Kl+kvGeVzZy78ubeGr1LnWFiIRAVl4Jq3cexgGp/btGupyoVW+L2sxua4lCwm3C0ESevX0smz47xjNr8ikPGN+/Yaha1iLNMH5wDxZOGsyIyzszNbl3pMuJWjGTUu3i45g+qjdjBnYDYMm6Aq12LtIMpWV+/vnPW1matYeEdvFq9IRRzD2yGVf1YtGkwcQ5x9b9x9UFItJET2fu4s1PDjJl+GUaDhtmUX8y8ULt4uN44PphxLk4frs2n7EDumsSGZEm+OJcRSOnd+dL1ZoOs5gLaqgI6+9OuZLU/l3VE
hBpokvaxp33UcInJoMaNB2jSHP4/AFG9e3CoolDuHfyFZEuJ+rppVBEGi0rr4QH/rSZtIHdtC5iC1BQi0ijaU6dlqWgvoDmBhGpW2mZn6dW72L84B46idhC9ChfIDg3iK5eFLmYzx/gR69uYfGafJ5Zkx/pcmKGgvoCE4YmsnDiEH67Nl8TOYlcIDO3mL9tP4QDkvt0jnQ5MUNnAS6goXsitfOXV7zLnD9hEFOTL49wNbFDLeoaVB+6p/5qkQo+f4CtRScBGNWni/qnW5Ae6TpoLmuRv1u17SBL/2c3BsQrpFuUHu06aMFckb/bcaiiNT1jZC8yrtLFYi1JfdR1CHaBvLejmIXLc1g4cQjfnXKl3vJJTLpn0pXEuTgWThyi50AL06PdABoJIrHO5w+wYfdRvjvlSl2JGAEK6gYIjgR5bk66RoJITHpry37+8aVs3tqyP9KlxCQFdQNpJIjEqtIyP0+vKQBglRbbiAgFdSPpykWJNc+u2UXBkdMM6dmBf705JdLlxCR1NjVS9f7qkX06E98mjglDE3VyRaJSaZmf3YdLccD3b7iSbh3bRbqkmKR0aaTq/dVl5wLMX5bNqu2HIl2WSFgsWVfA33YeZsao3lq8NoIU1E0Q7K/+tPgUAYPtB05GuiSRkCst83OuvJwF1w3i376ZoneNEaSuj2a4d/IVtIlzLJw4JNKliISUzx/gn/+8lTc/Och9k6/QkLwI00tkM3RMiOfBacPO+yPWfNYSDbLySnj7k4PcOKq3GiIeoKAOseCokMzc4hYNbJ8/wF+3HODf3s6ltMzfIj9TolNpmZ+PC49yd2WXh1rTkaffQIgF5wfxlwdYsCyb+ROGMLp/FzKu6hWyPj6fP8Cq7Qf5pOgEQ3t1YtuBk2zed5zN+45jwKrtB7k/YygzUnqrX1EapXqXRxxw9aAeWgTaAxTUIRY80ejzB1g06QqeXZOPAfMnDOaHU4c1OThLy/w8+U4um/YdZ8yArrz4wV4McIBVbjOmf1dOnvVRcOQMD/z3ZvJLSjU3SYQcK/XxyOtbmZrcixkpfVvN7yAzt5i/bTvEtORe3JTSW1fieoSCOkyCw/gKDpfy9vZDLM3azdVJ3RvdOikt8/Psml3sOlzKuzsPA7C56AQLJgwCo6pFfeD4WX75jYqLEX78+if07XYpS9YVkNq/q1pELay0zM+s5z4gv+Q0b28v5tPi0yF/VxVqPn+At7Ye4MX1hQTM+HpqX6aP0nA8r1BQh1G7+DgenzWapJ4dSO7TuVGtk4rujUO88P4eNn52vOr7ST3ac9+kwXxtTP+qJ/03L7jvb+ek4fMHGD+4R1haRMHath84yb0aEVDF5w/w1pb9PL2mgIIjpxncswMZwxJ5bl3F5dfTki/nxpTeTE2+vNbA9vkDZO4spuxcgG37j4ODKxLbszr3CF0vbcuHhZ/To0M8uQdL6dc9gQHdO/HQ9GH8aMUndLq0DXE44uJgSM+OPHDDsAb9bvZ/foavPb2Oo2crzqekD+xGxnC9uHuJM7P6t2qk9PR0y87ODvl+o0VpmZ8l6wpYOHFIrU+klRuLeOC/twCQNqArc8YPIKFdfKNbZcH+7B0HTnL3dUN4v6CEnMJjfHbsLAO6tydtQFemjmxYX7bPH+CNTft4dt0edh85DVQ8qZ+9fSwvbdhb5/FEs9IyP79a9SmZOw9SeOwLAK5I7MDKe6+lXXwcmTuLeWPLft7aVowDMob15Ipenbg/Y+h5j1dpmZ8H/7SJv1W+c2qobpfGc+zsxSeQpw3vxZBeHbhnUu0z3pWW+bn231ZzvPL+8XHw8Y9v0BWIEeCcyzGz9Jpua9Czyjk3HfgN0Ab4nZn9MoT1xZwl6wpYvCafc+UB2raJOy/gfP4AmbnF/OGDvQAM7tmBP3znS00OwKy8Er77x80EgFU7iskvOX3e7b//ACYPK6Kw5DTzrkuia/sEMKPsXIDVuYf5vzcO56X1hZz2+Xk//yh7jp6puu+Qnu3J3nuMe17ZSPbeY7y5dT99ul7K/uNneXneePp2b39RPT5/gFXbDrK56ASYkTqgG9cM6cnz7+8JSdCXlvn5j3c/Zd/xs3wv4wru++Nm7vzKAApLzlB0/CzTRlxOfBzsLD5F7y4JvPh+Ib27JnD8rI/TZ8sZN7g7Ow6eJD4ujl9+YxQPvfYJ/kCAkX278OmhU9ya3p/dR85gBiN6dyK/5DRl58r5/QeF1R6XDry64CtVxzJ9VG+uvTKRQyc/ZONnx1n96RFWf3qE9flHwTnGDujGPZOGVD2OADcMv4yB3ds3uUUdHxfHOzuLYScUHD7N07ePPe/FuLTMz9Or8/io8FhVSF8aH8cb912jkPagelvUzrk2QB5wA1AEfAzcZmY7aruPWtR1C7aoff4AS/9nN+kDu/HiXePomBDPezuKWbAsh3Iz0gd24z/npDfriRNsUf/+/UJyPjvO9BGX0bvLpVUt6oPHzpzXgotzYPb3E5RXJHY4L9yTuiXQrWMCs8f1Y8rw3jz//h6+PX4gt/9uw3nb9ep0CQ9/dRh5xaXM/tJA/uXNHfTveimpA7py/x83V+0/zsENI3rxzvZienVqx/0ZQzhw4ovzWoE+f4CsvBKGX96J7/1pM0/cksKfsovo1y2BP328jx4dLuEfUnszI6UvT63exeI1+QB0aBfHaV/Th0fWd//gidwx/bqyqeg4ie3bcFNqf34wteYuh2A4+vzlbCk6wcZ9J6puSx/Yjey9xxjTvwvjBnW/qLXdWD5/gO++nFP1u00b0JXbru5HfJs2bDtwgpy9x9m0r6JLbeyArnznmiSmJmuUUCTV1aJuSFB/GfiJmU2r/PphADP7RW33UVA3TGmZn7m//4jsvce4b/IVPDhtWFWLGoOM4aE7+VRbd0tpmZ9/fzuX/9l1uN4W9bYDJ1l6RzqXdUm4aP/HSn3805+3cMbnZ9fhUkpO+YCKIKse9vdMHMyI3p3Pa1Fv+uw4z7+/57z93TiqN4/NHMGjKz/hbHk56/KOktTjUgqPnqV350s4ePKL87Z3wH9+O53xg3u0eIv6rq8kNbrrJziKJ+ez41Ut6lB3H5WW+Xl27S7W53/OpqLjF92e2r8LXwrBi4KERnOD+hZgupn9Y+XXc4Avmdl9F2w3H5gPMGDAgLS9e/eGovao15D+6tYm+GJT5vNf1KKu6QRXaZmfJ1flsumz49yS1ocPCj7nne3FjBnQraorAGDy0EROfeGvs0WtFuHFjpX6eGTlVr7wB1idW7FC0bQRvXjy/6RGzd9cNGiRoK5OLWppjmBXR0rfLjy68hN6dGzHkTPn+OXXU9R/2gylZX6eWZNPcp/OdY48kcho7snE/UD/al/3q/yeSFhUX01nyZ1XR7ia6NExIZ5//upVkS5DmqAhL6kfA1c65wY559oB3wLeCG9ZIiISVG+L2sz8zrn7gHeoGJ73gpltD3tlIiICNHActZm9BbwV5lpERKQGOpsgIuJxCmoREY9TUIuIeJyCWkTE48Iye55zrgQ4DRwJ+c4jryfReVygY2utovXYovW4oOZj
G2hmNc5LHJagBnDOZdd2lU1rFq3HBTq21ipajy1ajwsaf2zq+hAR8TgFtYiIx4UzqJeGcd+RFK3HBTq21ipajy1ajwsaeWxh66MWEZHQUNeHiIjHKahFRDwurEHtnLvfOZfrnNvunPv3cP6sSHDO/dA5Z865npGuJVScc49X/s62Ouded851jXRNzeGcm+6c+9Q5l++ceyjS9YSKc66/c26Nc25H5fPre5GuKdScc22cc5ucc3+NdC2h5Jzr6pxbUfk821m53GGdwhbUzrnJwD8Ao80sGXgiXD8rEpxz/YGpwGeRriXE3gVGmlkKFYsaPxzhepqscmHmZ4CvAiOA25xzIyJbVcj4gR+a2QhgPHBvFB1b0PeAnZEuIgx+A/zNzK4CRtOAYwxni3oR8Esz+wLAzA7Xs31r82vgn/j7gt1RwcxWmZm/8ssNVKzo01qNA/LNbIcVo1AAAAJCSURBVLeZ+YD/oqLx0OqZ2UEz21j5+Skqnux9I1tV6Djn+gE3Ar+LdC2h5JzrAkwAngcwM5+ZXbzy8AXCGdRDgeuccx8659Y556JmTSXn3D8A+81sS6RrCbPvAG9Huohm6Avsq/Z1EVEUZkHOuSRgDPBhZCsJqf+goiEUiHQhITYIKAF+X9mt8zvnXIf67tSsJYidc+8Bl9dw0yOV++5Oxduyq4H/ds4NtlYyHrCeY/sxFd0erVJdx2Zmf6nc5hEq3l6/3JK1SeM45zoCfwYeMLOTka4nFJxzM4HDZpbjnJsU6XpCLB4YC9xvZh86534DPAT8v/ru1GRmdn1ttznnFgGvVQbzR865ABUTkZQ052e2lNqOzTk3iopXxS3OOajoGtjonBtnZodasMQmq+v3BuCcmwvMBKa0lhfWWkT1wszOubZUhPTLZvZapOsJoWuArznnZgAJQGfn3HIzuyPCdYVCEVBkZsF3PyuoCOo6hbPrYyUwGcA5NxRoRxTMhGVmn5jZZWaWZGZJVDzwY1tLSNfHOTedirecXzOzM5Gup5midmFmV9FKeB7YaWa/inQ9oWRmD5tZv8rn17eAzCgJaSpzYp9zbljlt6YAO+q7X7Na1PV4AXjBObcN8AF3tvLWWaxYDFwCvFv5jmGDmS2MbElNE+ULM18DzAE+cc5trvzejyvXNxVvux94ubLxsBu4q7476BJyERGP05WJIiIep6AWEfE4BbWIiMcpqEVEPE5BLSLicQpqERGPU1CLiHjc/wdWOPFXO0u5dgAAAABJRU5ErkJggg==", - "text/plain": [ - "<Figure size 432x288 with 1 Axes>" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], "source": [ - "outputs = ensemble_wrapper(x_val)\n", - "plt.scatter(x_val, outputs.epistemic, label='epistemic uncertainty', s=0.5)\n", + "# Capsa makes the epistemic uncertainty an attribute of the prediction!\n", + "pred = np.array(prediction.y_hat).flatten()\n", + "unc = np.array(prediction.epistemic).flatten()\n", + "\n", + "# Visualize the aleatoric uncertainty across the data space\n", + "plt.figure(figsize=(10, 6))\n", + "plt.scatter(x_train, y_train, s=1.5, label='train data')\n", + "plt.plot(x_test, y_test, c='r', zorder=-1, label='ground truth')\n", + "plt.fill_between(x_test.flatten(), pred-20*unc, pred+20*unc, color='b', alpha=0.2, label='epistemic')\n", "plt.legend()" - ] + ], + "metadata": { + "id": "eauNoKDOj_ZT" + }, + "execution_count": null, + "outputs": [] }, { "cell_type": "markdown", - "metadata": { - "id": "VU6eMpYX9m9N" - }, "source": [ - "## Conclusion\n", - "As expected, areas where there is no training data have very high epistemic uncertainty, since all of the testing data is OOD. If our training data contained more samples from this region, would you expect the epistemic uncertainty to decrease?" - ] + "#### **TODO: Estimating epistemic uncertainty**\n", + "\n", + "Write short (~1 sentence) answers to the questions below to complete the `TODO`s:\n", + "\n", + "1. For what values of $x$ is the epistemic uncertainty high or increasing suddenly?\n", + "2. How does your answer in (1) relate to how the $x$ values are distributed (refer back to original plot)? Think about both the train and test data.\n", + "3. How could you reduce the epistemic uncertainty in regions where it is high?" + ], + "metadata": { + "id": "N4LMn2tLPBdg" + } }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "id": "CkpvkOL06jRd" }, "source": [ + "# 1.6 Conclusion\n", "\n", - "You've just analyzed the bias, aleatoric uncertainty, and epistemic uncertainty for your first risk-aware model! This is a task that data scientists do constantly to determine methods of improving their models and datasets. 
In the next part, you'll continue to build off of these concepts to *mitigate* these risks, in addition to diagnosing them!\n", + "You've just analyzed the bias, aleatoric uncertainty, and epistemic uncertainty for your first risk-aware model! This is a task that data scientists do constantly to determine methods of improving their models and datasets.\n", + "\n", + "In the next part of the lab, you'll continue to build off of these concepts to study them in the context of facial detection systems: not only diagnosing issues of bias and uncertainty, but also developing solutions to *mitigate* these risks.\n", "\n", "" ] }, { "cell_type": "code", - "execution_count": null, + "source": [], "metadata": { - "id": "bs4mAQ5c6cMY" + "id": "nIpfPcpjlsKK" }, - "outputs": [], - "source": [] + "execution_count": null, + "outputs": [] } ], "metadata": { @@ -814,4 +575,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/mitdeeplearning/lab3.py b/mitdeeplearning/lab3.py index a33f4886..ee7b9213 100644 --- a/mitdeeplearning/lab3.py +++ b/mitdeeplearning/lab3.py @@ -76,6 +76,7 @@ def __init__(self, data_path, batch_size, training=True): self.train_inds = np.concatenate((self.pos_train_inds, self.neg_train_inds)) self.batch_size = batch_size self.p_pos = np.ones(self.pos_train_inds.shape) / len(self.pos_train_inds) + self.p_neg = np.ones(self.neg_train_inds.shape) / len(self.neg_train_inds) def get_train_size(self): return self.pos_train_inds.shape[0] + self.neg_train_inds.shape[0] @@ -88,7 +89,7 @@ def __getitem__(self, index): self.pos_train_inds, size=self.batch_size // 2, replace=False, p=self.p_pos ) selected_neg_inds = np.random.choice( - self.neg_train_inds, size=self.batch_size // 2, replace=False + self.neg_train_inds, size=self.batch_size // 2, replace=False, p = self.p_neg ) selected_inds = np.concatenate((selected_pos_inds, selected_neg_inds)) diff --git a/setup.py b/setup.py index 90d0df82..9f2ffd0b 100644 --- a/setup.py +++ b/setup.py @@ -22,13 +22,13 @@ def get_dist(pkgname): setup( name = 'mitdeeplearning', # How you named your package folder (MyLib) packages = ['mitdeeplearning'], # Chose the same as "name" - version = '0.3.0', # Start with a small number and increase it with every change you make + version = '0.4.0', # Start with a small number and increase it with every change you make license='MIT', # Chose a license from here: https://help.github.com/articles/licensing-a-repository description = 'Official software labs for MIT Introduction to Deep Learning (http://introtodeeplearning.com)', # Give a short description about your library author = 'Alexander Amini', # Type in your name author_email = 'introtodeeplearning-staff@mit.edu', # Type in your E-Mail url = 'http://introtodeeplearning.com', # Provide either the link to your github or to your website - download_url = 'https://github.com/aamini/introtodeeplearning/archive/v0.3.0.tar.gz', # I explain this later on + download_url = 'https://github.com/aamini/introtodeeplearning/archive/v0.4.0.tar.gz', # I explain this later on keywords = ['deep learning', 'neural networks', 'tensorflow', 'introduction'], # Keywords that define your package best install_requires=install_deps, classifiers=[