This commit is contained in:
2024-07-06 01:24:14 +08:00
parent f10a24eb0a
commit f105ba0150
3 changed files with 68 additions and 11 deletions

View File

@ -576,7 +576,7 @@ def check_convolution(tracker):
input = torch.rand(matrix_size, matrix_size)
student_output = models.Convolve(input, weights)
actual_output = conv2d(input,weights)
assert torch.isclose(student_output, actual_output).all(), "The convolution returned by Convolve() does not match expected output"
assert torch.isclose(student_output.cpu(), actual_output).all(), "The convolution returned by Convolve() does not match expected output"
tracker.add_points(1/2) # Partial credit for testing whether convolution function works