Commit 4cf1d74e, author: Varuna Jayasiri

sampling links

Parent f3189e23
...@@ -149,6 +149,11 @@ ...@@ -149,6 +149,11 @@
<ul><li><a href="uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li></ul> <ul><li><a href="uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li></ul>
<h4><a href="activations/index.html">Activations</a></h4> <h4><a href="activations/index.html">Activations</a></h4>
<ul><li><a href="activations/fta/index.html">Fuzzy Tiling Activations</a></li></ul> <ul><li><a href="activations/fta/index.html">Fuzzy Tiling Activations</a></li></ul>
<h4><a href="sampling/index.html">Sampling Techniques</a></h4>
<ul><li><a href="sampling/greedy.html">Greedy Sampling</a> </li>
<li><a href="sampling/temperature.html">Temperature Sampling</a> </li>
<li><a href="sampling/top_k.html">Top-k Sampling</a> </li>
<li><a href="sampling/nucleus.html">Nucleus Sampling</a></li></ul>
<h2>Highlighted Research Paper PDFs</h2> <h2>Highlighted Research Paper PDFs</h2>
<ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li> <ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li> <li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li>
......
...@@ -76,12 +76,13 @@ ...@@ -76,12 +76,13 @@
</div> </div>
<h1>Greedy Sampling</h1> <h1>Greedy Sampling</h1>
<p>Here we sample the most likely token from the distribution of logits.</p> <p>Here we sample the most likely token from the distribution of logits.</p>
<p>Here&#x27;s an <a href="experiment.html">experiment</a> that uses these sampling techniques.</p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">12</span><span></span><span class="kn">import</span> <span class="nn">torch</span> <div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">13</span> <span class="lineno">15</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div> <span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-1'> <div class='section' id='section-1'>
...@@ -92,7 +93,7 @@ ...@@ -92,7 +93,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">17</span><span class="k">class</span> <span class="nc">GreedySampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">19</span><span class="k">class</span> <span class="nc">GreedySampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-2'> <div class='section' id='section-2'>
...@@ -104,7 +105,7 @@ ...@@ -104,7 +105,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">18</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">20</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-3'> <div class='section' id='section-3'>
...@@ -115,7 +116,7 @@ ...@@ -115,7 +116,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">22</span> <span class="k">return</span> <span class="n">logits</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">24</span> <span class="k">return</span> <span class="n">logits</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
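<p>As a rough illustration of the greedy rule above, here is a plain-Python sketch that does not use PyTorch. The function name <code>greedy_sample</code> is hypothetical and is not part of <code>labml_nn</code>; it handles a single list of logits rather than a batched tensor.</p>

```python
def greedy_sample(logits):
    """Return the index of the largest logit.

    Hypothetical helper for illustration only; it mirrors an argmax over
    the last dimension for a single, unbatched list of logits.
    On ties, the first maximal index is returned.
    """
    return max(range(len(logits)), key=lambda i: logits[i])


# Token 1 has the largest logit, so it is always chosen.
chosen = greedy_sample([0.1, 2.5, -1.0])
```

<p>Because the choice is deterministic, repeated calls with the same logits return the same token.</p>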
<div class='footer'> <div class='footer'>
......
...@@ -81,13 +81,14 @@ ...@@ -81,13 +81,14 @@
<p>$$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$</p> <p>$$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$</p>
<p>That is, we pick the most probable tokens until the sum of their probabilities reaches at least $p$.</p> <p>That is, we pick the most probable tokens until the sum of their probabilities reaches at least $p$.</p>
<p>Then we sample from the selected tokens.</p> <p>Then we sample from the selected tokens.</p>
<p>Here&#x27;s an <a href="experiment.html">experiment</a> that uses these sampling techniques.</p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span></span><span class="kn">import</span> <span class="nn">torch</span> <div class="highlight"><pre><span class="lineno">29</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span> <span class="lineno">30</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">29</span> <span class="lineno">31</span>
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div> <span class="lineno">32</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-1'> <div class='section' id='section-1'>
...@@ -99,7 +100,7 @@ ...@@ -99,7 +100,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">33</span><span class="k">class</span> <span class="nc">NucleusSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">35</span><span class="k">class</span> <span class="nc">NucleusSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-2'> <div class='section' id='section-2'>
...@@ -114,7 +115,7 @@ ...@@ -114,7 +115,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">39</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-3'> <div class='section' id='section-3'>
...@@ -125,8 +126,8 @@ ...@@ -125,8 +126,8 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span> <span class="o">=</span> <span class="n">p</span> <div class="highlight"><pre><span class="lineno">44</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span> <span class="o">=</span> <span class="n">p</span>
<span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">sampler</span></pre></div> <span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">sampler</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-4'> <div class='section' id='section-4'>
...@@ -138,7 +139,7 @@ ...@@ -138,7 +139,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-5'> <div class='section' id='section-5'>
...@@ -150,7 +151,7 @@ ...@@ -150,7 +151,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">49</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-6'> <div class='section' id='section-6'>
...@@ -162,7 +163,7 @@ ...@@ -162,7 +163,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="n">probs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">55</span> <span class="n">probs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-7'> <div class='section' id='section-7'>
...@@ -174,7 +175,7 @@ ...@@ -174,7 +175,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">sorted_probs</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">58</span> <span class="n">sorted_probs</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-8'> <div class='section' id='section-8'>
...@@ -186,7 +187,7 @@ ...@@ -186,7 +187,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">cum_sum_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">60</span> <span class="n">cum_sum_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-9'> <div class='section' id='section-9'>
...@@ -198,7 +199,7 @@ ...@@ -198,7 +199,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">cum_sum_probs</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span></pre></div> <div class="highlight"><pre><span class="lineno">62</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">cum_sum_probs</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-10'> <div class='section' id='section-10'>
...@@ -210,7 +211,7 @@ ...@@ -210,7 +211,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">nucleus</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">nucleus</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)),</span> <span class="n">nucleus</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">65</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">nucleus</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">nucleus</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)),</span> <span class="n">nucleus</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-11'> <div class='section' id='section-11'>
...@@ -222,8 +223,8 @@ ...@@ -222,8 +223,8 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="n">sorted_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">)</span> <div class="highlight"><pre><span class="lineno">68</span> <span class="n">sorted_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">)</span>
<span class="lineno">67</span> <span class="n">sorted_log_probs</span><span class="p">[</span><span class="o">~</span><span class="n">nucleus</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div> <span class="lineno">69</span> <span class="n">sorted_log_probs</span><span class="p">[</span><span class="o">~</span><span class="n">nucleus</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-12'> <div class='section' id='section-12'>
...@@ -235,7 +236,7 @@ ...@@ -235,7 +236,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">sampled_sorted_indexes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">(</span><span class="n">sorted_log_probs</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">72</span> <span class="n">sampled_sorted_indexes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">(</span><span class="n">sorted_log_probs</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-13'> <div class='section' id='section-13'>
...@@ -247,7 +248,7 @@ ...@@ -247,7 +248,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">res</span> <span class="o">=</span> <span class="n">indices</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">sampled_sorted_indexes</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div> <div class="highlight"><pre><span class="lineno">75</span> <span class="n">res</span> <span class="o">=</span> <span class="n">indices</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">sampled_sorted_indexes</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-14'> <div class='section' id='section-14'>
...@@ -259,7 +260,7 @@ ...@@ -259,7 +260,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">return</span> <span class="n">res</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">78</span> <span class="k">return</span> <span class="n">res</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
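<p>The steps of the sampler above (softmax, sort, cumulative sum, mask, sample) can be sketched in plain Python without PyTorch. This is a minimal sketch under assumptions: the names <code>nucleus_tokens</code> and <code>nucleus_sample</code> are hypothetical, the input is a single unbatched list of logits, and the final draw samples proportionally to the probabilities inside the nucleus instead of delegating to another <code>Sampler</code>.</p>

```python
import math
import random


def nucleus_tokens(logits, p):
    """Return (token indices, probabilities) of the nucleus, most probable first.

    Hypothetical helper: a token is kept while the cumulative probability
    of the tokens *before* it is still below p, so the token that crosses
    the threshold is included.
    """
    # Numerically stable softmax over the logits.
    m = max(logits)
    exps = [math.exp(u - m) for u in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Sort token indices by probability, highest first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= p:
            break
    return kept, [probs[i] for i in kept]


def nucleus_sample(logits, p, rng=random):
    """Draw one token index from the nucleus, weighted by probability."""
    kept, weights = nucleus_tokens(logits, p)
    return rng.choices(kept, weights=weights, k=1)[0]


# Probabilities are roughly 0.5, 0.3 and 0.2 for tokens 0, 1 and 2.
example_logits = [math.log(0.5), math.log(0.3), math.log(0.2)]
token = nucleus_sample(example_logits, p=0.7, rng=random.Random(0))
```

<p>With <code>p=0.7</code> the nucleus holds tokens 0 and 1, so token 2 is never drawn; a smaller <code>p</code> shrinks the nucleus toward greedy behaviour.</p>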
<div class='footer'> <div class='footer'>
......
...@@ -78,13 +78,14 @@ ...@@ -78,13 +78,14 @@
<p>Here we sample from the following probability distribution where $V$ is the vocabulary, $u_{1:|V|}$ are the logits of the distribution and $T$ is the temperature:</p> <p>Here we sample from the following probability distribution where $V$ is the vocabulary, $u_{1:|V|}$ are the logits of the distribution and $T$ is the temperature:</p>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord coloredeq eqd" style=""><span class="mord mathnormal" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.01968em;">l</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span><span class="mrel mtight">:</span><span class="mord mathnormal mtight">i</span><span class="mbin mtight"></span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.61953em;vertical-align:-1.13453em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.4849999999999999em;"><span style="top:-2.301288em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;"></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">exp</span><span class="mopen">(</span><span 
class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.808712em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.50732em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight">u</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.7350000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mopen 
nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.717252em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.41586em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight">u</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.01968em;">l</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.13453em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></p> <p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" 
style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord coloredeq eqd" style=""><span class="mord mathnormal" style="margin-right:0.22222em">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.01968em;">l</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" 
style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span><span class="mrel mtight">:</span><span class="mord mathnormal mtight">i</span><span class="mbin mtight"></span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.61953em;vertical-align:-1.13453em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.4849999999999999em;"><span style="top:-2.301288em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;"></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span 
class="vlist" style="height:0.808712em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.50732em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight">u</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.7350000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.717252em;"><span 
style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.41586em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight">u</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.01968em;">l</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.13453em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></p>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">1</span></span></span></span> is normal random sampling.</p> <p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">1</span></span></span></span> is normal random sampling.</p>
<p>Here&#x27;s an <a href="experiment.html">experiment</a> that uses these sampling techniques.</p>
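<p>As a rough illustration of the temperature-scaled distribution defined above, here is a minimal pure-Python sketch. It is not the <code>TemperatureSampler</code> implementation on this page, which relies on <code>torch.distributions.Categorical</code>; the function names <code>temperature_probs</code> and <code>sample_with_temperature</code> are hypothetical and exist only for this example.</p>

```python
import math
import random


def temperature_probs(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities.

    Hypothetical helper for illustration; assumes temperature > 0.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [u / temperature for u in logits]
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_with_temperature(logits, temperature=1.0, rng=random):
    """Draw one token index from the temperature-scaled distribution."""
    probs = temperature_probs(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.0]
    for t in (0.5, 1.0, 2.0):
        print(t, temperature_probs(logits, t))
```

<p>In this sketch, a temperature below 1 concentrates probability on the largest logit, while a temperature above 1 flattens the distribution; a temperature of exactly 1 leaves the ordinary softmax unchanged.</p>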
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">17</span><span></span><span class="kn">import</span> <span class="nn">torch</span> <div class="highlight"><pre><span class="lineno">19</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">torch.distributions</span> <span class="kn">import</span> <span class="n">Categorical</span> <span class="lineno">20</span><span class="kn">from</span> <span class="nn">torch.distributions</span> <span class="kn">import</span> <span class="n">Categorical</span>
<span class="lineno">19</span> <span class="lineno">21</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div> <span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-1'> <div class='section' id='section-1'>
...@@ -96,7 +97,7 @@ ...@@ -96,7 +97,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">23</span><span class="k">class</span> <span class="nc">TemperatureSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">25</span><span class="k">class</span> <span class="nc">TemperatureSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-2'> <div class='section' id='section-2'>
...@@ -109,7 +110,7 @@ ...@@ -109,7 +110,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">27</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">29</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-3'> <div class='section' id='section-3'>
...@@ -120,7 +121,7 @@ ...@@ -120,7 +121,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">31</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span> <span class="o">=</span> <span class="n">temperature</span></pre></div> <div class="highlight"><pre><span class="lineno">33</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span> <span class="o">=</span> <span class="n">temperature</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-4'> <div class='section' id='section-4'>
...@@ -132,7 +133,7 @@ ...@@ -132,7 +133,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">35</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-5'> <div class='section' id='section-5'>
...@@ -144,7 +145,7 @@ ...@@ -144,7 +145,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">39</span> <span class="n">dist</span> <span class="o">=</span> <span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">logits</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">41</span> <span class="n">dist</span> <span class="o">=</span> <span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">logits</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-6'> <div class='section' id='section-6'>
...@@ -156,7 +157,7 @@ ...@@ -156,7 +157,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="k">return</span> <span class="n">dist</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span></pre></div> <div class="highlight"><pre><span class="lineno">44</span> <span class="k">return</span> <span class="n">dist</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
<div class='footer'> <div class='footer'>
......
...@@ -76,12 +76,13 @@ ...@@ -76,12 +76,13 @@
</div> </div>
<h1>Top-k Sampling</h1> <h1>Top-k Sampling</h1>
<p>Here we first pick the top-k tokens from the distribution of logits, and then sample from them.</p> <p>Here we first pick the top-k tokens from the distribution of logits, and then sample from them.</p>
<p>Here&#x27;s an <a href="experiment.html">experiment</a> that uses these sampling techniques.</p>
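<p>The top-k step described above can be sketched in plain Python as follows. This is an illustrative assumption rather than the <code>TopKSampler</code> code on this page, which uses <code>torch.topk</code> and <code>scatter_</code> on tensors; the names <code>top_k_filter</code> and <code>softmax</code> are hypothetical helpers for this example.</p>

```python
import math


def top_k_filter(logits, k):
    """Keep the k largest logits and set every other logit to -inf.

    Hypothetical helper for illustration; ties are broken by position.
    """
    if not 1 <= k <= len(logits):
        raise ValueError("k must be between 1 and len(logits)")
    # Indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    filtered = [-math.inf] * len(logits)
    for i in top:
        filtered[i] = logits[i]
    return filtered


def softmax(xs):
    """Softmax over a list; entries equal to -inf get probability 0."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    filtered = top_k_filter([1.0, 3.0, 2.0, 0.5], k=2)
    print(filtered)
    print(softmax(filtered))
```

<p>Because the masked logits are negative infinity, any sampler applied afterwards assigns them zero probability, so only the k retained tokens can be drawn.</p>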
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">13</span><span></span><span class="kn">import</span> <span class="nn">torch</span> <div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">14</span> <span class="lineno">16</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div> <span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-1'> <div class='section' id='section-1'>
...@@ -93,7 +94,7 @@ ...@@ -93,7 +94,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">TopKSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">20</span><span class="k">class</span> <span class="nc">TopKSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-2'> <div class='section' id='section-2'>
...@@ -110,7 +111,7 @@ ...@@ -110,7 +111,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">22</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">24</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-3'> <div class='section' id='section-3'>
...@@ -121,8 +122,8 @@ ...@@ -121,8 +122,8 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">30</span> <span class="bp">self</span><span class="o">.</span><span class="n">k</span> <span class="o">=</span> <span class="n">k</span> <div class="highlight"><pre><span class="lineno">32</span> <span class="bp">self</span><span class="o">.</span><span class="n">k</span> <span class="o">=</span> <span class="n">k</span>
<span class="lineno">31</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">sampler</span></pre></div> <span class="lineno">33</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">sampler</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-4'> <div class='section' id='section-4'>
...@@ -134,7 +135,7 @@ ...@@ -134,7 +135,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">35</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-5'> <div class='section' id='section-5'>
...@@ -146,7 +147,7 @@ ...@@ -146,7 +147,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">38</span> <span class="n">zeros</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">logits</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">40</span> <span class="n">zeros</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">logits</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-6'> <div class='section' id='section-6'>
...@@ -158,7 +159,7 @@ ...@@ -158,7 +159,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">40</span> <span class="n">values</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">42</span> <span class="n">values</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-7'> <div class='section' id='section-7'>
...@@ -170,7 +171,7 @@ ...@@ -170,7 +171,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">43</span> <span class="n">zeros</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">values</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">45</span> <span class="n">zeros</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">values</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-8'> <div class='section' id='section-8'>
...@@ -182,7 +183,7 @@ ...@@ -182,7 +183,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">46</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">(</span><span class="n">zeros</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">48</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">(</span><span class="n">zeros</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='footer'> <div class='footer'>
......
...@@ -316,49 +316,49 @@ ...@@ -316,49 +316,49 @@
<url> <url>
<loc>https://nn.labml.ai/sampling/experiment_tiny.html</loc> <loc>https://nn.labml.ai/sampling/experiment_tiny.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod> <lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority> <priority>1.00</priority>
</url> </url>
<url> <url>
<loc>https://nn.labml.ai/sampling/greedy.html</loc> <loc>https://nn.labml.ai/sampling/greedy.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod> <lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority> <priority>1.00</priority>
</url> </url>
<url> <url>
<loc>https://nn.labml.ai/sampling/index.html</loc> <loc>https://nn.labml.ai/sampling/index.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod> <lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority> <priority>1.00</priority>
</url> </url>
<url> <url>
<loc>https://nn.labml.ai/sampling/top_k.html</loc> <loc>https://nn.labml.ai/sampling/top_k.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod> <lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority> <priority>1.00</priority>
</url> </url>
<url> <url>
<loc>https://nn.labml.ai/sampling/temperature.html</loc> <loc>https://nn.labml.ai/sampling/temperature.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod> <lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority> <priority>1.00</priority>
</url> </url>
<url> <url>
<loc>https://nn.labml.ai/sampling/experiment.html</loc> <loc>https://nn.labml.ai/sampling/experiment.html</loc>
<lastmod>2022-05-07T16:30:00+00:00</lastmod> <lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority> <priority>1.00</priority>
</url> </url>
<url> <url>
<loc>https://nn.labml.ai/sampling/nucleus.html</loc> <loc>https://nn.labml.ai/sampling/nucleus.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod> <lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority> <priority>1.00</priority>
</url> </url>
......
...@@ -119,6 +119,12 @@ Solving games with incomplete information such as poker with CFR. ...@@ -119,6 +119,12 @@ Solving games with incomplete information such as poker with CFR.
* [Fuzzy Tiling Activations](activations/fta/index.html) * [Fuzzy Tiling Activations](activations/fta/index.html)
#### ✨ [Sampling Techniques](sampling/index.html)
* [Greedy Sampling](sampling/greedy.html)
* [Temperature Sampling](sampling/temperature.html)
* [Top-k Sampling](sampling/top_k.html)
* [Nucleus Sampling](sampling/nucleus.html)
## Highlighted Research Paper PDFs ## Highlighted Research Paper PDFs
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf) * [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)
......
...@@ -7,6 +7,8 @@ summary: A PyTorch implementation of greedy sampling from language models. ...@@ -7,6 +7,8 @@ summary: A PyTorch implementation of greedy sampling from language models.
# Greedy Sampling # Greedy Sampling
Here we sample the most likely token from the distribution of logits. Here we sample the most likely token from the distribution of logits.
Here's an [experiment](experiment.html) that uses these sampling techniques.
""" """
import torch import torch
......
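The greedy rule described in the docstring above amounts to an argmax over the logits. Below is a minimal sketch in plain Python rather than PyTorch, so it runs without dependencies. The helper name `greedy_sample` and the example logits are illustrative assumptions, not part of the repository.

```python
def greedy_sample(logits):
    """Return the index of the largest logit (hypothetical helper, not the repository's class)."""
    # Softmax is monotonic, so the largest logit is also the most probable token.
    return max(range(len(logits)), key=lambda i: logits[i])


# Illustrative logits over a four-token vocabulary (made-up values).
example_logits = [1.0, 3.5, 0.2, 2.9]
chosen = greedy_sample(example_logits)
print(chosen)
```

Because no randomness is involved, the same logits always yield the same token.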
...@@ -22,6 +22,8 @@ $$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$ ...@@ -22,6 +22,8 @@ $$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$
That is, we pick the most probable tokens until the sum of their probabilities is at least $p$. That is, we pick the most probable tokens until the sum of their probabilities is at least $p$.
Then we sample from the selected tokens. Then we sample from the selected tokens.
Here's an [experiment](experiment.html) that uses these sampling techniques.
""" """
import torch import torch
......
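The selection rule in the formula above (the smallest set of highest-probability tokens whose total probability is at least $p$) can be sketched as follows. This is plain Python under stated assumptions, not the repository's PyTorch code: the helper names `nucleus_indices` and `nucleus_sample` and the example probabilities are hypothetical.

```python
import random


def nucleus_indices(probs, p):
    """Return token indices, most probable first, whose probabilities sum to at least p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    return kept


def nucleus_sample(probs, p, rng=random):
    """Sample one index from the nucleus; rng.choices renormalizes the weights."""
    kept = nucleus_indices(probs, p)
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


# Illustrative distribution over four tokens (made-up values).
example_probs = [0.5, 0.3, 0.15, 0.05]
nucleus = nucleus_indices(example_probs, 0.7)
```

With these example values and $p = 0.7$, the first token alone (0.5) is not enough, so the second is added and the nucleus is the first two tokens.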
...@@ -12,6 +12,8 @@ $u_{1:|V|}$ are the logits of the distribution and T is the temperature: ...@@ -12,6 +12,8 @@ $u_{1:|V|}$ are the logits of the distribution and T is the temperature:
$$P(x_i=V_l | x_{1:i-1}) = \frac{\exp(\frac{u_l}{T})}{\sum_j \exp(\frac{u_j}{T})}$$ $$P(x_i=V_l | x_{1:i-1}) = \frac{\exp(\frac{u_l}{T})}{\sum_j \exp(\frac{u_j}{T})}$$
$T = 1$ is normal random sampling. $T = 1$ is normal random sampling.
Here's an [experiment](experiment.html) that uses these sampling techniques.
""" """
import torch import torch
......
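The temperature formula above divides every logit by $T$ before the softmax. The sketch below shows that computation in plain Python instead of PyTorch; the names `temperature_probs` and `temperature_sample` are assumptions made for illustration.

```python
import math
import random


def temperature_probs(logits, temperature):
    """Softmax of logits / temperature, i.e. exp(u_l / T) / sum_j exp(u_j / T)."""
    scaled = [u / temperature for u in logits]
    # Subtracting the maximum leaves the result unchanged but avoids overflow.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def temperature_sample(logits, temperature, rng=random):
    """Draw one token index from the temperature-adjusted distribution."""
    probs = temperature_probs(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


# Illustrative logits (made-up values).
example_logits = [1.0, 2.0, 3.0]
sharp = temperature_probs(example_logits, 0.5)  # T < 1 concentrates mass on the top token
flat = temperature_probs(example_logits, 2.0)   # T > 1 spreads mass more evenly
```

Lower temperatures push the distribution toward the greedy choice, while higher temperatures move it toward uniform.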
...@@ -8,6 +8,8 @@ summary: A PyTorch implementation of top-k sampling from language models. ...@@ -8,6 +8,8 @@ summary: A PyTorch implementation of top-k sampling from language models.
Here we first pick the top-k tokens from the distribution of logits, and then Here we first pick the top-k tokens from the distribution of logits, and then
sample from them. sample from them.
Here's an [experiment](experiment.html) that uses these sampling techniques.
""" """
import torch import torch
......
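Top-k sampling as described above keeps the $k$ largest logits and samples only among them. The following is a minimal plain-Python sketch, not the repository's PyTorch implementation. It assumes that discarded logits are set to negative infinity so they receive zero probability; `top_k_filter` and `sample_from_logits` are hypothetical helper names.

```python
import math
import random


def top_k_filter(logits, k):
    """Keep the k largest logits; set every other entry to -inf (assumed masking value)."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    filtered = [float("-inf")] * len(logits)
    for i in top:
        filtered[i] = logits[i]
    return filtered


def sample_from_logits(logits, rng=random):
    """Sample an index with probability proportional to exp(logit); exp(-inf) is 0.0."""
    m = max(logits)
    weights = [math.exp(u - m) for u in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


# Illustrative logits (made-up values) with k = 2.
example_logits = [1.0, 3.5, 0.2, 2.9]
filtered = top_k_filter(example_logits, 2)
token = sample_from_logits(filtered)
```

Only indices 1 and 3 survive the filter here, so `token` is always one of those two.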
...@@ -123,6 +123,12 @@ Solving games with incomplete information such as poker with CFR. ...@@ -123,6 +123,12 @@ Solving games with incomplete information such as poker with CFR.
* [Fuzzy Tiling Activations](https://nn.labml.ai/activations/fta/index.html) * [Fuzzy Tiling Activations](https://nn.labml.ai/activations/fta/index.html)
#### ✨ [Sampling Techniques](https://nn.labml.ai/sampling/index.html)
* [Greedy Sampling](https://nn.labml.ai/sampling/greedy.html)
* [Temperature Sampling](https://nn.labml.ai/sampling/temperature.html)
* [Top-k Sampling](https://nn.labml.ai/sampling/top_k.html)
* [Nucleus Sampling](https://nn.labml.ai/sampling/nucleus.html)
## Highlighted Research Paper PDFs ## Highlighted Research Paper PDFs
......