未验证 提交 0a068653 编写于 作者: L LI Yunxiang 提交者: GitHub

remove version 1.3 warnings (#252)

* remove version 1.3 warnings

* update

* yapf

* add algorithms test

* Update algs_test.py

* Update algs_test.py

add SAC DDPG TD3 tests

* yapf
上级 91e9814a
......@@ -59,7 +59,6 @@ Within class ``DQN(Algorithm)``, we define the following methods:
Args:
model (parl.Model): model defining forward network of Q function
hyperparas (dict): (deprecated) dict of hyper parameters.
act_dim (int): dimension of the action space
gamma (float): discounted factor for reward computation.
lr (float): learning rate.
......
......@@ -19,23 +19,16 @@ import copy
import paddle.fluid as fluid
from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid import layers
from parl.utils.deprecation import deprecated
__all__ = ['DQN']
class DQN(Algorithm):
def __init__(self,
model,
hyperparas=None,
act_dim=None,
gamma=None,
lr=None):
def __init__(self, model, act_dim=None, gamma=None, lr=None):
""" DQN algorithm
Args:
model (parl.Model): model defining forward network of Q function
hyperparas (dict): (deprecated) dict of hyper parameters.
act_dim (int): dimension of the action space
gamma (float): discounted factor for reward computation.
lr (float): learning rate.
......@@ -43,20 +36,12 @@ class DQN(Algorithm):
self.model = model
self.target_model = copy.deepcopy(model)
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.act_dim = hyperparas['action_dim']
self.gamma = hyperparas['gamma']
else:
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
assert isinstance(lr, float)
self.act_dim = act_dim
self.gamma = gamma
self.lr = lr
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
assert isinstance(lr, float)
self.act_dim = act_dim
self.gamma = gamma
self.lr = lr
def predict(self, obs):
""" use value model self.model to predict the action value
......@@ -100,12 +85,7 @@ class DQN(Algorithm):
cost = layers.reduce_mean(cost)
return cost
def sync_target(self, gpu_id=None):
def sync_target(self):
""" sync weights of self.model to self.target_model
"""
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.model.sync_weights_to(self.target_model)
......@@ -24,25 +24,17 @@ __all__ = ['A3C']
class A3C(Algorithm):
def __init__(self, model, hyperparas=None, vf_loss_coeff=None):
def __init__(self, model, vf_loss_coeff=None):
""" A3C/A2C algorithm
Args:
model (parl.Model): forward network of policy and value
hyperparas (dict): (deprecated) dict of hyper parameters.
vf_loss_coeff (float): coefficient of the value function loss
"""
self.model = model
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.A3C` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.vf_loss_coeff = hyperparas['vf_loss_coeff']
else:
assert isinstance(vf_loss_coeff, (int, float))
self.vf_loss_coeff = vf_loss_coeff
assert isinstance(vf_loss_coeff, (int, float))
self.vf_loss_coeff = vf_loss_coeff
def learn(self, obs, actions, advantages, target_values, learning_rate,
entropy_coeff):
......
......@@ -19,7 +19,6 @@ from parl.core.fluid import layers
from copy import deepcopy
from paddle import fluid
from parl.core.fluid.algorithm import Algorithm
from parl.utils.deprecation import deprecated
__all__ = ['DDPG']
......@@ -27,7 +26,6 @@ __all__ = ['DDPG']
class DDPG(Algorithm):
def __init__(self,
model,
hyperparas=None,
gamma=None,
tau=None,
actor_lr=None,
......@@ -37,53 +35,28 @@ class DDPG(Algorithm):
Args:
model (parl.Model): forward network of actor and critic.
The function get_actor_params() of model should be implemented.
hyperparas (dict): (deprecated) dict of hyper parameters.
gamma (float): discounted factor for reward computation.
tau (float): decay coefficient when updating the weights of self.target_model with self.model
actor_lr (float): learning rate of the actor model
critic_lr (float): learning rate of the critic model
"""
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.DDPG` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.gamma = hyperparas['gamma']
self.tau = hyperparas['tau']
self.actor_lr = hyperparas['actor_lr']
self.critic_lr = hyperparas['critic_lr']
else:
assert isinstance(gamma, float)
assert isinstance(tau, float)
assert isinstance(actor_lr, float)
assert isinstance(critic_lr, float)
self.gamma = gamma
self.tau = tau
self.actor_lr = actor_lr
self.critic_lr = critic_lr
assert isinstance(gamma, float)
assert isinstance(tau, float)
assert isinstance(actor_lr, float)
assert isinstance(critic_lr, float)
self.gamma = gamma
self.tau = tau
self.actor_lr = actor_lr
self.critic_lr = critic_lr
self.model = model
self.target_model = deepcopy(model)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
def define_predict(self, obs):
""" use actor model of self.model to predict the action
"""
return self.predict(obs)
def predict(self, obs):
""" use actor model of self.model to predict the action
"""
return self.model.policy(obs)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='learn')
def define_learn(self, obs, action, reward, next_obs, terminal):
""" update actor and critic model with DDPG algorithm
"""
return self.learn(obs, action, reward, next_obs, terminal)
def learn(self, obs, action, reward, next_obs, terminal):
""" update actor and critic model with DDPG algorithm
"""
......@@ -115,15 +88,7 @@ class DDPG(Algorithm):
optimizer.minimize(cost)
return cost
def sync_target(self,
gpu_id=None,
decay=None,
share_vars_parallel_executor=None):
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DDPG` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
def sync_target(self, decay=None, share_vars_parallel_executor=None):
if decay is None:
decay = 1.0 - self.tau
self.model.sync_weights_to(
......
......@@ -85,12 +85,7 @@ class DDQN(Algorithm):
optimizer.minimize(cost)
return cost
def sync_target(self, gpu_id=None):
def sync_target(self):
""" sync weights of self.model to self.target_model
"""
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.model.sync_weights_to(self.target_model)
......@@ -19,18 +19,16 @@ import copy
import paddle.fluid as fluid
from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid import layers
from parl.utils.deprecation import deprecated
__all__ = ['DQN']
class DQN(Algorithm):
def __init__(self, model, hyperparas=None, act_dim=None, gamma=None):
def __init__(self, model, act_dim=None, gamma=None):
""" DQN algorithm
Args:
model (parl.Model): model defining forward network of Q function
hyperparas (dict): (deprecated) dict of hyper parameters.
act_dim (int): dimension of the action space
gamma (float): discounted factor for reward computation.
lr (float): learning rate.
......@@ -38,38 +36,16 @@ class DQN(Algorithm):
self.model = model
self.target_model = copy.deepcopy(model)
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.act_dim = hyperparas['action_dim']
self.gamma = hyperparas['gamma']
else:
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
self.act_dim = act_dim
self.gamma = gamma
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
def define_predict(self, obs):
""" use value model self.model to predict the action value
"""
return self.predict(obs)
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
self.act_dim = act_dim
self.gamma = gamma
def predict(self, obs):
""" use value model self.model to predict the action value
"""
return self.model.value(obs)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='learn')
def define_learn(self, obs, action, reward, next_obs, terminal,
learning_rate):
return self.learn(obs, action, reward, next_obs, terminal,
learning_rate)
def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
""" update value model self.model with DQN algorithm
"""
......@@ -92,12 +68,7 @@ class DQN(Algorithm):
optimizer.minimize(cost)
return cost
def sync_target(self, gpu_id=None):
def sync_target(self):
""" sync weights of self.model to self.target_model
"""
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.model.sync_weights_to(self.target_model)
......@@ -85,7 +85,6 @@ class VTraceLoss(object):
class IMPALA(Algorithm):
def __init__(self,
model,
hyperparas=None,
sample_batch_steps=None,
gamma=None,
vf_loss_coeff=None,
......@@ -95,34 +94,22 @@ class IMPALA(Algorithm):
Args:
model (parl.Model): forward network of policy and value
hyperparas (dict): (deprecated) dict of hyper parameters.
sample_batch_steps (int): steps of each environment sampling.
gamma (float): discounted factor for reward computation.
vf_loss_coeff (float): coefficient of the value function loss.
clip_rho_threshold (float): clipping threshold for importance weights (rho).
clip_pg_rho_threshold (float): clipping threshold on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
"""
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.IMPALA` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.sample_batch_steps = hyperparas['sample_batch_steps']
self.gamma = hyperparas['gamma']
self.vf_loss_coeff = hyperparas['vf_loss_coeff']
self.clip_rho_threshold = hyperparas['clip_rho_threshold']
self.clip_pg_rho_threshold = hyperparas['clip_pg_rho_threshold']
else:
assert isinstance(sample_batch_steps, int)
assert isinstance(gamma, float)
assert isinstance(vf_loss_coeff, float)
assert isinstance(clip_rho_threshold, float)
assert isinstance(clip_pg_rho_threshold, float)
self.sample_batch_steps = sample_batch_steps
self.gamma = gamma
self.vf_loss_coeff = vf_loss_coeff
self.clip_rho_threshold = clip_rho_threshold
self.clip_pg_rho_threshold = clip_pg_rho_threshold
assert isinstance(sample_batch_steps, int)
assert isinstance(gamma, float)
assert isinstance(vf_loss_coeff, float)
assert isinstance(clip_rho_threshold, float)
assert isinstance(clip_pg_rho_threshold, float)
self.sample_batch_steps = sample_batch_steps
self.gamma = gamma
self.vf_loss_coeff = vf_loss_coeff
self.clip_rho_threshold = clip_rho_threshold
self.clip_pg_rho_threshold = clip_pg_rho_threshold
self.model = model
......
......@@ -18,51 +18,28 @@ warnings.simplefilter('default')
import paddle.fluid as fluid
from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid import layers
from parl.utils.deprecation import deprecated
__all__ = ['PolicyGradient']
class PolicyGradient(Algorithm):
def __init__(self, model, hyperparas=None, lr=None):
def __init__(self, model, lr=None):
""" Policy Gradient algorithm
Args:
model (parl.Model): forward network of the policy.
hyperparas (dict): (deprecated) dict of hyper parameters.
lr (float): learning rate of the policy model.
"""
self.model = model
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.PolicyGradient` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.lr = hyperparas['lr']
else:
assert isinstance(lr, float)
self.lr = lr
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
def define_predict(self, obs):
""" use policy model self.model to predict the action probability
"""
return self.predict(obs)
assert isinstance(lr, float)
self.lr = lr
def predict(self, obs):
""" use policy model self.model to predict the action probability
"""
return self.model(obs)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='learn')
def define_learn(self, obs, action, reward):
""" update policy model self.model with policy gradient algorithm
"""
return self.learn(obs, action, reward)
def learn(self, obs, action, reward):
""" update policy model self.model with policy gradient algorithm
"""
......
......@@ -20,7 +20,6 @@ from copy import deepcopy
from paddle import fluid
from parl.core.fluid import layers
from parl.core.fluid.algorithm import Algorithm
from parl.utils.deprecation import deprecated
__all__ = ['PPO']
......@@ -28,7 +27,6 @@ __all__ = ['PPO']
class PPO(Algorithm):
def __init__(self,
model,
hyperparas=None,
act_dim=None,
policy_lr=None,
value_lr=None,
......@@ -37,7 +35,6 @@ class PPO(Algorithm):
Args:
model (parl.Model): model defining forward network of policy and value.
hyperparas (dict): (deprecated) dict of hyper parameters.
act_dim (float): dimension of the action space.
policy_lr (float): learning rate of the policy model.
value_lr (float): learning rate of the value model.
......@@ -47,27 +44,14 @@ class PPO(Algorithm):
# Used to calculate probability of action in old policy
self.old_policy_model = deepcopy(model.policy_model)
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.PPO` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.act_dim = hyperparas['act_dim']
self.policy_lr = hyperparas['policy_lr']
self.value_lr = hyperparas['value_lr']
if 'epsilon' in hyperparas:
self.epsilon = hyperparas['epsilon']
else:
self.epsilon = 0.2 # default
else:
assert isinstance(act_dim, int)
assert isinstance(policy_lr, float)
assert isinstance(value_lr, float)
assert isinstance(epsilon, float)
self.act_dim = act_dim
self.policy_lr = policy_lr
self.value_lr = value_lr
self.epsilon = epsilon
assert isinstance(act_dim, int)
assert isinstance(policy_lr, float)
assert isinstance(value_lr, float)
assert isinstance(epsilon, float)
self.act_dim = act_dim
self.policy_lr = policy_lr
self.value_lr = value_lr
self.epsilon = epsilon
def _calc_logprob(self, actions, means, logvars):
""" Calculate log probabilities of actions, when given means and logvars
......@@ -111,49 +95,18 @@ class PPO(Algorithm):
log_det_cov_new - log_det_cov_old) + tr_old_new - self.act_dim)
return kl
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
def define_predict(self, obs):
""" Use policy model of self.model to predict means and logvars of actions
"""
return self.predict(obs)
def predict(self, obs):
""" Use the policy model of self.model to predict means and logvars of actions
"""
means, logvars = self.model.policy(obs)
return means
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='sample')
def define_sample(self, obs):
""" Use the policy model of self.model to sample actions
"""
return self.sample(obs)
def sample(self, obs):
""" Use the policy model of self.model to sample actions
"""
sampled_act = self.model.policy_sample(obs)
return sampled_act
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='policy_learn')
def define_policy_learn(self, obs, actions, advantages, beta=None):
""" Learn policy model with:
1. CLIP loss: Clipped Surrogate Objective
2. KLPEN loss: Adaptive KL Penalty Objective
See: https://arxiv.org/pdf/1707.02286.pdf
Args:
obs: Tensor, (batch_size, obs_dim)
actions: Tensor, (batch_size, act_dim)
advantages: Tensor (batch_size, )
beta: Tensor (1) or None
if None, use CLIP Loss; else, use KLPEN loss.
"""
return self.policy_learn(obs, actions, advantages, beta)
def policy_learn(self, obs, actions, advantages, beta=None):
""" Learn policy model with:
1. CLIP loss: Clipped Surrogate Objective
......@@ -196,27 +149,11 @@ class PPO(Algorithm):
optimizer.minimize(loss)
return loss, kl
@deprecated(
deprecated_in='1.2',
removed_in='1.3',
replace_function='value_predict')
def define_value_predict(self, obs):
""" Use value model of self.model to predict value of obs
"""
return self.value_predict(obs)
def value_predict(self, obs):
""" Use value model of self.model to predict value of obs
"""
return self.model.value(obs)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='value_learn')
def define_value_learn(self, obs, val):
""" Learn value model with square error cost
"""
return self.value_learn(obs, val)
def value_learn(self, obs, val):
""" Learn the value model with square error cost
"""
......@@ -227,12 +164,7 @@ class PPO(Algorithm):
optimizer.minimize(loss)
return loss
def sync_old_policy(self, gpu_id=None):
def sync_old_policy(self):
""" Synchronize weights of self.model.policy_model to self.old_policy_model
"""
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_old_policy` function in `parl.Algorithms.PPO` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.model.policy_model.sync_weights_to(self.old_policy_model)
此差异已折叠。
......@@ -27,7 +27,7 @@ __all__ = ['A2C']
class A2C(parl.Algorithm):
def __init__(self, model, config, hyperparas=None):
def __init__(self, model, config):
assert isinstance(config['vf_loss_coeff'], (int, float))
self.model = model
self.vf_loss_coeff = config['vf_loss_coeff']
......
......@@ -17,7 +17,6 @@ warnings.simplefilter('default')
import paddle.fluid as fluid
from parl.core.fluid import layers
from parl.utils.deprecation import deprecated
from parl.core.agent_base import AgentBase
from parl.core.fluid.algorithm import Algorithm
from parl.utils import machine_info
......@@ -46,7 +45,6 @@ class Agent(AgentBase):
This class will initialize the neural network parameters automatically, and provides an executor for users to run the programs (self.fluid_executor).
Attributes:
gpu_id (int): deprecated. specify which GPU to be used. -1 if to use the CPU.
fluid_executor (fluid.Executor): executor for running programs of the agent.
alg (parl.algorithm): algorithm of this agent.
......@@ -65,18 +63,12 @@ class Agent(AgentBase):
"""
def __init__(self, algorithm, gpu_id=None):
def __init__(self, algorithm):
"""Build programs by calling the method ``self.build_program()`` and run initialization function of ``fluid.default_startup_program()``.
Args:
algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.alg`.
gpu_id (int): deprecated. specify which GPU to be used. -1 if to use the CPU.
"""
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `__init__` function in `parl.Agent` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
assert isinstance(algorithm, Algorithm)
super(Agent, self).__init__(algorithm)
......@@ -119,26 +111,6 @@ class Agent(AgentBase):
"""
raise NotImplementedError
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='get_weights')
def get_params(self):
""" Returns a Python dictionary containing the whole parameters of self.alg.
Returns:
a Python List containing the parameters of self.alg.
"""
return self.algorithm.get_params()
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='set_weights')
def set_params(self, params):
"""Copy parameters from ``get_params()`` into this agent.
Args:
params(dict): a Python List containing the parameters of self.alg.
"""
self.algorithm.set_params(params)
def learn(self, *args, **kwargs):
"""The training interface for ``Agent``.
This function feeds the training data into the learn_program defined in ``build_program()``.
......
......@@ -17,7 +17,6 @@ warnings.simplefilter('default')
from parl.core.algorithm_base import AlgorithmBase
from parl.core.fluid.model import Model
from parl.utils.deprecation import deprecated
__all__ = ['Algorithm']
......@@ -57,47 +56,13 @@ class Algorithm(AlgorithmBase):
"""
def __init__(self, model=None, hyperparas=None):
def __init__(self, model=None):
"""
Args:
model(``parl.Model``): a neural network that represents a policy or a Q-value function.
hyperparas(dict): a dict storing the hyper-parameters relative to training.
"""
if model is not None:
warnings.warn(
"the `model` argument of `__init__` function in `parl.Algorithm` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
assert isinstance(model, Model)
self.model = model
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithm` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.hp = hyperparas
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='get_weights')
def get_params(self):
""" Get parameters of self.model.
Returns:
params(dict): a Python List containing the parameters of self.model.
"""
return self.model.get_params()
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='set_weights')
def set_params(self, params):
""" Set parameters from ``get_params`` to the model.
Args:
params(dict ): a Python List containing the parameters of self.model.
"""
self.model.set_params(params)
assert isinstance(model, Model)
self.model = model
def learn(self, *args, **kwargs):
""" Define the loss function and create an optimizer to minize the loss.
......
......@@ -17,7 +17,6 @@ import paddle.fluid as fluid
from parl.core.fluid.layers.layer_wrappers import LayerFunc
from parl.core.fluid.plutils import *
from parl.core.model_base import ModelBase
from parl.utils.deprecation import deprecated
from parl.utils import machine_info
__all__ = ['Model']
......@@ -67,30 +66,6 @@ class Model(ModelBase):
"""
@deprecated(
deprecated_in='1.2',
removed_in='1.3',
replace_function='sync_weights_to')
def sync_params_to(self,
target_net,
gpu_id=None,
decay=0.0,
share_vars_parallel_executor=None):
"""Synchronize parameters in the model to another model (target_net).
target_net_weights = decay * target_net_weights + (1 - decay) * source_net_weights
Args:
target_model (`parl.Model`): an instance of ``Model`` that has the same neural network architecture as the current model.
decay (float): the rate of decline in copying parameters. 0 if no parameters decay when synchronizing the parameters.
share_vars_parallel_executor (fluid.ParallelExecutor): Optional. If not None, will use fluid.ParallelExecutor
to run program instead of fluid.Executor
"""
self.sync_weights_to(
target_model=target_net,
decay=decay,
share_vars_parallel_executor=share_vars_parallel_executor)
def sync_weights_to(self,
target_model,
decay=0.0,
......@@ -181,21 +156,6 @@ class Model(ModelBase):
else:
self._cached_fluid_executor.run(fetch_list=[])
@property
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='parameters')
def parameter_names(self):
"""Get names of all parameters in this ``Model``.
Only parameters created by ``parl.layers`` are included.
The order of parameter names is consistent among
different instances of the same `Model`.
Returns:
param_names(list): list of string containing parameter names of all parameters.
"""
return self.parameters()
def parameters(self):
"""Get names of all parameters in this ``Model``.
......@@ -223,26 +183,6 @@ class Model(ModelBase):
self._parameter_names = self._get_parameter_names(self)
return self._parameter_names
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='get_weights')
def get_params(self):
""" Return a Python list containing parameters of current model.
Returns:
parameters: a Python list containing parameters of the current model.
"""
return self.get_weights()
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='set_weights')
def set_params(self, params, gpu_id=None):
"""Set parameters in the model with params.
Args:
params (List): List of numpy array .
"""
self.set_weights(weights=params)
def get_weights(self):
"""Returns a Python list containing parameters of current model.
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
warnings.warn(
"import way `import parl.framework` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.model import *
from parl.core.fluid.algorithm import *
from parl.core.fluid.agent import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
warnings.warn(
"module `parl.framework.agent_base.Agent` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.Agent` instead.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.agent import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
warnings.warn(
"module `parl.framework.algorithm_base.Algorithm` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.Algorithm` instead.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.algorithm import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
warnings.warn(
"module `parl.framework.model_base.Model` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.Model` instead.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.model import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
warnings.warn(
"module `parl.framework.policy_distribution` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.policy_distribution` instead.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.policy_distribution import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
warnings.warn(
"import way `import parl.layers` is deprecated since version 1.2 and will be removed in version 1.3, please use `from parl import layers` or `import parl; parl.layers` instead.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.layers import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
print(
"import way `import parl.plutils` is deprecated since version 1.2 and will be removed in version 1.3, please use `from parl import plutils` or `import parl; parl.plutils` instead."
)
from parl.core.fluid.plutils.common import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
print(
"import way `import parl.plutils` is deprecated since version 1.2 and will be removed in version 1.3, please use `from parl import plutils` or `import parl; parl.plutils` instead."
)
from parl.core.fluid.plutils.common import *
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册