ml q2
This commit is contained in:
@ -435,7 +435,7 @@ def check_regression(tracker):
|
||||
train_predicted = model(data_x)
|
||||
|
||||
verify_node(train_predicted, 'tensor', (dataset.x.shape[0], 1), "RegressionModel()")
|
||||
error = labels - train_predicted
|
||||
error = labels - train_predicted.cpu()
|
||||
sanity_loss = torch.mean((error.detach())**2)
|
||||
|
||||
assert torch.isclose(torch.tensor(train_loss), sanity_loss), (
|
||||
|
Reference in New Issue
Block a user