This commit is contained in:
2024-07-05 18:51:22 +08:00
parent c48da75a00
commit 422beb4a9b
3 changed files with 47 additions and 7 deletions

View File

@ -435,7 +435,7 @@ def check_regression(tracker):
train_predicted = model(data_x)
verify_node(train_predicted, 'tensor', (dataset.x.shape[0], 1), "RegressionModel()")
error = labels - train_predicted
error = labels - train_predicted.cpu()
sanity_loss = torch.mean((error.detach())**2)
assert torch.isclose(torch.tensor(train_loss), sanity_loss), (