// cartpole_solver_parallel.cc
//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <glog/logging.h>
#include <omp.h>
#include "cartpole.h"
#include "es_agent.h"
#include "paddle_api.h"

using namespace DeepES;
using namespace paddle::lite_api;

const int ITER = 10;

// Use PaddlePredictor of CartPole model to predict the action.
//
// Copies the observation into input tensor 0 (resized to shape {1, 4}), runs
// the predictor once, and returns the contents of output tensor 0.
//
// predictor: the predictor to run; its input tensor 0 is overwritten.
// obs:       observation buffer; since the input is resized to {1, 4} before
//            the copy, this presumably must hold at least 4 floats -- confirm
//            against CartPole::getState().
// Returns a 2-element vector filled from output tensor 0.
// NOTE(review): assumes the model emits exactly 2 values; a larger output
// would overrun `probs` -- confirm against the exported model.
std::vector<float> forward(std::shared_ptr<PaddlePredictor> predictor, const float* obs) {
  // GetInput()/GetOutput() already yield temporaries, so no std::move is
  // needed; wrapping them in std::move only inhibits copy elision.
  std::unique_ptr<Tensor> input_tensor(predictor->GetInput(0));
  input_tensor->Resize({1, 4});
  input_tensor->CopyFromCpu(obs);

  predictor->Run();

  std::vector<float> probs(2, 0.0f);
  std::unique_ptr<const Tensor> output_tensor(predictor->GetOutput(0));
  output_tensor->CopyToCpu(probs.data());
  return probs;
}

// Return the index of the largest element of `vec`.
// Ties resolve to the earliest index; an empty vector yields 0.
int arg_max(const std::vector<float>& vec) {
  const auto first = vec.begin();
  const auto largest = std::max_element(first, vec.end());
  return static_cast<int>(largest - first);
}


47
// Play one episode of `env` with the policy held by `agent` and return the
// sum of the rewards collected until the environment reports it is done.
// The action at each step is the arg-max of the predictor's output.
float evaluate(CartPole& env, std::shared_ptr<ESAgent> agent) {
  float episode_return = 0.0;
  env.reset();
  const float* state = env.getState();

  std::shared_ptr<PaddlePredictor> predictor = agent->get_predictor();

  bool finished = false;
  while (!finished) {
    const std::vector<float> action_probs = forward(predictor, state);
    const int action = arg_max(action_probs);
    env.step(action);
    const float step_reward = env.getReward();
    finished = env.isDone();
    episode_return += step_reward;
    // Only fetch the next observation when the episode continues.
    if (!finished) {
      state = env.getState();
    }
  }
  return episode_return;
}


int main(int argc, char* argv[]) {
  std::vector<CartPole> envs;
  for (int i = 0; i < ITER; ++i) {
    envs.push_back(CartPole());
  }

Z
zhoubo01 已提交
75 76
  //std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
  std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>("../demo/paddle/cartpole_init_model", "../benchmark/cartpole_config.prototxt");
Z
zenghsh3 已提交
77

78 79 80
  // Clone agents to sample (explore).
  std::vector< std::shared_ptr<ESAgent> > sampling_agents;
  for (int i = 0; i < ITER; ++i) {
Z
zenghsh3 已提交
81 82 83
    sampling_agents.push_back(agent->clone());
  }

Z
zhoubo01 已提交
84
  std::vector<SamplingInfo> noisy_keys;
Z
zenghsh3 已提交
85
  std::vector<float> noisy_rewards(ITER, 0.0f);
Z
zhoubo01 已提交
86
  noisy_keys.resize(ITER);
Z
zenghsh3 已提交
87 88

  omp_set_num_threads(10);
Z
zhoubo01 已提交
89
  for (int epoch = 0; epoch < 100; ++epoch) {
Z
zenghsh3 已提交
90 91 92
#pragma omp parallel for schedule(dynamic, 1)
    for (int i = 0; i < ITER; ++i) {
      std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
Z
zhoubo01 已提交
93 94
      SamplingInfo key;
      bool success = sampling_agent->add_noise(key);
Z
zenghsh3 已提交
95 96
      float reward = evaluate(envs[i], sampling_agent);

Z
zhoubo01 已提交
97
      noisy_keys[i] = key;
Z
zenghsh3 已提交
98 99 100 101
      noisy_rewards[i] = reward;
    }

    // NOTE: all parameters of sampling_agents will be updated
Z
zhoubo01 已提交
102
    bool success = agent->update(noisy_keys, noisy_rewards);
Z
zenghsh3 已提交
103
  
104
    int reward = evaluate(envs[0], agent);
Z
zenghsh3 已提交
105 106 107
    LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward;
  }
}