This commit is contained in:
2024-07-06 00:39:56 +08:00
parent 819fe60a6b
commit f10a24eb0a

View File

@ -339,6 +339,11 @@ class LanguageIDModel(Module):
super(LanguageIDModel, self).__init__()
"*** YOUR CODE HERE ***"
# Initialize your model parameters here
# Pick the GPU when one is available; every layer below is moved to this
# device so inputs only need a single .to(self.device) at run time.
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Width of the hidden state carried from one character to the next.
iteration_layer_size=1000
# Embeds one batch of one-hot character vectors (num_chars wide) into the
# hidden space; also reused at every time step in run().
self.fc_initial=Linear(self.num_chars, iteration_layer_size).to(self.device)
# Recurrent transition: previous hidden state -> next hidden state.
self.fc_iteration=Linear(iteration_layer_size, iteration_layer_size).to(self.device)
# Output head mapping the final hidden state to 5 language scores
# (logits); bias disabled — presumably a deliberate choice, TODO confirm.
self.fc_out=Linear(iteration_layer_size, 5,bias=False).to(self.device)
def run(self, xs):
@ -371,6 +376,14 @@ class LanguageIDModel(Module):
(also called logits)
"""
"*** YOUR CODE HERE ***"
x=xs[0].to(self.device)
x = torch.relu(self.fc_initial(x))
for i in range(1,len(xs)):
y = self.fc_iteration(x)
y += self.fc_initial(xs[i].to(self.device))
x=relu(y)
x = self.fc_out(x)
return x
def get_loss(self, xs, y):
@ -388,6 +401,7 @@ class LanguageIDModel(Module):
Returns: a loss node
"""
"*** YOUR CODE HERE ***"
return cross_entropy(self.run(xs), y.to(self.device))
def train(self, dataset):
@ -405,6 +419,28 @@ class LanguageIDModel(Module):
For more information, look at the pytorch documentation of torch.movedim()
"""
"*** YOUR CODE HERE ***"
# Mini-batch training with Adam, stopping early once validation accuracy
# clears required_accuracy, with a hard cap of max_round batches.
optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)
# "round" here counts individual optimizer steps (batches), not epochs.
max_round=30000
required_accuracy=0.9
round_cnt=0
while round_cnt<max_round:
    for sample in dataloader:
        x = sample['x']
        y = sample['label']
        # Move the batch dimension so run() can iterate per character —
        # presumably (batch, word_len, chars) -> (word_len, batch, chars);
        # TODO confirm against the dataset's tensor layout.
        x = movedim(x, 0, 1)
        loss = self.get_loss(x, y)
        # NOTE(review): validation accuracy is queried on every batch (three
        # call sites in this loop), which can be expensive; this break also
        # discards the just-computed loss without applying an optimizer step.
        if dataset.get_validation_accuracy() > required_accuracy:
            break
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        round_cnt+=1
        # Periodic progress report every 100 steps.
        if round_cnt%100==0:
            print(f"round: {round_cnt}, accuracy: {dataset.get_validation_accuracy()}")
    # Re-check after the inner loop so the outer while also terminates early.
    if dataset.get_validation_accuracy() > required_accuracy:
        break
print(f"total round: {round_cnt}, accuracy: {dataset.get_validation_accuracy()}")