This commit is contained in:
2024-07-05 00:56:40 +08:00
parent ef4f011fe2
commit c48da75a00

View File

@ -37,7 +37,11 @@ class PerceptronModel(Module):
super(PerceptronModel, self).__init__()
"*** YOUR CODE HERE ***"
self.w = None #Initialize your weights here
weight_vector = ones((1, dimensions))
self.w = Parameter(weight_vector)
# print(f"self.w: {self.w}")
"*** END YOUR CODE HERE ***"
def get_weights(self):
"""
@ -56,6 +60,9 @@ class PerceptronModel(Module):
The pytorch function `tensordot` may be helpful here.
"""
"*** YOUR CODE HERE ***"
# Calculate the dot product of the weight vector and the input vector
# print(f"the dot product of the weight vector and the input vector: {tensordot(x, self.w, dims=1)}")
return tensordot(x, self.w, dims=([1],[1]))
def get_prediction(self, x):
@ -65,6 +72,7 @@ class PerceptronModel(Module):
Returns: 1 or -1
"""
"*** YOUR CODE HERE ***"
return 1 if self.run(x) >= 0 else -1
@ -79,7 +87,17 @@ class PerceptronModel(Module):
"""
with no_grad():
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
"*** YOUR CODE HERE ***"
all_correct = False
while not all_correct:
all_correct = True
for sample in dataloader:
x = sample['x']
y = sample['label']
prediction = self.get_prediction(x)
if prediction != y:
all_correct = False
self.w += y * x