Source code for sconce.models.layers.fully_connected_layer

from torch import nn


class FullyConnectedLayer(nn.Module):
    """A dense layer with optional input batch-norm and optional dropout.

    ``forward`` runs the input through these stages in order:

        [BatchNorm1d(in_size)] -> Linear(in_size, out_size)
            -> [Dropout(dropout)] -> activation

    The bracketed stages are present only when enabled.

    Args:
        in_size: Number of input features. It sizes both the
            ``BatchNorm1d`` and the input side of the ``Linear``.
        out_size: Number of output features produced by the ``Linear``.
        activation: Callable applied last to the (possibly dropped-out)
            linear output. If it is an ``nn.Module``, it is registered as
            the ``activation`` submodule.
        with_batchnorm: When true, the input is batch-normalised before
            the linear projection.
        dropout: Dropout probability. A ``nn.Dropout`` is created and used
            only when this is strictly greater than ``0.0``. Values
            ``<= 0.0`` disable dropout.

    Note:
        The input is presumably shaped ``(batch, in_size)``; confirm this
        against callers. ``BatchNorm1d`` normalises over dim 1 and
        ``Linear`` projects the last dim, so both must equal ``in_size``
        for a 2-D input.
    """

    def __init__(self, in_size, out_size, activation,
                 with_batchnorm=True, dropout=0.0):
        super().__init__()

        # Submodule creation order (bn, fc, dropout) determines the
        # state_dict / parameters() ordering, so it is kept fixed.
        self.with_batchnorm = with_batchnorm
        if self.with_batchnorm:
            self.bn = nn.BatchNorm1d(in_size)

        self.fc = nn.Linear(in_size, out_size)

        # The raw value is remembered so forward() knows whether a
        # Dropout submodule exists on this instance.
        self._dropout_value = dropout
        if self._dropout_value > 0.0:
            self.dropout = nn.Dropout(self._dropout_value)

        self.activation = activation

    def forward(self, x_in):
        """Apply [bn] -> fc -> [dropout] -> activation to ``x_in``.

        Args:
            x_in: Input tensor (see the class-level note on shape).

        Returns:
            The activation applied to the projected tensor.
        """
        hidden = self.bn(x_in) if self.with_batchnorm else x_in
        hidden = self.fc(hidden)
        if self._dropout_value > 0.0:
            hidden = self.dropout(hidden)
        return self.activation(hidden)