diff --git a/reinforcement/deepQLearningAgents.py b/reinforcement/deepQLearningAgents.py index 1e8331b..5172364 100644 --- a/reinforcement/deepQLearningAgents.py +++ b/reinforcement/deepQLearningAgents.py @@ -73,6 +73,9 @@ class PacmanDeepQAgent(PacmanQAgent): else: print("Initializing new model parameters") def save_model(self, filename="para.bin"): + if model.kProductionMode: + print("in production mode, not saving model") + return print(f"Saving model parameters to {filename}") torch.save({ 'model_state_dict': self.model.state_dict(), diff --git a/reinforcement/model.py b/reinforcement/model.py index e9b82ae..e3405c9 100644 --- a/reinforcement/model.py +++ b/reinforcement/model.py @@ -11,7 +11,7 @@ from torch import tensor, double, optim from torch.nn.functional import relu, mse_loss import torch -kProductionMode=True +kProductionMode=False class DeepQNetwork(Module): """ A model that uses a Deep Q-value Network (DQN) to approximate Q(s,a) as part