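"""Source code for power_cogs.config.mnist_config.

MNIST model, dataset, trainer and top-level config classes, registered via
``add_configs`` at import time.
"""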
import copy
from dataclasses import field
from typing import Any, Dict, List, Optional

from omegaconf import MISSING  # Do not confuse with dataclass.MISSING
from pydantic.dataclasses import dataclass

from power_cogs.config.base import (
    BaseDatasetConfig,
    BaseModelConfig,
    BaseTrainerConfig,
    make_trainer_defaults,
)
from power_cogs.config.config_utils import add_configs
@dataclass
class MNISTModelConfig(BaseModelConfig):
    """Model settings for the MNIST example.

    ``_target_`` holds the dotted path of the model class this config refers
    to; every other field is a plain setting stored next to it.
    """

    _target_: str = "power_cogs.model.mnist_model.MNISTModel"
    # Left as None by default; presumably filled in by the caller or from the
    # dataset before the model is built — TODO confirm.
    input_dims: Optional[int] = None
    # A new list is created per instance so instances never share it.
    hidden_dims: List[int] = field(default_factory=lambda: [32])
    output_dims: Optional[int] = None
    # Stored as a dotted-path string, not as a callable.
    output_activation: str = "torch.nn.functional.relu"
    # Initialization knobs; their exact use lives in the model class.
    use_normal_init: bool = True
    normal_std: float = 1e-2
    zero_bias: bool = False
@dataclass
class MNISTDatasetConfig(BaseDatasetConfig):
    """Dataset settings for the MNIST example.

    Only overrides ``_target_`` (the dotted path of the dataset class); all
    other fields come from ``BaseDatasetConfig``.
    """

    _target_: str = "power_cogs.dataset.mnist_dataset.MNISTDataset"
# Overrides handed to make_trainer_defaults by MNISTTrainerConfig: each entry
# maps a trainer sub-config key to the option name "mnist".
trainer_defaults = [
    dict(model_config="mnist"),
    dict(dataset_config="mnist"),
]
@dataclass
class MNISTTrainerConfig(BaseTrainerConfig):
    """Trainer settings for the MNIST example.

    ``defaults`` is produced per instance by ``make_trainer_defaults``, fed
    with the module-level ``trainer_defaults`` overrides.
    """

    _target_: str = "power_cogs.trainer.mnist_trainer.MNISTTrainer"
    # The factory runs on every instantiation; the global ``trainer_defaults``
    # is looked up at that time, not at class-definition time.
    defaults: List[Any] = field(
        default_factory=lambda: make_trainer_defaults(
            overrides=trainer_defaults,
        )
    )
# Default entries for MNISTConfig.defaults: pick the "mnist" trainer.
config_defaults = [dict(trainer="mnist")]
@dataclass
class MNISTConfig:
    """Top-level config for the MNIST example.

    ``defaults`` starts as a copy of the module-level ``config_defaults``
    list. ``trainer`` is initialised to omegaconf's ``MISSING`` placeholder,
    i.e. it has no concrete default here.
    """

    # Deep-copy per instance: returning ``config_defaults`` itself would hand
    # every instance the same list (and the same inner dicts), so mutating one
    # instance's defaults would silently change the module constant and every
    # instance created afterwards.
    defaults: List[Any] = field(
        default_factory=lambda: copy.deepcopy(config_defaults)
    )
    trainer: Any = MISSING
# Registration table consumed by add_configs: each entry names the config
# "group" (absent for the top-level config), the "name" it is stored under,
# and the "node" config class.
config_dicts: List[Dict[str, Any]] = [
    {
        "group": "trainer/model_config",
        "name": "mnist",
        "node": MNISTModelConfig,
    },
    {
        "group": "trainer/dataset_config",
        "name": "mnist",
        "node": MNISTDatasetConfig,
    },
    {"group": "trainer", "name": "mnist", "node": MNISTTrainerConfig},
    {"name": "mnist", "node": MNISTConfig},
]

# Registration runs as an import-time side effect of this module.
add_configs(config_dicts)