Source code for power_cogs.model.mnist_model

from typing import List, Optional

import torch

from power_cogs.base import BaseTorchModel

# internal
from power_cogs.utils.torch_utils import create_linear_network


class MNISTModel(BaseTorchModel):
    """Fully connected classifier built with ``create_linear_network``.

    Args:
        input_dims: Size of the input feature vector passed to
            ``create_linear_network``.
        hidden_dims: Hidden layer sizes passed to ``create_linear_network``.
            ``None`` (the default) means ``[32]``.
        output_dims: Number of output units.
        output_activation: Optional Python expression string that is
            ``eval``-ed once at construction time (e.g. ``"torch.sigmoid"``);
            the resulting callable is applied to the network output in
            ``forward``. When ``None``, ``forward`` returns the raw network
            output.
        use_normal_init: If True, re-initialise weights of ``Linear`` /
            ``Conv3d`` submodules from a zero-mean normal distribution.
        normal_std: Standard deviation used by the normal initialisation.
        zero_bias: When normal initialisation runs, set biases to zero
            instead of drawing them from the normal distribution.
    """

    def __init__(
        self,
        input_dims: int = 64,
        hidden_dims: Optional[List[int]] = None,
        output_dims: int = 10,
        output_activation: Optional[str] = None,
        use_normal_init: bool = True,
        normal_std: float = 0.01,
        zero_bias: bool = False,
    ):
        super().__init__()
        # A list literal as the default would be a single object shared by
        # every instance; build a fresh list per construction instead.
        hidden_dims = [32] if hidden_dims is None else list(hidden_dims)

        # ``input_shape`` is kept for backward compatibility; ``input_dims``
        # matches the constructor argument name.
        self.input_shape = input_dims
        self.input_dims = input_dims
        self.hidden_dims = hidden_dims
        self.output_dims = output_dims
        if output_activation is not None:
            # SECURITY NOTE(review): ``eval`` executes arbitrary code. Only
            # pass trusted configuration strings here; consider replacing
            # this with a lookup table of allowed activations.
            self.output_activation = eval(output_activation)
        self.net = create_linear_network(input_dims, hidden_dims, output_dims)

        if use_normal_init:

            def init_weights(module):
                # The original matched only Conv3d, which never occurs in a
                # network from ``create_linear_network``; Linear is included
                # so the init actually runs. Assumes that helper builds
                # torch.nn.Linear layers — TODO confirm.
                if not isinstance(module, (torch.nn.Linear, torch.nn.Conv3d)):
                    return
                torch.nn.init.normal_(module.weight, std=normal_std)
                if getattr(module, "bias", None) is not None:
                    if zero_bias:
                        torch.nn.init.zeros_(module.bias)
                    else:
                        torch.nn.init.normal_(module.bias, std=normal_std)

            with torch.no_grad():
                self.apply(init_weights)

    def forward(self, x):
        """Run ``self.net`` on ``x`` and apply the output activation, if any.

        Args:
            x: Input accepted by ``self.net``.

        Returns:
            The network output, passed through ``output_activation`` when one
            was configured (or provided by the base class), otherwise
            unchanged.
        """
        x = self.net(x)
        # ``output_activation`` is only set when one was requested; fall back
        # to returning the raw output instead of raising AttributeError.
        activation = getattr(self, "output_activation", None)
        if activation is None:
            return x
        return activation(x)