reach +0!
@@ -17,6 +17,11 @@ class PacmanDeepQAgent(PacmanQAgent):
         self.epsilon_explore = 1.0
         self.epsilon0 = 0.4
         self.minimal_epsilon = 0.01
+        if model.kProductionMode:
+            self.epsilon_explore = 0.01
+            self.epsilon0 = 0.01
+            self.minimal_epsilon = 0.01
+            print("in production mode, epsilon set to 0.01")
         self.epsilon = self.epsilon0
         self.discount = 0.95
         self.update_frequency = 3
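In production mode all three exploration parameters collapse to the 0.01 floor, so the served agent is almost fully greedy. For context, a minimal sketch of how a three-parameter schedule like this is typically consumed; the decay shape and the episode-count parameters below are assumptions, since the agent's actual schedule is not part of this diff.

# Hypothetical epsilon schedule built from the three fields above.
# Shape and parameter names are assumptions, not this repo's code.
def epsilon_for_episode(episode, explore_episodes=500, total_episodes=5000,
                        epsilon_explore=1.0, epsilon0=0.4, minimal_epsilon=0.01):
    """Pure exploration first, then linear decay from epsilon0 to the floor."""
    if episode < explore_episodes:
        return epsilon_explore  # act randomly while the replay memory fills up
    frac = (episode - explore_episodes) / max(1, total_episodes - explore_episodes)
    return max(minimal_epsilon, epsilon0 * (1.0 - frac))

Under the production clamp a schedule like this returns 0.01 at every episode, which is the intended serving behavior.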
@@ -11,7 +11,7 @@ from torch import tensor, double, optim
 from torch.nn.functional import relu, mse_loss
 import torch
 
-
+kProductionMode = True
 class DeepQNetwork(Module):
     """
     A model that uses a Deep Q-value Network (DQN) to approximate Q(s,a) as part
@@ -26,10 +26,12 @@ class DeepQNetwork(Module):
         "*** YOUR CODE HERE ***"
         # Initialize layers
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        layer0_size = 256
         layer1_size = 512
         layer2_size = 128
         layer3_size = 64
-        self.fc1 = Linear(state_dim, layer1_size).to(self.device)
+        self.fc0 = Linear(state_dim, layer0_size).to(self.device)
+        self.fc1 = Linear(layer0_size, layer1_size).to(self.device)
         self.fc2 = Linear(layer1_size, layer2_size).to(self.device)
         self.fc3 = Linear(layer2_size, layer3_size).to(self.device)
         self.fc_out = Linear(layer3_size, action_dim).to(self.device)
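After this hunk the network is five layers deep (state_dim -> 256 -> 512 -> 128 -> 64 -> action_dim). A self-contained sketch of the assumed resulting module, reconstructed only from the lines visible in this diff; anything else in the real class (optimizer, loss helper, parameter dtype) is omitted.

import torch
from torch.nn import Module, Linear
from torch.nn.functional import relu

class DeepQNetworkSketch(Module):
    """Assumed shape of DeepQNetwork after this commit."""
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.fc0 = Linear(state_dim, 256).to(self.device)  # new front layer
        self.fc1 = Linear(256, 512).to(self.device)        # was Linear(state_dim, 512)
        self.fc2 = Linear(512, 128).to(self.device)
        self.fc3 = Linear(128, 64).to(self.device)
        self.fc_out = Linear(64, action_dim).to(self.device)

    def forward(self, states):
        if states.device.type != self.device.type:
            states = states.to(self.device)
        x = relu(self.fc0(states))
        x = relu(self.fc1(x))
        x = relu(self.fc2(x))
        x = relu(self.fc3(x))
        return self.fc_out(x)  # raw Q-values, one per action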
@@ -82,7 +84,8 @@ class DeepQNetwork(Module):
         "*** YOUR CODE HERE ***"
         if states.device.type != self.device.type:
             states = states.to(self.device)
-        x = relu(self.fc1(states))
+        x = relu(self.fc0(states))
+        x = relu(self.fc1(x))
         x = relu(self.fc2(x))
         x = relu(self.fc3(x))
         Q_values = self.fc_out(x)
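A quick shape check of the deeper forward pass, reusing the DeepQNetworkSketch defined above; the batch size and dimensions are arbitrary.

# Smoke test for the updated forward pass (dimensions are arbitrary).
net = DeepQNetworkSketch(state_dim=6, action_dim=5)
states = torch.randn(32, 6)       # batch of 32 state feature vectors
q_values = net(states)
assert q_values.shape == (32, 5)  # one Q-value per action

Note that forward() silently moves CPU inputs onto the model's device, so the caller never has to.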
@@ -107,6 +110,9 @@ class DeepQNetwork(Module):
             None
         """
         "*** YOUR CODE HERE ***"
+        if kProductionMode:
+            print("in production mode, no update")
+            return
         self.optimizer.zero_grad()
         loss = self.get_loss(states, Q_target)
         loss.backward()
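The new guard turns the update into a no-op when serving, so the shipped weights stay frozen. A sketch of the whole step as it presumably reads after this commit, written as a free function for self-containment; the trailing optimizer.step() and the exact loss are assumptions, since the hunk ends at loss.backward() and only the mse_loss import is visible in the diff.

import torch
from torch.nn.functional import mse_loss

kProductionMode = True  # module-level flag introduced by this commit

def gradient_update(net, optimizer, states, Q_target):
    """One DQN training step; skipped entirely in production."""
    if kProductionMode:
        print("in production mode, no update")
        return                              # serve the frozen policy untouched
    optimizer.zero_grad()                   # clear gradients from the last step
    loss = mse_loss(net(states), Q_target)  # assumed loss, matching the import
    loss.backward()                         # backprop through all five layers
    optimizer.step()                        # assumed: apply the update

A module-level boolean works, but the more idiomatic PyTorch way to freeze a served model is net.eval() plus wrapping inference in torch.no_grad().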