---
title: Actor-critic Model
keywords: fastai
sidebar: home_sidebar
summary: "RL Actor-critic model framework."
description: "RL Actor-critic model framework."
nb_path: "nbs/models/actor_critic.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class Actor[source]

Actor(embedded_state_size:int, action_weight_size:int, hidden_sizes:Tuple[int]) :: Module

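Actor network. Takes a batch of embedded states (last dimension `embedded_state_size`) and returns one vector of length `action_weight_size` per state, as shown in the example below.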

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class Critic[source]

Critic(embedded_state_size:int, embedded_action_size:int, hidden_sizes:Tuple[int]) :: Module

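Critic network. Takes a batch of embedded states (last dimension `embedded_state_size`) together with a batch of embedded actions (last dimension `embedded_action_size`) and returns a single value per state-action pair, as shown in the example below.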

{% endraw %} {% raw %}
{% endraw %} {% raw %}
embedded_state_size = 64
embedded_action_size = 5
actor_hidden_sizes = (128, 64)
critic_hidden_sizes = (32, 16)

actor = Actor(embedded_state_size=embedded_state_size,
              action_weight_size=embedded_action_size,
              hidden_sizes=actor_hidden_sizes)
critic = Critic(embedded_state_size=embedded_state_size,
                embedded_action_size=embedded_action_size,
                hidden_sizes=critic_hidden_sizes)
{% endraw %} {% raw %}
torch.manual_seed(0)
actor.forward(torch.rand(2,embedded_state_size))
tensor([[-0.1129, -0.0019,  0.0999, -0.0621,  0.0551],
        [-0.0445, -0.0425,  0.0677, -0.0397,  0.0003]],
       grad_fn=<AddmmBackward0>)
{% endraw %} {% raw %}
torch.manual_seed(0)
critic.forward(torch.rand(2,embedded_state_size),torch.rand(2,embedded_action_size))
tensor([[-0.1659],
        [-0.1432]], grad_fn=<AddmmBackward0>)
{% endraw %}