This commit is contained in:
2024-07-06 01:24:14 +08:00
parent f10a24eb0a
commit f105ba0150
3 changed files with 68 additions and 11 deletions

View File

@ -492,7 +492,7 @@ class DigitClassificationDataset2(CustomDataset):
def get_validation_accuracy(self):
dev_logits = self.model.run(torch.tensor(self.dev_images, dtype=torch.float32)).data
dev_predicted = torch.argmax(dev_logits, axis=1).detach()
dev_accuracy = torch.mean(torch.eq(dev_predicted, torch.tensor(self.dev_labels)).float())
dev_accuracy = torch.mean(torch.eq(dev_predicted.cpu(), torch.tensor(self.dev_labels)).float())
return dev_accuracy
def main():