A3C算法原理及代码详解

释放双眼,带上耳机,听听看~!
本文介绍了A3C算法的原理及其代码详解,讨论了A3C算法相比于Actor-Critic的优化,以及其在强化学习中的应用,旨在提高训练速度和效率。

前言

       强化学习有一个问题,就是它很慢,怎么提高训练的速度呢?在动漫《火影忍者》中,有一次鸣人想要在一周之内打败晓,所以要加快修行的速度,鸣人的老师就教他一个方法:用影分身进行同样的修行。两个一起修行,经验值累积的速度就会变成两倍,所以鸣人就使用了 1000 个影分身来进行修行。这就是异步优势演员-评论员算法的体现。

A3C算法原理及代码详解

       在提出 DDPG 后,DeepMind 在这个基础上提出了效果更好的 Asynchronous Advantage Actor-Critic(A3C),详见论文:Asynchronous Methods for Deep Reinforcement Learning (arxiv.org)

A3C算法原理

       A3C是A2C的异步版本。在A3C的设计中,协调器被移除。每个Worker节点直接和全局行动者和全局批评者进行对话。Master节点则不再需要等待各个Worker节点提供的梯度信息,而是在每次有Worker节点结束梯度计算的时候直接更新全局Actor-Critic。由于不再需要等待,A3C比A2C有更高的计算效率。但是同样也由于没有协调器协调各个Worker节点,Worker节点提供梯度信息和全局Actor-Critic的一致性不再成立,即每次Master节点从Worker节点得到的梯度信息很可能不再是当前全局Actor-Critic的梯度信息。

       A3C算法同时使用很多个进程(worker),每一个进程就像一个影分身,最后这些影分身会把所有的经验值集合在一起。如果我们没有很多CPU,不好实现异步优势演员-评论员算法,但可以实现优势演员-评论员算法。具体来说,利用多线程的方法,同时在多个线程里面分别和环境进行交互学习,每个线程都把学习的成果汇总起来,整理保存在一个公共的地方。并且,定期从公共的地方把大家的齐心学习的成果拿回来,指导自己和环境后面的学习交互。通过这种方法,A3C避免了经验回放相关性过强的问题,同时做到了异步并发的学习模型。

相比Actor-Critic,A3C的优化主要有3点:异步训练框架网络结构优化Critic评估点的优化

1. 异步训练框架:

A3C算法原理及代码详解

  • 图中Global Network就是共享的公共部分,主要是一个神经网络模型,包括Actor网络和Critic网络两部分

  • 下面有n个worker线程,每个线程里有和公共的神经网络一样的网络结构,每个线程会独立的和环境进行交互得到经验数据,这些线程之间互不干扰,独立运行

  • 每个线程和环境交互到一定量的数据后,就计算在自己线程里的神经网络损失函数的梯度,这些梯度用来更新公共的神经网络。

  • 每隔一段时间,线程的参数会更新为公共的参数,进而指导后面的环境交互

2. 网络结构优化:

在Actor-Critic中,我们使用了两个不同的网络Actor和Critic。在A3C这里,我们把两个网络放到了一起

3. Critic评估点的优化

在A3C中,使用了N步采样,可以加速收敛,因此使用优势函数表达式为:

A3C算法原理及代码详解

在参数更新当中加入了熵项:

A3C算法原理及代码详解

A3C算法的代码详解

Worker 类的结构如下所示:

class Worker(object):
    def __init__(self, name, globalAC):
        ...

    def work(self):
        ...

每个 Worker 节点都有自己的行动者网络批判者网络。所以在初始化函数中,我们通过实例化 ACNet 类来创建模型。

class Worker(object):
    def __init__(self, name, globalAC):
        self.env = gym.make(GAME).unwrapped
        self.name = name
        self.AC = ACNet(name, globalAC)

work() 函数是 Worker 类的主要函数。这里循环的主要内容是从智能体取得动作,并与环境交互。

    def work(self):
        global GLOBAL_RUNNING_R, GLOBAL_EP
        total_step = 1
        buffer_s, buffer_a, buffer_r = [], [], []
        while not COORD.should_stop() and GLOBAL_EP < MAX_GLOBAL_EP:
            s = self.env.reset()
            ep_r = 0
            for ep_t in range(MAX_EP_STEP):
                # if self.name == 'W_0':
                #     self.env.render()
                a = self.AC.choose_action(s)
                s_, r, done, info = self.env.step(a)
                done = True if ep_t == MAX_EP_STEP - 1 else False

                ep_r += r
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append((r+8)/8) 

当智能体采集足够的数据时,将开始更新全局网络。在那之后,本地网络的参数将被替换为更新后的最新全局网络参数。

                if total_step % UPDATE_GLOBAL_ITER == 0 or done:   
                    if done:
                        v_s_ = 0   # terminal
                    else:
                        v_s_ = SESS.run(self.AC.v, {self.AC.s: s_[np.newaxis, :]})[0, 0]
                    buffer_v_target = []
                    for r in buffer_r[::-1]:    # reverse buffer r
                        v_s_ = r + GAMMA * v_s_
                        buffer_v_target.append(v_s_)
                    buffer_v_target.reverse()

                    buffer_s, buffer_a, buffer_v_target = np.vstack(buffer_s), np.vstack(buffer_a), np.vstack(buffer_v_target)
                    feed_dict = {
                        self.AC.s: buffer_s,
                        self.AC.a_his: buffer_a,
                        self.AC.v_target: buffer_v_target,
                    }
                    self.AC.update_global(feed_dict)
                    buffer_s, buffer_a, buffer_r = [], [], []
                    self.AC.pull_global()

                s = s_
                total_step += 1
                if done:
                    if len(GLOBAL_RUNNING_R) == 0:  # record running episode reward
                        GLOBAL_RUNNING_R.append(ep_r)
                    else:
                        GLOBAL_RUNNING_R.append(0.9 * GLOBAL_RUNNING_R[-1] + 0.1 * ep_r)
                    print(
                        self.name,
                        "Ep:", GLOBAL_EP,
                        "| Ep_r: %i" % GLOBAL_RUNNING_R[-1],
                          )
                    GLOBAL_EP += 1
                    break

在上述代码中用到的 ACNet 类包含行动者和批判者。它的结构如下所示:

class ACNet(object):
    def __init__(self, scope, globalAC=None):
        ...
    def _build_net(self, scope):
        ...
    def update_global(self, feed_dict):  # run by a local
        ...
    def pull_global(self):  # 本地运行,从全局网络同步数据
        ...
    def choose_action(self, s):  # run by a local
        ...

update_global()函数是其中最重要的函数之一,它使用了采样数据来计算梯度,但是将梯度应用到全局网络,在那之后,再从全局网络更新数据,并继续循环。在这个模式下,可以异步更新多个Worker节点。

最后,准备工作都完成后,在主函数中逐一启动各个线程即可。

if __name__ == "__main__":
    SESS = tf.Session()

    with tf.device("/cpu:0"):
        OPT_A = tf.train.RMSPropOptimizer(LR_A, name='RMSPropA')
        OPT_C = tf.train.RMSPropOptimizer(LR_C, name='RMSPropC')
        GLOBAL_AC = ACNet(GLOBAL_NET_SCOPE)  # we only need its params
        workers = []
        # Create worker
        for i in range(N_WORKERS):
            i_name = 'W_%i' % i   # worker name
            workers.append(Worker(i_name, GLOBAL_AC))

    COORD = tf.train.Coordinator()
    SESS.run(tf.global_variables_initializer())

    if OUTPUT_GRAPH:
        if os.path.exists(LOG_DIR):
            shutil.rmtree(LOG_DIR)
        tf.summary.FileWriter(LOG_DIR, SESS.graph)

    worker_threads = []
    for worker in workers:
        job = lambda: worker.work()
        t = threading.Thread(target=job)
        t.start()
        worker_threads.append(t)
    COORD.join(worker_threads)

    plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
    plt.xlabel('step')
    plt.ylabel('Total moving reward')
    plt.show()

A3C算法Tensorflow实现结果:

A3C算法原理及代码详解

A3C算法原理及代码详解

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

Visual ChatGPT:跨界AI的新尝试

2023-12-17 15:24:14

AI教程

2022最新YOLOv7论文解读:图像检测算法的最前沿

2023-12-17 15:39:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索