try to solve smallClassic
@@ -5,6 +5,7 @@ import layout
 import copy
 import torch
 import numpy as np
+import os

 class PacmanDeepQAgent(PacmanQAgent):
     def __init__(self, layout_input="smallGrid", target_update_rate=300, doubleQ=True, **args):
@@ -15,8 +16,9 @@ class PacmanDeepQAgent(PacmanQAgent):
         self.update_amount = 0
         self.epsilon_explore = 1.0
         self.epsilon0 = 0.4
+        self.minimal_epsilon = 0.01
         self.epsilon = self.epsilon0
-        self.discount = 0.9
+        self.discount = 0.95
         self.update_frequency = 3
         self.counts = None
         self.replay_memory = ReplayMemory(50000)
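
Editor's note, not part of the commit: raising the discount from 0.9 to 0.95 roughly doubles the agent's effective planning horizon, since a discount of gamma weighs rewards about 1 / (1 - gamma) steps ahead; a longer horizon plausibly suits the larger smallClassic layout this commit targets. A quick check:

# Side-note sketch: effective planning horizon of a discount factor.
for gamma in (0.9, 0.95):
    # ~1 / (1 - gamma): about 10 steps for 0.9, about 20 for 0.95
    print(gamma, round(1 / (1 - gamma)))
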
@@ -54,6 +56,27 @@ class PacmanDeepQAgent(PacmanQAgent):
         import model
         self.model = model.DeepQNetwork(state_dim, action_dim)
         self.target_model = model.DeepQNetwork(state_dim, action_dim)
+        if os.path.exists('para.bin'):
+            print("Loading model parameters from para.bin")
+            checkpoint = torch.load('para.bin')
+            self.model.load_state_dict(checkpoint['model_state_dict'])
+            self.target_model.load_state_dict(checkpoint['target_model_state_dict'])
+            self.model.optimizer.load_state_dict(checkpoint['model_optimizer_state_dict'])
+            self.target_model.optimizer.load_state_dict(checkpoint['target_model_optimizer_state_dict'])
+            self.replay_memory = checkpoint['memory']
+            print(self.model.state_dict())
+        else:
+            print("Initializing new model parameters")
+    def save_model(self, filename="para.bin"):
+        print(f"Saving model parameters to {filename}")
+        torch.save({
+            'model_state_dict': self.model.state_dict(),
+            'target_model_state_dict': self.target_model.state_dict(),
+            'model_optimizer_state_dict': self.model.optimizer.state_dict(),
+            'target_model_optimizer_state_dict': self.target_model.optimizer.state_dict(),
+            'memory': self.replay_memory
+        }, filename)
+        print(self.model.state_dict())

     def getQValue(self, state, action):
         """
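
Editor's note: a minimal, self-contained sketch of the save/load round-trip this hunk adds. SmallQNet and para_demo.bin are hypothetical stand-ins for model.DeepQNetwork and para.bin; the one structural assumption carried over is that the network owns its optimizer as an attribute, which the agent code relies on via self.model.optimizer. Because the checkpoint pickles the whole ReplayMemory, PyTorch 2.6 and later (where torch.load defaults to weights_only=True) need an explicit weights_only=False to restore it:

import torch
import torch.nn as nn

class SmallQNet(nn.Module):
    """Hypothetical stand-in for model.DeepQNetwork."""
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(),
                                 nn.Linear(64, action_dim))
        # The agent reads self.model.optimizer, so the network is
        # assumed to own its optimizer as an attribute.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        return self.net(x)

model, target = SmallQNet(6, 5), SmallQNet(6, 5)

torch.save({
    'model_state_dict': model.state_dict(),
    'target_model_state_dict': target.state_dict(),
    'model_optimizer_state_dict': model.optimizer.state_dict(),
    'target_model_optimizer_state_dict': target.optimizer.state_dict(),
    'memory': [],  # the real code stores the entire ReplayMemory here
}, 'para_demo.bin')

# 'memory' is an arbitrary pickled object, so weights_only=False is
# required on PyTorch >= 2.6 (the parameter exists since ~1.13).
checkpoint = torch.load('para_demo.bin', weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.optimizer.load_state_dict(checkpoint['model_optimizer_state_dict'])

Stashing the replay memory in the same file is convenient for resuming training, but it ties para.bin to pickle and makes it large; saving only the state dicts would keep the checkpoint weights-only.
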
@@ -136,7 +159,7 @@ class PacmanDeepQAgent(PacmanQAgent):
         if len(self.replay_memory) < self.min_transitions_before_training:
             self.epsilon = self.epsilon_explore
         else:
-            self.epsilon = max(self.epsilon0 * (1 - self.update_amount / 20000), 0)
+            self.epsilon = max(self.epsilon0 * (1 - self.update_amount / 20000), self.minimal_epsilon)


         if len(self.replay_memory) > self.min_transitions_before_training and self.update_amount % self.update_frequency == 0:
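
Editor's note: the behavioral change in this hunk, sketched in isolation. Epsilon still decays linearly over the first 20000 updates, but it now bottoms out at minimal_epsilon = 0.01 instead of 0, so the agent keeps a small amount of exploration forever. The constants are those set in __init__ above; `horizon` is just an illustrative name for the hard-coded 20000.

epsilon0 = 0.4
minimal_epsilon = 0.01

def epsilon_at(update_amount, horizon=20000):
    # linear anneal from epsilon0 down to a fixed exploration floor
    return max(epsilon0 * (1 - update_amount / horizon), minimal_epsilon)

print([round(epsilon_at(t), 2) for t in (0, 10000, 20000, 40000)])
# -> [0.4, 0.2, 0.01, 0.01]   (the old code hit 0.0 at 20000 and stayed greedy)
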