---
title: Layers
keywords: fastai
sidebar: home_sidebar
summary: "Custom activations, layers, and layer blocks are contained in this module."
description: "Custom activations, layers, and layer blocks are contained in this module."
---
{% raw %}
{% endraw %} {% raw %}
%load_ext autoreload
%autoreload 2
%matplotlib inline
{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Used for testing only.
from collections import defaultdict, Counter
from itertools import chain
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from htools import assert_raises, InvalidArgumentError
from incendio.data import probabilistic_hash_item
import pandas_htools
{% endraw %}

## Activations

{% raw %}
{% endraw %} {% raw %}

class GRelu[source]

GRelu(leak=0.0, max=inf, sub=0.0) :: Module

Generic ReLU.
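
The parameters map onto a leaky slope, an upper clamp, and a downward shift. Below is a minimal sketch of what a layer with this signature typically computes; it is an assumption based on the parameter names and the `GReLU(leak=0.1, max=6.0, sub=0.4)` default used later on this page, not the incendio source.

```python
import torch
import torch.nn.functional as F

# Hypothetical reference for a "generic ReLU" with (leak, max, sub) parameters.
def generic_relu(x, leak=0.0, max=float('inf'), sub=0.0):
    x = F.leaky_relu(x, leak)  # leaky slope for negative inputs
    x = x.clamp(max=max)       # cap activations (ReLU6-style when max=6)
    return x - sub             # shift outputs down toward zero mean

x = torch.linspace(-3, 3, 13)
generic_relu(x, leak=0.1, max=6.0, sub=0.4)
```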
{% endraw %} {% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class Mish[source]

Mish() :: Module

OOP form of mish activation.

Mish: A Self Regularized Non-Monotonic Neural Activation Function
https://arxiv.org/pdf/1908.08681v1.pdf
{% endraw %} {% raw %}
{% endraw %} {% raw %}

mish[source]

mish(x)

Functional form of mish activation.

Mish: A Self Regularized Non-Monotonic Neural Activation Function
https://arxiv.org/pdf/1908.08681v1.pdf

Parameters
----------
x: torch.Tensor[float]
    Input tensor.
Returns
-------
torch.Tensor[float]: Tensor of same shape as input x.
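
The paper defines mish as x * tanh(softplus(x)). Here is a one-line functional sketch of that formula, assuming the incendio version follows the paper:

```python
import torch
import torch.nn.functional as F

def mish_ref(x):
    """Reference mish from the paper: x * tanh(softplus(x))."""
    return x * torch.tanh(F.softplus(x))

mish_ref(torch.linspace(-5, 5, 11))  # smooth and slightly non-monotonic near zero
```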
{% endraw %} {% raw %}
def plot_activations(z, a, mode='scatter', **kwargs):
    """Plot an input tensor and its corresponding activations.  Both tensors
    will be flattened for plotting.
    
    Parameters
    ----------
    z: torch.Tensor
        Tensor containing values to plot on the x axis (we can often think of
        this as the output of a linear layer, where z=f(x) and a=mish(z)).
    a: torch.Tensor
        Tensor containing values to plot on y axis.
    mode: str
        'scatter' for scatter plot or 'plot' for line plot.
    kwargs: Values to be passed to the matplotlib plotting function, such as 
        's' when in 'scatter' mode or 'lw' in 'plot' mode.
        
    Returns
    -------
    None
    """
    plt_func = getattr(plt, mode)
    kwargs = kwargs or {}
    if mode == 'scatter' and not kwargs:
        kwargs = {'s': .75}
    plt_func(z.numpy().flatten(), a.numpy().flatten(), **kwargs)
    plt.axvline(0, lw=.5, alpha=.5)
    plt.axhline(0, lw=.5, alpha=.5)
    plt.show()
{% endraw %} {% raw %}
x = torch.arange(-5, 5, .05)
a = mish(x)
{% endraw %} {% raw %}
plot_activations(x, a, 'plot')
{% endraw %}

## Layer Blocks

{% raw %}
{% endraw %} {% raw %}

class ConvBlock[source]

ConvBlock(c_in, c_out, kernel_size=3, norm=True, activation=GReLU(leak=0.1, max=6.0, sub=0.4), **kwargs) :: Module

Create a convolutional block optionally followed by a batch norm layer.
{% endraw %} {% raw %}
conv = ConvBlock(3, 5, norm=False)
conv
ConvBlock(
  (block): Sequential(
    (0): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1))
    (1): GReLU(leak=0.1, max=6.0, sub=0.4)
  )
)
{% endraw %} {% raw %}
x = torch.rand(2, 3, 4, 4)
conv(x).shape
torch.Size([2, 5, 2, 2])
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class ResBlock[source]

ResBlock(c_in, activation=GReLU(leak=0.1, max=6.0, sub=0.4), f=3, stride=1, pad=1, skip_size=2, norm=True) :: Module

Residual block. The input passes through `skip_size` ConvBlock layers
(3x3 convolutions by default, optionally with batch norm) and is then
combined with the block's original input via a skip connection before the
activation is applied, following the standard ResNet pattern.
{% endraw %} {% raw %}
ResBlock(4)
ResBlock(
  (layers): ModuleList(
    (0): ConvBlock(
      (block): Sequential(
        (0): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ConvBlock(
      (block): Sequential(
        (0): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (activation): GReLU(leak=0.1, max=6.0, sub=0.4)
)
{% endraw %} {% raw %}
ResBlock(4, norm=False)
ResBlock(
  (layers): ModuleList(
    (0): ConvBlock(
      (block): Sequential(
        (0): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (1): ConvBlock(
      (block): Sequential(
        (0): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
  )
  (activation): GReLU(leak=0.1, max=6.0, sub=0.4)
)
{% endraw %} {% raw %}
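As a quick shape check, the defaults (3x3 kernels, stride 1, padding 1) should leave both the channel count and the spatial size unchanged, which is what makes the residual addition valid. A hedged usage sketch:

```python
import torch

res = ResBlock(4)
x = torch.rand(2, 4, 8, 8)
res(x).shape  # expected: torch.Size([2, 4, 8, 8]) with the default settings
```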
{% endraw %} {% raw %}

ReflectionPaddedConv2d[source]

ReflectionPaddedConv2d(in_channels, out_channels, padding=1, kernel_size=3, **kwargs)

Conv2d only allows padding_mode of `zeros` or `circular`. This
    layer is a quick way for us to use reflection padding.


Applies a 2D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
    can be precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)


    where :math:`\star` is the valid 2D `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`H` is a height of input planes in pixels, and :math:`W` is
    width in pixels.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a tuple.

    * :attr:`padding` controls the amount of implicit zero-paddings on both
      sides for :attr:`padding` number of points for each dimension.

    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the à trous algorithm. It is harder to describe, but this `link`_
      has a nice visualization of what :attr:`dilation` does.

    * :attr:`groups` controls the connections between inputs and outputs.
      :attr:`in_channels` and :attr:`out_channels` must both be divisible by
      :attr:`groups`. For example,

        * At groups=1, all inputs are convolved to all outputs.
        * At groups=2, the operation becomes equivalent to having two conv
          layers side by side, each seeing half the input channels,
          and producing half the output channels, and both subsequently
          concatenated.
        * At groups= :attr:`in_channels`, each input channel is convolved with
          its own set of filters, of size:
          :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    .. note::

         Depending of the size of your kernel, several (of the last)
         columns of the input might be lost, because it is a valid `cross-correlation`_,
         and not a full `cross-correlation`_.
         It is up to the user to add proper padding.

    .. note::

        When `groups == in_channels` and `out_channels == K * in_channels`,
        where `K` is a positive integer, this operation is also termed in
        literature as depthwise convolution.

        In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
        a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments
        :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    .. include:: cudnn_deterministic.rst

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        padding_mode (string, optional). Accepted values `zeros` and `circular` Default: `zeros`
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[0] - \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[1] - \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
                         The values of these weights are sampled from
                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{1}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor):   the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
                         then the values of these weights are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{1}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
{% endraw %} {% raw %}
def show_img(img):
    plt.imshow(img.permute(1, 2, 0) / 255)
    plt.show()
{% endraw %} {% raw %}
rconv = ReflectionPaddedConv2d(3, 3, kernel_size=1, padding=2)
rconv
ReflectionPaddedConv2d(
  (reflect): ReflectionPad2d((2, 2, 2, 2))
  (conv): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
)
{% endraw %} {% raw %}
x = torch.randint(255, (1, 3, 3, 3)).float()
show_img(x[0])
{% endraw %} {% raw %}
x2 = rconv.reflect(x)
show_img(x2[0])
{% endraw %} {% raw %}
# Tests
assert nn.Conv2d.__doc__ in ReflectionPaddedConv2d.__doc__

with assert_raises(InvalidArgumentError):
    ReflectionPaddedConv2d(3, 3, padding_mode='zeros')
As expected, got InvalidArgumentError(Remove `padding_mode` from arguments.).
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class Dropin[source]

Dropin(scale=0.5) :: Module

Additive dropout. This injects small amounts of noise into a model
in the form of randomly generated floats drawn from a zero-centered
Gaussian distribution (the variance can be adjusted). It does nothing
in eval mode. Unlike Dropout, it does not rescale activations during
training since the zero-centered noise does not bias them in any
direction.
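
Below is a minimal sketch of the additive-noise idea. The exact relationship between `scale` and the noise standard deviation is an assumption here; Dropin's internal scaling may differ (the simulations further down suggest the noise variance grows with `scale`).

```python
import torch
import torch.nn as nn

class AdditiveNoise(nn.Module):
    """Toy additive-dropout layer: add zero-centered Gaussian noise while
    training, pass inputs through untouched in eval mode."""

    def __init__(self, scale=0.5):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        if not self.training:
            return x
        # Illustrative choice of std; not necessarily what Dropin uses.
        self.noise = torch.randn_like(x) * self.scale ** 0.5
        return x + self.noise
```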
{% endraw %} {% raw %}
class Net(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.drop = Dropin()
        
    def forward(self, x):
        return self.drop(x)
{% endraw %} {% raw %}
net = Net()
x = torch.randn(8, 128, 128, 3)
assert np.corrcoef(net(x).flatten(), x.flatten())[0][1] > .9
{% endraw %} {% raw %}
net.eval()
assert torch.eq(net(x), x).all()
assert not net.drop.training
{% endraw %} {% raw %}
def simulate_activation_stats(scale=1.0, trials=10_000):
    act_stats = defaultdict(list)
    noise_stats = defaultdict(list)
    
    drop = Dropin(scale)
    for _ in range(trials):
        x = torch.randn(3, 4, dtype=torch.float)
        z = drop(x)
        noise = drop.noise
        noise_stats['mean'].append(noise.mean())
        noise_stats['std'].append(noise.std())
        noise_stats['act_corr'].append(
            np.corrcoef(z.flatten(), noise.flatten())[0][1]
        )
        
        act_stats['mean'].append(z.mean())
        act_stats['std'].append(z.std())
        act_stats['x_corr'].append(
            np.corrcoef(z.flatten(), x.flatten())[0][1]
        )

    return pd.DataFrame(dict(
        act={k: np.mean(v).round(4) for k, v in act_stats.items()}, 
        noise={k: np.mean(v).round(4) for k, v in noise_stats.items()}
    ))
{% endraw %} {% raw %}
for scale in [10, 1, .75, .5, .25, .1]:
    print('\n', scale)
    simulate_activation_stats(scale, 1_000).pprint()
 10
             act   noise
mean      0.0132  0.0094
std       1.8189  1.5192
x_corr    0.5324     NaN
act_corr     NaN  0.8304

 1
             act   noise
mean     -0.0141  0.0034
std       1.0921  0.4870
x_corr    0.8855     NaN
act_corr     NaN  0.4282

 0.75
             act   noise
mean     -0.0015  0.0022
std       1.0633  0.4240
x_corr    0.9100     NaN
act_corr     NaN  0.3899

 0.5
             act   noise
mean      0.0107  0.0008
std       1.0558  0.3442
x_corr    0.9409     NaN
act_corr     NaN  0.3235

 0.25
             act   noise
mean      0.0098 -0.0057
std       1.0013  0.2461
x_corr    0.9667     NaN
act_corr     NaN  0.2298

 0.1
             act   noise
mean     -0.0057 -0.0014
std       0.9969  0.1533
x_corr    0.9868     NaN
act_corr     NaN  0.1394
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class LinearSkipBlock[source]

LinearSkipBlock(x_dim, layer_dims, op, activation='mish') :: Module

This lets us easily create residual block equivalents with linear
layers.
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class LinearResBlock[source]

LinearResBlock(x_dim, hidden_dims, activation='mish') :: LinearSkipBlock

Equivalent of ResNet block with linear layers.
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class LinearDenseBlock[source]

LinearDenseBlock(x_dim, hidden_dims, activation='mish') :: LinearSkipBlock

Equivalent of DenseNet block with linear layers.
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class WeightedLinearResBlock[source]

WeightedLinearResBlock(x_dim, hidden_dims, weights=(0.25, 0.75), activation='mish') :: LinearSkipBlock

Like a LinearResBlock but takes a weighted average of the input and output
rather than adding them. Addition gives them equal weight and we may want to
weight the output more heavily.
{% endraw %}
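
A rough sketch of the idea behind these blocks (illustrative only, not the incendio implementation): run the input through a stack of linear layers, then combine the result with the original input via some op. `torch.add` gives residual behavior when the final layer maps back to `x_dim`; concatenation gives dense behavior, and a weighted average gives the WeightedLinearResBlock variant.

```python
import torch
import torch.nn as nn

class ToyLinearSkipBlock(nn.Module):
    """Illustrative linear skip block: y = op(x, f(x))."""

    def __init__(self, x_dim, layer_dims, op=torch.add):
        super().__init__()
        dims = [x_dim, *layer_dims]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(dims, dims[1:])]
        )
        self.act = nn.ReLU()
        self.op = op

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = self.act(layer(out))
        return self.op(x, out)

# Residual-style: final layer maps back to x_dim so addition is valid.
block = ToyLinearSkipBlock(8, [16, 8], op=torch.add)
block(torch.randn(4, 8)).shape   # torch.Size([4, 8])

# Dense-style: concatenate input and output along the feature dimension.
dense = ToyLinearSkipBlock(8, [16, 12], op=lambda a, b: torch.cat([a, b], dim=-1))
dense(torch.randn(4, 8)).shape   # torch.Size([4, 20])
```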

## Embeddings and Encodings

{% raw %}
{% endraw %} {% raw %}

trunc_normal_[source]

trunc_normal_(x, mean=0.0, std=1.0)

Ported from fastai to remove dependency:

Truncated normal initialization.
From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12
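
For reference, the fastai helper this was ported from is essentially a one-liner: sample a standard normal, fold the values into (-2, 2) with `fmod_`, then rescale and shift. Shown here as a sketch; the incendio copy may differ in detail.

```python
import torch

def trunc_normal_sketch(x, mean=0.0, std=1.0):
    # In-place: N(0, 1) samples folded into (-2, 2), then scaled and shifted.
    return x.normal_().fmod_(2).mul_(std).add_(mean)

trunc_normal_sketch(torch.empty(4, 3), std=0.01)
```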
{% endraw %} {% raw %}
{% endraw %} {% raw %}

PaddedEmbedding[source]

PaddedEmbedding(num_embeddings, embedding_dim, padding_idx=None, **kwargs)

Patched version of Fastai `embedding` that allows us to specify a row of
    zeros for a padding token.


A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using indices.
    The input to the module is a list of indices, and the output is the corresponding
    word embeddings.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
                                         (initialized to zeros) whenever it encounters the index.
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                                 See Notes for more details regarding sparse gradients.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
                         initialized from :math:`\mathcal{N}(0, 1)`

    Shape:
        - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`

    .. note::
        Keep in mind that only a limited number of optimizers support
        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)

    .. note::
        With :attr:`padding_idx` set, the embedding vector at
        :attr:`padding_idx` is initialized to all zeros. However, note that this
        vector can be modified afterwards, e.g., using a customized
        initialization method, and thus changing the vector used to pad the
        output. The gradient for this vector from :class:`~torch.nn.Embedding`
        is always zero.

    Examples::

        >>> # an Embedding module containing 10 tensors of size 3
        >>> embedding = nn.Embedding(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
        >>> embedding(input)
        tensor([[[-0.0251, -1.6902,  0.7172],
                 [-0.6431,  0.0748,  0.6969],
                 [ 1.4970,  1.3448, -0.9685],
                 [-0.3677, -2.7265, -0.1685]],

                [[ 1.4970,  1.3448, -0.9685],
                 [ 0.4362, -0.4004,  0.9400],
                 [-0.6431,  0.0748,  0.6969],
                 [ 0.9124, -2.3616,  1.1151]]])


        >>> # example with padding_idx
        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
        >>> input = torch.LongTensor([[0,2,0,5]])
        >>> embedding(input)
        tensor([[[ 0.0000,  0.0000,  0.0000],
                 [ 0.1535, -2.0309,  0.9315],
                 [ 0.0000,  0.0000,  0.0000],
                 [-0.1655,  0.9897,  0.0635]]])
{% endraw %} {% raw %}
PaddedEmbedding(4, 3, 0).weight
Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.0145, -0.0067, -0.0038],
        [-0.0176, -0.0179,  0.0010],
        [-0.0042, -0.0020,  0.0027]], requires_grad=True)
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class BloomEmbedding[source]

BloomEmbedding(n_emb=251, emb_dim=100, n_hashes=4, padding_idx=0, pre_hashed=False) :: Module

Bloom Embedding layer for memory-efficient word representations.
Each word is encoded by a combination of rows of the embedding
matrix. The number of rows can therefore be far lower than the number
of words in our vocabulary while still providing unique representations.
The reduction in rows allows us to use memory in other ways: a larger
embedding dimension, more or larger layers after the embedding,
larger batch sizes, etc.

Note that if hashing is done in the Dataset, we could use a simple
nn.EmbeddingBag to achieve the same thing. Many users have reported
poor performance with this layer though (especially on CPU, but in some
cases on GPU) so I stick with the standard Embedding. We also bake in
the truncated normal initialization provided by fastai, with a slight tweak
to allow a row for padding.
{% endraw %} {% raw %}
class Data(Dataset):
    
    def __init__(self, sentences, labels, seq_len):
        x = [s.split(' ') for s in sentences]
        self.w2i = self.make_w2i(x)
        self.seq_len = seq_len
        self.x = self.encode(x)
        self.y = torch.tensor(labels)
        
    def __getitem__(self, i):
        return self.x[i], self.y[i]
    
    def __len__(self):
        return len(self.y)
    
    def make_w2i(self, tok_rows):
        return {k: i for i, (k, v) in 
                enumerate(Counter(chain(*tok_rows)).most_common(), 1)}
    
    def encode(self, tok_rows):
        enc = np.zeros((len(tok_rows), self.seq_len), dtype=int)
        for i, row in enumerate(tok_rows):
            trunc = [self.w2i.get(w, 0) for w in row[:self.seq_len]]
            enc[i, :len(trunc)] = trunc
        return torch.tensor(enc)
{% endraw %} {% raw %}
sents = [
    'I walked to the store so I hope it is not closed.',
    'The theater is closed today and the sky is grey.',
    'His dog is brown while hers is grey.'
]
labels = [0, 1, 1]
{% endraw %} {% raw %}
ds = Data(sents, labels, 10)
ds[1]
(tensor([13, 14,  1, 15, 16, 17,  3, 18,  1,  4]), tensor(1))
{% endraw %} {% raw %}
dl = DataLoader(ds, batch_size=3)
x, y = next(iter(dl))
x, y
(tensor([[ 2,  5,  6,  3,  7,  8,  2,  9, 10,  1],
         [13, 14,  1, 15, 16, 17,  3, 18,  1,  4],
         [19, 20,  1, 21, 22, 23,  1,  4,  0,  0]]), tensor([0, 1, 1]))
{% endraw %} {% raw %}
be = BloomEmbedding(11, 4)
be.emb.weight
Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0054, -0.0142, -0.0027, -0.0052],
        [ 0.0097,  0.0003, -0.0098, -0.0082],
        [ 0.0049,  0.0084,  0.0028, -0.0158],
        [ 0.0035,  0.0056, -0.0106, -0.0135],
        [ 0.0088, -0.0058,  0.0105, -0.0072],
        [-0.0003, -0.0004, -0.0012,  0.0142],
        [ 0.0089, -0.0114, -0.0001,  0.0037],
        [ 0.0089,  0.0007,  0.0076,  0.0034],
        [-0.0039,  0.0119,  0.0090, -0.0077],
        [ 0.0040, -0.0154,  0.0072, -0.0024]], requires_grad=True)
{% endraw %} {% raw %}
x
tensor([[ 2,  5,  6,  3,  7,  8,  2,  9, 10,  1],
        [13, 14,  1, 15, 16, 17,  3, 18,  1,  4],
        [19, 20,  1, 21, 22, 23,  1,  4,  0,  0]])
{% endraw %} {% raw %}
# (bs, seq_len) -> (bs, seq_len, emb_size)
y = be(x)
y.shape
torch.Size([3, 10, 4])
{% endraw %} {% raw %}
y[0]
tensor([[ 0.0363, -0.0098,  0.0052,  0.0023],
        [ 0.0337, -0.0130, -0.0147, -0.0183],
        [ 0.0075, -0.0316,  0.0120,  0.0236],
        [ 0.0304, -0.0328,  0.0388, -0.0240],
        [ 0.0144,  0.0004, -0.0022,  0.0019],
        [ 0.0084,  0.0117,  0.0089, -0.0284],
        [ 0.0363, -0.0098,  0.0052,  0.0023],
        [ 0.0177, -0.0087,  0.0344, -0.0139],
        [ 0.0272, -0.0108, -0.0036,  0.0130],
        [ 0.0035, -0.0162,  0.0048,  0.0260]], grad_fn=<SelectBackward>)
{% endraw %}

Below, we show step by step how to get from x to y. This is meant to demonstrate the basic mechanism, not to show how PyTorch actually implements it under the hood. Let's look at a single row of x, corresponding to one sentence where each word is mapped to its index in the vocabulary.

{% raw %}
x[0]
tensor([ 2,  5,  6,  3,  7,  8,  2,  9, 10,  1])
{% endraw %}

Next, we hash each item.

{% raw %}
hashed = [probabilistic_hash_item(i.item(), 11, int, 4) for i in x[0]]
hashed
[[8, 2, 7, 8],
 [2, 8, 1, 2],
 [6, 6, 10, 10],
 [10, 5, 5, 5],
 [6, 9, 7, 2],
 [5, 9, 4, 0],
 [8, 2, 7, 8],
 [5, 10, 8, 9],
 [7, 8, 6, 2],
 [6, 10, 6, 0]]
{% endraw %}

Then we use each row of hashed integers to index into the embedding weight matrix.

{% raw %}
output = []
for row in hashed:
    row_out = be.emb.weight[row]
    output.append(row_out)
output = torch.stack(output)
print(output.shape)
output[:2]
torch.Size([10, 4, 4])
tensor([[[ 0.0089,  0.0007,  0.0076,  0.0034],
         [ 0.0097,  0.0003, -0.0098, -0.0082],
         [ 0.0089, -0.0114, -0.0001,  0.0037],
         [ 0.0089,  0.0007,  0.0076,  0.0034]],

        [[ 0.0097,  0.0003, -0.0098, -0.0082],
         [ 0.0089,  0.0007,  0.0076,  0.0034],
         [ 0.0054, -0.0142, -0.0027, -0.0052],
         [ 0.0097,  0.0003, -0.0098, -0.0082]]], grad_fn=<SliceBackward>)
{% endraw %}

Finally, we sum up the embedding rows. Above, each word is represented by four rows of the embedding matrix. After summing, we get a single vector for each word.

{% raw %}
output = output.sum(-2)
output
tensor([[ 0.0363, -0.0098,  0.0052,  0.0023],
        [ 0.0337, -0.0130, -0.0147, -0.0183],
        [ 0.0075, -0.0316,  0.0120,  0.0236],
        [ 0.0304, -0.0328,  0.0388, -0.0240],
        [ 0.0144,  0.0004, -0.0022,  0.0019],
        [ 0.0084,  0.0117,  0.0089, -0.0284],
        [ 0.0363, -0.0098,  0.0052,  0.0023],
        [ 0.0177, -0.0087,  0.0344, -0.0139],
        [ 0.0272, -0.0108, -0.0036,  0.0130],
        [ 0.0035, -0.0162,  0.0048,  0.0260]], grad_fn=<SumBackward1>)
{% endraw %}

Notice that the values now match the output of our embedding layer.

{% raw %}
assert torch.isclose(output, y[0]).all()
{% endraw %}
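
Putting the pieces together, the forward pass is essentially "hash, gather, sum". A compact sketch using the helper shown above (BloomEmbedding itself may vectorize this differently, but given the same hash parameters the result should match):

```python
import torch

def bloom_forward_sketch(emb_weight, x, n_emb=11, n_hashes=4):
    """Hash each index to n_hashes rows, gather those rows, then sum them."""
    hashed = [[probabilistic_hash_item(i.item(), n_emb, int, n_hashes)
               for i in row] for row in x]
    rows = torch.tensor(hashed)        # (bs, seq_len, n_hashes)
    return emb_weight[rows].sum(-2)    # (bs, seq_len, emb_dim)

assert torch.isclose(bloom_forward_sketch(be.emb.weight, x), be(x)).all()
```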

Axial encodings are intended to serve as positional embeddings for transformer-like architectures. They might also work as word embeddings, similar to our use of Bloom embeddings. However, the standard version of axial encodings produces similar vectors for adjacent indices. That makes sense for positional indices, but for word indices it would likely require some additional preprocessing: for example, we could compress word embeddings down to one dimension and sort them, or simply sort words by frequency in our corpus (which amounts to roughly the same thing). Large chunks of the output vectors will be shared among different inputs, whereas Bloom embeddings seem better able to avoid this issue.

{% raw %}
{% endraw %} {% raw %}

class AxialEncoding[source]

AxialEncoding(vocab_dim, emb_dim, pad_idx=None) :: Module

Axial encodings. These are intended to encode position in a sequence
(e.g. index in a sentence). It's possible we could adapt these for use as
word embeddings but this would likely require some experimentation (for
example, words would likely need to be sorted in a thoughtful manner
(e.g. pre-trained embeddings compressed to 1D?) since adjacent inputs will
share half of their encodings).
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class MultiAxialEncoding[source]

MultiAxialEncoding(vocab_dim, emb_dim, n_blocks=2, pre_hashed=False, pad_idx=None) :: Module

Adapted axial encodings to allow for more than 2 embedding matrices.
These are intended to encode position in a sequence (e.g. index in a
sentence) but might work as word embeddings. This version may be better
suited for that use case because using more blocks results in fewer shared
numbers in the output vectors of adjacent inputs.

Some experimentation is still required for this use case (for
example, words would likely need to be sorted in a thoughtful manner
(e.g. pre-trained embeddings compressed to 1D?) since adjacent inputs will
share half of their encodings).
{% endraw %} {% raw %}
def reduction_ratio(ax, vocab_size, emb_dim):
    """For testing purposes. Lets us compare the number of weights in a
    traditional embedding matrix vs. the number of weights in our axial
    encoding.
    """
    normal_n = vocab_size * emb_dim
    ax_n = sum(e.weight.numel() for e in ax.emb)
    print('Normal embedding weights:', normal_n)
    print('Axial encoding weights:', ax_n)
    print('Difference:', normal_n - ax_n)
    print('Ratio:', normal_n / ax_n)
{% endraw %} {% raw %}
vocab_size = 30_000
emb_dim = 100
bs = 12

ax = AxialEncoding(vocab_size, emb_dim)
x = torch.randint(0, vocab_size, (bs, 2))
print(x.shape)
ax
torch.Size([12, 2])
AxialEncoding(
  (emb): ModuleList(
    (0): Embedding(174, 50)
    (1): Embedding(174, 50)
  )
)
{% endraw %} {% raw %}
res = ax(x)
print(res.shape)
torch.Size([12, 2, 100])
{% endraw %} {% raw %}
reduction_ratio(ax, vocab_size, emb_dim)
Normal embedding weights: 3000000
Axial encoding weights: 17400
Difference: 2982600
Ratio: 172.41379310344828
{% endraw %} {% raw %}
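The saving comes from factoring the index space: two tables of 174 rows cover 174² = 30,276 ≥ 30,000 ids, and each index can be split into a quotient and a remainder that each select half of the final vector. A conceptual sketch of that decomposition (not the incendio implementation):

```python
import math
import torch
import torch.nn as nn

class ToyAxialEncoding(nn.Module):
    """Two-table axial lookup: split each index into (i // rows, i % rows)
    and concatenate the two half-dimension embeddings."""

    def __init__(self, vocab_dim, emb_dim):
        super().__init__()
        self.rows = math.ceil(math.sqrt(vocab_dim))   # 174 for vocab_dim=30_000
        half = emb_dim // 2
        self.emb = nn.ModuleList([nn.Embedding(self.rows, half),
                                  nn.Embedding(self.rows, half)])

    def forward(self, x):
        return torch.cat([self.emb[0](x // self.rows),
                          self.emb[1](x % self.rows)], dim=-1)

toy = ToyAxialEncoding(30_000, 100)
toy(torch.randint(0, 30_000, (12, 2))).shape   # torch.Size([12, 2, 100])
```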
vocab_size = 30_000
emb_dim = 100
bs = 12

ax = MultiAxialEncoding(vocab_size, emb_dim, 4)
x = torch.randint(0, vocab_size, (bs, 2))
print(x.shape)
ax
torch.Size([12, 2])
MultiAxialEncoding(
  (emb): ModuleList(
    (0): Embedding(14, 25)
    (1): Embedding(14, 25)
    (2): Embedding(14, 25)
    (3): Embedding(14, 25)
  )
)
{% endraw %} {% raw %}
res1 = ax(x)
res1.shape
torch.Size([12, 2, 100])
{% endraw %} {% raw %}
vocab_size = 30_000
emb_dim = 100
bs = 12

ax_pre = MultiAxialEncoding(vocab_size, emb_dim, 4, pre_hashed=True)
ax_pre
MultiAxialEncoding(
  (emb): ModuleList(
    (0): Embedding(14, 25)
    (1): Embedding(14, 25)
    (2): Embedding(14, 25)
    (3): Embedding(14, 25)
  )
)
{% endraw %}

By setting the weights of our pre-hashed embedding to the weights of our hashing embedding, we can check that the outputs are ultimately the same.

{% raw %}
for e, e_pre in zip(ax.emb, ax_pre.emb):
    e_pre.weight.data = e.weight.data
{% endraw %} {% raw %}
xhash = probabilistic_hash_tensor(x, 14, 4)
res2 = ax_pre(xhash)
res2.shape
torch.Size([12, 2, 100])
{% endraw %} {% raw %}
(res1 == res2).all()
tensor(True)
{% endraw %} {% raw %}
reduction_ratio(ax_pre, vocab_size, emb_dim)
Normal embedding weights: 3000000
Axial encoding weights: 1400
Difference: 2998600
Ratio: 2142.8571428571427
{% endraw %}

I imagine that as we increase n_blocks, there's likely a point where we simply won't have enough weights to encode the amount of information that's present in the data. It would take some experimentation to find where that line is, however.

{% raw %}
ax_large = MultiAxialEncoding(vocab_size, emb_dim, 8, pre_hashed=True)
ax_large
MultiAxialEncoding(
  (emb): ModuleList(
    (0): Embedding(4, 12)
    (1): Embedding(4, 12)
    (2): Embedding(4, 12)
    (3): Embedding(4, 12)
    (4): Embedding(4, 12)
    (5): Embedding(4, 12)
    (6): Embedding(4, 12)
    (7): Embedding(4, 12)
  )
)
{% endraw %} {% raw %}
reduction_ratio(ax_large, vocab_size, emb_dim)
Normal embedding weights: 3000000
Axial encoding weights: 384
Difference: 2999616
Ratio: 7812.5
{% endraw %}