This commit is contained in:
2024-07-05 19:21:02 +08:00
parent 422beb4a9b
commit 99f53727ba
2 changed files with 57 additions and 1 deletions

View File

@ -265,7 +265,12 @@ class DigitClassificationDataset(CustomDataset):
def get_validation_accuracy(self): def get_validation_accuracy(self):
dev_logits = self.model.run(torch.tensor(self.dev_images, dtype=torch.float32)).data dev_logits = self.model.run(torch.tensor(self.dev_images, dtype=torch.float32)).data
dev_predicted = torch.argmax(dev_logits, axis=1).detach() dev_predicted = torch.argmax(dev_logits, axis=1).detach()
dev_accuracy = (dev_predicted == self.dev_labels).mean() # print(f"dev_predicted:{dev_predicted}")
# print(f"self.dev_labels: {self.dev_labels}")
total = len(dev_predicted)
correct = torch.sum(torch.eq(dev_predicted.cpu(), torch.tensor(self.dev_labels))).float()
# dev_accuracy = (dev_predicted == self.dev_labels).mean()
dev_accuracy = correct / total
return dev_accuracy return dev_accuracy
class LanguageIDDataset(CustomDataset): class LanguageIDDataset(CustomDataset):

View File

@ -225,8 +225,36 @@ class DigitClassificationModel(Module):
input_size = 28 * 28 input_size = 28 * 28
output_size = 10 output_size = 10
"*** YOUR CODE HERE ***" "*** YOUR CODE HERE ***"
hidden_layer1_size=300
hidden_layer2_size=300
hidden_layer3_size=300
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.fc1 = Linear(input_size, hidden_layer1_size).to(self.device)
self.fc2 = Linear(hidden_layer1_size, hidden_layer2_size).to(self.device)
self.fc3 = Linear(hidden_layer2_size, hidden_layer3_size).to(self.device)
self.fc_out = Linear(hidden_layer3_size, output_size,bias=False).to(self.device)
def forward(self, x):
"""
Runs the model for a batch of examples.
Your model should predict a node with shape (batch_size x 10),
containing scores. Higher scores correspond to greater probability of
the image belonging to a particular class.
Inputs:
x: a tensor with shape (batch_size x 784)
Output:
A node with shape (batch_size x 10) containing predicted scores
(also called logits)
"""
x=x.to(self.device)
x = relu(self.fc1(x))
x = relu(self.fc2(x))
x = relu(self.fc3(x))
x = self.fc_out(x)
return x
def run(self, x): def run(self, x):
""" """
@ -243,6 +271,7 @@ class DigitClassificationModel(Module):
(also called logits) (also called logits)
""" """
""" YOUR CODE HERE """ """ YOUR CODE HERE """
return self.forward(x)
def get_loss(self, x, y): def get_loss(self, x, y):
@ -259,6 +288,7 @@ class DigitClassificationModel(Module):
Returns: a loss tensor Returns: a loss tensor
""" """
""" YOUR CODE HERE """ """ YOUR CODE HERE """
return cross_entropy(self.forward(x.to(self.device)), y.to(self.device))
@ -267,6 +297,27 @@ class DigitClassificationModel(Module):
Trains the model. Trains the model.
""" """
""" YOUR CODE HERE """ """ YOUR CODE HERE """
optimizer = torch.optim.Adam(self.parameters(), lr=0.0005)
dataloader = DataLoader(dataset, batch_size=20, shuffle=True)
max_round=15000
required_accuracy=0.99
round_cnt=0
while round_cnt<max_round:
for sample in dataloader:
x = sample['x'].to(self.device)
y = sample['label'].to(self.device)
loss = self.get_loss(x, y)
if dataset.get_validation_accuracy() > required_accuracy:
break
optimizer.zero_grad()
loss.backward()
optimizer.step()
round_cnt+=1
if round_cnt%100==0:
print(f"round: {round_cnt}, accuracy: {dataset.get_validation_accuracy()}")
if dataset.get_validation_accuracy() > required_accuracy:
break
print(f"round: {round_cnt}, accuracy: {dataset.get_validation_accuracy()}")