Commit ad71feba authored by N niuyazhe

style(nyz): update issue template doc link and polish comment doc

Parent 16833c62
......@@ -15,7 +15,7 @@ assignees: ''
+ [ ] code design/refactor
+ [ ] documentation request
+ [ ] new feature request
- [ ] I have visited the [readme](https://github.com/opendilab/DI-engine/blob/github-dev/README.md) and [doc]()
- [ ] I have visited the [readme](https://github.com/opendilab/DI-engine/blob/github-dev/README.md) and [doc](https://opendilab.github.io/DI-engine/)
- [ ] I have searched through the [issue tracker](https://github.com/opendilab/DI-engine/issues) and [pr tracker](https://github.com/opendilab/DI-engine/pulls)
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
......
......@@ -49,7 +49,6 @@ class Task:
Task object of the connections.
Linking calls are fully supported.
Example:
- A simple and common usage
>>> with master.new_connection('cnn1', '127.0.0.1', 2333) as connection:
>>> task = connection.new_task({'data': 233})
>>> # task is not sent yet
......
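A hedged sketch of the linking-call style, continuing the example above; ``start()`` and ``join()`` are assumed Task methods not shown in this hunk:
>>> with master.new_connection('cnn1', '127.0.0.1', 2333) as connection:
>>>     task = connection.new_task({'data': 233})
>>>     task.start().join()  # linking call: send the task, then wait for its completion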
......@@ -404,6 +404,7 @@ class Slave(ControllableService):
Overview:
Start the current slave client
Here are the steps executed internally, in order:
1. Start the task-processing thread
2. Start the heartbeat thread
3. Start the http server thread
......@@ -431,6 +432,7 @@ class Slave(ControllableService):
Overview:
Wait until the current slave client is completely down.
Here are the steps executed internally, in order; a combined lifecycle sketch follows the list:
1. Wait until the http server thread is down
2. Wait until the heartbeat thread is down
3. Wait until the task-processing thread is down
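A hedged lifecycle sketch tying ``start`` and ``join`` together; ``MySlave``, the bind address, and ``shutdown`` are assumptions:
>>> slave = MySlave('0.0.0.0', 2333)  # hypothetical Slave subclass
>>> slave.start()     # spawns task-processing, heartbeat, then http server threads
>>> slave.shutdown()  # assumed stop call from ControllableService
>>> slave.join()      # waits for http server, heartbeat, then task-processing threads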
......@@ -498,16 +500,19 @@ class Slave(ControllableService):
master will receive the failure signal.
Example:
- A successful task with a return value (the return value will be received on the master end)
>>> def _process_task(self, task):
>>> print('this is task data :', task)
>>> return str(task)
- A failed task with data (the data will be received on the master end)
>>> def _process_task(self, task):
>>> print('this is task data :', task)
>>> raise TaskFail(task) # this is a failed task
- A failed task with data and a message (both will be received on the master end)
>>> def _process_task(self, task):
>>> print('this is task data :', task)
>>> raise TaskFail(task, 'this is message') # this is a failed task with message
......
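Putting the three cases above together, a hedged sketch of a minimal slave subclass (class name hypothetical):
>>> class EchoSlave(Slave):
>>>     def _process_task(self, task):
>>>         if 'data' not in task:
>>>             raise TaskFail(task, 'missing data')  # failure signal sent back to master
>>>         return {'echo': task['data']}             # return value sent back to master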
......@@ -32,26 +32,25 @@ class QAC(nn.Module):
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
) -> None:
r"""
"""
Overview:
Init the QAC Model according to arguments.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
- action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's space, such as 4, (3, ),
- action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's space, such as 4, (3, ), \
EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
- actor_head_type (:obj:`str`): Whether to choose ``regression``, ``reparameterization`` or ``hybrid``.
- twin_critic (:obj:`bool`): Whether to include a twin critic.
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
- actor_head_layer_num (:obj:`int`):
The num of layers used in the network to compute Q value output for actor's nn.
- actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
for actor's nn.
- critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
- critic_head_layer_num (:obj:`int`):
The num of layers used in the network to compute Q value output for critic's nn.
- activation (:obj:`Optional[nn.Module]`):
The type of activation function to use in ``MLP`` the after ``layer_fn``,
if ``None`` then default set to ``nn.ReLU()``
- norm_type (:obj:`Optional[str]`):
The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
- critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
for critic's nn.
- activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
after ``layer_fn``, if ``None`` then default set to ``nn.ReLU()``
- norm_type (:obj:`Optional[str]`): The type of normalization to use, \
see ``ding.torch_utils.network`` for more details.
"""
super(QAC, self).__init__()
obs_shape: int = squeeze(obs_shape)
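A hedged construction sketch based on the argument list above (values are illustrative; assumes ``QAC`` is imported):
>>> model = QAC(obs_shape=8, action_shape=3, actor_head_type='regression', twin_critic=True)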
......@@ -145,29 +144,28 @@ class QAC(nn.Module):
)
def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
r"""
"""
Overview:
Use observation and action tensors to predict output.
Parameter updates are performed via QAC's MLP forward setup.
Arguments:
Forward with ``'compute_actor'``:
- inputs (:obj:`torch.Tensor`):
The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` depend on ``mode``.
Forward with ``compute_actor``:
- inputs (:obj:`torch.Tensor`): The encoded embedding tensor, determined with given ``hidden_size``, \
i.e. ``(B, N=hidden_size)``.
Forward with ``'compute_critic'``, inputs (`Dict`) Necessary Keys:
- ``obs``, ``action`` encoded tensors.
Forward with ``compute_critic``:
- inputs (:obj:`Dict`)
- mode (:obj:`str`): Name of the forward mode.
Returns:
- outputs (:obj:`Dict`): Outputs of network forward.
Forward with ``'compute_actor'``, Necessary Keys (either):
Forward with ``compute_actor``
- action (:obj:`torch.Tensor`): Action tensor with same size as input ``x``.
- logit (:obj:`torch.Tensor`):
Logit tensor encoding ``mu`` and ``sigma``, both with same size as input ``x``.
- logit (:obj:`torch.Tensor`): Logit tensor encoding ``mu`` and ``sigma``, both with same size \
as input ``x``.
Forward with ``'compute_critic'``, Necessary Keys:
Forward with ``compute_critic``
- q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
Actor Shapes:
- inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size``
......@@ -176,7 +174,7 @@ class QAC(nn.Module):
Critic Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
- action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is``action_shape``
- action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
- logit (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
Actor Examples:
......@@ -205,14 +203,14 @@ class QAC(nn.Module):
return getattr(self, mode)(inputs)
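Since ``forward`` simply dispatches via ``getattr(self, mode)(inputs)``, a hedged usage sketch (regression mode assumed):
>>> import torch
>>> obs = torch.randn(4, 8)                                # (B, N=hidden_size)
>>> action = model(obs, mode='compute_actor')['action']
>>> q_value = model({'obs': obs, 'action': action}, mode='compute_critic')['q_value']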
def compute_actor(self, inputs: torch.Tensor) -> Dict:
r"""
"""
Overview:
Use encoded embedding tensor to predict output.
Execute parameter updates with ``'compute_actor'`` mode
Execute parameter updates with ``compute_actor`` mode
Use encoded embedding tensor to predict output.
Arguments:
- inputs (:obj:`torch.Tensor`):
The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``. \
``hidden_size = actor_head_hidden_size``
- mode (:obj:`str`): Name of the forward mode.
Returns:
......@@ -220,17 +218,17 @@ class QAC(nn.Module):
ReturnsKeys (either):
- action (:obj:`torch.Tensor`): Continuous action tensor with same size as ``action_shape``.
- logit (:obj:`torch.Tensor`):
Logit tensor encoding ``mu`` and ``sigma``, both with same size as input ``x``.
- logit + action_args
- logit (:obj:`torch.Tensor`): Logit tensor encoding ``mu`` and ``sigma``, both with same size \
as input ``x``.
Shapes:
- inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size``
- action (:obj:`torch.Tensor`): :math:`(B, N0)`
- logit (:obj:`Union[list, torch.Tensor]`):
- case1(continuous space, list): 2 elements, mu and sigma, each is the shape of :math:`(B, N0)`.
- case2(hybrid space, torch.Tensor): :math:`(B, N1)`, where N1 is action_type_shape
- q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, B is batch size.
- action_args (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where N2 is action_args_shape
- action_args (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where N2 is action_args_shape \
(action_args are continuous real value)
Examples:
>>> # Regression mode
......@@ -261,19 +259,21 @@ class QAC(nn.Module):
def compute_critic(self, inputs: Dict) -> Dict:
r"""
Overview:
Execute parameter updates with ``'compute_critic'`` mode
Execute parameter updates with ``compute_critic`` mode
Use encoded embedding tensor to predict output.
Arguments:
- inputs (:obj: `Dict`): ``obs``, ``action`` and ``logit` tensors.
- inputs (:obj:`Dict`): ``obs``, ``action`` and ``logit`` tensors.
- mode (:obj:`str`): Name of the forward mode.
Returns:
- outputs (:obj:`Dict`): Q-value output.
ArgumentsKeys:
- necessary:
- obs (:obj:`torch.Tensor`): 2-dim vector observation
- action (:obj:`Union[torch.Tensor, Dict]`): action from actor
- optional:
- logit (:obj:`torch.Tensor`): discrete action logit
ReturnKeys:
- q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
......
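A hedged sketch of calling ``compute_critic`` directly with the necessary ArgumentsKeys (``logit`` is only needed in the hybrid setting):
>>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 3)}
>>> q_value = model.compute_critic(inputs)['q_value']  # shape (B, )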
......@@ -269,6 +269,7 @@ class CollaQPolicy(Policy):
Load the state_dict variable into policy learn mode.
Arguments:
- state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
.. tip::
If you want to load only some parts of the model, you can simply set the ``strict`` argument in \
load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
......@@ -424,6 +425,7 @@ class CollaQPolicy(Policy):
Return this algorithm default model setting for demonstration.
Returns:
- model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names
.. note::
The user can define and use customized network model but must obey the same interface definition indicated \
by import_names path. For collaq, ``ding.model.qmix.qmix``
......
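The tip above about partial loading, as a hedged standalone sketch (the checkpoint path and the 'model' key are illustrative):
>>> import torch
>>> ckpt = torch.load('ckpt_best.pth.tar', map_location='cpu')  # illustrative path
>>> model.load_state_dict(ckpt['model'], strict=False)          # skip missing/unexpected keys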
......@@ -364,6 +364,7 @@ class COMAPolicy(Policy):
Return this algorithm default model setting for demonstration.
Returns:
- model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names
.. note::
The user can define and use customized network model but must obey the same interface definition indicated \
by import_names path. For coma, ``ding.model.coma.coma``
......
......@@ -319,6 +319,7 @@ class PPGPolicy(Policy):
When the value is distilled into the policy network, we need to make sure the policy \
network does not change its action predictions, so we need two optimizers: \
_optimizer_ac is used in the policy net, and _optimizer_aux_critic is used in the value net.
.. tip::
If you want to load only some parts of the model, you can simply set the ``strict`` argument in \
load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
......@@ -459,6 +460,7 @@ class PPGPolicy(Policy):
Return this algorithm default model setting for demonstration.
Returns:
- model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names
.. note::
The user can define and use customized network model but must obey the same interface definition indicated \
by import_names path.
......
......@@ -13,13 +13,13 @@ from .base_policy import Policy
@POLICY_REGISTRY.register('qmix')
class QMIXPolicy(Policy):
r"""
"""
Overview:
Policy class of QMIX algorithm. QMIX is a multi-agent reinforcement learning algorithm, \
you can view the paper at the following link: https://arxiv.org/abs/1803.11485
Interface:
_init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn\
_init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval\
_init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn \
_init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval \
_reset_eval, _get_train_sample, default_model
Config:
== ==================== ======== ============== ======================================== =======================
......@@ -257,11 +257,12 @@ class QMIXPolicy(Policy):
}
def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
r"""
"""
Overview:
Load the state_dict variable into policy learn mode.
Arguments:
- state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
.. tip::
If you want to load only some parts of the model, you can simply set the ``strict`` argument in \
load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
......@@ -416,6 +417,7 @@ class QMIXPolicy(Policy):
Return this algorithm default model setting for demonstration.
Returns:
- model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names
.. note::
The user can define and use customized network model but must obey the same interface definition indicated \
by import_names path. For QMIX, ``ding.model.qmix.qmix``
......
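Given the note above, a hedged sketch of what ``default_model`` returns for QMIX:
>>> from typing import List, Tuple
>>> def default_model(self) -> Tuple[str, List[str]]:  # sketch of the policy method
>>>     return 'qmix', ['ding.model.qmix.qmix']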
......@@ -32,6 +32,7 @@ class PdeilRewardModel(BaseRewardModel):
Overview:
Initialize ``self.`` See ``help(type(self))`` for accurate signature.
Some rules in naming the attributes of ``self.``:
- ``e_`` : expert values
- ``_sigma_`` : standard deviation values
- ``p_`` : current policy values
......
......@@ -188,7 +188,8 @@ class CheckpointHelper:
- logger_prefix (:obj:`str`): prefix of logger
- state_dict_mask (:obj:`list`): A list containing state_dict keys, \
which shouldn't be loaded into the model (after the prefix op)
..note:
.. note::
The checkpoint loaded from load_path is a dict whose format is like ``{'state_dict': OrderedDict(), ...}``
"""
......
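A hedged sketch of the checkpoint layout described in the note (the extra key is illustrative):
>>> import torch
>>> from collections import OrderedDict
>>> ckpt = {'state_dict': OrderedDict([('encoder.weight', torch.zeros(4, 4))]), 'last_iter': 0}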
......@@ -10,7 +10,7 @@ import torch
def to_device(item: Any, device: str, ignore_keys: list = []) -> Any:
r"""
"""
Overview:
Transfer data to a certain device
Arguments:
......@@ -19,6 +19,7 @@ def to_device(item: Any, device: str, ignore_keys: list = []) -> Any:
- ignore_keys (:obj:`list`): the keys to be ignored in the transfer, default set to empty
Returns:
- item (:obj:`Any`): the transferred item
.. note::
Now supports item type: :obj:`torch.nn.Module`, :obj:`torch.Tensor`, :obj:`Sequence`, \
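A hedged usage sketch of ``to_device`` on a dict of tensors (keys are illustrative; dict support is assumed from the truncated note above):
>>> import torch
>>> data = {'obs': torch.randn(2, 4), 'action': torch.zeros(2)}
>>> data = to_device(data, 'cuda' if torch.cuda.is_available() else 'cpu')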
......@@ -61,6 +62,7 @@ def to_dtype(item: Any, dtype: type) -> Any:
- dtype (:obj:`type`): the type wanted
Returns:
- item (:obj:`object`): the dtype changed item
.. note::
Now supports item type: :obj:`torch.Tensor`, :obj:`Sequence`, :obj:`dict`
......@@ -89,6 +91,7 @@ def to_tensor(
- dtype (:obj:`type`): the type of wanted tensor
Returns:
- item (:obj:`torch.Tensor`): the changed tensor
.. note::
Now supports item type: :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`
......@@ -158,6 +161,7 @@ def to_ndarray(item: Any, dtype: np.dtype = None) -> np.ndarray:
- dtype (:obj:`type`): the type of wanted ndarray
Returns:
- item (:obj:`object`): the changed ndarray
.. note::
Now supports item type: :obj:`torch.Tensor`, :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`
......@@ -214,9 +218,10 @@ def to_list(item: Any) -> list:
- item (:obj:`Any`): the item to be transformed
Returns:
- item (:obj:`list`): the list after transformation
.. note::
Now supports item type: :obj:`torch.Tensor`,:obj:`numpy.ndarray`, :obj:`dict`, :obj:`list`, \
Now supports item type: :obj:`torch.Tensor`, :obj:`numpy.ndarray`, :obj:`dict`, :obj:`list`, \
:obj:`tuple` and :obj:`None`
"""
if item is None:
......@@ -243,6 +248,7 @@ def tensor_to_list(item):
- item (:obj:`Any`): the item to be transformed
Returns:
- item (:obj:`list`): the list after transformation
.. note::
Now supports item type: :obj:`torch.Tensor`, :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`
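A hedged round-trip sketch with the converters above:
>>> import torch
>>> t = to_tensor({'a': [1, 2]}, dtype=torch.float32)  # dict of lists -> dict of tensors
>>> l = tensor_to_list(t['a'])                         # back to a plain python list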
......@@ -329,8 +335,8 @@ class CudaFetcher(object):
def run(self) -> None:
"""
Overview:
Start `producer` thread: Keep fetching data from source,
change the device, and put into `queue` for request.
Start ``producer`` thread: Keep fetching data from source,
change the device, and put into ``queue`` for request.
"""
self._end_flag = False
self._producer_thread.start()
......@@ -338,7 +344,7 @@ class CudaFetcher(object):
def close(self) -> None:
"""
Overview:
Stop `producer` thread by setting `end_flag` to `True`.
Stop ``producer`` thread by setting ``end_flag`` to ``True`` .
"""
self._end_flag = True
......
......@@ -14,6 +14,7 @@ class Pd(object):
Abstract class for parameterizable probability distributions and sampling functions.
Interface:
neglogp, entropy, noise_mode, mode, sample
.. tip::
In derived classes, ``logits`` should be an attribute member stored in the class.
......@@ -85,7 +86,7 @@ class CategoricalPd(Pd):
Overview:
Update logits
Arguments:
- logits (:obj:torch.Tensor): logits to update
- logits (:obj:`torch.Tensor`): logits to update
"""
self.logits = logits
......@@ -122,12 +123,12 @@ class CategoricalPd(Pd):
return entropy.mean()
def noise_mode(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
r"""
"""
Overview:
Add noise to logits
Arguments:
- viz (:obj:`bool`): Whether to return numpy form of logits, noise and noise_logits; \
Short for "visualize". (Because tensor type cannot visualize in tb or text log)
Short for ``visualize`` . (Because tensor type cannot visualize in tb or text log)
Returns:
- result (:obj:`torch.Tensor`): noised logits
- viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization.
......@@ -146,12 +147,12 @@ class CategoricalPd(Pd):
return result
def mode(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
r"""
"""
Overview:
Return logits argmax result
Argiments:
- viz (:obj:`bool`): Whether to return numpy from of logits, noise and noise_logits; \
Short for "visualize". (Because tensor type cannot visualize in tb or text log)
Arguments:
- viz (:obj:`bool`): Whether to return numpy form of logits, noise and noise_logits; \
Short for ``visualize`` . (Because tensor type cannot visualize in tb or text log)
Returns:
- result (:obj:`torch.Tensor`): the logits argmax result
- viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization.
......@@ -165,12 +166,12 @@ class CategoricalPd(Pd):
return result
def sample(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
r"""
"""
Overview:
Sample from the logits' distribution by using softmax
Arguments:
- viz (:obj:`bool`): Whether to return numpy form of logits, noise and noise_logits; \
Short for "visualize". (Because tensor type cannot visualize in tb or text log)
Short for ``visualize`` . (Because tensor type cannot visualize in tb or text log)
Returns:
- result (:obj:`torch.Tensor`): the logits sampled result
- viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization.
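A hedged end-to-end sketch of the interface above (the constructor is assumed to allow setting logits afterwards):
>>> import torch
>>> pd = CategoricalPd()
>>> pd.update_logits(torch.randn(4, 6))
>>> action = pd.sample()  # softmax sampling
>>> greedy = pd.mode()    # argmax; pass viz=True for ndarray visualization data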
......@@ -186,15 +187,12 @@ class CategoricalPd(Pd):
class CategoricalPdPytorch(torch.distributions.Categorical):
r"""
"""
Overview:
Wrapped ``torch.distributions.Categorical``
Notes:
Please refer to ``torch.distributions.Categorical`` doc: \
https://pytorch.org/docs/stable/distributions.html?highlight=torch%20distributions#module-torch.distributions\
Categorical
Interface:
update_logits, updata_probs, sample, neglogp, mode, entropy
update_logits, update_probs, sample, neglogp, mode, entropy
"""
def __init__(self, probs: torch.Tensor = None) -> None:
......@@ -206,7 +204,7 @@ class CategoricalPdPytorch(torch.distributions.Categorical):
Overview:
Update logits
Arguments:
- logits (:obj:torch.Tensor): logits to update
- logits (:obj:`torch.Tensor`): logits to update
"""
super().__init__(logits=logits)
......@@ -215,7 +213,7 @@ class CategoricalPdPytorch(torch.distributions.Categorical):
Overview:
Update probs
Arguments:
- probs (:obj:torch.Tensor): probs to update
- probs (:obj:`torch.Tensor`): probs to update
"""
super().__init__(probs=probs)
......@@ -250,7 +248,7 @@ class CategoricalPdPytorch(torch.distributions.Categorical):
Overview:
Return logits argmax result
Return:
- result(:obj: `torch.Tensor`): the logits argmax result
- result(:obj:`torch.Tensor`): the logits argmax result
"""
return self.probs.argmax(dim=-1)
......
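A hedged sketch for the wrapper; the ``probs`` constructor argument and ``mode`` are visible in this diff, while ``sample`` is inherited from ``torch.distributions.Categorical``:
>>> import torch
>>> pd = CategoricalPdPytorch(torch.softmax(torch.randn(4, 6), dim=-1))
>>> action = pd.sample()
>>> greedy = pd.mode()  # self.probs.argmax(dim=-1), as shown above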
......@@ -9,7 +9,7 @@ class GLU(nn.Module):
Gated Linear Unit (GLU).
This class does the following:
.. code::python
.. code:: python
# Inputs: input, context, output_size
# The gate value is a learnt function of the input.
......@@ -20,6 +20,7 @@ class GLU(nn.Module):
return output
Interfaces:
forward
.. tip::
This module also supports 2D convolution, in which case, the input and context must have the same shape.
......
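A hedged sketch of the gating computation the pseudo-code above describes (layer names and sizes are hypothetical):
>>> import torch
>>> import torch.nn as nn
>>> gate_fc, out_fc = nn.Linear(16, 8), nn.Linear(8, 32)  # hypothetical learnt layers
>>> x, context = torch.randn(4, 8), torch.randn(4, 16)
>>> gate = torch.sigmoid(gate_fc(context))  # the gate value is a learnt function
>>> output = out_fc(gate * x)               # gated input projected to output_size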
......@@ -91,6 +91,7 @@ def conv1d_block(
- norm_type (:obj:`str`): type of the normalization
Returns:
- block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 1-dim convolution layer
.. note::
Conv1d (https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d)
......@@ -132,6 +133,7 @@ def conv2d_block(
- norm_type (:obj:`str`): type of the normalization, default set to None, now supports ['BN', 'IN', 'SyncBN']
Returns:
- block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 2-dim convolution layer
.. note::
Conv2d (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d)
......@@ -182,6 +184,7 @@ def deconv2d_block(
Returns:
- block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 2-dim \
transpose convolution layer
.. note::
ConvTranspose2d (https://pytorch.org/docs/master/generated/torch.nn.ConvTranspose2d.html)
......@@ -227,6 +230,7 @@ def fc_block(
- dropout_probability (:obj:`float`): probability of an element to be zeroed in the dropout. Default: 0.5
Returns:
- block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block
.. note::
you can refer to nn.Linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html)
......@@ -270,6 +274,7 @@ def MLP(
- dropout_probability (:obj:`float`): probability of an element to be zeroed in the dropout. Default: 0.5
Returns:
- block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block
.. note::
you can refer to nn.Linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html)
......@@ -299,6 +304,7 @@ class ChannelShuffle(nn.Module):
Apply channel shuffle to the input tensor
Interface:
forward
.. note::
You can see the original ShuffleNet paper at https://arxiv.org/abs/1707.01083
......@@ -565,6 +571,7 @@ def noise_block(
- sigma0 (:obj:`float`): the sigma0 is the default noise volume when initializing NoiseLinearLayer
Returns:
- block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block
.. note::
you can refer to nn.Linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html)
......
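A hedged construction sketch for the block builders above; the keyword names follow the documented arguments, but the full signatures are assumptions:
>>> import torch
>>> block = conv2d_block(3, 16, kernel_size=3, norm_type='BN')  # assumed in/out channel positions
>>> y = block(torch.randn(1, 3, 32, 32))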
......@@ -122,13 +122,13 @@ class LSTMForwardWrapper(object):
class LSTM(nn.Module, LSTMForwardWrapper):
r"""
Overview:
Implimentation of LSTM cell
Implementation of LSTM cell with LN
Interface:
forward
.. note::
For begainners, you can refer to <https://zhuanlan.zhihu.com/p/32085405> to learn the basics about lstm
For beginners, you can refer to <https://zhuanlan.zhihu.com/p/32085405> to learn the basics about lstm
"""
def __init__(
......@@ -141,13 +141,13 @@ s
) -> None:
r"""
Overview:
Initializate the LSTM cell
Initialize the LSTM cell arguments and parameters
Arguments:
- input_size (:obj:`int`): size of the input vector
- hidden_size (:obj:`int`): size of the hidden state vector
- num_layers (:obj:`int`): number of lstm layers
- norm_type (:obj:`Optional[str]`): type of the normalization, (default: None)
- dropout (:obj:float): dropout rate, default set to .0
- dropout (:obj:`float`): dropout rate, default to 0
"""
super(LSTM, self).__init__()
self.input_size = input_size
......@@ -180,7 +180,7 @@ s
inputs: torch.Tensor,
prev_state: torch.Tensor,
list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
r"""
"""
Overview:
Take the previous state and the input, then calculate the output and the next state
Arguments:
......
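A hedged usage sketch of the LSTM above; the ``(seq_len, batch, input_size)`` layout and zero-initialization for ``prev_state=None`` are assumptions:
>>> import torch
>>> lstm = LSTM(input_size=16, hidden_size=32, num_layers=1, norm_type='LN', dropout=0.)
>>> inputs = torch.randn(8, 4, 16)  # assumed (seq_len, batch, input_size)
>>> output, next_state = lstm(inputs, prev_state=None, list_next_state=True)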
......@@ -93,8 +93,10 @@ def get_data_decompressor(name: str):
Get the data decompressor according to the input name
Arguments:
- name(:obj:`str`): Name of the decompressor, support ``['lz4', 'zlib', 'none']``
.. note::
For all the decompressors, the input of a bytes-like object is required.
Returns:
- (:obj:`Callable`): Corresponding data_decompressor.
......
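A hedged usage sketch; 'none' is assumed to be a pass-through decompressor:
>>> decompress = get_data_decompressor('none')
>>> data = decompress(b'some bytes')  # a bytes-like input is required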