From c48da75a0036783698689a1377bfcf032c3e1bf4 Mon Sep 17 00:00:00 2001 From: ZhuangYumin Date: Fri, 5 Jul 2024 00:56:40 +0800 Subject: [PATCH] ml q1 --- machinelearning/models.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/machinelearning/models.py b/machinelearning/models.py index c87d5bd..9904ddd 100644 --- a/machinelearning/models.py +++ b/machinelearning/models.py @@ -37,7 +37,11 @@ class PerceptronModel(Module): super(PerceptronModel, self).__init__() "*** YOUR CODE HERE ***" - self.w = None #Initialize your weights here + weight_vector = ones((1, dimensions)) + self.w = Parameter(weight_vector) + # print(f"self.w: {self.w}") + "*** END YOUR CODE HERE ***" + def get_weights(self): """ @@ -56,6 +60,9 @@ class PerceptronModel(Module): The pytorch function `tensordot` may be helpful here. """ "*** YOUR CODE HERE ***" + # Calculate the dot product of the weight vector and the input vector + # print(f"the dot product of the weight vector and the input vector: {tensordot(x, self.w, dims=1)}") + return tensordot(x, self.w, dims=([1],[1])) def get_prediction(self, x): @@ -65,6 +72,7 @@ class PerceptronModel(Module): Returns: 1 or -1 """ "*** YOUR CODE HERE ***" + return 1 if self.run(x) >= 0 else -1 @@ -79,7 +87,17 @@ class PerceptronModel(Module): """ with no_grad(): dataloader = DataLoader(dataset, batch_size=1, shuffle=True) - "*** YOUR CODE HERE ***" + all_correct = False + + while not all_correct: + all_correct = True + for sample in dataloader: + x = sample['x'] + y = sample['label'] + prediction = self.get_prediction(x) + if prediction != y: + all_correct = False + self.w += y * x