未验证 提交 7e8b0d31 编写于 作者: L Luke Melas-Kyriazi 提交者: GitHub

Merge pull request #250 from rvandeghen/patch-1

Add new checkpoint
......@@ -238,18 +238,18 @@ class EfficientNet(nn.Module):
Returns:
Dictionary of last intermediate features
with reduction levels i in [1, 2, 3, 4, 5].
Example:
>>> import torch
>>> from efficientnet.model import EfficientNet
>>> inputs = torch.rand(1, 3, 224, 224)
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
>>> endpoints = model.extract_endpoints(inputs)
>>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
Example:
>>> import torch
>>> from efficientnet.model import EfficientNet
>>> inputs = torch.rand(1, 3, 224, 224)
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
>>> endpoints = model.extract_endpoints(inputs)
>>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
>>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
"""
endpoints = dict()
......@@ -265,6 +265,8 @@ class EfficientNet(nn.Module):
x = block(x, drop_connect_rate=drop_connect_rate)
if prev_x.size(2) > x.size(2):
endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
elif idx == len(self._blocks) - 1:
endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
prev_x = x
# Head
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册