accelerate ml q5 using GPU
This commit is contained in:
@ -457,27 +457,33 @@ def Convolve(input: tensor, weight: tensor):
|
||||
|
||||
This returns a subtensor whose first element is tensor[y,x] and which has height 'height' and width 'width'
|
||||
"""
|
||||
input_tensor_height, input_tensor_width = input.shape
|
||||
weight_height, weight_width = weight.shape
|
||||
|
||||
# Calculate output dimensions
|
||||
output_height = input_tensor_height - weight_height + 1
|
||||
output_width = input_tensor_width - weight_width + 1
|
||||
|
||||
# Initialize output tensor
|
||||
if input.device.type!=Convolve.device.type:
|
||||
input=input.to(Convolve.device)
|
||||
if weight.device.type!=Convolve.device.type:
|
||||
weight=weight.to(Convolve.device)
|
||||
output = torch.zeros((output_height, output_width),device=Convolve.device)
|
||||
input_4d = input.unsqueeze(0).unsqueeze(0) # Make it shape (1, 1, H, W)
|
||||
weight_4d = weight.unsqueeze(0).unsqueeze(0) # Make it shape (1, 1, kH, kW)
|
||||
|
||||
# Perform convolution
|
||||
for i in range(output_height):
|
||||
for j in range(output_width):
|
||||
output[i, j] = torch.tensordot(input[i:i+weight_height, j:j+weight_width], weight, dims=2)
|
||||
output_4d = torch.nn.functional.conv2d(input_4d, weight_4d)
|
||||
|
||||
# Remove the extra dimensions
|
||||
output = output_4d.squeeze(0).squeeze(0)
|
||||
|
||||
# input_tensor_height, input_tensor_width = input.shape
|
||||
# weight_height, weight_width = weight.shape
|
||||
|
||||
# # Calculate output dimensions
|
||||
# output_height = input_tensor_height - weight_height + 1
|
||||
# output_width = input_tensor_width - weight_width + 1
|
||||
# output = torch.zeros((output_height, output_width),device=Convolve.device)
|
||||
|
||||
# # Perform convolution
|
||||
# for i in range(output_height):
|
||||
# for j in range(output_width):
|
||||
# output[i, j] = torch.tensordot(input[i:i+weight_height, j:j+weight_width], weight, dims=2)
|
||||
return output
|
||||
|
||||
# Choose the compute device once: CUDA when torch reports it as available,
# otherwise the CPU. It is kept as an attribute on the Convolve function,
# whose body reads Convolve.device.
if torch.cuda.is_available():
    Convolve.device = torch.device("cuda")
else:
    Convolve.device = torch.device("cpu")
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user