ml q3
This commit is contained in:
@ -265,7 +265,12 @@ class DigitClassificationDataset(CustomDataset):
|
|||||||
def get_validation_accuracy(self):
    """
    Return the fraction of validation examples the model classifies correctly.

    Feeds ``self.dev_images`` through ``self.model.run``, takes the
    highest-scoring class for each row and compares it with
    ``self.dev_labels``.

    Returns:
        A 0-dim float tensor equal to (number correct) / (number of examples).
    """
    # Evaluation never needs gradients; skipping the autograd graph keeps
    # repeated calls during training cheap.
    with torch.no_grad():
        dev_inputs = torch.tensor(self.dev_images, dtype=torch.float32)
        dev_logits = self.model.run(dev_inputs).data
        dev_predicted = torch.argmax(dev_logits, dim=1)

    # The model may have produced its output on another device, so compare
    # on the CPU. as_tensor avoids re-copying the labels on every call.
    dev_labels = torch.as_tensor(self.dev_labels)
    total = len(dev_predicted)
    correct = torch.sum(torch.eq(dev_predicted.cpu(), dev_labels)).float()
    dev_accuracy = correct / total
    return dev_accuracy
|
||||||
|
|
||||||
class LanguageIDDataset(CustomDataset):
|
class LanguageIDDataset(CustomDataset):
|
||||||
|
@ -225,8 +225,36 @@ class DigitClassificationModel(Module):
|
|||||||
input_size = 28 * 28
|
input_size = 28 * 28
|
||||||
output_size = 10
|
output_size = 10
|
||||||
"*** YOUR CODE HERE ***"
|
"*** YOUR CODE HERE ***"
|
||||||
|
hidden_layer1_size=300
|
||||||
|
hidden_layer2_size=300
|
||||||
|
hidden_layer3_size=300
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.fc1 = Linear(input_size, hidden_layer1_size).to(self.device)
|
||||||
|
self.fc2 = Linear(hidden_layer1_size, hidden_layer2_size).to(self.device)
|
||||||
|
self.fc3 = Linear(hidden_layer2_size, hidden_layer3_size).to(self.device)
|
||||||
|
self.fc_out = Linear(hidden_layer3_size, output_size,bias=False).to(self.device)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
    """
    Runs the model for a batch of examples.

    Your model should predict a node with shape (batch_size x 10),
    containing scores. Higher scores correspond to greater probability of
    the image belonging to a particular class.

    Inputs:
        x: a tensor with shape (batch_size x 784)
    Output:
        A node with shape (batch_size x 10) containing predicted scores
            (also called logits)
    """
    # Move the batch to the device recorded on the model before applying
    # any layer.
    x = x.to(self.device)
    # Three hidden layers, each followed by a ReLU non-linearity.
    for hidden_layer in (self.fc1, self.fc2, self.fc3):
        x = relu(hidden_layer(x))
    # No activation after the output layer: the result is raw logits.
    return self.fc_out(x)
|
||||||
|
|
||||||
def run(self, x):
|
def run(self, x):
|
||||||
"""
|
"""
|
||||||
@ -243,6 +271,7 @@ class DigitClassificationModel(Module):
|
|||||||
(also called logits)
|
(also called logits)
|
||||||
"""
|
"""
|
||||||
""" YOUR CODE HERE """
|
""" YOUR CODE HERE """
|
||||||
|
return self.forward(x)
|
||||||
|
|
||||||
|
|
||||||
def get_loss(self, x, y):
|
def get_loss(self, x, y):
|
||||||
@ -259,6 +288,7 @@ class DigitClassificationModel(Module):
|
|||||||
Returns: a loss tensor
|
Returns: a loss tensor
|
||||||
"""
|
"""
|
||||||
""" YOUR CODE HERE """
|
""" YOUR CODE HERE """
|
||||||
|
return cross_entropy(self.forward(x.to(self.device)), y.to(self.device))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -267,6 +297,27 @@ class DigitClassificationModel(Module):
|
|||||||
Trains the model.
|
Trains the model.
|
||||||
"""
|
"""
|
||||||
""" YOUR CODE HERE """
|
""" YOUR CODE HERE """
|
||||||
|
optimizer = torch.optim.Adam(self.parameters(), lr=0.0005)
|
||||||
|
dataloader = DataLoader(dataset, batch_size=20, shuffle=True)
|
||||||
|
max_round=15000
|
||||||
|
required_accuracy=0.99
|
||||||
|
round_cnt=0
|
||||||
|
while round_cnt<max_round:
|
||||||
|
for sample in dataloader:
|
||||||
|
x = sample['x'].to(self.device)
|
||||||
|
y = sample['label'].to(self.device)
|
||||||
|
loss = self.get_loss(x, y)
|
||||||
|
if dataset.get_validation_accuracy() > required_accuracy:
|
||||||
|
break
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
round_cnt+=1
|
||||||
|
if round_cnt%100==0:
|
||||||
|
print(f"round: {round_cnt}, accuracy: {dataset.get_validation_accuracy()}")
|
||||||
|
if dataset.get_validation_accuracy() > required_accuracy:
|
||||||
|
break
|
||||||
|
print(f"round: {round_cnt}, accuracy: {dataset.get_validation_accuracy()}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user