ml q5
This commit is contained in:
@ -576,7 +576,7 @@ def check_convolution(tracker):
|
||||
input = torch.rand(matrix_size, matrix_size)
|
||||
student_output = models.Convolve(input, weights)
|
||||
actual_output = conv2d(input,weights)
|
||||
assert torch.isclose(student_output, actual_output).all(), "The convolution returned by Convolve() does not match expected output"
|
||||
assert torch.isclose(student_output.cpu(), actual_output).all(), "The convolution returned by Convolve() does not match expected output"
|
||||
|
||||
tracker.add_points(1/2) # Partial credit for testing whether convolution function works
|
||||
|
||||
|
Reference in New Issue
Block a user