ml q1
This commit is contained in:
@ -37,7 +37,11 @@ class PerceptronModel(Module):
|
||||
super(PerceptronModel, self).__init__()
|
||||
|
||||
"*** YOUR CODE HERE ***"
|
||||
self.w = None #Initialize your weights here
|
||||
weight_vector = ones((1, dimensions))
|
||||
self.w = Parameter(weight_vector)
|
||||
# print(f"self.w: {self.w}")
|
||||
"*** END YOUR CODE HERE ***"
|
||||
|
||||
|
||||
def get_weights(self):
|
||||
"""
|
||||
@ -56,6 +60,9 @@ class PerceptronModel(Module):
|
||||
The pytorch function `tensordot` may be helpful here.
|
||||
"""
|
||||
"*** YOUR CODE HERE ***"
|
||||
# Calculate the dot product of the weight vector and the input vector
|
||||
# print(f"the dot product of the weight vector and the input vector: {tensordot(x, self.w, dims=1)}")
|
||||
return tensordot(x, self.w, dims=([1],[1]))
|
||||
|
||||
|
||||
def get_prediction(self, x):
|
||||
@ -65,6 +72,7 @@ class PerceptronModel(Module):
|
||||
Returns: 1 or -1
|
||||
"""
|
||||
"*** YOUR CODE HERE ***"
|
||||
return 1 if self.run(x) >= 0 else -1
|
||||
|
||||
|
||||
|
||||
@ -79,7 +87,17 @@ class PerceptronModel(Module):
|
||||
"""
|
||||
with no_grad():
|
||||
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
|
||||
"*** YOUR CODE HERE ***"
|
||||
all_correct = False
|
||||
|
||||
while not all_correct:
|
||||
all_correct = True
|
||||
for sample in dataloader:
|
||||
x = sample['x']
|
||||
y = sample['label']
|
||||
prediction = self.get_prediction(x)
|
||||
if prediction != y:
|
||||
all_correct = False
|
||||
self.w += y * x
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user