ml q4
This commit is contained in:
@ -339,6 +339,11 @@ class LanguageIDModel(Module):
|
||||
super(LanguageIDModel, self).__init__()
|
||||
"*** YOUR CODE HERE ***"
|
||||
# Initialize your model parameters here
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
iteration_layer_size=1000
|
||||
self.fc_initial=Linear(self.num_chars, iteration_layer_size).to(self.device)
|
||||
self.fc_iteration=Linear(iteration_layer_size, iteration_layer_size).to(self.device)
|
||||
self.fc_out=Linear(iteration_layer_size, 5,bias=False).to(self.device)
|
||||
|
||||
|
||||
def run(self, xs):
|
||||
@ -371,6 +376,14 @@ class LanguageIDModel(Module):
|
||||
(also called logits)
|
||||
"""
|
||||
"*** YOUR CODE HERE ***"
|
||||
x=xs[0].to(self.device)
|
||||
x = torch.relu(self.fc_initial(x))
|
||||
for i in range(1,len(xs)):
|
||||
y = self.fc_iteration(x)
|
||||
y += self.fc_initial(xs[i].to(self.device))
|
||||
x=relu(y)
|
||||
x = self.fc_out(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_loss(self, xs, y):
|
||||
@ -388,6 +401,7 @@ class LanguageIDModel(Module):
|
||||
Returns: a loss node
|
||||
"""
|
||||
"*** YOUR CODE HERE ***"
|
||||
return cross_entropy(self.run(xs), y.to(self.device))
|
||||
|
||||
|
||||
def train(self, dataset):
|
||||
@ -405,6 +419,28 @@ class LanguageIDModel(Module):
|
||||
For more information, look at the pytorch documentation of torch.movedim()
|
||||
"""
|
||||
"*** YOUR CODE HERE ***"
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
|
||||
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)
|
||||
max_round=30000
|
||||
required_accuracy=0.9
|
||||
round_cnt=0
|
||||
while round_cnt<max_round:
|
||||
for sample in dataloader:
|
||||
x = sample['x']
|
||||
y = sample['label']
|
||||
x = movedim(x, 0, 1)
|
||||
loss = self.get_loss(x, y)
|
||||
if dataset.get_validation_accuracy() > required_accuracy:
|
||||
break
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
round_cnt+=1
|
||||
if round_cnt%100==0:
|
||||
print(f"round: {round_cnt}, accuracy: {dataset.get_validation_accuracy()}")
|
||||
if dataset.get_validation_accuracy() > required_accuracy:
|
||||
break
|
||||
print(f"total round: {round_cnt}, accuracy: {dataset.get_validation_accuracy()}")
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user