Unverified commit ebce4044 authored by Varuna Jayasiri, committed by GitHub

Masked Language Model (#56)

Parent 4a9716b0
...@@ -92,6 +92,7 @@ implementations.</p> ...@@ -92,6 +92,7 @@ implementations.</p>
<li><a href="transformers/fast_weights/index.html">Fast Weights Transformer</a></li> <li><a href="transformers/fast_weights/index.html">Fast Weights Transformer</a></li>
<li><a href="transformers/fnet/index.html">FNet</a></li> <li><a href="transformers/fnet/index.html">FNet</a></li>
<li><a href="transformers/aft/index.html">Attention Free Transformer</a></li> <li><a href="transformers/aft/index.html">Attention Free Transformer</a></li>
<li><a href="transformers/mlm/index.html">Masked Language Model</a></li>
</ul> </ul>
<h4><a href="recurrent_highway_networks/index.html">Recurrent Highway Networks</a></h4> <h4><a href="recurrent_highway_networks/index.html">Recurrent Highway Networks</a></h4>
<h4><a href="lstm/index.html">LSTM</a></h4> <h4><a href="lstm/index.html">LSTM</a></h4>
......
...@@ -461,6 +461,20 @@ ...@@ -461,6 +461,20 @@
</url> </url>
<url>
<loc>https://nn.labml.ai/utils/index.html</loc>
<lastmod>2021-06-06T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/utils/tokenizer.html</loc>
<lastmod>2021-06-06T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url> <url>
<loc>https://nn.labml.ai/optimizers/adam_warmup.html</loc> <loc>https://nn.labml.ai/optimizers/adam_warmup.html</loc>
<lastmod>2021-01-13T16:30:00+00:00</lastmod> <lastmod>2021-01-13T16:30:00+00:00</lastmod>
...@@ -811,9 +825,23 @@ ...@@ -811,9 +825,23 @@
</url> </url>
<url>
<loc>https://nn.labml.ai/transformers/mlm/index.html</loc>
<lastmod>2021-06-06T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/mlm/experiment.html</loc>
<lastmod>2021-06-06T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url> <url>
<loc>https://nn.labml.ai/transformers/aft/index.html</loc> <loc>https://nn.labml.ai/transformers/aft/index.html</loc>
<lastmod>2021-06-04T16:30:00+00:00</lastmod> <lastmod>2021-06-06T16:30:00+00:00</lastmod>
<priority>1.00</priority> <priority>1.00</priority>
</url> </url>
...@@ -930,13 +958,6 @@ ...@@ -930,13 +958,6 @@
</url> </url>
<url>
<loc>https://nn.labml.ai/utils.html</loc>
<lastmod>2021-05-26T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url> <url>
<loc>https://nn.labml.ai/capsule_networks/mnist.html</loc> <loc>https://nn.labml.ai/capsule_networks/mnist.html</loc>
<lastmod>2021-02-27T16:30:00+00:00</lastmod> <lastmod>2021-02-27T16:30:00+00:00</lastmod>
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
<head> <head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/> <meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/> <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This is an annotated implementation/tutorial the AFT (Attention Free Transformer) in PyTorch."/> <meta name="description" content="This is an annotated implementation/tutorial of the AFT (Attention Free Transformer) in PyTorch."/>
<meta name="twitter:card" content="summary"/> <meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/> <meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="An Attention Free Transformer"/> <meta name="twitter:title" content="An Attention Free Transformer"/>
<meta name="twitter:description" content="This is an annotated implementation/tutorial the AFT (Attention Free Transformer) in PyTorch."/> <meta name="twitter:description" content="This is an annotated implementation/tutorial of the AFT (Attention Free Transformer) in PyTorch."/>
<meta name="twitter:site" content="@labmlai"/> <meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/> <meta name="twitter:creator" content="@labmlai"/>
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
<meta property="og:site_name" content="LabML Neural Networks"/> <meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/> <meta property="og:type" content="object"/>
<meta property="og:title" content="An Attention Free Transformer"/> <meta property="og:title" content="An Attention Free Transformer"/>
<meta property="og:description" content="This is an annotated implementation/tutorial the AFT (Attention Free Transformer) in PyTorch."/> <meta property="og:description" content="This is an annotated implementation/tutorial of the AFT (Attention Free Transformer) in PyTorch."/>
<title>An Attention Free Transformer</title> <title>An Attention Free Transformer</title>
<link rel="shortcut icon" href="/icon.png"/> <link rel="shortcut icon" href="/icon.png"/>
......
...@@ -108,12 +108,15 @@ It does single GPU training but we implement the concept of switching as describ ...@@ -108,12 +108,15 @@ It does single GPU training but we implement the concept of switching as describ
<h2><a href="aft/index.html">Attention Free Transformer</a></h2> <h2><a href="aft/index.html">Attention Free Transformer</a></h2>
<p>This is an implementation of the paper <p>This is an implementation of the paper
<a href="https://papers.labml.ai/paper/2105.14103">An Attention Free Transformer</a>.</p> <a href="https://papers.labml.ai/paper/2105.14103">An Attention Free Transformer</a>.</p>
<h2><a href="mlm/index.html">Masked Language Model</a></h2>
<p>This is an implementation of the Masked Language Model used for pre-training in the paper
<a href="https://arxiv.org/abs/1810.04805">BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding</a>.</p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">72</span><span></span><span class="kn">from</span> <span class="nn">.configs</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span> <div class="highlight"><pre><span class="lineno">77</span><span></span><span class="kn">from</span> <span class="nn">.configs</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span>
<span class="lineno">73</span><span class="kn">from</span> <span class="nn">.models</span> <span class="kn">import</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">Decoder</span><span class="p">,</span> <span class="n">Generator</span><span class="p">,</span> <span class="n">EncoderDecoder</span> <span class="lineno">78</span><span class="kn">from</span> <span class="nn">.models</span> <span class="kn">import</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">Decoder</span><span class="p">,</span> <span class="n">Generator</span><span class="p">,</span> <span class="n">EncoderDecoder</span>
<span class="lineno">74</span><span class="kn">from</span> <span class="nn">.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span> <span class="lineno">79</span><span class="kn">from</span> <span class="nn">.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
<span class="lineno">75</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.xl.relative_mha</span> <span class="kn">import</span> <span class="n">RelativeMultiHeadAttention</span></pre></div> <span class="lineno">80</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.xl.relative_mha</span> <span class="kn">import</span> <span class="n">RelativeMultiHeadAttention</span></pre></div>
</div> </div>
</div> </div>
</div> </div>
......
This diff is collapsed.
This diff is collapsed.
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Masked Language Model (MLM)"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/mlm/readme.html"/>
<meta property="og:title" content="Masked Language Model (MLM)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Masked Language Model (MLM)"/>
<meta property="og:description" content=""/>
<title>Masked Language Model (MLM)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/transformers/mlm/readme.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">transformers</a>
<a class="parent" href="index.html">mlm</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/transformers/mlm/readme.md">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/transformers/mlm/index.html">Masked Language Model (MLM)</a></h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of Masked Language Model (MLM)
used to pre-train the BERT model introduced in the paper
<a href="https://arxiv.org/abs/1810.04805">BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding</a>.</p>
<h2>BERT Pretraining</h2>
<p>The BERT model is a transformer model.
The paper pre-trains it with MLM and next sentence prediction.
We have only implemented MLM here.</p>
<h3>Next sentence prediction</h3>
<p>In next sentence prediction, the model is given two sentences <code>A</code> and <code>B</code>,
and it makes a binary prediction of whether <code>B</code> is the sentence that follows <code>A</code> in the actual text.
The model is fed actual sentence pairs 50% of the time and random pairs 50% of the time.
This classification is done while applying MLM. <em>We haven&rsquo;t implemented this here.</em></p>
<h2>Masked LM</h2>
<p>This masks a percentage of tokens at random and trains the model to predict
the masked tokens.
The paper <strong>masks 15% of the tokens</strong> by replacing them with a special <code>[MASK]</code> token.</p>
<p>The loss is computed only on predicting the masked tokens.
This causes a problem during fine-tuning and actual usage, since there are no <code>[MASK]</code> tokens
at that time.
Therefore, the model might not give meaningful representations.</p>
<p>To overcome this, <strong>10% of the masked tokens are replaced with the original token</strong>,
and another <strong>10% of the masked tokens are replaced with a random token</strong>.
This trains the model to give representations of the actual token whether or not the
input token at that position is a <code>[MASK]</code>.
Replacing with a random token also forces the representation to carry information from the context,
because the model has to use the context to fix randomly replaced tokens.</p>
<h2>Training</h2>
<p>MLMs are harder to train than autoregressive models because they have a smaller training signal:
only a small percentage of predictions are trained per sample.</p>
<p>Another problem is that since the model is bidirectional, any token can see any other token.
This makes &ldquo;credit assignment&rdquo; harder.
Let&rsquo;s say you have a character-level model trying to predict <code>home *s where i want to be</code>.
At least during the early stages of training, it&rsquo;ll be very hard to figure out that the
replacement for <code>*</code> should be <code>i</code>; it could be anything in the whole sentence.
In contrast, in an autoregressive setting the model only has to use <code>h</code> to predict <code>o</code>,
<code>hom</code> to predict <code>e</code>, and so on. So the model initially predicts with shorter contexts
and learns to use longer contexts later.
Since MLMs have this problem, it&rsquo;s a lot faster to train them if you start with a smaller sequence length
and then move to a longer sequence length later.</p>
<p>Here is <a href="https://nn.labml.ai/transformers/mlm/experiment.html">the training code</a> for a simple MLM model.</p>
<p><a href="https://app.labml.ai/run/3a6d22b6c67111ebb03d6764d13a38d1"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>
\ No newline at end of file
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
<meta name="twitter:site" content="@labmlai"/> <meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/> <meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/utils.html"/> <meta property="og:url" content="https://nn.labml.ai/utils/index.html"/>
<meta property="og:title" content="Utilities"/> <meta property="og:title" content="Utilities"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/> <meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/> <meta property="og:site_name" content="LabML Neural Networks"/>
...@@ -22,8 +22,8 @@ ...@@ -22,8 +22,8 @@
<title>Utilities</title> <title>Utilities</title>
<link rel="shortcut icon" href="/icon.png"/> <link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="./pylit.css"> <link rel="stylesheet" href="../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/utils.html"/> <link rel="canonical" href="https://nn.labml.ai/utils/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics --> <!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script> <script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script> <script>
...@@ -45,10 +45,11 @@ ...@@ -45,10 +45,11 @@
<div class='docs'> <div class='docs'>
<p> <p>
<a class="parent" href="/">home</a> <a class="parent" href="/">home</a>
<a class="parent" href="index.html">utils</a>
</p> </p>
<p> <p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/utils.py"> <a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/utils/__init__.py">
<img alt="Github" <img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social" src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a> style="max-width:100%;"/></a>
......
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="tokenizer.py"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/utils/tokenizer.html"/>
<meta property="og:title" content="tokenizer.py"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="tokenizer.py"/>
<meta property="og:description" content=""/>
<title>tokenizer.py</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/utils/tokenizer.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">utils</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/utils/tokenizer.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">1</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Callable</span>
<span class="lineno">2</span>
<span class="lineno">3</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">BaseConfigs</span><span class="p">,</span> <span class="n">option</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p><a id="TokenizerConfigs"></p>
<h2>Tokenizer Configurations</h2>
<p></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">6</span><span class="k">class</span> <span class="nc">TokenizerConfigs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">13</span> <span class="n">tokenizer</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="s1">&#39;character&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">15</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">16</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">_primary</span><span class="o">=</span><span class="s1">&#39;tokenizer&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<h3>Basic English tokenizer</h3>
<p>We use a character-level tokenizer in this experiment.
You can switch by setting</p>
<pre><code> 'tokenizer': 'basic_english',
</code></pre>
<p>in the configurations dictionary when starting the experiment.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">19</span><span class="nd">@option</span><span class="p">(</span><span class="n">TokenizerConfigs</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">)</span>
<span class="lineno">20</span><span class="k">def</span> <span class="nf">basic_english</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span> <span class="kn">from</span> <span class="nn">torchtext.data</span> <span class="kn">import</span> <span class="n">get_tokenizer</span>
<span class="lineno">35</span> <span class="k">return</span> <span class="n">get_tokenizer</span><span class="p">(</span><span class="s1">&#39;basic_english&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<h3>Character level tokenizer</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span><span class="k">def</span> <span class="nf">character_tokenizer</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Character level tokenizer configuration</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span><span class="nd">@option</span><span class="p">(</span><span class="n">TokenizerConfigs</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">)</span>
<span class="lineno">46</span><span class="k">def</span> <span class="nf">character</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span> <span class="k">return</span> <span class="n">character_tokenizer</span></pre></div>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>
\ No newline at end of file
...@@ -28,6 +28,7 @@ implementations. ...@@ -28,6 +28,7 @@ implementations.
* [Fast Weights Transformer](transformers/fast_weights/index.html) * [Fast Weights Transformer](transformers/fast_weights/index.html)
* [FNet](transformers/fnet/index.html) * [FNet](transformers/fnet/index.html)
* [Attention Free Transformer](transformers/aft/index.html) * [Attention Free Transformer](transformers/aft/index.html)
* [Masked Language Model](transformers/mlm/index.html)
#### ✨ [Recurrent Highway Networks](recurrent_highway_networks/index.html) #### ✨ [Recurrent Highway Networks](recurrent_highway_networks/index.html)
......
...@@ -67,6 +67,11 @@ This is an implementation of the paper ...@@ -67,6 +67,11 @@ This is an implementation of the paper
This is an implementation of the paper This is an implementation of the paper
[An Attention Free Transformer](https://papers.labml.ai/paper/2105.14103). [An Attention Free Transformer](https://papers.labml.ai/paper/2105.14103).
## [Masked Language Model](mlm/index.html)
This is an implementation of the Masked Language Model used for pre-training in the paper
[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
""" """
from .configs import TransformerConfigs from .configs import TransformerConfigs
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
--- ---
title: An Attention Free Transformer title: An Attention Free Transformer
summary: > summary: >
This is an annotated implementation/tutorial the AFT (Attention Free Transformer) in PyTorch. This is an annotated implementation/tutorial of the AFT (Attention Free Transformer) in PyTorch.
--- ---
# An Attention Free Transformer # An Attention Free Transformer
......
"""
---
title: Masked Language Model
summary: >
This is an annotated implementation/tutorial of Masked Language Model in PyTorch.
---
# Masked Language Model (MLM)
This is a [PyTorch](https://pytorch.org) implementation of Masked Language Model (MLM)
used to pre-train the BERT model introduced in the paper
[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
## BERT Pretraining
The BERT model is a transformer model.
The paper pre-trains it with MLM and next sentence prediction.
We have only implemented MLM here.
### Next sentence prediction
In next sentence prediction, the model is given two sentences `A` and `B`,
and it makes a binary prediction of whether `B` is the sentence that follows `A` in the actual text.
The model is fed actual sentence pairs 50% of the time and random pairs 50% of the time.
This classification is done while applying MLM. *We haven't implemented this here.*
## Masked LM
This masks a percentage of tokens at random and trains the model to predict
the masked tokens.
The paper **masks 15% of the tokens** by replacing them with a special `[MASK]` token.
The loss is computed only on predicting the masked tokens.
This causes a problem during fine-tuning and actual usage, since there are no `[MASK]` tokens
at that time.
Therefore, the model might not give meaningful representations.
To overcome this, **10% of the masked tokens are replaced with the original token**,
and another **10% of the masked tokens are replaced with a random token**.
This trains the model to give representations of the actual token whether or not the
input token at that position is a `[MASK]`.
Replacing with a random token also forces the representation to carry information from the context,
because the model has to use the context to fix randomly replaced tokens.
## Training
MLMs are harder to train than autoregressive models because they have a smaller training signal:
only a small percentage of predictions are trained per sample.
Another problem is that since the model is bidirectional, any token can see any other token.
This makes "credit assignment" harder.
Let's say you have a character-level model trying to predict `home *s where i want to be`.
At least during the early stages of training, it'll be very hard to figure out that the
replacement for `*` should be `i`; it could be anything in the whole sentence.
In contrast, in an autoregressive setting the model only has to use `h` to predict `o`,
`hom` to predict `e`, and so on. So the model initially predicts with shorter contexts
and learns to use longer contexts later.
Since MLMs have this problem, it's a lot faster to train them if you start with a smaller sequence length
and then move to a longer sequence length later.
Here is [the training code](experiment.html) for a simple MLM model.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/3a6d22b6c67111ebb03d6764d13a38d1)
"""
from typing import List
import torch
class MLM:
"""
## Masked LM (MLM)
This class implements the masking procedure for a given batch of token sequences.
"""
def __init__(self, *,
padding_token: int, mask_token: int, no_mask_tokens: List[int], n_tokens: int,
masking_prob: float = 0.15, randomize_prob: float = 0.1, no_change_prob: float = 0.1,
):
"""
* `padding_token` is the padding token `[PAD]`.
We will use this to mark the labels that shouldn't be used for loss calculation.
* `mask_token` is the masking token `[MASK]`.
* `no_mask_tokens` is a list of tokens that should not be masked.
This is useful if we are training the MLM with another task like classification at the same time,
and we have tokens such as `[CLS]` that shouldn't be masked.
* `n_tokens` is the total number of tokens (used for generating random tokens)
* `masking_prob` is the masking probability
* `randomize_prob` is the probability of replacing with a random token
* `no_change_prob` is the probability of replacing with original token
"""
self.n_tokens = n_tokens
self.no_change_prob = no_change_prob
self.randomize_prob = randomize_prob
self.masking_prob = masking_prob
self.no_mask_tokens = no_mask_tokens + [padding_token, mask_token]
self.padding_token = padding_token
self.mask_token = mask_token
def __call__(self, x: torch.Tensor):
"""
* `x` is the batch of input token sequences.
It's a tensor of type `long` with shape `[seq_len, batch_size]`.
"""
# Mask `masking_prob` of tokens
full_mask = torch.rand(x.shape, device=x.device) < self.masking_prob
# Unmask `no_mask_tokens`
for t in self.no_mask_tokens:
full_mask &= x != t
# A mask for tokens to be replaced with original tokens
unchanged = full_mask & (torch.rand(x.shape, device=x.device) < self.no_change_prob)
# A mask for tokens to be replaced with a random token
random_token_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.randomize_prob)
# Indexes of tokens to be replaced with random tokens
random_token_idx = torch.nonzero(random_token_mask, as_tuple=True)
# Random tokens for each of the locations
random_tokens = torch.randint(0, self.n_tokens, (len(random_token_idx[0]),), device=x.device)
# The final set of tokens that are going to be replaced by `[MASK]`
mask = full_mask & ~random_token_mask & ~unchanged
# Make a clone of the input for the labels
y = x.clone()
# Replace with `[MASK]` tokens;
# note that this doesn't include the tokens that will have the original token unchanged and
# those that get replaced with a random token.
x.masked_fill_(mask, self.mask_token)
# Assign random tokens
x[random_token_idx] = random_tokens
# Assign token `[PAD]` to all the other locations in the labels.
# The labels equal to `[PAD]` will not be used in the loss.
y.masked_fill_(~full_mask, self.padding_token)
# Return the masked input and the labels
return x, y
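# --- Usage sketch (illustration only; not part of the repository) ---
# A minimal example of applying the `MLM` masker above to a toy batch.
# The ids here are hypothetical: a 10-token vocabulary with ids 10 and 11 reserved
# for `[PAD]` and `[MASK]`, similar to how the experiment below reserves the last two ids.
if __name__ == '__main__':
    # Toy masker over a 12-id vocabulary (10 ordinary tokens + `[PAD]` + `[MASK]`)
    mlm = MLM(padding_token=10, mask_token=11, no_mask_tokens=[], n_tokens=12)
    # A batch of token ids with shape `[seq_len, batch_size]`
    x = torch.randint(0, 10, (32, 4), dtype=torch.long)
    # `MLM.__call__` modifies its input in place, so pass a clone to keep the original
    masked, labels = mlm(x.clone())
    # `masked` has roughly 15% of positions corrupted (mostly `[MASK]`, some random, some unchanged);
    # `labels` holds the original tokens at those positions and `[PAD]` everywhere else
    print(masked[:, 0], labels[:, 0])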
"""
---
title: Masked Language Model Experiment
summary: This experiment trains Masked Language Model (MLM) on Tiny Shakespeare dataset.
---
# [Masked Language Model (MLM)](index.html) Experiment
This is an annotated PyTorch experiment to train a [Masked Language Model](index.html).
"""
from typing import List
import torch
from torch import nn
from labml import experiment, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.mlm import MLM
class TransformerMLM(nn.Module):
"""
# Transformer based model for MLM
"""
def __init__(self, *, encoder: Encoder, src_embed: Module, generator: Generator):
"""
* `encoder` is the transformer [Encoder](../models.html#Encoder)
* `src_embed` is the token
[embedding module (with positional encodings)](../models.html#EmbeddingsWithLearnedPositionalEncoding)
* `generator` is the [final fully connected layer](../models.html#Generator) that gives the logits.
"""
super().__init__()
self.generator = generator
self.src_embed = src_embed
self.encoder = encoder
def forward(self, x: torch.Tensor):
# Get the token embeddings with positional encodings
x = self.src_embed(x)
# Transformer encoder
x = self.encoder(x, None)
# Logits for the output
y = self.generator(x)
# Return results
# (second value is for state, since our trainer is used with RNNs also)
return y, None
class Configs(NLPAutoRegressionConfigs):
"""
## Configurations
This inherits from
[`NLPAutoRegressionConfigs`](../../experiments/nlp_autoregression.html)
because it has the data pipeline implementations that we reuse here.
We have implemented a custom training step for MLM.
"""
# MLM model
model: TransformerMLM
# Transformer
transformer: TransformerConfigs
# Number of tokens
n_tokens: int = 'n_tokens_mlm'
# Tokens that shouldn't be masked
no_mask_tokens: List[int] = []
# Probability of masking a token
masking_prob: float = 0.15
# Probability of replacing a masked token with a random token
randomize_prob: float = 0.1
# Probability of keeping a masked token unchanged (original token)
no_change_prob: float = 0.1
# [Masked Language Model (MLM) class](index.html) to generate the mask
mlm: MLM
# `[MASK]` token
mask_token: int
# `[PADDING]` token
padding_token: int
# Prompt to sample
prompt: List[str] = [
"We are accounted poor citizens, the patricians good.",
"What authority surfeits on would relieve us: if they",
"would yield us but the superfluity, while it were",
"wholesome, we might guess they relieved us humanely;",
"but they think we are too dear: the leanness that",
"afflicts us, the object of our misery, is as an",
"inventory to particularise their abundance; our",
"sufferance is a gain to them Let us revenge this with",
"our pikes, ere we become rakes: for the gods know I",
"speak this in hunger for bread, not in thirst for revenge.",
]
def init(self):
"""
### Initialization
"""
# `[MASK]` token
self.mask_token = self.n_tokens - 1
# `[PAD]` token
self.padding_token = self.n_tokens - 2
# [Masked Language Model (MLM) class](index.html) to generate the mask
self.mlm = MLM(padding_token=self.padding_token,
mask_token=self.mask_token,
no_mask_tokens=self.no_mask_tokens,
n_tokens=self.n_tokens,
masking_prob=self.masking_prob,
randomize_prob=self.randomize_prob,
no_change_prob=self.no_change_prob)
# Accuracy metric (ignore the labels equal to `[PAD]`)
self.accuracy = Accuracy(ignore_index=self.padding_token)
# Cross entropy loss (ignore the labels equal to `[PAD]`)
self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)
#
super().init()
def step(self, batch: any, batch_idx: BatchIndex):
"""
### Training or validation step
"""
# Move the input to the device
data = batch[0].to(self.device)
# Update global step (number of tokens processed) when in training mode
if self.mode.is_train:
tracker.add_global_step(data.shape[0] * data.shape[1])
# Get the masked input and labels
with torch.no_grad():
data, labels = self.mlm(data)
# Whether to capture model outputs
with self.mode.update(is_log_activations=batch_idx.is_last):
# Get model outputs.
# It's returning a tuple for states when using RNNs.
# This is not implemented yet.
output, *_ = self.model(data)
# Calculate and log the loss
loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
tracker.add("loss.", loss)
# Calculate and log accuracy
self.accuracy(output, labels)
self.accuracy.track()
# Train the model
if self.mode.is_train:
# Calculate gradients
loss.backward()
# Clip gradients
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
# Take optimizer step
self.optimizer.step()
# Log the model parameters and gradients on last batch of every epoch
if batch_idx.is_last:
tracker.add('model', self.model)
# Clear the gradients
self.optimizer.zero_grad()
# Save the tracked metrics
tracker.save()
@torch.no_grad()
def sample(self):
"""
### Sampling function to generate samples periodically while training
"""
# Empty tensor for data filled with `[PAD]`.
data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)
# Add the prompts one by one
for i, p in enumerate(self.prompt):
# Get token indexes
d = self.text.text_to_i(p)
# Add to the tensor
s = min(self.seq_len, len(d))
data[:s, i] = d[:s]
# Move the tensor to current device
data = data.to(self.device)
# Get masked input and labels
data, labels = self.mlm(data)
# Get model outputs
output, *_ = self.model(data)
# Print the samples generated
for j in range(data.shape[1]):
# Collect output for printing
log = []
# For each token
for i in range(len(data)):
# If the label is not `[PAD]`
if labels[i, j] != self.padding_token:
# Get the prediction
t = output[i, j].argmax().item()
# If it's a printable character
if t < len(self.text.itos):
# Correct prediction
if t == labels[i, j]:
log.append((self.text.itos[t], Text.value))
# Incorrect prediction
else:
log.append((self.text.itos[t], Text.danger))
# If it's not a printable character
else:
log.append(('*', Text.danger))
# If the label is `[PAD]` (unmasked) print the original.
elif data[i, j] < len(self.text.itos):
log.append((self.text.itos[data[i, j]], Text.subtle))
# Print
logger.log(log)
@option(Configs.n_tokens)
def n_tokens_mlm(c: Configs):
"""
Number of tokens including `[PAD]` and `[MASK]`
"""
return c.text.n_tokens + 2
@option(Configs.transformer)
def _transformer_configs(c: Configs):
"""
### Transformer configurations
"""
# We use our
# [configurable transformer implementation](../configs.html#TransformerConfigs)
conf = TransformerConfigs()
# Set the vocabulary sizes for embeddings and generating logits
conf.n_src_vocab = c.n_tokens
conf.n_tgt_vocab = c.n_tokens
# Embedding size
conf.d_model = c.d_model
#
return conf
@option(Configs.model)
def _model(c: Configs):
"""
Create the MLM model
"""
m = TransformerMLM(encoder=c.transformer.encoder,
src_embed=c.transformer.src_embed,
generator=c.transformer.generator).to(c.device)
return m
def main():
# Create experiment
experiment.create(name="mlm")
# Create configs
conf = Configs()
# Override configurations
experiment.configs(conf, {
# Batch size
'batch_size': 64,
# Sequence length of $32$. We use a short sequence length to train faster.
# Otherwise it takes forever to train.
'seq_len': 32,
# Train for 1024 epochs.
'epochs': 1024,
# Switch between training and validation once per epoch
'inner_iterations': 1,
# Transformer configurations (same as defaults)
'd_model': 128,
'transformer.ffn.d_ff': 256,
'transformer.n_heads': 8,
'transformer.n_layers': 6,
# Use [Noam optimizer](../../optimizers/noam.html)
'optimizer.optimizer': 'Noam',
'optimizer.learning_rate': 1.,
})
# Set models for saving and loading
experiment.add_pytorch_models({'model': conf.model})
# Start the experiment
with experiment.start():
# Run training
conf.run()
#
if __name__ == '__main__':
main()
# [Masked Language Model (MLM)](https://nn.labml.ai/transformers/mlm/index.html)
This is a [PyTorch](https://pytorch.org) implementation of Masked Language Model (MLM)
used to pre-train the BERT model introduced in the paper
[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
## BERT Pretraining
The BERT model is a transformer model.
The paper pre-trains it with MLM and next sentence prediction.
We have only implemented MLM here.
### Next sentence prediction
In next sentence prediction, the model is given two sentences `A` and `B`,
and it makes a binary prediction of whether `B` is the sentence that follows `A` in the actual text.
The model is fed actual sentence pairs 50% of the time and random pairs 50% of the time.
This classification is done while applying MLM. *We haven't implemented this here.*
## Masked LM
This masks a percentage of tokens at random and trains the model to predict
the masked tokens.
The paper **masks 15% of the tokens** by replacing them with a special `[MASK]` token.
The loss is computed only on predicting the masked tokens.
This causes a problem during fine-tuning and actual usage, since there are no `[MASK]` tokens
at that time.
Therefore, the model might not give meaningful representations.
To overcome this, **10% of the masked tokens are replaced with the original token**,
and another **10% of the masked tokens are replaced with a random token**.
This trains the model to give representations of the actual token whether or not the
input token at that position is a `[MASK]`.
Replacing with a random token also forces the representation to carry information from the context,
because the model has to use the context to fix randomly replaced tokens.
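As a toy illustration (hypothetical tokens shown as strings rather than ids; this example is not taken from the paper or the code), a single masked position looks like this:

```python
original = ["the", "cat", "sat", "on", "the", "mat"]
masked   = ["the", "[MASK]", "sat", "on", "the", "mat"]          # input fed to the model
labels   = ["[PAD]", "cat", "[PAD]", "[PAD]", "[PAD]", "[PAD]"]  # loss only where the label is not [PAD]
```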
## Training
MLMs are harder to train than autoregressive models because they have a smaller training signal:
only a small percentage of predictions are trained per sample.
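For a rough sense of scale (a back-of-the-envelope sketch using the batch size, sequence length, and masking probability from the training code linked below; the exact count varies per batch):

```python
batch_size, seq_len, masking_prob = 64, 32, 0.15
total_positions = batch_size * seq_len             # 2048 positions per batch
expected_masked = total_positions * masking_prob   # ~307 positions carry a loss signal
```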
Another problem is that since the model is bidirectional, any token can see any other token.
This makes "credit assignment" harder.
Let's say you have a character-level model trying to predict `home *s where i want to be`.
At least during the early stages of training, it'll be very hard to figure out that the
replacement for `*` should be `i`; it could be anything in the whole sentence.
In contrast, in an autoregressive setting the model only has to use `h` to predict `o`,
`hom` to predict `e`, and so on. So the model initially predicts with shorter contexts
and learns to use longer contexts later.
Since MLMs have this problem, it's a lot faster to train them if you start with a smaller sequence length
and then move to a longer sequence length later.
Here is [the training code](https://nn.labml.ai/transformers/mlm/experiment.html) for a simple MLM model.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/3a6d22b6c67111ebb03d6764d13a38d1)
from typing import Callable
from labml.configs import BaseConfigs, option
class TokenizerConfigs(BaseConfigs):
"""
<a id="TokenizerConfigs">
## Tokenizer Configurations
</a>
"""
tokenizer: Callable = 'character'
def __init__(self):
super().__init__(_primary='tokenizer')
@option(TokenizerConfigs.tokenizer)
def basic_english():
"""
### Basic English tokenizer
We use a character-level tokenizer in this experiment.
You can switch by setting
```
'tokenizer': 'basic_english',
```
in the configurations dictionary when starting the experiment.
"""
from torchtext.data import get_tokenizer
return get_tokenizer('basic_english')
def character_tokenizer(x: str):
"""
### Character level tokenizer
"""
return list(x)
@option(TokenizerConfigs.tokenizer)
def character():
"""
Character level tokenizer configuration
"""
return character_tokenizer
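# --- Sanity check (illustration only; not part of the repository) ---
# The character-level tokenizer above simply splits a string into characters:
if __name__ == '__main__':
    assert character_tokenizer("hello") == ['h', 'e', 'l', 'l', 'o']
    # The `basic_english` option instead uses torchtext's word-level tokenizer,
    # which roughly lower-cases the text and splits on words and punctuation.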
...@@ -34,6 +34,7 @@ implementations almost weekly. ...@@ -34,6 +34,7 @@ implementations almost weekly.
* [Fast Weights Transformer](https://nn.labml.ai/transformers/fast_weights/index.html) * [Fast Weights Transformer](https://nn.labml.ai/transformers/fast_weights/index.html)
* [FNet](https://nn.labml.ai/transformers/fnet/index.html) * [FNet](https://nn.labml.ai/transformers/fnet/index.html)
* [Attention Free Transformer](https://nn.labml.ai/transformers/aft/index.html) * [Attention Free Transformer](https://nn.labml.ai/transformers/aft/index.html)
* [Masked Language Model](https://nn.labml.ai/transformers/mlm/index.html)
#### ✨ [Recurrent Highway Networks](https://nn.labml.ai/recurrent_highway_networks/index.html) #### ✨ [Recurrent Highway Networks](https://nn.labml.ai/recurrent_highway_networks/index.html)
......
...@@ -21,7 +21,7 @@ setuptools.setup( ...@@ -21,7 +21,7 @@ setuptools.setup(
'test', 'test',
'test.*')), 'test.*')),
install_requires=['labml>=0.4.110', install_requires=['labml>=0.4.110',
'labml-helpers>=0.4.76', 'labml-helpers>=0.4.77',
'torch', 'torch',
'einops', 'einops',
'numpy'], 'numpy'],
......