bettermdptools.algorithms.planner

Author: Miguel Morales
BSD 3-Clause License

Copyright (c) 2018, Miguel Morales
All rights reserved.
https://github.com/mimoralea/gdrl/blob/master/LICENSE

  1"""
  2Author: Miguel Morales
  3BSD 3-Clause License
  4
  5Copyright (c) 2018, Miguel Morales
  6All rights reserved.
  7https://github.com/mimoralea/gdrl/blob/master/LICENSE
  8"""
  9
 10"""
 11modified by: John Mansfield
 12
 13documentation added by: Gagandeep Randhawa
 14"""
 15
 16"""
 17Class that contains functions related to planning algorithms (Value Iteration, Policy Iteration). 
 18Planner init expects a reward and transitions matrix P, which is nested dictionary gym style discrete environment 
 19where P[state][action] is a list of tuples (probability, next state, reward, terminal).
 20
 21Model-based learning algorithms: Value Iteration and Policy Iteration
 22"""
 23
 24import warnings
 25
 26import numpy as np
 27from tqdm.auto import tqdm
 28
 29
 30class Planner:
 31    def __init__(self, P):
 32        """
 33        Initializes the Planner class.
 34
 35        Parameters
 36        ----------
 37        P : dict
 38            Transition probability matrix where P[state][action] is a list of tuples
 39            (probability, next state, reward, terminal).
 40        """
 41        self.P = P
 42
 43    def value_iteration(self, gamma=1.0, n_iters=1000, theta=1e-10, dtype=np.float32):
 44        """
 45        Value Iteration algorithm.
 46
 47        Parameters
 48        ----------
 49        gamma : float, optional
 50            Discount factor, by default 1.0.
 51        n_iters : int, optional
 52            Number of iterations, by default 1000.
 53        theta : float, optional
 54            Convergence criterion for value iteration, by default 1e-10.
 55
 56        Returns
 57        -------
 58        tuple
 59            V : np.ndarray
 60                State values array.
 61            V_track : np.ndarray
 62                Log of V(s) for each iteration.
 63            pi : dict
 64                Policy mapping states to actions.
 65        """
 66        V = np.zeros(len(self.P), dtype=dtype)
 67        V_track = np.zeros((n_iters, len(self.P)), dtype=dtype)
 68        converged = False
 69        for i in tqdm(range(n_iters), leave=False):
 70            Q = np.zeros((len(self.P), len(self.P[0])), dtype=dtype)
 71            for s in range(len(self.P)):
 72                for a in range(len(self.P[s])):
 73                    for prob, next_state, reward, done in self.P[s][a]:
 74                        Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))
 75            if np.max(np.abs(V - np.max(Q, axis=1))) < theta:
 76                converged = True
 77            V = np.max(Q, axis=1)
 78            V_track[i] = V
 79            if converged:
 80                break
 81        if not converged:
 82            warnings.warn("Max iterations reached before convergence.  Check n_iters.")
 83
 84        pi = {s: a for s, a in enumerate(np.argmax(Q, axis=1))}
 85        return V, V_track, pi
 86
 87    def value_iteration_vectorized(
 88        self, gamma=1.0, n_iters=1000, theta=1e-10, dtype=np.float32
 89    ):
 90        """
 91        Vectorized Value Iteration algorithm.
 92
 93        Parameters
 94        ----------
 95        gamma : float
 96            Discount factor
 97
 98        n_iters : int
 99            Number of iterations
100
101        theta : float
102            Convergence criterion for value iteration.
103            State values are considered to be converged when the maximum difference between new and previous state values is less than theta.
104            Stops at n_iters or theta convergence - whichever comes first.
105
106        Returns
107        -------
108        tuple
109            V : np.ndarray
110                State values array.
111            V_track : np.ndarray
112                Log of V(s) for each iteration.
113            pi : dict
114                Policy mapping states to actions.
115        """
116        S = len(self.P)
117        A = len(self.P[0])
118
119        max_K = max(len(self.P[s][a]) for s in range(S) for a in range(A))
120
121        prob_array = np.zeros((S, A, max_K), dtype=dtype)
122        next_state_array = np.zeros((S, A, max_K), dtype=np.int32)
123        reward_array = np.zeros((S, A, max_K), dtype=dtype)
124        done_array = np.zeros((S, A, max_K), dtype=bool)
125        mask_array = np.zeros((S, A, max_K), dtype=bool)
126
127        for s in range(S):
128            for a in range(A):
129                transitions = self.P[s][a]
130                for k, (prob, next_state, reward, done) in enumerate(transitions):
131                    prob_array[s, a, k] = prob
132                    next_state_array[s, a, k] = next_state
133                    reward_array[s, a, k] = reward
134                    done_array[s, a, k] = done
135                    mask_array[s, a, k] = True
136
137        V = np.zeros(S, dtype=dtype)
138        V_track = np.zeros((n_iters, S), dtype=dtype)
139        converged = False
140        # Simpler way to handle done states
141        not_done_array = 1 - done_array
142
143        for i in tqdm(range(n_iters), leave=False):
144            Q = np.sum(
145                prob_array
146                * (reward_array + gamma * V[next_state_array] * not_done_array)
147                * mask_array,
148                axis=2,
149            )
150            V_new = np.max(Q, axis=1)
151
152            if np.max(np.abs(V - V_new)) < theta:
153                converged = True
154
155            V = V_new
156            V_track[i] = V
157
158            if converged:
159                break
160
161        if not converged:
162            warnings.warn("Max iterations reached before convergence. Check n_iters.")
163
164        return V, V_track, dict(enumerate(np.argmax(Q, axis=1)))
165
166    def policy_iteration(self, gamma=1.0, n_iters=50, theta=1e-10, dtype=np.float32):
167        """
168        Policy Iteration algorithm.
169
170        Parameters
171        ----------
172        gamma : float, optional
173            Discount factor, by default 1.0.
174        n_iters : int, optional
175            Number of iterations, by default 50.
176        theta : float, optional
177            Convergence criterion for policy evaluation, by default 1e-10.
178
179        Returns
180        -------
181        tuple
182            V : np.ndarray
183                State values array.
184            V_track : np.ndarray
185                Log of V(s) for each iteration.
186            pi : dict
187                Policy mapping states to actions.
188        """
189        random_actions = np.random.choice(tuple(self.P[0].keys()), len(self.P))
190
191        pi = {s: a for s, a in enumerate(random_actions)}
192        # initial V to give to `policy_evaluation` for the first time
193        V = np.zeros(len(self.P), dtype=dtype)
194        V_track = np.zeros((n_iters, len(self.P)), dtype=dtype)
195        converged = False
196        for i in tqdm(range(n_iters), leave=False):
197            old_pi = pi
198            V = self.policy_evaluation(pi, V, gamma=gamma, theta=theta, dtype=dtype)
199            V_track[i] = V
200            pi = self.policy_improvement(V, gamma=gamma, dtype=dtype)
201            if old_pi == pi:
202                converged = True
203                break
204        if not converged:
205            warnings.warn("Max iterations reached before convergence.  Check n_iters.")
206        return V, V_track, pi
207
208    def policy_evaluation(self, pi, prev_V, gamma=1.0, theta=1e-10, dtype=np.float32):
209        """
210        Policy Evaluation algorithm.
211
212        Parameters
213        ----------
214        pi : dict
215            Policy mapping states to actions.
216        prev_V : np.ndarray
217            Previous state values array.
218        gamma : float, optional
219            Discount factor, by default 1.0.
220        theta : float, optional
221            Convergence criterion, by default 1e-10.
222
223        Returns
224        -------
225        np.ndarray
226            State values array.
227        """
228        while True:
229            V = np.zeros(len(self.P), dtype=dtype)
230            for s in range(len(self.P)):
231                for prob, next_state, reward, done in self.P[s][pi[s]]:
232                    V[s] += prob * (reward + gamma * prev_V[next_state] * (not done))
233            if np.max(np.abs(prev_V - V)) < theta:
234                break
235            prev_V = V.copy()
236        return V
237
238    def policy_improvement(self, V, gamma=1.0, dtype=np.float32):
239        """
240        Policy Improvement algorithm.
241
242        Parameters
243        ----------
244        V : np.ndarray
245            State values array.
246        gamma : float, optional
247            Discount factor, by default 1.0.
248
249        Returns
250        -------
251        dict
252            Policy mapping states to actions.
253        """
254        Q = np.zeros((len(self.P), len(self.P[0])), dtype=dtype)
255        for s in range(len(self.P)):
256            for a in range(len(self.P[s])):
257                for prob, next_state, reward, done in self.P[s][a]:
258                    Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))
259
260        return dict(enumerate(np.argmax(Q, axis=1)))
class Planner:
Planner(P)

Initializes the Planner class.

Parameters
  • P (dict): Transition probability matrix where P[state][action] is a list of tuples (probability, next state, reward, terminal).
P: the transition probability matrix passed to the constructor (instance attribute).
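
Example: a minimal sketch of the transition model Planner expects. The two-state MDP below is hand-built for illustration and is not part of bettermdptools; any gym-style discrete environment exposing the same nested-dict P can be passed the same way.

    from bettermdptools.algorithms.planner import Planner

    # Hand-built toy MDP in the expected format:
    # P[state][action] = [(probability, next_state, reward, terminal), ...]
    P = {
        0: {  # state 0
            0: [(1.0, 0, 0.0, False)],                       # action 0: stay in state 0
            1: [(0.8, 1, 1.0, True), (0.2, 0, 0.0, False)],  # action 1: usually reach terminal state 1
        },
        1: {  # state 1 (terminal, absorbing)
            0: [(1.0, 1, 0.0, True)],
            1: [(1.0, 1, 0.0, True)],
        },
    }

    planner = Planner(P)
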
def value_iteration(self, gamma=1.0, n_iters=1000, theta=1e-10, dtype=np.float32):

Value Iteration algorithm.

Parameters
  • gamma (float, optional): Discount factor, by default 1.0.
  • n_iters (int, optional): Number of iterations, by default 1000.
  • theta (float, optional): Convergence criterion for value iteration, by default 1e-10.
Returns
  • tuple:
      • V (np.ndarray): State values array.
      • V_track (np.ndarray): Log of V(s) for each iteration.
      • pi (dict): Policy mapping states to actions.
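
Example: a sketch of running value iteration against a gym-style environment. It assumes gymnasium is installed and that FrozenLake-v1 exposes the nested-dict transition model as env.unwrapped.P; the hyperparameter values are illustrative, not recommendations.

    import gymnasium as gym
    from bettermdptools.algorithms.planner import Planner

    env = gym.make("FrozenLake-v1")
    planner = Planner(env.unwrapped.P)

    # V holds one value per state, V_track logs V after every sweep,
    # and pi maps each state to its greedy action.
    V, V_track, pi = planner.value_iteration(gamma=0.99, n_iters=1000, theta=1e-10)
    print(V.shape)   # (16,) for the default 4x4 map
    print(pi[0])     # greedy action in the start state
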
def value_iteration_vectorized(self, gamma=1.0, n_iters=1000, theta=1e-10, dtype=np.float32):

Vectorized Value Iteration algorithm.

Parameters
  • gamma (float): Discount factor
  • n_iters (int): Number of iterations
  • theta (float): Convergence criterion for value iteration. State values are considered to be converged when the maximum difference between new and previous state values is less than theta. Stops at n_iters or theta convergence - whichever comes first.
Returns
  • tuple:
      • V (np.ndarray): State values array.
      • V_track (np.ndarray): Log of V(s) for each iteration.
      • pi (dict): Policy mapping states to actions.
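
Example: a sketch comparing the vectorized and loop-based implementations on the same model (FrozenLake-v1 via gymnasium, assumed installed). The state values should agree numerically; the policy comparison is informational only, since argmax tie-breaking over near-equal Q values could in principle differ.

    import numpy as np
    import gymnasium as gym
    from bettermdptools.algorithms.planner import Planner

    planner = Planner(gym.make("FrozenLake-v1").unwrapped.P)

    V_loop, _, pi_loop = planner.value_iteration(gamma=0.99)
    V_vec, _, pi_vec = planner.value_iteration_vectorized(gamma=0.99)

    print(np.allclose(V_loop, V_vec, atol=1e-5))  # state values should match closely
    print(sum(pi_loop[s] == pi_vec[s] for s in pi_loop), "of", len(pi_loop), "greedy actions agree")
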
def policy_iteration(self, gamma=1.0, n_iters=50, theta=1e-10, dtype=np.float32):

Policy Iteration algorithm.

Parameters
  • gamma (float, optional): Discount factor, by default 1.0.
  • n_iters (int, optional): Number of iterations, by default 50.
  • theta (float, optional): Convergence criterion for policy evaluation, by default 1e-10.
Returns
  • tuple:
      • V (np.ndarray): State values array.
      • V_track (np.ndarray): Log of V(s) for each iteration.
      • pi (dict): Policy mapping states to actions.
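
Example: a sketch of policy iteration on the same FrozenLake-v1 model (gymnasium assumed installed). The starting policy is random, so different runs may need a different number of improvement sweeps before old_pi == pi.

    import gymnasium as gym
    from bettermdptools.algorithms.planner import Planner

    planner = Planner(gym.make("FrozenLake-v1").unwrapped.P)
    V, V_track, pi = planner.policy_iteration(gamma=0.99, n_iters=50, theta=1e-10)
    print(pi)   # {state: greedy action} for the final policy
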
def policy_evaluation(self, pi, prev_V, gamma=1.0, theta=1e-10, dtype=np.float32):

Policy Evaluation algorithm.

Parameters
  • pi (dict): Policy mapping states to actions.
  • prev_V (np.ndarray): Previous state values array.
  • gamma (float, optional): Discount factor, by default 1.0.
  • theta (float, optional): Convergence criterion, by default 1e-10.
Returns
  • np.ndarray: State values array.
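
Example: a sketch of evaluating a fixed policy. The all-zeros policy and the zero-initialized prev_V below are illustrative choices, not library defaults (FrozenLake-v1 via gymnasium assumed).

    import numpy as np
    import gymnasium as gym
    from bettermdptools.algorithms.planner import Planner

    P = gym.make("FrozenLake-v1").unwrapped.P
    planner = Planner(P)

    pi = {s: 0 for s in range(len(P))}           # always take action 0
    prev_V = np.zeros(len(P), dtype=np.float32)  # initial guess for the state values
    V_pi = planner.policy_evaluation(pi, prev_V, gamma=0.99, theta=1e-10)
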
def policy_improvement(self, V, gamma=1.0, dtype=np.float32):

Policy Improvement algorithm.

Parameters
  • V (np.ndarray): State values array.
  • gamma (float, optional): Discount factor, by default 1.0.
Returns
  • dict: Policy mapping states to actions.
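
Example: a sketch of one evaluation/improvement step done by hand with the two helpers, mirroring what policy_iteration does internally (FrozenLake-v1 via gymnasium assumed; the starting policy is arbitrary).

    import numpy as np
    import gymnasium as gym
    from bettermdptools.algorithms.planner import Planner

    P = gym.make("FrozenLake-v1").unwrapped.P
    planner = Planner(P)

    pi = {s: 0 for s in range(len(P))}  # arbitrary starting policy
    V = planner.policy_evaluation(pi, np.zeros(len(P), dtype=np.float32), gamma=0.99)
    better_pi = planner.policy_improvement(V, gamma=0.99)  # greedy with respect to V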