From 9f8cd3573c30aef2460026c3bbd1e6c3038e7a3e Mon Sep 17 00:00:00 2001 From: ShusenTang Date: Sun, 10 Nov 2019 00:35:49 +0800 Subject: [PATCH] add more info about ModuleList --- .../4.1_model-construction.ipynb | 119 +++++++++++++++--- .../4.1_model-construction.md | 53 ++++++++ 2 files changed, 152 insertions(+), 20 deletions(-) diff --git a/code/chapter04_DL_computation/4.1_model-construction.ipynb b/code/chapter04_DL_computation/4.1_model-construction.ipynb index f9b83dc..84c7593 100644 --- a/code/chapter04_DL_computation/4.1_model-construction.ipynb +++ b/code/chapter04_DL_computation/4.1_model-construction.ipynb @@ -16,7 +16,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.4.1\n" + "1.2.0\n" ] } ], @@ -78,10 +78,10 @@ { "data": { "text/plain": [ - "tensor([[ 0.1351, -0.0034, 0.0948, -0.1652, 0.1512, 0.0887, -0.0032, 0.0692,\n", - " 0.0942, 0.0956],\n", - " [ 0.1624, -0.0383, 0.1557, -0.0735, 0.1931, 0.1699, -0.0067, 0.0353,\n", - " 0.1712, 0.1568]], grad_fn=)" + "tensor([[ 0.0234, -0.2646, -0.1168, -0.2127, 0.0884, -0.0456, 0.0811, 0.0297,\n", + " 0.2032, 0.1364],\n", + " [ 0.1479, -0.1545, -0.0265, -0.2119, -0.0543, -0.0086, 0.0902, -0.1017,\n", + " 0.1504, 0.1144]], grad_fn=)" ] }, "execution_count": 3, @@ -107,7 +107,9 @@ { "cell_type": "code", "execution_count": 4, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "class MySequential(nn.Module):\n", @@ -146,10 +148,10 @@ { "data": { "text/plain": [ - "tensor([[ 0.1883, -0.1269, -0.1886, 0.0638, -0.1004, -0.0600, 0.0760, -0.1788,\n", - " -0.1844, -0.2131],\n", - " [ 0.1319, -0.0490, -0.1365, 0.0133, -0.0483, -0.0861, 0.0369, -0.0830,\n", - " -0.0462, -0.2066]], grad_fn=)" + "tensor([[ 0.1273, 0.1642, -0.1060, 0.1401, 0.0609, -0.0199, -0.0140, -0.0588,\n", + " 0.1765, -0.1296],\n", + " [ 0.0267, 0.1670, -0.0626, 0.0744, 0.0574, 0.0413, 0.1313, -0.1479,\n", + " 0.0932, -0.0615]], grad_fn=)" ] }, "execution_count": 5, @@ -199,6 +201,74 @@ "print(net)" ] }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# net(torch.zeros(1, 784)) # 会报NotImplementedError" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "class MyModule(nn.Module):\n", + " def __init__(self):\n", + " super(MyModule, self).__init__()\n", + " self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])\n", + "\n", + " def forward(self, x):\n", + " # ModuleList can act as an iterable, or be indexed using ints\n", + " for i, l in enumerate(self.linears):\n", + " x = self.linears[i // 2](x) + l(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "net1:\n", + "torch.Size([10, 10])\n", + "torch.Size([10])\n", + "net2:\n" + ] + } + ], + "source": [ + "class Module_ModuleList(nn.Module):\n", + " def __init__(self):\n", + " super(Module_ModuleList, self).__init__()\n", + " self.linears = nn.ModuleList([nn.Linear(10, 10)])\n", + " \n", + "class Module_List(nn.Module):\n", + " def __init__(self):\n", + " super(Module_List, self).__init__()\n", + " self.linears = [nn.Linear(10, 10)]\n", + "\n", + "net1 = Module_ModuleList()\n", + "net2 = Module_List()\n", + "\n", + "print(\"net1:\")\n", + "for p in net1.parameters():\n", + " print(p.size())\n", + "\n", + "print(\"net2:\")\n", + "for p in net2.parameters():\n", + " print(p)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -208,7 +278,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -236,6 +306,15 @@ "print(net)" ] }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# net(torch.zeros(1, 784)) # 会报NotImplementedError" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -245,7 +324,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 12, "metadata": { "collapsed": true }, @@ -275,7 +354,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -290,10 +369,10 @@ { "data": { "text/plain": [ - "tensor(12.1594, grad_fn=)" + "tensor(0.8907, grad_fn=)" ] }, - "execution_count": 9, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -307,7 +386,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -331,10 +410,10 @@ { "data": { "text/plain": [ - "tensor(0.1509, grad_fn=)" + "tensor(-0.4605, grad_fn=)" ] }, - "execution_count": 10, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -367,7 +446,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [default]", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -381,7 +460,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.6.2" } }, "nbformat": 4, diff --git a/docs/chapter04_DL_computation/4.1_model-construction.md b/docs/chapter04_DL_computation/4.1_model-construction.md index 909f1a2..45924e5 100644 --- a/docs/chapter04_DL_computation/4.1_model-construction.md +++ b/docs/chapter04_DL_computation/4.1_model-construction.md @@ -114,6 +114,7 @@ net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()]) net.append(nn.Linear(256, 10)) # # 类似List的append操作 print(net[-1]) # 类似List的索引访问 print(net) +# net(torch.zeros(1, 784)) # 会报NotImplementedError ``` 输出: ``` @@ -125,6 +126,55 @@ ModuleList( ) ``` +既然`Sequential`和`ModuleList`都可以进行列表化构造网络,那二者区别是什么呢。`ModuleList`仅仅是一个储存各种模块的列表,这些模块之间没有联系也没有顺序(所以不用保证相邻层的输入输出维度匹配),而且没有实现`forward`功能需要自己实现,所以上面执行`net(torch.zeros(1, 784))`会报`NotImplementedError`;而`Sequential`内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部`forward`功能已经实现。 + +`ModuleList`的出现只是让网络定义前向传播时更加灵活,见下面官网的例子。 +``` python +class MyModule(nn.Module): + def __init__(self): + super(MyModule, self).__init__() + self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) + + def forward(self, x): + # ModuleList can act as an iterable, or be indexed using ints + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return x +``` + +另外,`ModuleList`不同于一般的Python的`list`,加入到`ModuleList`里面的所有模块的参数会被自动添加到整个网络中,下面看一个例子对比一下。 + +``` python +class Module_ModuleList(nn.Module): + def __init__(self): + super(Module_ModuleList, self).__init__() + self.linears = nn.ModuleList([nn.Linear(10, 10)]) + +class Module_List(nn.Module): + def __init__(self): + super(Module_List, self).__init__() + self.linears = [nn.Linear(10, 10)] + +net1 = Module_ModuleList() +net2 = Module_List() + +print("net1:") +for p in net1.parameters(): + print(p.size()) + +print("net2:") +for p in net2.parameters(): + print(p) +``` +输出: +``` +net1: +torch.Size([10, 10]) +torch.Size([10]) +net2: +``` + + ### 4.1.2.3 `ModuleDict`类 `ModuleDict`接收一个子模块的字典作为输入, 然后也可以类似字典那样进行添加访问操作: ``` python @@ -136,6 +186,7 @@ net['output'] = nn.Linear(256, 10) # 添加 print(net['linear']) # 访问 print(net.output) print(net) +# net(torch.zeros(1, 784)) # 会报NotImplementedError ``` 输出: ``` @@ -148,6 +199,7 @@ ModuleDict( ) ``` +和`ModuleList`一样,`ModuleDict`实例仅仅是存放了一些模块的字典,并没有定义`forward`函数需要自己定义。同样,`ModuleDict`也与Python的`Dict`有所不同,`ModuleDict`里的所有模块的参数会被自动添加到整个网络中。 ## 4.1.3 构造复杂的模型 @@ -230,6 +282,7 @@ tensor(14.4908, grad_fn=) * 可以通过继承`Module`类来构造模型。 * `Sequential`、`ModuleList`、`ModuleDict`类都继承自`Module`类。 +* 与`Sequential`不同,`ModuleList`和`ModuleDict`并没有定义一个完整的网络,它们只是将不同的模块存放在一起,需要自己定义`forward`函数。 * 虽然`Sequential`等类可以使模型构造更加简单,但直接继承`Module`类可以极大地拓展模型构造的灵活性。 -- GitLab