ml q3
This commit is contained in:
@ -265,7 +265,12 @@ class DigitClassificationDataset(CustomDataset):
|
||||
def get_validation_accuracy(self):
|
||||
dev_logits = self.model.run(torch.tensor(self.dev_images, dtype=torch.float32)).data
|
||||
dev_predicted = torch.argmax(dev_logits, axis=1).detach()
|
||||
dev_accuracy = (dev_predicted == self.dev_labels).mean()
|
||||
# print(f"dev_predicted:{dev_predicted}")
|
||||
# print(f"self.dev_labels: {self.dev_labels}")
|
||||
total = len(dev_predicted)
|
||||
correct = torch.sum(torch.eq(dev_predicted.cpu(), torch.tensor(self.dev_labels))).float()
|
||||
# dev_accuracy = (dev_predicted == self.dev_labels).mean()
|
||||
dev_accuracy = correct / total
|
||||
return dev_accuracy
|
||||
|
||||
class LanguageIDDataset(CustomDataset):
|
||||
|
Reference in New Issue
Block a user