
Hello, could you take a look at this Nature-DQN code for me? It runs extremely slowly

Open miss-fang opened this issue 6 years ago • 0 comments

```python
import tensorflow as tf
import numpy as np
import gym
import random
from collections import deque
from keras.utils.np_utils import to_categorical

import tensorflow.keras.backend as K


class QNetwork(tf.keras.Model):

    def __init__(self):
        super().__init__()
        # dense1/dense2: online Q-network
        self.dense1 = tf.keras.layers.Dense(24, activation='relu')
        self.dense2 = tf.keras.layers.Dense(2)
        # dense3/dense4: target Q-network
        self.dense3 = tf.keras.layers.Dense(24, activation='relu')
        self.dense4 = tf.keras.layers.Dense(2)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return x

    def tarNet_Q(self, inputs):
        x = self.dense3(inputs)
        x = self.dense4(x)
        return x

    def get_action(self, inputs):
        q_values = self(inputs)
        return K.eval(tf.argmax(q_values, axis=-1))[0]


env = gym.make('CartPole-v0')

num_episodes = 300
num_exploration = 200
max_len = 400
batch_size = 32
lr = 1e-3
gamma = 0.9
initial_epsilon = 0.5
final_epsilon = 0.01
replay_buffer = deque(maxlen=10000)
tarNet_update_frequence = 10
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
qNet = QNetwork()

for i in range(1, num_episodes + 1):
    state = env.reset()
    # linearly anneal epsilon over the exploration episodes
    epsilon = max(initial_epsilon * (num_exploration - i) / num_exploration, final_epsilon)
    for t in range(max_len):  # cap the maximum score per episode
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = qNet.get_action(tf.constant(np.expand_dims(state, axis=0), dtype=tf.float32))
        next_state, reward, done, info = env.step(action)
        reward = -1. if done else reward
        replay_buffer.append((state, action, reward, next_state, done))
        state = next_state
        if done:
            print('episode %d, epsilon %f, score %d' % (i, epsilon, t))
            break
        if len(replay_buffer) >= batch_size:
            batch_state, batch_action, batch_reward, batch_next_state, batch_done = \
                [np.array(a, dtype=np.float32) for a in zip(*random.sample(replay_buffer, batch_size))]
            # TD target y is computed from the target network
            q_value = qNet.tarNet_Q(tf.constant(batch_next_state, dtype=tf.float32))
            y = batch_reward + (gamma * tf.reduce_max(q_value, axis=1)) * (1 - batch_done)
            with tf.GradientTape() as tape:
                loss = tf.losses.mean_squared_error(
                    y,
                    tf.reduce_max(qNet(tf.constant(batch_state)) * to_categorical(batch_action, num_classes=2), axis=1))
            # only the online network's variables (the first four) are trained
            grads = tape.gradient(loss, qNet.variables[:4])
            optimizer.apply_gradients(grads_and_vars=zip(grads, qNet.variables[:4]))
    # every tarNet_update_frequence episodes, copy online weights into the target layers
    if i % tarNet_update_frequence == 0:
        for j in range(2):
            tf.assign(qNet.variables[4 + j], qNet.dense1.get_weights()[j])
            tf.assign(qNet.variables[6 + j], qNet.dense2.get_weights()[j])
env.close()
```

I suspect it runs slowly because the way I copy the network parameters is wrong. Any suggestions from anyone who sees this would be much appreciated.
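For comparison, a minimal sketch of the more conventional target-network layout: keep the online and target networks as two separate Keras models and sync them with a single `set_weights()` call, rather than indexing into `qNet.variables` and issuing one `tf.assign` per weight tensor. The names `build_net`, `online`, and `target` below are introduced purely for illustration, and this is only one possible cause of the slowdown, not a confirmed fix.

```python
# Sketch (illustrative, not the code above): two separate models,
# synced with one set_weights() call per target update.
import tensorflow as tf

def build_net():
    return tf.keras.Sequential([
        tf.keras.layers.Dense(24, activation='relu'),
        tf.keras.layers.Dense(2),
    ])

online = build_net()  # updated by the optimizer every training step
target = build_net()  # frozen copy, used only to compute the TD target y

# CartPole-v0 observations are 4-dimensional
online.build(input_shape=(None, 4))
target.build(input_shape=(None, 4))
target.set_weights(online.get_weights())  # initial sync

# ...then, every tarNet_update_frequence episodes inside the training loop:
target.set_weights(online.get_weights())
```

With this layout, `tape.gradient(loss, online.trainable_variables)` also replaces the `qNet.variables[:4]` slicing, which is fragile if the variable creation order ever changes.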

miss-fang · Oct 21 '19 15:10