sc2_model.py 3.6 KB
Newer Older
R
rical730 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import parl


class ComaModel(parl.Model):
    """COMA model bundling one shared actor network and one critic network.

    Args:
        config (dict): must provide the integer sizes ``n_actions``,
            ``n_agents``, ``state_shape`` and ``obs_shape``. The actor and
            critic input widths are derived from them (see
            ``_get_actor_input_dim`` / ``_get_critic_input_dim``).
    """

    def __init__(self, config):
        super(ComaModel, self).__init__()
        self.n_actions = config['n_actions']
        self.n_agents = config['n_agents']
        self.state_shape = config['state_shape']
        self.obs_shape = config['obs_shape']

        actor_input_dim = self._get_actor_input_dim()
        critic_input_dim = self._get_critic_input_dim()

        self.actor_model = ActorModel(actor_input_dim, self.n_actions)
        self.critic_model = CriticModel(critic_input_dim, self.n_actions)

    def policy(self, obs, hidden_state):
        """Run one actor step; delegates to ``ActorModel.policy``.

        Returns:
            Whatever ``ActorModel.policy`` returns: the action scores and the
            updated recurrent hidden state.
        """
        return self.actor_model.policy(obs, hidden_state)

    def value(self, inputs):
        """Evaluate the critic; delegates to ``CriticModel.value``."""
        return self.critic_model.value(inputs)

    def get_actor_params(self):
        """Return the actor network's parameters (its ``.parameters()``)."""
        return self.actor_model.parameters()

    def get_critic_params(self):
        """Return the critic network's parameters (its ``.parameters()``)."""
        return self.critic_model.parameters()

    def _get_actor_input_dim(self):
        """Width of one actor input row.

        The row is the agent's observation, its last action (one-hot) and its
        one-hot agent id, e.g. 30 + 9 + 3 = 42 in the 3m map.
        """
        return (self.obs_shape  # obs: 30 in 3m map
                + self.n_actions  # agent's last action (one-hot): 9 in 3m map
                + self.n_agents)  # agent's one-hot id: 3 in 3m map

    def _get_critic_input_dim(self):
        """Width of one critic input row.

        The row is the state, the agent's observation, the agent's one-hot id,
        and the one-hot current and last actions of every agent, e.g.
        48 + 30 + 3 + 2 * 9 * 3 = 48 + 30 + 3 + 54 = 135 in the 3m map.
        """
        return (self.state_shape  # state: 48 in 3m map
                + self.obs_shape  # obs: 30 in 3m map
                + self.n_agents  # agent_id: 3 in 3m map
                # all agents' action and last_action (one-hot): 54 in 3m map
                + self.n_actions * self.n_agents * 2)


# all agents share one actor network
class ActorModel(parl.Model):
    """Recurrent actor network: fc1 -> ReLU -> GRUCell -> fc2.

    All agents share this one network.

    Args:
        input_shape (int): width of one input row. ``ComaModel`` passes
            obs_shape + n_actions + n_agents (the observation, the agent's
            last action and its one-hot id).
        act_dim (int): number of actions, i.e. the output width of ``fc2``.
    """

    def __init__(self, input_shape, act_dim):
        super(ActorModel, self).__init__()
        self.hid_size = 64

        self.fc1 = nn.Linear(input_shape, self.hid_size)
        self.rnn = nn.GRUCell(self.hid_size, self.hid_size)
        self.fc2 = nn.Linear(self.hid_size, act_dim)

    def init_hidden(self):
        """Return a fresh all-zero hidden state of shape (1, hid_size).

        The tensor takes the dtype and device of ``fc1.weight``, so it lives
        wherever the model lives.
        """
        return self.fc1.weight.new_zeros((1, self.hid_size))

    def policy(self, obs, h0):
        """Run one recurrent step of the actor.

        Args:
            obs: assumed to be (batch, input_shape).
            h0: previous hidden state. It is reshaped to (-1, hid_size), so
                its element count must be a multiple of hid_size. Its row
                count must match obs's batch size for the GRUCell.

        Returns:
            (scores, h): ``scores`` is the raw ``fc2`` output of shape
            (batch, act_dim). No activation is applied here; presumably the
            caller turns it into a distribution. ``h`` is the new hidden
            state of shape (batch, hid_size).
        """
        x = F.relu(self.fc1(obs))
        h_in = h0.reshape(-1, self.hid_size)
        h_out = self.rnn(x, h_in)
        scores = self.fc2(h_out)
        return scores, h_out


class CriticModel(parl.Model):
    """Feed-forward critic: two ReLU hidden layers of width 128 and a linear head.

    Per the original author: each input row is
    [ s(t), o(t)_a, u(t)_a, agent_a, u(t-1) ] with shape (Batch, input_shape).
    The output Q has shape (Batch, act_dim), one value per action.
    Batch is described as ep_num * n_agents (not checked here).
    """

    def __init__(self, input_shape, act_dim):
        super().__init__()
        width = 128
        # Layers are created in fc1, fc2, fc3 order, and the attribute names
        # are kept (they are the parameter/state_dict names).
        self.fc1 = nn.Linear(input_shape, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, act_dim)

    def value(self, inputs):
        """Return the critic output: fc3(relu(fc2(relu(fc1(inputs)))))."""
        features = inputs
        for hidden_layer in (self.fc1, self.fc2):
            features = F.relu(hidden_layer(features))
        return self.fc3(features)