rein q7
This commit is contained in:
@ -97,15 +97,15 @@ class PacmanDeepQAgent(PacmanQAgent):
|
||||
next_states = torch.tensor(next_states)
|
||||
done = np.array([x.done for x in minibatch])
|
||||
|
||||
Q_predict = network.run(states).data.detach().numpy()
|
||||
Q_predict = network.run(states).data.detach().cpu().numpy()
|
||||
Q_target = np.copy(Q_predict )
|
||||
state_indices = states.int().detach().numpy()
|
||||
state_indices = (state_indices[:, 0], state_indices[:, 1])
|
||||
exploration_bonus = 1 / (2 * np.sqrt((self.counts[state_indices] / 100)))
|
||||
|
||||
replace_indices = np.arange(actions.shape[0])
|
||||
action_indices = np.argmax(network.run(next_states).data, axis=1)
|
||||
target = rewards + exploration_bonus + (1 - done) * self.discount * target_network.run(next_states).data[replace_indices, action_indices].detach().numpy()
|
||||
action_indices = np.argmax(network.run(next_states).data.cpu(), axis=1)
|
||||
target = rewards + exploration_bonus + (1 - done) * self.discount * target_network.run(next_states).data[replace_indices, action_indices].detach().cpu().numpy()
|
||||
|
||||
Q_target[replace_indices, actions] = target
|
||||
|
||||
|
Reference in New Issue
Block a user