reach +0!

This commit is contained in:
2024-07-19 05:55:58 +08:00
parent ceae34ea86
commit 25f46c9a13
2 changed files with 15 additions and 4 deletions

View File

@ -17,6 +17,11 @@ class PacmanDeepQAgent(PacmanQAgent):
self.epsilon_explore = 1.0 self.epsilon_explore = 1.0
self.epsilon0 = 0.4 self.epsilon0 = 0.4
self.minimal_epsilon = 0.01 self.minimal_epsilon = 0.01
if model.kProductionMode:
self.epsilon_explore=0.01
self.epsilon0=0.01
self.minimal_epsilon=0.01
print("in production mode, epsilon set to 0.01")
self.epsilon = self.epsilon0 self.epsilon = self.epsilon0
self.discount = 0.95 self.discount = 0.95
self.update_frequency = 3 self.update_frequency = 3

View File

@ -11,7 +11,7 @@ from torch import tensor, double, optim
from torch.nn.functional import relu, mse_loss from torch.nn.functional import relu, mse_loss
import torch import torch
kProductionMode=True
class DeepQNetwork(Module): class DeepQNetwork(Module):
""" """
A model that uses a Deep Q-value Network (DQN) to approximate Q(s,a) as part A model that uses a Deep Q-value Network (DQN) to approximate Q(s,a) as part
@ -26,10 +26,12 @@ class DeepQNetwork(Module):
"*** YOUR CODE HERE ***" "*** YOUR CODE HERE ***"
# Initialize layers # Initialize layers
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer0_size=256
layer1_size=512 layer1_size=512
layer2_size=128 layer2_size=128
layer3_size=64 layer3_size=64
self.fc1 = Linear(state_dim, layer1_size).to(self.device) self.fc0 = Linear(state_dim, layer0_size).to(self.device)
self.fc1 = Linear(layer0_size, layer1_size).to(self.device)
self.fc2 = Linear(layer1_size, layer2_size).to(self.device) self.fc2 = Linear(layer1_size, layer2_size).to(self.device)
self.fc3 = Linear(layer2_size, layer3_size).to(self.device) self.fc3 = Linear(layer2_size, layer3_size).to(self.device)
self.fc_out= Linear(layer3_size, action_dim).to(self.device) self.fc_out= Linear(layer3_size, action_dim).to(self.device)
@ -82,7 +84,8 @@ class DeepQNetwork(Module):
"*** YOUR CODE HERE ***" "*** YOUR CODE HERE ***"
if states.device.type != self.device.type: if states.device.type != self.device.type:
states = states.to(self.device) states = states.to(self.device)
x = relu(self.fc1(states)) x = relu(self.fc0(states))
x = relu(self.fc1(x))
x = relu(self.fc2(x)) x = relu(self.fc2(x))
x = relu(self.fc3(x)) x = relu(self.fc3(x))
Q_values = self.fc_out(x) Q_values = self.fc_out(x)
@ -107,6 +110,9 @@ class DeepQNetwork(Module):
None None
""" """
"*** YOUR CODE HERE ***" "*** YOUR CODE HERE ***"
if kProductionMode:
print("in production mode, no update")
return
self.optimizer.zero_grad() self.optimizer.zero_grad()
loss = self.get_loss(states, Q_target) loss = self.get_loss(states, Q_target)
loss.backward() loss.backward()