ml q5
This commit is contained in:
@ -492,7 +492,7 @@ class DigitClassificationDataset2(CustomDataset):
|
||||
def get_validation_accuracy(self):
|
||||
dev_logits = self.model.run(torch.tensor(self.dev_images, dtype=torch.float32)).data
|
||||
dev_predicted = torch.argmax(dev_logits, axis=1).detach()
|
||||
dev_accuracy = torch.mean(torch.eq(dev_predicted, torch.tensor(self.dev_labels)).float())
|
||||
dev_accuracy = torch.mean(torch.eq(dev_predicted.cpu(), torch.tensor(self.dev_labels)).float())
|
||||
return dev_accuracy
|
||||
|
||||
def main():
|
||||
|
Reference in New Issue
Block a user