diff --git a/machinelearning/models.py b/machinelearning/models.py index af9965f..4824e81 100644 --- a/machinelearning/models.py +++ b/machinelearning/models.py @@ -298,8 +298,8 @@ class DigitClassificationModel(Module): """ """ YOUR CODE HERE """ optimizer = torch.optim.Adam(self.parameters(), lr=0.0005) - dataloader = DataLoader(dataset, batch_size=20, shuffle=True) - max_round=15000 + dataloader = DataLoader(dataset, batch_size=100, shuffle=True) + max_round=30000 required_accuracy=0.99 round_cnt=0 while round_cnt