0


A2C算法原理及代码实现

本文主要参考王树森老师的强化学习课程

1.A2C算法原理

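A2C(Advantage Actor-Critic)是策略学习中比较经典的一个算法,它建立在 Barto 等人 1983 年提出的 Actor-Critic 方法之上。我们知道,策略梯度方法用策略梯度更新策略网络参数 $\theta$,从而增大目标函数;实际计算时使用的是策略梯度的无偏估计,即下面的随机梯度:

$$g(s,a;\theta) = Q_\pi(s,a)\cdot\nabla_\theta \ln \pi(a\mid s;\theta)$$

其中动作价值函数 $Q_\pi(s,a)$ 是未知的,因此需要对它做近似。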

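Actor-Critic 方法中用一个神经网络近似动作价值函数 $Q_\pi(s,a)$,这个神经网络叫做“价值网络”,记为 $q(s,a;w)$,其中的 $w$ 表示神经网络中可训练的参数。价值网络的输入是状态 $s$,输出是每个动作的价值。动作空间 $\mathcal{A}$ 中有多少种动作,价值网络的输出就是多少维的向量,向量的每个元素对应一个动作。举个例子,动作空间是 $\mathcal{A}=\{左, 右, 上\}$,那么价值网络的输出是三维向量 $\big[q(s,左;w),\ q(s,右;w),\ q(s,上;w)\big]^{\top}$,每个元素是对应动作的价值估计。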

神经网络可以采用以下结构:

虽然价值网络 $q(s,a;w)$ 与 DQN 有相同的结构,但是两者的意义不同,训练算法也不同。

  • 价值网络是对动作价值函数 $Q_\pi(s,a)$ 的近似,而 DQN 则是对最优动作价值函数 $Q_\star(s,a)$ 的近似。
  • 对价值网络的训练使用的是SARSA算法,它属于同策略,不能用经验回放。对DQN的训练使用的是 Q 学习算法,它属于异策略,可以用经验回放。

Actor-Critic 翻译成“演员—评论家”方法。策略网络 $\pi(a\mid s;\theta)$ 相当于演员,它基于状态 $s$ 做出动作 $a$。价值网络 $q(s,a;w)$ 相当于评论家,它给演员的表现打分,量化在状态 $s$ 的情况下做出动作 $a$ 的好坏程度。策略网络(演员)和价值网络(评论家)的关系如下图所示。

2. A2C算法训练流程

设当前策略网络参数是 $\theta_{\text{now}}$,价值网络参数是 $w_{\text{now}}$。执行下面的步骤,将参数更新成 $\theta_{\text{new}}$ 和 $w_{\text{new}}$:

3.A2C代码实现

基于 PyTorch,在 Gym 的基础环境中选择经典的倒立摆环境 CartPole-v0 进行验证。

3.1 算法代码:

  1. import torch.optim as optim
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.distributions import Categorical
  class ActorCritic(nn.Module):
      """Actor-critic network for A2C built from two independent MLP heads.

      The critic maps a state to a single scalar value estimate; the actor
      maps a state to a categorical distribution over ``output_dim`` actions.
      """

      def __init__(self, input_dim, output_dim, hidden_dim):
          """Build the critic and actor heads.

          Args:
              input_dim: Size of the state vector fed to both heads.
              output_dim: Number of discrete actions (actor output size).
              hidden_dim: Width of the single hidden layer in each head.
          """
          super().__init__()
          self.critic = nn.Sequential(
              nn.Linear(input_dim, hidden_dim),
              nn.ReLU(),
              nn.Linear(hidden_dim, 1),
          )
          self.actor = nn.Sequential(
              nn.Linear(input_dim, hidden_dim),
              nn.ReLU(),
              nn.Linear(hidden_dim, output_dim),
              # Normalise over the last (action) axis so that batched
              # (batch, input_dim) and unbatched (input_dim,) states both
              # work; dim=1 raises on a 1-D input.
              nn.Softmax(dim=-1),
          )

      def forward(self, x):
          """Evaluate both heads on the state tensor ``x``.

          Returns:
              A ``(dist, value)`` tuple: ``dist`` is a ``Categorical`` over
              the actions and ``value`` is the critic output, whose last
              dimension has size 1.
          """
          value = self.critic(x)
          probs = self.actor(x)
          dist = Categorical(probs)
          return dist, value
  class A2C:
      """A2C agent: owns the actor-critic model and its Adam optimizer."""

      def __init__(self, state_dim, action_dim, cfg) -> None:
          """Create the model and optimizer from a config object.

          Args:
              state_dim: Size of the state vector.
              action_dim: Number of discrete actions.
              cfg: Config exposing ``gamma``, ``device`` and the hidden layer
                  width as ``hidden_size`` or ``hidden_dim``; ``lr`` is
                  optional.
          """
          self.gamma = cfg.gamma
          self.device = cfg.device
          # Accept both spellings of the hidden width: the original attribute
          # name `hidden_size` wins when present, otherwise fall back to
          # `hidden_dim` (reading only `hidden_size` raised AttributeError on
          # configs that define `hidden_dim`).
          hidden_dim = getattr(cfg, "hidden_size", None)
          if hidden_dim is None:
              hidden_dim = cfg.hidden_dim
          self.model = ActorCritic(state_dim, action_dim, hidden_dim).to(self.device)
          # Honour the configured learning rate; 1e-3 is Adam's default and
          # therefore matches the previous behaviour when `lr` is absent.
          self.optimizer = optim.Adam(
              self.model.parameters(), lr=getattr(cfg, "lr", 1e-3)
          )

      def compute_returns(self, next_value, rewards, masks):
          """Compute discounted returns by backing up from ``next_value``.

          Args:
              next_value: Bootstrap value for the state after the last step.
              rewards: Per-step rewards, oldest first.
              masks: Per-step multipliers applied to the bootstrapped term
                  (0 cuts the backup at that step, 1 lets it through).

          Returns:
              A list of returns aligned with ``rewards`` (oldest first).
          """
          R = next_value
          returns = []
          for step in reversed(range(len(rewards))):
              R = rewards[step] + self.gamma * R * masks[step]
              returns.append(R)
          # Built newest-first with O(1) appends; flip once instead of paying
          # O(n) for every insert(0, ...).
          returns.reverse()
          return returns

3.2 实验代码:

  1. import sys
  2. import os
  3. curr_path = os.path.dirname(os.path.abspath(__file__)) # 当前文件所在绝对路径
  4. parent_path = os.path.dirname(curr_path) # 父路径
  5. sys.path.append(parent_path) # 添加路径到系统路径
  6. import gym
  7. import numpy as np
  8. import torch
  9. import torch.optim as optim
  10. import datetime
  11. from common.multiprocessing_env import SubprocVecEnv
  12. from a2c import ActorCritic
  13. from common.utils import save_results, make_dir
  14. from common.utils import plot_rewards
  15. curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间
  16. algo_name = 'A2C' # 算法名称
  17. env_name = 'CartPole-v0' # 环境名称
  class A2CConfig:
      """Hyper-parameters and runtime settings for the A2C experiment."""

      def __init__(self) -> None:
          """Populate every setting with its default value."""
          use_cuda = torch.cuda.is_available()
          # Identification
          self.algo_name = algo_name  # algorithm name
          self.env_name = env_name  # environment name
          # Rollout and optimisation settings
          self.n_envs = 8  # number of parallel environments
          self.gamma = 0.99  # discount factor
          self.hidden_dim = 256
          self.lr = 1e-3  # learning rate
          self.max_frames = 30000
          self.n_steps = 5
          # Runtime device: GPU when available, otherwise CPU
          self.device = torch.device("cuda" if use_cuda else "cpu")
  class PlotConfig:
      """Settings consumed by the plotting and result-saving helpers."""

      def __init__(self) -> None:
          """Populate names, device and the per-run output directories."""
          has_gpu = torch.cuda.is_available()
          self.algo_name = algo_name  # algorithm name
          self.env_name = env_name  # environment name
          self.device = torch.device("cuda" if has_gpu else "cpu")  # detect GPU
          # Both output folders share the same per-run prefix:
          # curr_path/outputs/ENV_NAME/TIMESTAMP
          run_dir = curr_path + "/outputs/" + self.env_name + '/' + curr_time
          self.result_path = run_dir + '/results/'  # where results are saved
          self.model_path = run_dir + '/models/'  # where models are saved
          self.save = True  # whether to save figures
  def make_envs(env_name, seed=2):
      """Return a zero-argument factory that builds a seeded Gym environment.

      The factory (rather than the environment itself) is what gets shipped
      to a worker process, so the environment is created inside that process.

      Args:
          env_name: Gym environment id passed to ``gym.make``.
          seed: Seed applied to the new environment. Defaults to 2, the value
              that used to be hard-coded; pass a different seed per worker to
              decorrelate parallel environments.

      Returns:
          A callable that creates, seeds and returns the environment.
      """
      def _thunk():
          env = gym.make(env_name)
          env.seed(seed)
          return env

      return _thunk
  def ceshi_env(env, model, vis=False):
      """Run one evaluation episode and return its total reward.

      Args:
          env: A single (non-vectorised) environment exposing ``reset``,
              ``step`` and, when ``vis`` is true, ``render``.
          model: Callable returning ``(dist, value)`` for a batch of states.
          vis: Render the environment after reset and after every step.

      Returns:
          The sum of the rewards collected until ``done`` is true.
      """
      # Follow the model's own device instead of reading a module-level
      # `cfg`, which only exists when the file is executed as a script.
      first_param = next(model.parameters(), None)
      device = first_param.device if first_param is not None else torch.device("cpu")

      state = env.reset()
      if vis:
          env.render()
      done = False
      total_reward = 0
      while not done:
          state_t = torch.FloatTensor(state).unsqueeze(0).to(device)
          with torch.no_grad():  # evaluation only: no autograd graph needed
              dist, _ = model(state_t)
          action = dist.sample().cpu().numpy()[0]
          state, reward, done, _ = env.step(action)
          if vis:
              env.render()
          total_reward += reward
      return total_reward
  def compute_returns(next_value, rewards, masks, gamma=0.99):
      """Compute discounted returns by backing up from ``next_value``.

      Args:
          next_value: Bootstrap value for the state after the last step.
          rewards: Per-step rewards, oldest first.
          masks: Per-step multipliers applied to the bootstrapped term
              (0 cuts the backup at that step, 1 lets it through).
          gamma: Discount factor.

      Returns:
          A list of returns aligned with ``rewards`` (oldest first).
      """
      R = next_value
      returns = []
      for step in reversed(range(len(rewards))):
          R = rewards[step] + gamma * R * masks[step]
          returns.append(R)
      # Built newest-first with O(1) appends; flip once instead of paying
      # O(n) for every insert(0, ...).
      returns.reverse()
      return returns
  def train(cfg, envs):
      """Train an A2C agent on vectorised environments.

      Every iteration collects ``cfg.n_steps`` transitions from all parallel
      environments, bootstraps the return from the critic's value of the last
      state, and takes one gradient step on the combined actor, critic and
      entropy loss. Every 100 frames the policy is evaluated for 10 episodes
      on a separate single environment.

      Args:
          cfg: Config exposing ``env_name``, ``algo_name``, ``device``,
              ``hidden_dim``, ``lr``, ``gamma``, ``n_steps`` and
              ``max_frames``.
          envs: Vectorised environment providing ``observation_space``,
              ``action_space``, ``reset`` and ``step``.

      Returns:
          ``(test_rewards, test_ma_rewards)``: the mean evaluation reward at
          each evaluation point and its exponential moving average.
      """
      print('开始训练!')
      print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
      env = gym.make(cfg.env_name)  # single env, used only for evaluation
      env.seed(10)
      state_dim = envs.observation_space.shape[0]
      action_dim = envs.action_space.n
      model = ActorCritic(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
      # Use the configured learning rate instead of silently ignoring it.
      optimizer = optim.Adam(model.parameters(), lr=cfg.lr)
      frame_idx = 0
      test_rewards = []
      test_ma_rewards = []
      state = envs.reset()
      while frame_idx < cfg.max_frames:
          log_probs = []
          values = []
          rewards = []
          masks = []
          entropy = 0
          # Roll out n_steps transitions in every parallel environment.
          for _ in range(cfg.n_steps):
              state = torch.FloatTensor(state).to(cfg.device)
              dist, value = model(state)
              action = dist.sample()
              next_state, reward, done, _ = envs.step(action.cpu().numpy())
              entropy += dist.entropy().mean()
              # log_prob has one entry per environment while value, reward
              # and mask carry a trailing axis of size 1. Give log_prob the
              # same trailing axis; otherwise `log_probs * advantage` below
              # broadcasts (N,) against (N, 1) into an (N, N) matrix and the
              # actor loss pairs every log-prob with every advantage.
              log_probs.append(dist.log_prob(action).unsqueeze(1))
              values.append(value)
              rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(cfg.device))
              masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(cfg.device))
              state = next_state
              frame_idx += 1
              if frame_idx % 100 == 0:
                  test_reward = np.mean([ceshi_env(env, model) for _ in range(10)])
                  print(f"frame_idx:{frame_idx}, test_reward:{test_reward}")
                  test_rewards.append(test_reward)
                  if test_ma_rewards:
                      test_ma_rewards.append(0.9 * test_ma_rewards[-1] + 0.1 * test_reward)
                  else:
                      test_ma_rewards.append(test_reward)
          # Bootstrap from the critic's value of the state after the rollout.
          # The returns are detached below, so no graph is needed here.
          next_state = torch.FloatTensor(next_state).to(cfg.device)
          with torch.no_grad():
              _, next_value = model(next_state)
          # Pass the configured discount instead of the helper's default.
          returns = compute_returns(next_value, rewards, masks, gamma=cfg.gamma)
          log_probs = torch.cat(log_probs)
          returns = torch.cat(returns).detach()
          values = torch.cat(values)
          advantage = returns - values
          # The actor treats the advantage as a constant; only the critic
          # loss back-propagates through `values`.
          actor_loss = -(log_probs * advantage.detach()).mean()
          critic_loss = advantage.pow(2).mean()
          loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
      env.close()  # release the evaluation environment
      print('完成训练!')
      return test_rewards, test_ma_rewards
  if __name__ == "__main__":
      # Script entry point: build configs and parallel envs, train, then
      # save and plot the evaluation rewards.
      cfg = A2CConfig()
      plot_cfg = PlotConfig()
      envs = SubprocVecEnv([make_envs(cfg.env_name) for _ in range(cfg.n_envs)])
      try:
          # Training
          rewards, ma_rewards = train(cfg, envs)
      finally:
          # Shut the worker processes down even when training fails.
          envs.close()
      make_dir(plot_cfg.result_path, plot_cfg.model_path)
      save_results(rewards, ma_rewards, tag='train', path=plot_cfg.result_path)  # save results
      plot_rewards(rewards, ma_rewards, plot_cfg, tag="train")  # plot results

3.2 一些依赖的文件(common文件夹)

3.2.1 multiprocessing_env.py(来自 OpenAI Baselines,用于多进程并行环境)

  1. # 该代码来自 openai baseline,用于多线程环境
  2. # https://github.com/openai/baselines/tree/master/baselines/common/vec_env
  3. import numpy as np
  4. from multiprocessing import Process, Pipe
  def worker(remote, parent_remote, env_fn_wrapper):
      """Subprocess loop: build an environment and serve commands from a pipe.

      Args:
          remote: This process's pipe end; ``(cmd, data)`` tuples are received
              on it and replies are sent back on it.
          parent_remote: The other pipe end, closed immediately in this
              process.
          env_fn_wrapper: Object whose ``x`` attribute is a zero-argument
              callable that creates the environment.

      Raises:
          NotImplementedError: If an unknown command is received.
      """
      parent_remote.close()
      env = env_fn_wrapper.x()
      running = True
      while running:
          command, payload = remote.recv()
          if command == 'close':
              remote.close()
              running = False
          elif command == 'step':
              observation, reward, done, info = env.step(payload)
              if done:
                  # Auto-reset: reply with the first observation of the
                  # next episode while still reporting done=True.
                  observation = env.reset()
              remote.send((observation, reward, done, info))
          elif command == 'reset':
              remote.send(env.reset())
          elif command == 'reset_task':
              remote.send(env.reset_task())
          elif command == 'get_spaces':
              remote.send((env.observation_space, env.action_space))
          else:
              raise NotImplementedError
  class VecEnv(object):
      """Abstract base class for an asynchronous, vectorized environment."""

      def __init__(self, num_envs, observation_space, action_space):
          """Store the environment count and the observation/action spaces."""
          self.num_envs = num_envs
          self.observation_space = observation_space
          self.action_space = action_space

      def reset(self):
          """Reset all environments.

          Implementations return an array of observations, or a tuple of
          observation arrays. If step_async is still doing work, that work is
          cancelled and step_wait() should not be called until step_async()
          is invoked again. This base implementation does nothing.
          """
          return None

      def step_async(self, actions):
          """Tell all environments to start taking a step with ``actions``.

          Call step_wait() to collect the results. Do not call this while a
          previous step_async is still pending. This base implementation does
          nothing.
          """
          return None

      def step_wait(self):
          """Wait for the step started with step_async().

          Implementations return ``(obs, rews, dones, infos)``:
           - obs: an array of observations, or a tuple of observation arrays
           - rews: an array of rewards
           - dones: an array of "episode done" booleans
           - infos: a sequence of info objects

          This base implementation does nothing.
          """
          return None

      def close(self):
          """Clean up the environments' resources (no-op in the base class)."""
          return None

      def step(self, actions):
          """Synchronous helper: step_async() followed by step_wait()."""
          self.step_async(actions)
          outcome = self.step_wait()
          return outcome
  class CloudpickleWrapper(object):
      """Holder whose payload is serialized with cloudpickle.

      Without it, multiprocessing would serialize the payload with plain
      pickle.
      """

      def __init__(self, x):
          """Wrap the object ``x``."""
          self.x = x

      def __getstate__(self):
          """Serialize the wrapped object with cloudpickle."""
          from cloudpickle import dumps
          return dumps(self.x)

      def __setstate__(self, ob):
          """Restore the wrapped object from its serialized bytes ``ob``."""
          from pickle import loads
          self.x = loads(ob)
  class SubprocVecEnv(VecEnv):
      """Vectorized environment that runs each environment in its own process."""

      def __init__(self, env_fns, spaces=None):
          """Start one worker process per environment factory.

          Args:
              env_fns: List of zero-argument callables, each creating one
                  environment inside its subprocess.
              spaces: Unused.
          """
          self.waiting = False
          self.closed = False
          nenvs = len(env_fns)
          self.nenvs = nenvs
          pipes = [Pipe() for _ in range(nenvs)]
          self.remotes, self.work_remotes = zip(*pipes)
          self.ps = []
          for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
              proc = Process(
                  target=worker,
                  args=(work_remote, remote, CloudpickleWrapper(env_fn)),
              )
              self.ps.append(proc)
          for proc in self.ps:
              # If the main process crashes, the workers must not keep it hanging.
              proc.daemon = True
              proc.start()
          # The worker ends now belong to the subprocesses; close them here.
          for work_remote in self.work_remotes:
              work_remote.close()
          # Ask the first worker for the spaces.
          first_remote = self.remotes[0]
          first_remote.send(('get_spaces', None))
          observation_space, action_space = first_remote.recv()
          VecEnv.__init__(self, len(env_fns), observation_space, action_space)

      def step_async(self, actions):
          """Send one action to each worker without waiting for the results."""
          for remote, action in zip(self.remotes, actions):
              remote.send(('step', action))
          self.waiting = True

      def step_wait(self):
          """Collect the pending step results and stack them per field."""
          replies = []
          for remote in self.remotes:
              replies.append(remote.recv())
          self.waiting = False
          obs, rews, dones, infos = zip(*replies)
          return np.stack(obs), np.stack(rews), np.stack(dones), infos

      def reset(self):
          """Reset every environment and return the stacked observations."""
          for remote in self.remotes:
              remote.send(('reset', None))
          observations = [remote.recv() for remote in self.remotes]
          return np.stack(observations)

      def reset_task(self):
          """Send 'reset_task' to every worker and stack the replies."""
          for remote in self.remotes:
              remote.send(('reset_task', None))
          observations = [remote.recv() for remote in self.remotes]
          return np.stack(observations)

      def close(self):
          """Stop all workers and join their processes; later calls are no-ops."""
          if self.closed:
              return
          if self.waiting:
              # Drain the results of a step that was started but never collected.
              for remote in self.remotes:
                  remote.recv()
          for remote in self.remotes:
              remote.send(('close', None))
          for proc in self.ps:
              proc.join()
          self.closed = True

      def __len__(self):
          """Return the number of environments."""
          return self.nenvs

3.2.2 utils.py(主要是文件创建与绘图函数)

  1. import os
  2. import numpy as np
  3. from pathlib import Path
  4. import matplotlib.pyplot as plt
  5. # import seaborn as sns
  6. from matplotlib.font_manager import FontProperties # 导入字体模块
  def chinese_font(fname='/System/Library/Fonts/STHeiti Light.ttc', size=15):
      """Return font properties for a Chinese-capable font, or None.

      Args:
          fname: Path of the font file. The default is a macOS system font;
              change it to a font that exists on your machine.
          size: Font size.

      Returns:
          A ``FontProperties`` instance, or ``None`` when the font file does
          not exist or cannot be loaded, so callers fall back to the default
          font.
      """
      # Check for the file up front so that a missing font really does yield
      # the None fallback.
      if not os.path.isfile(fname):
          return None
      try:
          return FontProperties(fname=fname, size=size)
      except Exception:
          # Best-effort helper: an unusable font falls back to the default,
          # without swallowing KeyboardInterrupt/SystemExit as a bare
          # `except:` would.
          return None
  def plot_rewards_cn(rewards, ma_rewards, plot_cfg, tag='train'):
      """Plot the reward curves with Chinese title, axis label and legend.

      Args:
          rewards: Sequence of rewards to plot.
          ma_rewards: Sequence of moving-average rewards to plot.
          plot_cfg: Object exposing ``env_name``, ``algo_name``, ``save`` and
              ``result_path``.
          tag: Prefix of the saved figure's file name.
      """
      # Build the font once and reuse it for title, axis label and legend.
      font = chinese_font()
      plt.figure()
      plt.title(
          u"{}环境下{}算法的学习曲线".format(plot_cfg.env_name, plot_cfg.algo_name),
          fontproperties=font,
      )
      plt.xlabel(u'回合数', fontproperties=font)
      plt.plot(rewards)
      plt.plot(ma_rewards)
      plt.legend((u'奖励', u'滑动平均奖励',), loc="best", prop=font)
      if plot_cfg.save:
          plt.savefig(plot_cfg.result_path+f"{tag}_rewards_curve_cn")
  def plot_rewards(rewards, ma_rewards, plot_cfg, tag='train'):
      """Plot rewards and moving-average rewards, then save and show the figure.

      Args:
          rewards: Sequence of rewards to plot.
          ma_rewards: Sequence of moving-average rewards to plot.
          plot_cfg: Object exposing ``device``, ``algo_name``, ``env_name``,
              ``save`` and ``result_path``.
          tag: Prefix of the saved figure's file name.
      """
      plt.figure()  # new figure, so several plots can coexist
      plt.title("learning curve on {} of {} for {}".format(
          plot_cfg.device, plot_cfg.algo_name, plot_cfg.env_name))
      plt.xlabel('episodes')  # was misspelled 'epsiodes'
      plt.plot(rewards, label='rewards')
      plt.plot(ma_rewards, label='ma rewards')
      plt.legend()
      if plot_cfg.save:
          plt.savefig(plot_cfg.result_path+"{}_rewards_curve".format(tag))
      plt.show()
  def plot_losses(losses, algo="DQN", save=True, path='./'):
      """Plot a loss curve, optionally save it, then show it.

      Args:
          losses: Sequence of loss values to plot.
          algo: Algorithm name used in the title.
          save: Whether to save the figure.
          path: Prefix prepended to the file name ``losses_curve``.
      """
      plt.figure()
      plt.title("loss curve of {}".format(algo))
      plt.xlabel('episodes')  # was misspelled 'epsiodes'
      # The curve shows losses; the legend used to label it 'rewards'.
      plt.plot(losses, label='losses')
      plt.legend()
      if save:
          plt.savefig(path+"losses_curve")
      plt.show()
  def save_results(rewards, ma_rewards, tag='train', path='./results'):
      """Save rewards and moving-average rewards as ``.npy`` files in a directory.

      Args:
          rewards: Rewards to save as ``{tag}_rewards.npy``.
          ma_rewards: Moving-average rewards to save as
              ``{tag}_ma_rewards.npy``.
          tag: File-name prefix.
          path: Target directory, with or without a trailing separator. It is
              created when missing.
      """
      if path:
          os.makedirs(path, exist_ok=True)
      # Join instead of concatenating: with the default './results' plain
      # concatenation produced './resultstrain_rewards.npy'.
      np.save(os.path.join(path, '{}_rewards.npy'.format(tag)), rewards)
      np.save(os.path.join(path, '{}_ma_rewards.npy'.format(tag)), ma_rewards)
      print('结果保存完毕!')
  def make_dir(*paths):
      """Create every directory in ``paths``, including missing parents.

      Directories that already exist are left untouched.
      """
      for target in map(Path, paths):
          target.mkdir(parents=True, exist_ok=True)
  def del_empty_dir(*paths):
      """Delete the empty immediate sub-directories of each directory in ``paths``.

      Regular files and non-empty folders are left alone, and the directories
      in ``paths`` themselves are never removed.
      """
      for path in paths:
          for entry in os.listdir(path):
              candidate = os.path.join(path, entry)
              # Only real directories qualify: os.listdir() raises on regular
              # files and os.rmdir() cannot remove a symlink.
              if os.path.islink(candidate) or not os.path.isdir(candidate):
                  continue
              if not os.listdir(candidate):
                  # rmdir removes just this folder; os.removedirs would go on
                  # to prune `path` and its ancestors once they became empty.
                  os.rmdir(candidate)

4 实验结果


本文转载自: https://blog.csdn.net/weixin_45985148/article/details/127143122
版权归原作者 Cary. 所有, 如有侵权,请联系我们删除。

“A2C算法原理及代码实现”的评论:

还没有评论