This commit is contained in:
2024-07-05 19:21:02 +08:00
parent 422beb4a9b
commit 99f53727ba
2 changed files with 57 additions and 1 deletions

View File

@ -265,7 +265,12 @@ class DigitClassificationDataset(CustomDataset):
def get_validation_accuracy(self):
dev_logits = self.model.run(torch.tensor(self.dev_images, dtype=torch.float32)).data
dev_predicted = torch.argmax(dev_logits, axis=1).detach()
dev_accuracy = (dev_predicted == self.dev_labels).mean()
# print(f"dev_predicted:{dev_predicted}")
# print(f"self.dev_labels: {self.dev_labels}")
total = len(dev_predicted)
correct = torch.sum(torch.eq(dev_predicted.cpu(), torch.tensor(self.dev_labels))).float()
# dev_accuracy = (dev_predicted == self.dev_labels).mean()
dev_accuracy = correct / total
return dev_accuracy
class LanguageIDDataset(CustomDataset):