From f10a24eb0a32144f761c64bb500a39fe2f4707c8 Mon Sep 17 00:00:00 2001
From: ZhuangYumin
Date: Sat, 6 Jul 2024 00:39:56 +0800
Subject: [PATCH] ml q4

---
 machinelearning/models.py | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/machinelearning/models.py b/machinelearning/models.py
index 4824e81..8607d85 100644
--- a/machinelearning/models.py
+++ b/machinelearning/models.py
@@ -339,6 +339,11 @@ class LanguageIDModel(Module):
         super(LanguageIDModel, self).__init__()
         "*** YOUR CODE HERE ***"
         # Initialize your model parameters here
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        iteration_layer_size=1000
+        self.fc_initial=Linear(self.num_chars, iteration_layer_size).to(self.device)
+        self.fc_iteration=Linear(iteration_layer_size, iteration_layer_size).to(self.device)
+        self.fc_out=Linear(iteration_layer_size, 5,bias=False).to(self.device)


     def run(self, xs):
@@ -371,6 +376,14 @@ class LanguageIDModel(Module):
         (also called logits)
         """
         "*** YOUR CODE HERE ***"
+        x=xs[0].to(self.device)
+        x = torch.relu(self.fc_initial(x))
+        for i in range(1,len(xs)):
+            y = self.fc_iteration(x)
+            y += self.fc_initial(xs[i].to(self.device))
+            x=relu(y)
+        x = self.fc_out(x)
+        return x


     def get_loss(self, xs, y):
@@ -388,6 +401,7 @@ class LanguageIDModel(Module):
         Returns: a loss node
         """
         "*** YOUR CODE HERE ***"
+        return cross_entropy(self.run(xs), y.to(self.device))


     def train(self, dataset):
@@ -405,6 +419,28 @@ class LanguageIDModel(Module):
         For more information, look at the pytorch documentation of torch.movedim()
         """
         "*** YOUR CODE HERE ***"
+        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
+        dataloader = DataLoader(dataset, batch_size=100, shuffle=True)
+        max_round=30000
+        required_accuracy=0.9
+        round_cnt=0
+        while round_cnt<max_round:
+            for batch in dataloader:
+                x = batch['x']
+                y = batch['label']
+                x = movedim(x, 0, 1)
+                loss = self.get_loss(x, y)
+                if dataset.get_validation_accuracy() > required_accuracy:
+                    break
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                round_cnt+=1
+                if round_cnt%100==0:
+                    print(f"round: {round_cnt}, accuracy: {dataset.get_validation_accuracy()}")
+            if dataset.get_validation_accuracy() > required_accuracy:
+                break
+        print(f"total round: {round_cnt}, accuracy: {dataset.get_validation_accuracy()}")