This commit is contained in:
2024-07-05 18:51:22 +08:00
parent c48da75a00
commit 422beb4a9b
3 changed files with 47 additions and 7 deletions

View File

@ -149,9 +149,9 @@ class RegressionDataset(CustomDataset):
self.processed += 1
if use_graphics and time.time() - self.last_update > 0.1:
predicted = self.model(torch.tensor(self.x, dtype=torch.float32)).data
loss = self.model.get_loss(x, y).data
if use_graphics and time.time() - self.last_update > 1:
predicted = self.model(torch.tensor(self.x, dtype=torch.float32).to(self.model.device)).data.cpu().numpy()
loss = self.model.get_loss(x, y).data.cpu().item()
self.learned.set_data(self.x[self.argsort_x], predicted[self.argsort_x])
self.text.set_text("processed: {:,}\nloss: {:.6f}".format(
self.processed, loss))