reach +0!

commit 25f46c9a13
parent ceae34ea86
Date: 2024-07-19 05:55:58 +08:00
2 changed files with 15 additions and 4 deletions


@@ -17,6 +17,11 @@ class PacmanDeepQAgent(PacmanQAgent):
self.epsilon_explore = 1.0
self.epsilon0 = 0.4
self.minimal_epsilon = 0.01
if model.kProductionMode:
    self.epsilon_explore = 0.01
    self.epsilon0 = 0.01
    self.minimal_epsilon = 0.01
    print("in production mode, epsilon set to 0.01")
self.epsilon = self.epsilon0
self.discount = 0.95
self.update_frequency = 3
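
The epsilon values above only define the exploration schedule; the action-selection code itself is not part of this commit. Below is a minimal epsilon-greedy sketch, assuming the agent exposes the usual getLegalActions/getQValue interface and decays epsilon geometrically toward minimal_epsilon (the project's actual decay rule may differ):

import random

def choose_action(agent, state):
    # Epsilon-greedy: explore with probability epsilon, otherwise act greedily.
    legal = agent.getLegalActions(state)
    if not legal:
        return None
    if random.random() < agent.epsilon:
        action = random.choice(legal)
    else:
        action = max(legal, key=lambda a: agent.getQValue(state, a))
    # Decay epsilon toward the floor set in __init__ (geometric decay assumed here).
    agent.epsilon = max(agent.minimal_epsilon, agent.epsilon * 0.999)
    return action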


@@ -11,7 +11,7 @@ from torch import tensor, double, optim
from torch.nn.functional import relu, mse_loss
import torch
kProductionMode = True
class DeepQNetwork(Module):
"""
A model that uses a Deep Q-value Network (DQN) to approximate Q(s,a) as part
@@ -26,10 +26,12 @@ class DeepQNetwork(Module):
"*** YOUR CODE HERE ***"
# Initialize layers
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer0_size = 256
layer1_size = 512
layer2_size = 128
layer3_size = 64
self.fc1 = Linear(state_dim, layer1_size).to(self.device)
self.fc0 = Linear(state_dim, layer0_size).to(self.device)
self.fc1 = Linear(layer0_size, layer1_size).to(self.device)
self.fc2 = Linear(layer1_size, layer2_size).to(self.device)
self.fc3 = Linear(layer2_size, layer3_size).to(self.device)
self.fc_out = Linear(layer3_size, action_dim).to(self.device)
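
After this change the network maps state_dim -> 256 -> 512 -> 128 -> 64 -> action_dim. For reference, here is a standalone sketch of the same shape using torch.nn.Sequential; this is not the project's class, just an equivalent structure:

import torch
from torch import nn

def build_dqn(state_dim: int, action_dim: int) -> nn.Sequential:
    # Same layer widths as fc0..fc_out above, with ReLU between hidden layers.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return nn.Sequential(
        nn.Linear(state_dim, 256), nn.ReLU(),
        nn.Linear(256, 512), nn.ReLU(),
        nn.Linear(512, 128), nn.ReLU(),
        nn.Linear(128, 64), nn.ReLU(),
        nn.Linear(64, action_dim),
    ).to(device)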
@@ -82,7 +84,8 @@ class DeepQNetwork(Module):
"*** YOUR CODE HERE ***"
if states.device.type != self.device.type:
    states = states.to(self.device)
x = relu(self.fc1(states))
x = relu(self.fc0(states))
x = relu(self.fc1(x))
x = relu(self.fc2(x))
x = relu(self.fc3(x))
Q_values = self.fc_out(x)
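
With the new fc0 layer the forward pass now chains four ReLU-activated hidden layers before the linear output head. A quick shape check using the Sequential sketch above (the dimensions are placeholders; the real state_dim depends on the Pacman layout):

net = build_dqn(state_dim=54, action_dim=5)  # illustrative dimensions only
states = torch.rand(8, 54, device=next(net.parameters()).device)
q_values = net(states)                       # shape (8, 5): one Q-value per action
print(q_values.shape)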
@@ -107,6 +110,9 @@ class DeepQNetwork(Module):
None
"""
"*** YOUR CODE HERE ***"
if kProductionMode:
    print("in production mode, no update")
    return
self.optimizer.zero_grad()
loss = self.get_loss(states, Q_target)
loss.backward()
@@ -115,4 +121,4 @@ class DeepQNetwork(Module):
# self.scheduler2.step()
self.output_cnt += 1
if self.output_cnt % self.output_step == 0:
    print("now lr is:", self.optimizer.param_groups[0]['lr'], "update count", self.output_cnt)
print("now lr is: ", self.optimizer.param_groups[0]['lr'],"update count", self.output_cnt)