Supervised Autodiff Predictive Coding with precision weighting of prediction errors

Supervised predictive coding with implicit gradients, computed by TensorFlow's automatic differentiation of the energy with respect to the representations and the learnable parameters.
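In the usual Gaussian formulation of precision-weighted predictive coding, each layer's prediction error is weighted by a precision matrix P_l (an inverse covariance). The pseudocode below writes the energy in shorthand. Assuming that shorthand stands for the standard form (the exact weighting and sign convention of the implementation are not stated here), the energy and its gradient with respect to the precision read:

```latex
E = \sum_l \left[ \tfrac{1}{2}\, \epsilon_l^\top P_l\, \epsilon_l - \tfrac{1}{2} \log |P_l| \right],
\qquad \epsilon_l = r_l - \hat{r}_l

\frac{\partial E}{\partial P_l} = \tfrac{1}{2}\, \epsilon_l \epsilon_l^\top - \tfrac{1}{2}\, P_l^{-1}
```

Written with the covariance \Sigma_l = P_l^{-1} instead, the log term becomes +\tfrac{1}{2}\log|\Sigma_l|. Setting the gradient to zero shows that the precision update drives P_l^{-1} toward the (batch-averaged) error covariance \epsilon_l \epsilon_l^\top, so layers whose predictions are reliable get their errors weighted more heavily.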

precision_modulated_supervised_autodiff_pc.infer(model, data, ir=0.025, T=200, predictions_flow_upward=False, target_shape=None)

Implements the following logic:

Initialize representations
do T times
    E = 0.5 * norm(r - model(r)) ^ 2 + log |P|
    r -= ir * dE/dr
return r

Parameters
  • model (list of tf_utils.Dense or tf_utils.BiasedDense) – a sequential network described as a list of layers; can be generated, e.g., using tf_utils.mlp()

  • data (3d tf.Tensor of float32) – input data batch

  • ir (float, optional) – inference rate, defaults to 0.025

  • T (int, optional) – number of inference steps, defaults to 200

  • predictions_flow_upward (bool, optional) – direction of prediction flow, defaults to False

  • target_shape (1d tf.Tensor of int32, optional) – shape of target minibatch, defaults to None

Returns

latent representations

Return type

list of 3d tf.Tensor of float32
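The inference loop of infer() can be sketched in NumPy. This is a minimal sketch under stated assumptions, not the library's implementation: the layers are linear, each precision P is diagonal (stored as a vector p[l]), gradients are written out by hand instead of being obtained from TensorFlow's autodiff, and the default predictions_flow_upward=False is assumed to mean that layer l+1 predicts layer l, with the data clamped at r[0]. The helper names errors, energy, grad_r and infer are hypothetical, and higher layers are initialized to zero as an assumption.

```python
import numpy as np

# Hypothetical sketch: linear layers, diagonal precisions, and predictions that
# flow downward (layer l+1 predicts layer l). Shapes are (batch, features);
# W[l] has shape (n_{l+1}, n_l) and p[l] has shape (n_l,).

def errors(r, W):
    """Prediction error e[l] = r[l] - r[l+1] @ W[l] for every predicted layer."""
    return [r[l] - r[l + 1] @ W[l] for l in range(len(W))]

def energy(r, W, p):
    """Batch-averaged 0.5 * e^T P e - 0.5 * log|P|, summed over layers."""
    E = 0.0
    for l, e in enumerate(errors(r, W)):
        E += 0.5 * np.mean(np.sum(p[l] * e**2, axis=1))  # precision-weighted error
        E -= 0.5 * np.sum(np.log(p[l]))                   # log-determinant of diagonal P
    return E

def grad_r(r, W, p):
    """Per-sample dE/dr[k] for every layer (the clamped r[0] gets no update)."""
    e = errors(r, W)
    g = [np.zeros_like(x) for x in r]
    for k in range(1, len(r)):
        g[k] = -(p[k - 1] * e[k - 1]) @ W[k - 1].T   # r[k] as a predictor of r[k-1]
        if k < len(W):
            g[k] += p[k] * e[k]                      # r[k] as a prediction target
    return g

def infer(W, p, data, ir=0.025, T=200):
    """Clamp r[0] to the data and relax all higher layers by gradient descent."""
    batch = data.shape[0]
    r = [data] + [np.zeros((batch, w.shape[0])) for w in W]  # zero init (an assumption)
    for _ in range(T):
        g = grad_r(r, W, p)
        r = [r[0]] + [r[k] - ir * g[k] for k in range(1, len(r))]
    return r
```

For example, with W = [np.eye(2)] and p = [np.ones(2)], infer(W, p, np.array([[1.0, 2.0]])) moves r[1] toward [1.0, 2.0], the zero-error solution; with ir=0.025 and T=200 it ends within about 1% of it. Because the precision enters only as a per-unit weight on the error, the log|P| term has no effect on this inference step.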

precision_modulated_supervised_autodiff_pc.learn(model, data, target, ir=0.1, lr=0.001, pr=0.001, T=40, predictions_flow_upward=False)

Implements the following logic:

Initialize representations
do T times
    E = 0.5 * norm(r - model(r)) ^ 2 + log |P|
    r -= ir * dE/dr
W -= lr * dE/dW
P -= pr * dE/dP

Parameters
  • model (list of tf_utils.Dense or tf_utils.BiasedDense) – a sequential network described as a list of layers; can be generated, e.g., using tf_utils.mlp()

  • data (3d tf.Tensor of float32) – input data batch

  • target (3d tf.Tensor of float32) – output target batch

  • ir (float, optional) – inference rate, defaults to 0.1

  • lr (float, optional) – learning rate, defaults to 0.001

  • pr (float, optional) – learning rate for the precision parameters, defaults to 0.001

  • T (int, optional) – number of inference steps, defaults to 40

  • predictions_flow_upward (bool, optional) – direction of prediction flow, defaults to False
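A matching NumPy sketch of the learn() step, under the same assumptions as a hypothetical stand-in for the library: linear layers, diagonal precisions stored as vectors p[l], hand-written gradients in place of autodiff, and layer l+1 predicting layer l. The data is clamped at the bottom and the target at the top, only the hidden layers are relaxed for T steps, and then W and p each take one gradient step. The min_precision floor is an addition of this sketch to keep log|P| defined; the library may handle positivity differently.

```python
import numpy as np

# Hypothetical sketch: layer l+1 predicts layer l, W[l] has shape (n_{l+1}, n_l),
# and p[l] (diagonal precision of layer l) has shape (n_l,).

def errors(r, W):
    """Prediction error e[l] = r[l] - r[l+1] @ W[l] for every predicted layer."""
    return [r[l] - r[l + 1] @ W[l] for l in range(len(W))]

def learn(W, p, data, target, ir=0.1, lr=0.001, pr=0.001, T=40, min_precision=1e-3):
    """One learning step; updates the W and p lists in place and returns them with r."""
    batch = data.shape[0]
    # Clamp the data at the bottom and the target at the top; hidden layers start at zero.
    hidden = [np.zeros((batch, w.shape[0])) for w in W[:-1]]
    r = [data] + hidden + [target]
    for _ in range(T):                        # inference: relax hidden layers only
        e = errors(r, W)
        for k in range(1, len(r) - 1):
            g = p[k] * e[k] - (p[k - 1] * e[k - 1]) @ W[k - 1].T
            r[k] = r[k] - ir * g
    e = errors(r, W)                          # errors at the end of inference
    for l in range(len(W)):
        dW = -r[l + 1].T @ (p[l] * e[l]) / batch             # dE/dW[l], batch mean
        dp = 0.5 * np.mean(e[l] ** 2, axis=0) - 0.5 / p[l]   # dE/dp[l] for diagonal P
        W[l] = W[l] - lr * dW
        # Floor keeps the precision positive; an assumption of this sketch.
        p[l] = np.maximum(p[l] - pr * dp, min_precision)
    return W, p, r
```

Since the gradient with respect to p[l] is 0.5 * mean(e^2) - 0.5 / p[l], repeated calls push each precision toward the inverse of that layer's mean squared prediction error, which is the precision weighting described at the top of this page.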