25.md 2.7 KB
Newer Older
W
wizardforcel 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
# torch.utils.checkpoint

### torch.utils.checkpoint.checkpoint(function, *args)

`checkpoint`模型或模型的一部分

`checkpoint`通过交换计算内存来工作。而不是存储整个计算图的所有中间激活用于向后计算,`checkpoint`不会不保存中间激活部分,而是在反向传递中重新计算它们。它可以应用于模型的任何部分。

具体来说,在正向传递中,`function`将以`torch.no_grad()`方式运行 ,即不存储中间激活。相反,正向传递保存输入元组和 `function`参数。在向后计算中,保存的输入变量以及 `function`会被回收,并且正向计算被`function`再次计算 ,现在跟踪中间激活,然后使用这些激活值来计算梯度。

> **警告** `checkpoint`在`torch.autograd.grad()`中不起作用,但仅适用于`torch.autograd.backward()`。
> 
> **警告** 如果`function`在向后执行和前向执行都不同,例如由于某个全局变量,`checkpoint`版本将不等同,并且不幸的是无法检测到。

参数:

*   function - 描述在模型的正向传递或模型的一部分中运行的内容。它也应该知道如何处理作为元组传递的输入。例如,在 LSTM 中,如果用户通过 ,应正确使用第一个输入作为第二个输入(activation, hidden)functionactivationhidden
*   args - 包含输入的元组 function

返回: attr`function``*args`

返回类型: 运行输出

### torch.utils.checkpoint.checkpoint_sequential(functions, segments, *inputs)

用于`checkpoint` `sequential`模型的辅助函数。

`sequential 模型按顺序执行一系列模块/函数(按顺序)。因此,我们可以将这种模型分为不同的部分和`checkpoint`。除最后一个段以外的所有段都将以某种`torch.no_grad()`方式运行 ,即不存储中间活动。将保存每个`checkpoint`段的输入,以便在向后传递中重新运行段。

关于`checkpoint`如何工作可以参考[checkpoint()](https://pytorch.org/docs/master/checkpoint.html#torch.utils.checkpoint.checkpoint)。

> **警告** `checkpoint`在`torch.autograd.grad()`中不起作用,但仅适用于`torch.autograd.backward()`。

参数:

*   functions - A torch.nn.Sequential 或按顺序运行的模块或函数(包括模型)的列表。
*   segments - 要在模型中创建的块数
*   segments - 输入的张量元组 functions

返回:

`functions`按顺序运行的输出`*inputs`

例:

```
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var) 
```

部分地方存在翻译错误,即将修复

### 译者署名

| 用户名 | 头像 | 职能 | 签名 |
| --- | --- | --- | --- |
| [Song](https://ptorch.com) | ![](img/2018033000352689884.jpeg) | 翻译 | 人生总要追求点什么 |