diff --git a/labml_nn/resnet/__init__.py b/labml_nn/resnet/__init__.py index 31373800b95c8caa766b55c66b41f08d844a5a44..a4aacbce5f5949fca1680c4a714220181545443f 100644 --- a/labml_nn/resnet/__init__.py +++ b/labml_nn/resnet/__init__.py @@ -128,7 +128,7 @@ class ResidualBlock(Module): self.bn2 = nn.BatchNorm2d(out_channels) # Shortcut connection should be a projection if the stride length is not $1$ - # of if the number of channels change + # or if the number of channels change if stride != 1 or in_channels != out_channels: # Projection $W_s x$ self.shortcut = ShortcutProjection(in_channels, out_channels, stride) @@ -210,7 +210,7 @@ class BottleneckResidualBlock(Module): self.bn3 = nn.BatchNorm2d(out_channels) # Shortcut connection should be a projection if the stride length is not $1$ - # of if the number of channels change + # or if the number of channels change if stride != 1 or in_channels != out_channels: # Projection $W_s x$ self.shortcut = ShortcutProjection(in_channels, out_channels, stride)