try to solve smallClassic

2024-07-18 19:18:55 +08:00
parent 1bf4cc1efe
commit ceae34ea86
6 changed files with 58 additions and 13 deletions


@@ -5,6 +5,7 @@ import layout
 import copy
 import torch
 import numpy as np
+import os

 class PacmanDeepQAgent(PacmanQAgent):
     def __init__(self, layout_input="smallGrid", target_update_rate=300, doubleQ=True, **args):
@@ -15,8 +16,9 @@ class PacmanDeepQAgent(PacmanQAgent):
         self.update_amount = 0
         self.epsilon_explore = 1.0
         self.epsilon0 = 0.4
+        self.minimal_epsilon = 0.01
         self.epsilon = self.epsilon0
-        self.discount = 0.9
+        self.discount = 0.95
         self.update_frequency = 3
         self.counts = None
         self.replay_memory = ReplayMemory(50000)
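
Note (not part of the diff): bumping the discount from 0.9 to 0.95 roughly doubles the agent's effective planning horizon, since 1/(1-γ) grows from 10 to 20 steps, which plausibly matters on smallClassic, where food and ghosts are farther apart than on smallGrid. A minimal illustration, using the standard geometric-series estimate of the horizon:

    # Effective horizon ~ 1/(1 - gamma): beyond this many steps,
    # discounting shrinks future rewards to a negligible weight.
    for gamma in (0.9, 0.95):
        print(f"gamma={gamma}: effective horizon ~ {1 / (1 - gamma):.0f} steps")
    # gamma=0.9: ~10 steps; gamma=0.95: ~20 steps
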
@@ -54,6 +56,27 @@ class PacmanDeepQAgent(PacmanQAgent):
         import model
         self.model = model.DeepQNetwork(state_dim, action_dim)
         self.target_model = model.DeepQNetwork(state_dim, action_dim)
+        if os.path.exists('para.bin'):
+            print("Loading model parameters from para.bin")
+            checkpoint = torch.load('para.bin')
+            self.model.load_state_dict(checkpoint['model_state_dict'])
+            self.target_model.load_state_dict(checkpoint['target_model_state_dict'])
+            self.model.optimizer.load_state_dict(checkpoint['model_optimizer_state_dict'])
+            self.target_model.optimizer.load_state_dict(checkpoint['target_model_optimizer_state_dict'])
+            self.replay_memory = checkpoint['memory']
+            print(self.model.state_dict())
+        else:
+            print("Initializing new model parameters")
+    def save_model(self, filename="para.bin"):
+        print(f"Saving model parameters to {filename}")
+        torch.save({
+            'model_state_dict': self.model.state_dict(),
+            'target_model_state_dict': self.target_model.state_dict(),
+            'model_optimizer_state_dict': self.model.optimizer.state_dict(),
+            'target_model_optimizer_state_dict': self.target_model.optimizer.state_dict(),
+            'memory': self.replay_memory
+        }, filename)
+        print(self.model.state_dict())

     def getQValue(self, state, action):
         """
@@ -136,7 +159,7 @@ class PacmanDeepQAgent(PacmanQAgent):
         if len(self.replay_memory) < self.min_transitions_before_training:
             self.epsilon = self.epsilon_explore
         else:
-            self.epsilon = max(self.epsilon0 * (1 - self.update_amount / 20000), 0)
+            self.epsilon = max(self.epsilon0 * (1 - self.update_amount / 20000), self.minimal_epsilon)

         if len(self.replay_memory) > self.min_transitions_before_training and self.update_amount % self.update_frequency == 0:
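
Note (not part of the diff): the changed line replaces a floor of 0 with self.minimal_epsilon, so epsilon now decays linearly from 0.4 but never drops below 0.01; previously the agent became fully greedy after 20000 updates and could get stuck repeating a bad policy. A small sketch of the resulting schedule, with the constants copied from the diff:

    # Linear epsilon decay with a floor, mirroring the updated line above.
    epsilon0, minimal_epsilon, decay_steps = 0.4, 0.01, 20000

    def epsilon_at(update_amount):
        return max(epsilon0 * (1 - update_amount / decay_steps), minimal_epsilon)

    for t in (0, 10000, 19500, 20000, 50000):
        print(t, round(epsilon_at(t), 3))  # 0.4, 0.2, 0.01, 0.01, 0.01
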