# Source code for power_cogs.trainer.mnist_trainer
import attr
import numpy as np
import torch.nn.functional as F
# internal
from power_cogs.base.base_torch_trainer import BaseTorchTrainer
@attr.s
class MNISTTrainer(BaseTorchTrainer):
    """Trainer for MNIST classification, built on :class:`BaseTorchTrainer`.

    Relies on the base class to have set up ``self.model``,
    ``self.optimizer``, ``self.scheduler``, ``self.dataset``,
    ``self.dataloader`` and ``self.model_config`` before training.
    """

    # Registry key under which this trainer's config is looked up.
    _config_name_: str = "mnist"

    def post_dataset_setup(self):
        """Copy the dataset's input/output dimensions into the model config.

        Runs after the dataset is constructed so the model can be built
        with layer sizes matching the data.
        """
        self.model_config["input_dims"] = self.dataset.input_dims
        self.model_config["output_dims"] = self.dataset.output_dims

    def train_iter(self, batch_size: int = 32, iteration: int = 0):
        """Run one full pass over ``self.dataloader``; return loss metrics.

        Args:
            batch_size: Unused here — batching is governed by the dataloader;
                kept for the base-class interface.
            iteration: Current outer training iteration; unused here, kept
                for the base-class interface.

        Returns:
            dict with keys ``"out"`` (always ``None``), ``"metrics"``
            (summary statistics over per-batch losses) and ``"loss"``
            (mean per-batch loss).
        """
        losses = []
        for batch_ndx, sample in enumerate(self.dataloader):
            self.optimizer.zero_grad()
            # assumes each sample is a dict with "data" and "targets" keys
            # — matches whatever dataset the base class constructed.
            data = sample["data"].float()
            targets = sample["targets"]
            out = self.model(data)
            loss = F.cross_entropy(out, targets)
            loss.backward()
            self.optimizer.step()
            # NOTE(review): scheduler is stepped once per *batch*, not per
            # epoch — confirm this matches the scheduler's intended schedule.
            self.scheduler.step()
            losses.append(loss.item())
        # Hoisted: the original recomputed np.mean(losses) three times.
        mean_loss = np.mean(losses)
        return {
            "out": None,
            "metrics": {
                "loss": mean_loss,
                "min_loss": np.min(losses),
                "max_loss": np.max(losses),
                "mean_loss": mean_loss,
                "sum_loss": np.sum(losses),
                "median_loss": np.median(losses),
            },
            "loss": mean_loss,
        }