This commit is contained in:
2024-07-05 18:51:22 +08:00
parent c48da75a00
commit 422beb4a9b
3 changed files with 47 additions and 7 deletions

View File

@ -1,6 +1,7 @@
from torch import no_grad, stack
from torch.utils.data import DataLoader
from torch.nn import Module
import torch
"""
@ -111,6 +112,12 @@ class RegressionModel(Module):
# Initialize your model parameters here
"*** YOUR CODE HERE ***"
super().__init__()
hidden_size1=400
hidden_size2=400
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.fc1 = Linear(1, hidden_size1).to(self.device)
self.fc2 = Linear(hidden_size1, hidden_size2).to(self.device)
self.fc_out = Linear(hidden_size2, 1,bias=False).to(self.device)
@ -123,7 +130,11 @@ class RegressionModel(Module):
Returns:
A node with shape (batch_size x 1) containing predicted y-values
"""
"*** YOUR CODE HERE ***"
x=x.to(self.device)
x = relu(self.fc1(x))
x = relu(self.fc2(x))
x = self.fc_out(x)
return x
def get_loss(self, x, y):
@ -136,7 +147,18 @@ class RegressionModel(Module):
to be used for training
Returns: a tensor of size 1 containing the loss
"""
"*** YOUR CODE HERE ***"
return mse_loss(self.forward(x.to(self.device)), y.to(self.device))
def run(self, x):
    """
    Produce the model's predictions for one batch of inputs.

    This is a thin convenience wrapper: it hands ``x`` to
    ``self.forward`` and returns whatever that call yields.

    Inputs:
        x: a node with shape (batch_size x 1)
    Returns:
        A node with shape (batch_size x 1) containing predicted y-values
    """
    predictions = self.forward(x)
    return predictions
@ -154,7 +176,25 @@ class RegressionModel(Module):
dataset: a PyTorch dataset object containing data to be trained on
"""
"*** YOUR CODE HERE ***"
optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
max_round=10000
allow_loss=1e-4
for round_cnt in range(max_round):
for sample in dataloader:
x = sample['x'].to(self.device)
y = sample['label'].to(self.device)
loss = self.get_loss(x, y)
if loss < allow_loss:
break
optimizer.zero_grad()
loss.backward()
optimizer.step()
if loss < allow_loss:
break
if round_cnt%100==0:
print(f"round: {round_cnt}, loss: {loss.item()}")
print(f"round: {round_cnt}, loss: {loss.item()}")