Unverified commit b6607524, authored by Varuna Jayasiri, committed by GitHub

Evidential Deep Learning to Quantify Classification Uncertainty (#85)

Parent 387b6dfd
......@@ -154,15 +154,19 @@ implementations.</p>
<ul>
<li><a href="adaptive_computation/ponder_net/index.html">PonderNet</a></li>
</ul>
<h4><a href="uncertainty/index.html">Uncertainty</a></h4>
<ul>
<li><a href="uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li>
</ul>
<h3>Installation</h3>
<pre><code class="bash">pip install labml-nn
</code></pre>
<h3>Citing LabML</h3>
<p>If you use LabML for academic research, please cite the library using the following BibTeX entry.</p>
<p>If you use this for academic research, please cite it using the following BibTeX entry.</p>
<pre><code class="bibtex">@misc{labml,
author = {Varuna Jayasiri, Nipun Wijerathne},
title = {LabML: A library to organize machine learning experiments},
title = {labml.ai Annotated Paper Implementations},
year = {2020},
url = {https://nn.labml.ai/},
}
......
......@@ -268,7 +268,10 @@ and set a new function to calculate the model.</p>
<p>Load configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span><span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">})</span></pre></div>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">76</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
<span class="lineno">77</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">0.001</span><span class="p">,</span>
<span class="lineno">78</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
......@@ -279,8 +282,8 @@ and set a new function to calculate the model.</p>
<p>Start the experiment and run the training loop</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">78</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">80</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">81</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
......@@ -291,8 +294,8 @@ and set a new function to calculate the model.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">83</span> <span class="n">main</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">85</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">86</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
......
......@@ -114,6 +114,9 @@
"1704.03477": [
"https://nn.labml.ai/sketch_rnn/index.html"
],
"1806.01768": [
"https://nn.labml.ai/uncertainty/evidence/index.html"
],
"1509.06461": [
"https://nn.labml.ai/rl/dqn/index.html"
],
......
......@@ -204,7 +204,7 @@
<url>
<loc>https://nn.labml.ai/normalization/batch_norm/mnist.html</loc>
<lastmod>2021-08-19T16:30:00+00:00</lastmod>
<lastmod>2021-08-20T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
......@@ -281,7 +281,7 @@
<url>
<loc>https://nn.labml.ai/index.html</loc>
<lastmod>2021-08-12T16:30:00+00:00</lastmod>
<lastmod>2021-08-21T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
......@@ -797,6 +797,27 @@
</url>
<url>
<loc>https://nn.labml.ai/uncertainty/evidence/index.html</loc>
<lastmod>2021-08-21T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/uncertainty/evidence/experiment.html</loc>
<lastmod>2021-08-21T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/uncertainty/index.html</loc>
<lastmod>2021-08-21T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/rl/game.html</loc>
<lastmod>2020-12-10T16:30:00+00:00</lastmod>
......
This diff is collapsed.
This diff is collapsed.
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/uncertainty/evidence/readme.html"/>
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
<meta property="og:description" content=""/>
<title>Evidential Deep Learning to Quantify Classification Uncertainty</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/uncertainty/evidence/readme.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">uncertainty</a>
<a class="parent" href="index.html">evidence</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/evidence/readme.md">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://papers.labml.ai/paper/1806.01768">Evidential Deep Learning to Quantify Classification Uncertainty</a>.</p>
<p>Here is the <a href="https://nn.labml.ai/uncertainty/evidence/experiment.html">training code <code>experiment.py</code></a> to train a model on MNIST dataset.</p>
<p><a href="https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>
\ No newline at end of file
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A set of PyTorch implementations/tutorials related to uncertainty estimation"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Neural Networks with Uncertainty Estimation"/>
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials related to uncertainty estimation"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/uncertainty/index.html"/>
<meta property="og:title" content="Neural Networks with Uncertainty Estimation"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Neural Networks with Uncertainty Estimation"/>
<meta property="og:description" content="A set of PyTorch implementations/tutorials related to uncertainty estimation"/>
<title>Neural Networks with Uncertainty Estimation</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/uncertainty/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">uncertainty</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Neural Networks with Uncertainty Estimation</h1>
<p>These are neural network architectures that estimate the uncertainty of the predictions.</p>
<ul>
<li><a href="evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>
\ No newline at end of file
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Neural Networks with Uncertainty Estimation"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/uncertainty/readme.html"/>
<meta property="og:title" content="Neural Networks with Uncertainty Estimation"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Neural Networks with Uncertainty Estimation"/>
<meta property="og:description" content=""/>
<title>Neural Networks with Uncertainty Estimation</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/uncertainty/readme.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">uncertainty</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/readme.md">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/uncertainty/index.html">Neural Networks with Uncertainty Estimation</a></h1>
<p>These are neural network architectures that estimate the uncertainty of the predictions.</p>
<ul>
<li><a href="https://nn.labml.ai/uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li>
</ul>
</div>
<div class='code'>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>
\ No newline at end of file
......@@ -94,6 +94,10 @@ Solving games with incomplete information such as poker with CFR.
* [PonderNet](adaptive_computation/ponder_net/index.html)
#### ✨ [Uncertainty](uncertainty/index.html)
* [Evidential Deep Learning to Quantify Classification Uncertainty](uncertainty/evidence/index.html)
### Installation
```bash
......@@ -102,12 +106,12 @@ pip install labml-nn
### Citing LabML
If you use LabML for academic research, please cite the library using the following BibTeX entry.
If you use this for academic research, please cite it using the following BibTeX entry.
```bibtex
@misc{labml,
author = {Varuna Jayasiri, Nipun Wijerathne},
title = {LabML: A library to organize machine learning experiments},
title = {labml.ai Annotated Paper Implementations},
year = {2020},
url = {https://nn.labml.ai/},
}
......
......@@ -72,7 +72,10 @@ def main():
# Create configurations
conf = MNISTConfigs()
# Load configurations
experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
experiment.configs(conf, {
'optimizer.optimizer': 'Adam',
'optimizer.learning_rate': 0.001,
})
# Start the experiment and run the training loop
with experiment.start():
conf.run()
......
"""
---
title: Neural Networks with Uncertainty Estimation
summary: >
A set of PyTorch implementations/tutorials related to uncertainty estimation
---
# Neural Networks with Uncertainty Estimation
These are neural network architectures that estimate the uncertainty of the predictions.
* [Evidential Deep Learning to Quantify Classification Uncertainty](evidence/index.html)
"""
"""
---
title: "Evidential Deep Learning to Quantify Classification Uncertainty"
summary: >
A PyTorch implementation/tutorial of the paper Evidential Deep Learning to Quantify Classification
Uncertainty.
---
# Evidential Deep Learning to Quantify Classification Uncertainty
This is a [PyTorch](https://pytorch.org) implementation of the paper
[Evidential Deep Learning to Quantify Classification Uncertainty](https://papers.labml.ai/paper/1806.01768).
[Dempster-Shafer Theory of Evidence](https://en.wikipedia.org/wiki/Dempster%E2%80%93Shafer_theory)
assigns belief masses to sets of classes (unlike assigning a probability to a single class).
The sum of the masses over all subsets is $1$.
Individual class probabilities (plausibilities) can be derived from these masses.
Assigning a mass to the set of all classes means the sample could be any one of the classes; i.e. saying "I don't know".
If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and
an overall uncertainty mass $u \ge 0$ to all classes.
$$u + \sum_{k=1}^K b_k = 1$$
Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$, as $b_k = \frac{e_k}{S}$
and $u = \frac{K}{S}$ where $S = \sum_{k=1}^K (e_k + 1)$.
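As a quick numeric sketch of these relations (plain Python, with hypothetical evidence values):

```python
# Belief masses and uncertainty mass from evidence (hypothetical values)
K = 3
e = [4.0, 1.0, 0.0]                  # evidence e_k >= 0 for each class
S = sum(ek + 1 for ek in e)          # S = sum of alpha_k, where alpha_k = e_k + 1
b = [ek / S for ek in e]             # belief masses b_k = e_k / S
u = K / S                            # uncertainty mass u = K / S

# The belief masses and the uncertainty mass always sum to one
assert abs(u + sum(b) - 1.0) < 1e-9
```

With no evidence at all ($e_k = 0$ for every $k$), $S = K$ and $u = 1$: the model says "I don't know".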
The paper uses the term evidence as a measure of the amount of support
collected from the data in favor of a sample being classified into a certain class.
This corresponds to a [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution)
with parameters $\color{cyan}{\alpha_k} = e_k + 1$, and
$\color{cyan}{\alpha_0} = S = \sum_{k=1}^K \color{cyan}{\alpha_k}$ is known as the Dirichlet strength.
The Dirichlet distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$
is a distribution over categorical distributions; i.e. you can sample class probabilities
from a Dirichlet distribution.
The expected probability for class $k$ is $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$.
We get the model to output evidences
$$\mathbf{e} = \color{cyan}{\mathbf{\alpha}} - 1 = f(\mathbf{x} | \Theta)$$
for a given input $\mathbf{x}$.
We use a function such as
[ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) or a
[Softplus](https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html)
at the final layer to get $f(\mathbf{x} | \Theta) \ge 0$.
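For instance, a numerically stable elementwise softplus (sketched here in plain Python; `torch.nn.Softplus` behaves the same way per element) maps arbitrary logits to non-negative evidence:

```python
import math

def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

logits = [-1.2, 0.5, 3.0]            # hypothetical final-layer outputs
evidence = [softplus(z) for z in logits]

# Evidence is non-negative regardless of the sign of the logits
assert all(e >= 0.0 for e in evidence)
```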
The paper proposes a few loss functions to train the model, which we have implemented below.
Here is the [training code `experiment.py`](experiment.html) to train a model on MNIST dataset.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106)
"""
import torch
from labml import tracker
from labml_helpers.module import Module
class MaximumLikelihoodLoss(Module):
"""
<a id="MaximumLikelihoodLoss"></a>
## Type II Maximum Likelihood Loss
The distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$ is a prior on the likelihood
$Multi(\mathbf{y} \vert p)$,
and the negative log marginal likelihood is calculated by integrating over the class probabilities
$\mathbf{p}$.
If the target probabilities (one-hot targets) are $y_k$ for a given sample, the loss is,
\begin{align}
\mathcal{L}(\Theta)
&= -\log \Bigg(
\int
\prod_{k=1}^K p_k^{y_k}
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
d\mathbf{p}
\Bigg ) \\
&= \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)
\end{align}
"""
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
"""
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
"""
# $\color{cyan}{\alpha_k} = e_k + 1$
alpha = evidence + 1.
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
strength = alpha.sum(dim=-1)
# Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)$
loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
# Mean loss over the batch
return loss.mean()
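As a sanity check on the closed form, here is a standalone plain-Python mirror of the `forward` above for a single sample (values are hypothetical):

```python
import math

def max_likelihood_loss(evidence, target):
    # L = sum_k y_k * (log S - log alpha_k), with alpha_k = e_k + 1
    alpha = [e + 1.0 for e in evidence]
    s = sum(alpha)
    return sum(y * (math.log(s) - math.log(a)) for y, a in zip(target, alpha))

# Evidence (1, 0) for target class 0: alpha = (2, 1), S = 3, loss = log 3 - log 2
assert abs(max_likelihood_loss([1.0, 0.0], [1.0, 0.0])
           - (math.log(3) - math.log(2))) < 1e-9

# More evidence for the correct class lowers the loss
assert max_likelihood_loss([9.0, 0.0], [1.0, 0.0]) < max_likelihood_loss([1.0, 0.0], [1.0, 0.0])
```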
class CrossEntropyBayesRisk(Module):
"""
<a id="CrossEntropyBayesRisk"></a>
## Bayes Risk with Cross Entropy Loss
Bayes risk is the expected cost of making incorrect estimates.
It takes a cost function that gives the cost of making an incorrect estimate
and integrates it over all possible outcomes, weighted by their probabilities.
Here the cost function is cross-entropy loss, for one-hot coded $\mathbf{y}$
$$\sum_{k=1}^K -y_k \log p_k$$
We integrate this cost over all $\mathbf{p}$
\begin{align}
\mathcal{L}(\Theta)
&= \int
\Big[ \sum_{k=1}^K -y_k \log p_k \Big]
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
d\mathbf{p} \\
&= \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)
\end{align}
where $\psi(\cdot)$ is the digamma function.
"""
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
"""
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
"""
# $\color{cyan}{\alpha_k} = e_k + 1$
alpha = evidence + 1.
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
strength = alpha.sum(dim=-1)
# Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)$
loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)
# Mean loss over the batch
return loss.mean()
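A quick numeric check of this digamma form for a single sample (plain Python, approximating $\psi$ by a central difference of `math.lgamma`; values hypothetical): with $\color{cyan}{\mathbf{\alpha}} = (2, 1)$ and $y = (1, 0)$, the loss is $\psi(3) - \psi(2)$, which by the recurrence $\psi(x+1) = \psi(x) + 1/x$ equals exactly $1/2$.

```python
import math

def digamma(x: float) -> float:
    # Crude digamma via a central difference of lgamma; fine for a sanity check
    h = 1e-5
    return (math.lgamma(x + h) - math.lgamma(x - h)) / (2 * h)

# alpha = (2, 1), y = (1, 0): loss = psi(S) - psi(alpha_1) = psi(3) - psi(2) = 1/2
loss = digamma(3.0) - digamma(2.0)
assert abs(loss - 0.5) < 1e-6
```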
class SquaredErrorBayesRisk(Module):
"""
<a id="SquaredErrorBayesRisk"></a>
## Bayes Risk with Squared Error Loss
Here the cost function is squared error,
$$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$$
We integrate this cost over all $\mathbf{p}$
\begin{align}
\mathcal{L}(\Theta)
&= \int
\Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big]
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
d\mathbf{p} \\
&= \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 -2 y_k p_k + p_k^2 \Big] \\
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)
\end{align}
Where $$\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$$
is the expected probability when sampled from the Dirichlet distribution
and $$\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$$
where
$$\text{Var}(p_k) = \frac{\color{cyan}{\alpha_k}(S - \color{cyan}{\alpha_k})}{S^2 (S + 1)}
= \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$$
is the variance.
This gives,
\begin{align}
\mathcal{L}(\Theta)
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big) \\
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) \\
&= \sum_{k=1}^K \Big( \big( y_k -\mathbb{E}[p_k] \big)^2 + \text{Var}(p_k) \Big) \\
&= \sum_{k=1}^K \Big( ( y_k -\hat{p}_k)^2 + \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1} \Big)
\end{align}
The first part of this expression, $\big(y_k -\mathbb{E}[p_k]\big)^2$, is the squared error term and
the second part is the variance.
"""
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
"""
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
"""
# $\color{cyan}{\alpha_k} = e_k + 1$
alpha = evidence + 1.
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
strength = alpha.sum(dim=-1)
# $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$
p = alpha / strength[:, None]
# Error $(y_k -\hat{p}_k)^2$
err = (target - p) ** 2
# Variance $\text{Var}(p_k) = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$
var = p * (1 - p) / (strength[:, None] + 1)
# Sum of them
loss = (err + var).sum(dim=-1)
# Mean loss over the batch
return loss.mean()
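A single-sample numeric check of this decomposition (plain Python, hypothetical values): with $\color{cyan}{\mathbf{\alpha}} = (2, 1)$ and $y = (1, 0)$ we get $S = 3$, $\hat{\mathbf{p}} = (2/3, 1/3)$, each variance term equal to $1/18$, and a loss of exactly $1/3$.

```python
alpha = [2.0, 1.0]                             # hypothetical Dirichlet parameters
s = sum(alpha)                                 # Dirichlet strength S = 3
p_hat = [a / s for a in alpha]                 # expected probabilities (2/3, 1/3)
var = [p * (1 - p) / (s + 1) for p in p_hat]   # Var(p_k) = p(1 - p) / (S + 1)
y = [1.0, 0.0]                                 # one-hot target

# loss = sum_k (y_k - p_hat_k)^2 + Var(p_k) = 1/9 + 1/9 + 1/18 + 1/18 = 1/3
loss = sum((yk - pk) ** 2 + vk for yk, pk, vk in zip(y, p_hat, var))
assert abs(loss - 1.0 / 3.0) < 1e-9
```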
class KLDivergenceLoss(Module):
"""
<a id="KLDivergenceLoss"></a>
## KL Divergence Regularization Loss
This tries to shrink the total evidence to zero if the sample cannot be correctly classified.
First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$, the
Dirichlet parameters after removing the correct evidence.
\begin{align}
&KL \Big[ D(\mathbf{p} \vert \mathbf{\tilde{\alpha}}) \Big \Vert
D(\mathbf{p} \vert \langle 1, \dots, 1 \rangle) \Big] \\
&= \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg)
+ \sum_{k=1}^K (\tilde{\alpha}_k - 1)
\Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]
\end{align}
where $\Gamma(\cdot)$ is the gamma function,
$\psi(\cdot)$ is the digamma function, and
$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$.
"""
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
"""
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
"""
# $\color{cyan}{\alpha_k} = e_k + 1$
alpha = evidence + 1.
# Number of classes
n_classes = evidence.shape[-1]
# Remove non-misleading evidence
# $$\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$$
alpha_tilde = target + (1 - target) * alpha
# $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$
strength_tilde = alpha_tilde.sum(dim=-1)
# The first term
# \begin{align}
# &\log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
# {\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) \\
# &= \log \Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)
# - \log \Gamma(K)
# - \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k)
# \end{align}
first = (torch.lgamma(alpha_tilde.sum(dim=-1))
- torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
- (torch.lgamma(alpha_tilde)).sum(dim=-1))
# The second term
# $$\sum_{k=1}^K (\tilde{\alpha}_k - 1)
# \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]$$
second = (
(alpha_tilde - 1) *
(torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
).sum(dim=-1)
# Sum of the terms
loss = first + second
# Mean loss over the batch
return loss.mean()
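One easy consistency check (plain Python, using `math.lgamma`): when $\tilde{\alpha}_k = 1$ for every class, the distribution is already the uniform Dirichlet, the second term vanishes because $\tilde{\alpha}_k - 1 = 0$, and the whole KL term should be zero.

```python
import math

def kl_first_term(alpha_tilde):
    # log Gamma(sum alpha~) - log Gamma(K) - sum_k log Gamma(alpha~_k)
    k = len(alpha_tilde)
    return (math.lgamma(sum(alpha_tilde))
            - math.lgamma(k)
            - sum(math.lgamma(a) for a in alpha_tilde))

# All-ones alpha~ (K = 3): KL to the uniform Dirichlet is zero
assert abs(kl_first_term([1.0, 1.0, 1.0])) < 1e-9
```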
class TrackStatistics(Module):
"""
<a id="TrackStatistics"></a>
### Track statistics
This module computes statistics and tracks them with [labml `tracker`](https://docs.labml.ai/api/tracker.html).
"""
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
# Number of classes
n_classes = evidence.shape[-1]
# Predictions that match the target (greedy prediction of the highest-probability class)
match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
# Track accuracy
tracker.add('accuracy.', match.sum() / match.shape[0])
# $\color{cyan}{\alpha_k} = e_k + 1$
alpha = evidence + 1.
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
strength = alpha.sum(dim=-1)
# $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$
expected_probability = alpha / strength[:, None]
# Expected probability of the greedily selected (highest probability) class
expected_probability, _ = expected_probability.max(dim=-1)
# Uncertainty mass $u = \frac{K}{S}$
uncertainty_mass = n_classes / strength
# Track $u$ for correct predictions
tracker.add('u.succ.', uncertainty_mass.masked_select(match))
# Track $u$ for incorrect predictions
tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
# Track $\hat{p}_k$ for correct predictions
tracker.add('prob.succ.', expected_probability.masked_select(match))
# Track $\hat{p}_k$ for incorrect predictions
tracker.add('prob.fail.', expected_probability.masked_select(~match))
"""
---
title: "Evidential Deep Learning to Quantify Classification Uncertainty Experiment"
summary: >
This trains an EDL model on MNIST
---
# [Evidential Deep Learning to Quantify Classification Uncertainty](index.html) Experiment
This trains a model based on [Evidential Deep Learning to Quantify Classification Uncertainty](index.html)
on MNIST dataset.
"""
from typing import Any
import torch.nn as nn
import torch.utils.data
from labml import tracker, experiment
from labml.configs import option, calculate
from labml_helpers.module import Module
from labml_helpers.schedule import Schedule, RelativePiecewise
from labml_helpers.train_valid import BatchIndex
from labml_nn.experiments.mnist import MNISTConfigs
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
CrossEntropyBayesRisk, SquaredErrorBayesRisk
class Model(Module):
"""
    ## LeNet-based model for MNIST classification
"""
def __init__(self, dropout: float):
super().__init__()
        # First $5 \times 5$ convolution layer
self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
# ReLU activation
self.act1 = nn.ReLU()
        # $2 \times 2$ max-pooling
self.max_pool1 = nn.MaxPool2d(2, 2)
        # Second $5 \times 5$ convolution layer
self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
# ReLU activation
self.act2 = nn.ReLU()
# $2x2$ max-pooling
self.max_pool2 = nn.MaxPool2d(2, 2)
# First fully-connected layer that maps to $500$ features
self.fc1 = nn.Linear(50 * 4 * 4, 500)
# ReLU activation
self.act3 = nn.ReLU()
        # Final fully-connected layer to output evidence for $10$ classes.
# The ReLU or Softplus activation is applied to this outside the model to get the
# non-negative evidence
self.fc2 = nn.Linear(500, 10)
# Dropout for the hidden layer
self.dropout = nn.Dropout(p=dropout)
def __call__(self, x: torch.Tensor):
"""
* `x` is the batch of MNIST images of shape `[batch_size, 1, 28, 28]`
"""
# Apply first convolution and max pooling.
# The result has shape `[batch_size, 20, 12, 12]`
x = self.max_pool1(self.act1(self.conv1(x)))
# Apply second convolution and max pooling.
# The result has shape `[batch_size, 50, 4, 4]`
x = self.max_pool2(self.act2(self.conv2(x)))
# Flatten the tensor to shape `[batch_size, 50 * 4 * 4]`
x = x.view(x.shape[0], -1)
# Apply hidden layer
x = self.act3(self.fc1(x))
# Apply dropout
x = self.dropout(x)
# Apply final layer and return
return self.fc2(x)
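The shape comments in `__call__` can be verified arithmetically; a small sketch using the standard output-size formulas for valid (no-padding) convolution and max-pooling:

```python
def conv_out(size: int, kernel: int) -> int:
    # Valid convolution, stride 1: output = input - kernel + 1
    return size - kernel + 1

def pool_out(size: int, kernel: int, stride: int) -> int:
    # Max-pooling output size
    return (size - kernel) // stride + 1

s = 28                              # MNIST images are 28x28
s = pool_out(conv_out(s, 5), 2, 2)  # conv1 + max_pool1: 28 -> 24 -> 12
s = pool_out(conv_out(s, 5), 2, 2)  # conv2 + max_pool2: 12 -> 8 -> 4
flattened = 50 * s * s              # matches nn.Linear(50 * 4 * 4, 500)
```

This is why the first fully-connected layer takes `50 * 4 * 4 = 800` input features.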
class Configs(MNISTConfigs):
"""
## Configurations
We use [`MNISTConfigs`](../../experiments/mnist.html#MNISTConfigs) configurations.
"""
# [KL Divergence regularization](index.html#KLDivergenceLoss)
kl_div_loss = KLDivergenceLoss()
# KL Divergence regularization coefficient schedule
kl_div_coef: Schedule
    # Schedule points for the KL Divergence regularization coefficient,
    # as (relative training position, coefficient) pairs
kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
# [Stats module](index.html#TrackStatistics) for tracking
stats = TrackStatistics()
# Dropout
dropout: float = 0.5
    # Module to convert the model output to non-negative evidence $e_k \ge 0$
outputs_to_evidence: Module
def init(self):
"""
### Initialization
"""
# Set tracker configurations
tracker.set_scalar("loss.*", True)
tracker.set_scalar("accuracy.*", True)
tracker.set_histogram('u.*', True)
tracker.set_histogram('prob.*', False)
tracker.set_scalar('annealing_coef.*', False)
tracker.set_scalar('kl_div_loss.*', False)
#
self.state_modules = []
def step(self, batch: Any, batch_idx: BatchIndex):
"""
### Training or validation step
"""
# Training/Evaluation mode
self.model.train(self.mode.is_train)
# Move data to the device
data, target = batch[0].to(self.device), batch[1].to(self.device)
# One-hot coded targets
eye = torch.eye(10).to(torch.float).to(self.device)
target = eye[target]
# Update global step (number of samples processed) when in training mode
if self.mode.is_train:
tracker.add_global_step(len(data))
# Get model outputs
outputs = self.model(data)
        # Get evidence $e_k \ge 0$
evidence = self.outputs_to_evidence(outputs)
# Calculate loss
loss = self.loss_func(evidence, target)
# Calculate KL Divergence regularization loss
kl_div_loss = self.kl_div_loss(evidence, target)
tracker.add("loss.", loss)
tracker.add("kl_div_loss.", kl_div_loss)
# KL Divergence loss coefficient $\lambda_t$
annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
tracker.add("annealing_coef.", annealing_coef)
# Total loss
loss = loss + annealing_coef * kl_div_loss
# Track statistics
self.stats(evidence, target)
# Train the model
if self.mode.is_train:
# Calculate gradients
loss.backward()
# Take optimizer step
self.optimizer.step()
# Clear the gradients
self.optimizer.zero_grad()
# Save the tracked metrics
tracker.save()
@option(Configs.model)
def mnist_model(c: Configs):
"""
### Create model
"""
return Model(c.dropout).to(c.device)
@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):
"""
### KL Divergence Loss Coefficient Schedule
"""
# Create a [relative piecewise schedule](https://docs.labml.ai/api/helpers.html#labml_helpers.schedule.Piecewise)
return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
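`RelativePiecewise` is a labml helper; a rough sketch of the linear interpolation it is assumed to perform over the schedule `[(0, 0.), (0.2, 0.01), (1, 1.)]` (the helper's actual behavior may differ in detail):

```python
def relative_piecewise(points, total_steps, step):
    """Linearly interpolate between (relative position, value) pairs."""
    x = step / total_steps
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x <= x1:
            t = (x - x0) / (x1 - x0)
            return y0 + t * (y1 - y0)
    # Past the last point: hold the final value
    return points[-1][1]

schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
# 10% into training: halfway through the first segment
coef_early = relative_piecewise(schedule, 1_000, 100)
# 60% into training: halfway through the second segment
coef_late = relative_piecewise(schedule, 1_000, 600)
```

The coefficient stays tiny early on so the model can first accumulate evidence, then ramps up toward full regularization strength by the end of training.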
# [Maximum Likelihood Loss](index.html#MaximumLikelihoodLoss)
calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
# [Cross Entropy Bayes Risk](index.html#CrossEntropyBayesRisk)
calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
# [Squared Error Bayes Risk](index.html#SquaredErrorBayesRisk)
calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())
# ReLU to calculate evidence
calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
# Softplus to calculate evidence
calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
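The two evidence activations behave differently around zero; a quick standalone comparison of scalar versions (not the `nn` modules):

```python
import math

def relu(x: float) -> float:
    return max(0.0, x)

def softplus(x: float) -> float:
    # log(1 + e^x): a smooth, strictly positive approximation of ReLU
    return math.log(1.0 + math.exp(x))

outputs = [-2.0, 0.0, 3.0]
relu_evidence = [relu(o) for o in outputs]
softplus_evidence = [softplus(o) for o in outputs]
```

ReLU clips negative outputs to exactly zero evidence (with zero gradient there), while Softplus keeps the evidence strictly positive and differentiable everywhere.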
def main():
# Create experiment
experiment.create(name='evidence_mnist')
# Create configurations
conf = Configs()
# Load configurations
experiment.configs(conf, {
'optimizer.optimizer': 'Adam',
'optimizer.learning_rate': 0.001,
'optimizer.weight_decay': 0.005,
# 'loss_func': 'max_likelihood_loss',
# 'loss_func': 'cross_entropy_bayes_risk',
'loss_func': 'squared_error_bayes_risk',
'outputs_to_evidence': 'softplus',
'dropout': 0.5,
})
# Start the experiment and run the training loop
with experiment.start():
conf.run()
#
if __name__ == '__main__':
main()
# [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
This is a [PyTorch](https://pytorch.org) implementation of the paper
[Evidential Deep Learning to Quantify Classification Uncertainty](https://papers.labml.ai/paper/1806.01768).
Here is the [training code `experiment.py`](https://nn.labml.ai/uncertainty/evidence/experiment.html) to train a model on the MNIST dataset.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106)
# [Neural Networks with Uncertainty Estimation](https://nn.labml.ai/uncertainty/index.html)
These are neural network architectures that estimate the uncertainty of the predictions.
* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
@@ -99,6 +99,10 @@
Solving games with incomplete information such as poker with CFR.
* [PonderNet](https://nn.labml.ai/adaptive_computation/ponder_net/index.html)
#### ✨ [Uncertainty](https://nn.labml.ai/uncertainty/index.html)
* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
### Installation
```bash
pip install labml-nn
```
@@ -5,10 +5,10 @@
with open("readme.md", "r") as f:
setuptools.setup(
name='labml-nn',
-    version='0.4.109',
+    version='0.4.110',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
-    description="🧠 Implementations/tutorials of deep learning papers with side-by-side notes; including transformers (original, xl, switch, feedback, vit), optimizers(adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), reinforcement learning (ppo, dqn), capsnet, distillation, etc.",
+    description="🧑‍🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, etc. 🧠",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/labmlai/annotated_deep_learning_paper_implementations",
@@ -20,7 +20,7 @@
setuptools.setup(
'labml_helpers', 'labml_helpers.*',
'test',
'test.*')),
-    install_requires=['labml>=0.4.129',
+    install_requires=['labml>=0.4.132',
'labml-helpers>=0.4.81',
'torch',
'einops',