Unverified commit f0014586, authored by: Xu Jingxin, committed by: GitHub

Merge pull request #122 from opendilab/dev-torch1.1.0

feature(nyz): extend torch1.1.0 support
# ding/compatibility.py: runtime check for whether the installed torch is >= 1.3.1
import torch


def torch_gt_131():
    # strip everything but the digits from torch.__version__ and compare as an integer
    return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 131
@@ -2,7 +2,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from ding.torch_utils import ResFCBlock, ResBlock
+from ding.torch_utils import ResFCBlock, ResBlock, Flatten
 from ding.utils import SequenceType
@@ -49,7 +49,7 @@ class ConvEncoder(nn.Module):
         assert len(set(hidden_size_list[3:-1])) <= 1, "Please indicate the same hidden size for res block parts"
         for i in range(3, len(self.hidden_size_list) - 1):
             layers.append(ResBlock(self.hidden_size_list[i], activation=self.act, norm_type=norm_type))
-        layers.append(nn.Flatten())
+        layers.append(Flatten())
         self.main = nn.Sequential(*layers)
         flatten_size = self._get_flatten_size()
......
@@ -4,6 +4,7 @@ import sys
 import traceback
 from typing import Callable
 import torch
+import torch.utils.data  # torch1.1.0 compatibility
 from ding.utils import read_file, save_file
 logger = logging.getLogger('default_logger')
......
 from .activation import build_activation, Swish
 from .res_block import ResBlock, ResFCBlock
 from .nn_module import fc_block, conv2d_block, one_hot, deconv2d_block, BilinearUpsample, NearestUpsample, \
-    binary_encode, NoiseLinearLayer, noise_block, MLP
+    binary_encode, NoiseLinearLayer, noise_block, MLP, Flatten
 from .normalization import build_normalization
 from .rnn import get_lstm, sequence_mask
 from .soft_argmax import SoftArgmax
......
@@ -4,6 +4,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.init import xavier_normal_, kaiming_normal_, orthogonal_
 from typing import Union, Tuple, List, Callable
+from ding.compatibility import torch_gt_131
 from .normalization import build_normalization
@@ -577,3 +578,23 @@ def noise_block(
     if use_dropout:
         block.append(nn.Dropout(dropout_probability))
     return sequential_pack(block)
+
+
+class NaiveFlatten(nn.Module):
+
+    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
+        super(NaiveFlatten, self).__init__()
+        self.start_dim = start_dim
+        self.end_dim = end_dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.end_dim != -1:
+            return x.view(*x.shape[:self.start_dim], -1, *x.shape[self.end_dim + 1:])
+        else:
+            return x.view(*x.shape[:self.start_dim], -1)
+
+
+if torch_gt_131():
+    Flatten = nn.Flatten
+else:
+    Flatten = NaiveFlatten
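For context, a small usage sketch (not part of the commit): on torch >= 1.3.1 Flatten resolves to nn.Flatten, otherwise to the NaiveFlatten fallback above, and both are intended to behave identically. The shapes below are illustrative:

import torch
from ding.torch_utils import Flatten  # nn.Flatten or NaiveFlatten, depending on the installed torch

x = torch.randn(4, 3, 8, 8)
assert Flatten()(x).shape == (4, 3 * 8 * 8)     # flatten everything after the batch dimension
assert Flatten(1, 2)(x).shape == (4, 3 * 8, 8)  # flatten only dims 1 through 2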
@@ -6,6 +6,7 @@ import math
 import numpy as np
 import torch.nn as nn
 import torch.nn.functional as F
+from .nn_module import Flatten
 def to_2tuple(item):
@@ -94,7 +95,7 @@ class ClassifierHead(nn.Module):
         self.drop_rate = drop_rate
         self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
         self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
-        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
+        self.flatten = Flatten(1) if use_conv and pool_type else nn.Identity()
     def forward(self, x):
         x = self.global_pool(x)
......
@@ -2,7 +2,7 @@ import torch
 import pytest
 from ding.torch_utils import build_activation, build_normalization
 from ding.torch_utils.network.nn_module import conv1d_block, conv2d_block, fc_block, deconv2d_block, ChannelShuffle, \
-    one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_
+    one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_, NaiveFlatten
 batch_size = 2
 in_channels = 2
@@ -148,3 +148,16 @@ class TestNnModule:
         max_val = torch.tensor(8)
         output = binary_encode(input, max_val)
         assert torch.equal(output, torch.tensor([[0, 1, 0, 0]]))
+
+    @pytest.mark.tmp
+    def test_flatten(self):
+        inputs = torch.randn(4, 3, 8, 8)
+        model1 = NaiveFlatten()
+        output1 = model1(inputs)
+        assert output1.shape == (4, 3 * 8 * 8)
+        model2 = NaiveFlatten(1, 2)
+        output2 = model2(inputs)
+        assert output2.shape == (4, 3 * 8, 8)
+        model3 = NaiveFlatten(1, 3)
+        output3 = model3(inputs)
+        assert output3.shape == (4, 3 * 8 * 8)
@@ -5,6 +5,7 @@ import torch
 import re
 from torch._six import string_classes
 import collections.abc as container_abcs
+from ding.compatibility import torch_gt_131
 int_classes = int
 np_str_obj_array_pattern = re.compile(r'[SaUO]')
@@ -50,7 +51,7 @@ def default_collate(batch: Sequence,
     elem_type = type(elem)
     if isinstance(elem, torch.Tensor):
         out = None
-        if torch.utils.data.get_worker_info() is not None:
+        if torch_gt_131() and torch.utils.data.get_worker_info() is not None:
             # If we're in a background process, directly concatenate into a
             # shared memory tensor to avoid an extra copy
             numel = sum([x.numel() for x in batch])
......
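Why the guard matters, as a rough self-contained sketch (assumptions: the helper name _collate_tensors is hypothetical, and the lines after numel follow the standard upstream default_collate recipe, which this diff does not show): on torch < 1.3.1 the shared-memory preallocation is skipped, out stays None, and torch.stack allocates a fresh tensor instead.

import torch
import torch.utils.data
from ding.compatibility import torch_gt_131

def _collate_tensors(batch):
    # hypothetical stand-in for the tensor branch of default_collate
    elem = batch[0]
    out = None
    if torch_gt_131() and torch.utils.data.get_worker_info() is not None:
        # inside a DataLoader worker: preallocate shared memory and stack into it
        numel = sum([x.numel() for x in batch])
        storage = elem.storage()._new_shared(numel)
        out = elem.new(storage)
    # on torch 1.1.0 (or in the main process) out is None and stack allocates normally
    return torch.stack(batch, 0, out=out)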
@@ -74,13 +74,6 @@ def main(cfg, seed=0):
         if train_data is None:
             break
         learner.train(train_data, collector.envstep)
-    # evaluate
-    evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
-    evaluator_env.enable_save_replay(cfg.env.replay_path)  # switch save replay interface
-    evaluator = InteractionSerialEvaluator(
-        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
-    )
-    evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
 if __name__ == "__main__":
......
@@ -50,7 +50,7 @@ setup(
         'requests>=2.25.1',
         'six',
         'gym==0.20.0',  # pypy incompatible
-        'torch>=1.3.1,<=1.9.0',  # PyTorch 1.9.0 is available, if some errors, you need to do something like https://github.com/opendilab/DI-engine/discussions/81
+        'torch>=1.1.0,<=1.9.0',  # PyTorch 1.9.0 is available, if some errors, you need to do something like https://github.com/opendilab/DI-engine/discussions/81
         'pyyaml<6.0',
         'easydict==1.9',
         'tensorboardX>=2.1,<=2.2',
......