bettermdptools.algorithms.planner
Author: Miguel Morales
BSD 3-Clause License

Copyright (c) 2018, Miguel Morales. All rights reserved.
https://github.com/mimoralea/gdrl/blob/master/LICENSE
1""" 2Author: Miguel Morales 3BSD 3-Clause License 4 5Copyright (c) 2018, Miguel Morales 6All rights reserved. 7https://github.com/mimoralea/gdrl/blob/master/LICENSE 8""" 9 10""" 11modified by: John Mansfield 12 13documentation added by: Gagandeep Randhawa 14""" 15 16""" 17Class that contains functions related to planning algorithms (Value Iteration, Policy Iteration). 18Planner init expects a reward and transitions matrix P, which is nested dictionary gym style discrete environment 19where P[state][action] is a list of tuples (probability, next state, reward, terminal). 20 21Model-based learning algorithms: Value Iteration and Policy Iteration 22""" 23 24import warnings 25 26import numpy as np 27from tqdm.auto import tqdm 28 29 30class Planner: 31 def __init__(self, P): 32 """ 33 Initializes the Planner class. 34 35 Parameters 36 ---------- 37 P : dict 38 Transition probability matrix where P[state][action] is a list of tuples 39 (probability, next state, reward, terminal). 40 """ 41 self.P = P 42 43 def value_iteration(self, gamma=1.0, n_iters=1000, theta=1e-10, dtype=np.float32): 44 """ 45 Value Iteration algorithm. 46 47 Parameters 48 ---------- 49 gamma : float, optional 50 Discount factor, by default 1.0. 51 n_iters : int, optional 52 Number of iterations, by default 1000. 53 theta : float, optional 54 Convergence criterion for value iteration, by default 1e-10. 55 56 Returns 57 ------- 58 tuple 59 V : np.ndarray 60 State values array. 61 V_track : np.ndarray 62 Log of V(s) for each iteration. 63 pi : dict 64 Policy mapping states to actions. 65 """ 66 V = np.zeros(len(self.P), dtype=dtype) 67 V_track = np.zeros((n_iters, len(self.P)), dtype=dtype) 68 converged = False 69 for i in tqdm(range(n_iters), leave=False): 70 Q = np.zeros((len(self.P), len(self.P[0])), dtype=dtype) 71 for s in range(len(self.P)): 72 for a in range(len(self.P[s])): 73 for prob, next_state, reward, done in self.P[s][a]: 74 Q[s][a] += prob * (reward + gamma * V[next_state] * (not done)) 75 if np.max(np.abs(V - np.max(Q, axis=1))) < theta: 76 converged = True 77 V = np.max(Q, axis=1) 78 V_track[i] = V 79 if converged: 80 break 81 if not converged: 82 warnings.warn("Max iterations reached before convergence. Check n_iters.") 83 84 pi = {s: a for s, a in enumerate(np.argmax(Q, axis=1))} 85 return V, V_track, pi 86 87 def value_iteration_vectorized( 88 self, gamma=1.0, n_iters=1000, theta=1e-10, dtype=np.float32 89 ): 90 """ 91 Vectorized Value Iteration algorithm. 92 93 Parameters 94 ---------- 95 gamma : float 96 Discount factor 97 98 n_iters : int 99 Number of iterations 100 101 theta : float 102 Convergence criterion for value iteration. 103 State values are considered to be converged when the maximum difference between new and previous state values is less than theta. 104 Stops at n_iters or theta convergence - whichever comes first. 105 106 Returns 107 ------- 108 tuple 109 V : np.ndarray 110 State values array. 111 V_track : np.ndarray 112 Log of V(s) for each iteration. 113 pi : dict 114 Policy mapping states to actions. 
115 """ 116 S = len(self.P) 117 A = len(self.P[0]) 118 119 max_K = max(len(self.P[s][a]) for s in range(S) for a in range(A)) 120 121 prob_array = np.zeros((S, A, max_K), dtype=dtype) 122 next_state_array = np.zeros((S, A, max_K), dtype=np.int32) 123 reward_array = np.zeros((S, A, max_K), dtype=dtype) 124 done_array = np.zeros((S, A, max_K), dtype=bool) 125 mask_array = np.zeros((S, A, max_K), dtype=bool) 126 127 for s in range(S): 128 for a in range(A): 129 transitions = self.P[s][a] 130 for k, (prob, next_state, reward, done) in enumerate(transitions): 131 prob_array[s, a, k] = prob 132 next_state_array[s, a, k] = next_state 133 reward_array[s, a, k] = reward 134 done_array[s, a, k] = done 135 mask_array[s, a, k] = True 136 137 V = np.zeros(S, dtype=dtype) 138 V_track = np.zeros((n_iters, S), dtype=dtype) 139 converged = False 140 # Simpler way to handle done states 141 not_done_array = 1 - done_array 142 143 for i in tqdm(range(n_iters), leave=False): 144 Q = np.sum( 145 prob_array 146 * (reward_array + gamma * V[next_state_array] * not_done_array) 147 * mask_array, 148 axis=2, 149 ) 150 V_new = np.max(Q, axis=1) 151 152 if np.max(np.abs(V - V_new)) < theta: 153 converged = True 154 155 V = V_new 156 V_track[i] = V 157 158 if converged: 159 break 160 161 if not converged: 162 warnings.warn("Max iterations reached before convergence. Check n_iters.") 163 164 return V, V_track, dict(enumerate(np.argmax(Q, axis=1))) 165 166 def policy_iteration(self, gamma=1.0, n_iters=50, theta=1e-10, dtype=np.float32): 167 """ 168 Policy Iteration algorithm. 169 170 Parameters 171 ---------- 172 gamma : float, optional 173 Discount factor, by default 1.0. 174 n_iters : int, optional 175 Number of iterations, by default 50. 176 theta : float, optional 177 Convergence criterion for policy evaluation, by default 1e-10. 178 179 Returns 180 ------- 181 tuple 182 V : np.ndarray 183 State values array. 184 V_track : np.ndarray 185 Log of V(s) for each iteration. 186 pi : dict 187 Policy mapping states to actions. 188 """ 189 random_actions = np.random.choice(tuple(self.P[0].keys()), len(self.P)) 190 191 pi = {s: a for s, a in enumerate(random_actions)} 192 # initial V to give to `policy_evaluation` for the first time 193 V = np.zeros(len(self.P), dtype=dtype) 194 V_track = np.zeros((n_iters, len(self.P)), dtype=dtype) 195 converged = False 196 for i in tqdm(range(n_iters), leave=False): 197 old_pi = pi 198 V = self.policy_evaluation(pi, V, gamma=gamma, theta=theta, dtype=dtype) 199 V_track[i] = V 200 pi = self.policy_improvement(V, gamma=gamma, dtype=dtype) 201 if old_pi == pi: 202 converged = True 203 break 204 if not converged: 205 warnings.warn("Max iterations reached before convergence. Check n_iters.") 206 return V, V_track, pi 207 208 def policy_evaluation(self, pi, prev_V, gamma=1.0, theta=1e-10, dtype=np.float32): 209 """ 210 Policy Evaluation algorithm. 211 212 Parameters 213 ---------- 214 pi : dict 215 Policy mapping states to actions. 216 prev_V : np.ndarray 217 Previous state values array. 218 gamma : float, optional 219 Discount factor, by default 1.0. 220 theta : float, optional 221 Convergence criterion, by default 1e-10. 222 223 Returns 224 ------- 225 np.ndarray 226 State values array. 
227 """ 228 while True: 229 V = np.zeros(len(self.P), dtype=dtype) 230 for s in range(len(self.P)): 231 for prob, next_state, reward, done in self.P[s][pi[s]]: 232 V[s] += prob * (reward + gamma * prev_V[next_state] * (not done)) 233 if np.max(np.abs(prev_V - V)) < theta: 234 break 235 prev_V = V.copy() 236 return V 237 238 def policy_improvement(self, V, gamma=1.0, dtype=np.float32): 239 """ 240 Policy Improvement algorithm. 241 242 Parameters 243 ---------- 244 V : np.ndarray 245 State values array. 246 gamma : float, optional 247 Discount factor, by default 1.0. 248 249 Returns 250 ------- 251 dict 252 Policy mapping states to actions. 253 """ 254 Q = np.zeros((len(self.P), len(self.P[0])), dtype=dtype) 255 for s in range(len(self.P)): 256 for a in range(len(self.P[s])): 257 for prob, next_state, reward, done in self.P[s][a]: 258 Q[s][a] += prob * (reward + gamma * V[next_state] * (not done)) 259 260 return dict(enumerate(np.argmax(Q, axis=1)))
class Planner:
Planner(P)
```python
def __init__(self, P):
    """
    Initializes the Planner class.

    Parameters
    ----------
    P : dict
        Transition probability matrix where P[state][action] is a list of tuples
        (probability, next state, reward, terminal).
    """
    self.P = P
```
Initializes the Planner class.
Parameters
- P (dict): Transition probability matrix where P[state][action] is a list of tuples (probability, next state, reward, terminal).
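As a usage sketch (not from the library's own docs), P can also come straight from a Gymnasium toy-text environment, which exposes the same nested-dictionary model on the unwrapped env; gymnasium and the FrozenLake-v1 environment are assumptions here, not requirements of Planner.

```python
import gymnasium as gym

from bettermdptools.algorithms.planner import Planner

# FrozenLake's unwrapped env exposes P[state][action] as a list of
# (probability, next_state, reward, terminated) tuples.
env = gym.make("FrozenLake-v1", map_name="4x4", is_slippery=True)
planner = Planner(env.unwrapped.P)
```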
def value_iteration(self, gamma=1.0, n_iters=1000, theta=1e-10, dtype=np.float32):
```python
def value_iteration(self, gamma=1.0, n_iters=1000, theta=1e-10, dtype=np.float32):
    """
    Value Iteration algorithm.

    Parameters
    ----------
    gamma : float, optional
        Discount factor, by default 1.0.
    n_iters : int, optional
        Number of iterations, by default 1000.
    theta : float, optional
        Convergence criterion for value iteration, by default 1e-10.

    Returns
    -------
    tuple
        V : np.ndarray
            State values array.
        V_track : np.ndarray
            Log of V(s) for each iteration.
        pi : dict
            Policy mapping states to actions.
    """
    V = np.zeros(len(self.P), dtype=dtype)
    V_track = np.zeros((n_iters, len(self.P)), dtype=dtype)
    converged = False
    for i in tqdm(range(n_iters), leave=False):
        Q = np.zeros((len(self.P), len(self.P[0])), dtype=dtype)
        for s in range(len(self.P)):
            for a in range(len(self.P[s])):
                for prob, next_state, reward, done in self.P[s][a]:
                    Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))
        if np.max(np.abs(V - np.max(Q, axis=1))) < theta:
            converged = True
        V = np.max(Q, axis=1)
        V_track[i] = V
        if converged:
            break
    if not converged:
        warnings.warn("Max iterations reached before convergence. Check n_iters.")

    pi = {s: a for s, a in enumerate(np.argmax(Q, axis=1))}
    return V, V_track, pi
```
Value Iteration algorithm.
Parameters
- gamma (float, optional): Discount factor, by default 1.0.
- n_iters (int, optional): Maximum number of iterations, by default 1000.
- theta (float, optional): Convergence criterion for value iteration, by default 1e-10.
- dtype (np.dtype, optional): NumPy dtype used for the value arrays, by default np.float32.
Returns
- tuple:
  - V (np.ndarray): State values array.
  - V_track (np.ndarray): Log of V(s) for each iteration.
  - pi (dict): Policy mapping states to actions.
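A minimal usage sketch, assuming the hypothetical FrozenLake planner built in the earlier example; the hyperparameter values are illustrative, not recommendations.

```python
V, V_track, pi = planner.value_iteration(gamma=0.9, n_iters=1000, theta=1e-10)

print(V.shape)   # one value per state
print(pi[0])     # greedy action chosen in state 0
# Rows of V_track after the convergence iteration remain zero,
# since the loop breaks early once the theta criterion is met.
```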
def value_iteration_vectorized(self, gamma=1.0, n_iters=1000, theta=1e-10, dtype=np.float32):
```python
def value_iteration_vectorized(
    self, gamma=1.0, n_iters=1000, theta=1e-10, dtype=np.float32
):
    """
    Vectorized Value Iteration algorithm.

    Parameters
    ----------
    gamma : float
        Discount factor

    n_iters : int
        Number of iterations

    theta : float
        Convergence criterion for value iteration.
        State values are considered to be converged when the maximum difference
        between new and previous state values is less than theta.
        Stops at n_iters or theta convergence - whichever comes first.

    Returns
    -------
    tuple
        V : np.ndarray
            State values array.
        V_track : np.ndarray
            Log of V(s) for each iteration.
        pi : dict
            Policy mapping states to actions.
    """
    S = len(self.P)
    A = len(self.P[0])

    max_K = max(len(self.P[s][a]) for s in range(S) for a in range(A))

    prob_array = np.zeros((S, A, max_K), dtype=dtype)
    next_state_array = np.zeros((S, A, max_K), dtype=np.int32)
    reward_array = np.zeros((S, A, max_K), dtype=dtype)
    done_array = np.zeros((S, A, max_K), dtype=bool)
    mask_array = np.zeros((S, A, max_K), dtype=bool)

    for s in range(S):
        for a in range(A):
            transitions = self.P[s][a]
            for k, (prob, next_state, reward, done) in enumerate(transitions):
                prob_array[s, a, k] = prob
                next_state_array[s, a, k] = next_state
                reward_array[s, a, k] = reward
                done_array[s, a, k] = done
                mask_array[s, a, k] = True

    V = np.zeros(S, dtype=dtype)
    V_track = np.zeros((n_iters, S), dtype=dtype)
    converged = False
    # Simpler way to handle done states
    not_done_array = 1 - done_array

    for i in tqdm(range(n_iters), leave=False):
        Q = np.sum(
            prob_array
            * (reward_array + gamma * V[next_state_array] * not_done_array)
            * mask_array,
            axis=2,
        )
        V_new = np.max(Q, axis=1)

        if np.max(np.abs(V - V_new)) < theta:
            converged = True

        V = V_new
        V_track[i] = V

        if converged:
            break

    if not converged:
        warnings.warn("Max iterations reached before convergence. Check n_iters.")

    return V, V_track, dict(enumerate(np.argmax(Q, axis=1)))
```
Vectorized Value Iteration algorithm.
Parameters
- gamma (float, optional): Discount factor, by default 1.0.
- n_iters (int, optional): Maximum number of iterations, by default 1000.
- theta (float, optional): Convergence criterion for value iteration, by default 1e-10. State values are considered converged when the maximum difference between new and previous state values is less than theta. Stops at n_iters or theta convergence, whichever comes first.
- dtype (np.dtype, optional): NumPy dtype used for the value arrays, by default np.float32.
Returns
- tuple:
  - V (np.ndarray): State values array.
  - V_track (np.ndarray): Log of V(s) for each iteration.
  - pi (dict): Policy mapping states to actions.
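Because the vectorized version performs the same Bellman backups, just on padded (S, A, K) NumPy arrays instead of Python loops, its results should match value_iteration up to floating-point precision. A sanity-check sketch under the same FrozenLake assumption:

```python
import numpy as np

V_loop, _, pi_loop = planner.value_iteration(gamma=0.9)
V_vec, _, pi_vec = planner.value_iteration_vectorized(gamma=0.9)

assert np.allclose(V_loop, V_vec, atol=1e-5)
# The greedy policies normally agree too, though exact ties in Q
# can be broken differently by the two implementations.
```

The padding to max_K transitions per state-action pair trades extra memory for fully vectorized NumPy operations, which tends to pay off on larger state spaces.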
def policy_iteration(self, gamma=1.0, n_iters=50, theta=1e-10, dtype=np.float32):
```python
def policy_iteration(self, gamma=1.0, n_iters=50, theta=1e-10, dtype=np.float32):
    """
    Policy Iteration algorithm.

    Parameters
    ----------
    gamma : float, optional
        Discount factor, by default 1.0.
    n_iters : int, optional
        Number of iterations, by default 50.
    theta : float, optional
        Convergence criterion for policy evaluation, by default 1e-10.

    Returns
    -------
    tuple
        V : np.ndarray
            State values array.
        V_track : np.ndarray
            Log of V(s) for each iteration.
        pi : dict
            Policy mapping states to actions.
    """
    random_actions = np.random.choice(tuple(self.P[0].keys()), len(self.P))

    pi = {s: a for s, a in enumerate(random_actions)}
    # initial V to give to `policy_evaluation` for the first time
    V = np.zeros(len(self.P), dtype=dtype)
    V_track = np.zeros((n_iters, len(self.P)), dtype=dtype)
    converged = False
    for i in tqdm(range(n_iters), leave=False):
        old_pi = pi
        V = self.policy_evaluation(pi, V, gamma=gamma, theta=theta, dtype=dtype)
        V_track[i] = V
        pi = self.policy_improvement(V, gamma=gamma, dtype=dtype)
        if old_pi == pi:
            converged = True
            break
    if not converged:
        warnings.warn("Max iterations reached before convergence. Check n_iters.")
    return V, V_track, pi
```
Policy Iteration algorithm.
Parameters
- gamma (float, optional): Discount factor, by default 1.0.
- n_iters (int, optional): Maximum number of iterations, by default 50.
- theta (float, optional): Convergence criterion for policy evaluation, by default 1e-10.
- dtype (np.dtype, optional): NumPy dtype used for the value arrays, by default np.float32.
Returns
- tuple:
  - V (np.ndarray): State values array.
  - V_track (np.ndarray): Log of V(s) for each iteration.
  - pi (dict): Policy mapping states to actions.
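A usage sketch for policy_iteration, again assuming the hypothetical FrozenLake planner from the earlier example; for a well-posed discounted MDP it should reach the same optimal state values as value_iteration.

```python
import numpy as np

V_pi, V_track_pi, pi_opt = planner.policy_iteration(gamma=0.9, n_iters=50)
V_vi, _, _ = planner.value_iteration(gamma=0.9)

print(np.max(np.abs(V_pi - V_vi)))  # expected to be close to zero
print(pi_opt)                       # dict: state -> greedy action
```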
def policy_evaluation(self, pi, prev_V, gamma=1.0, theta=1e-10, dtype=np.float32):
```python
def policy_evaluation(self, pi, prev_V, gamma=1.0, theta=1e-10, dtype=np.float32):
    """
    Policy Evaluation algorithm.

    Parameters
    ----------
    pi : dict
        Policy mapping states to actions.
    prev_V : np.ndarray
        Previous state values array.
    gamma : float, optional
        Discount factor, by default 1.0.
    theta : float, optional
        Convergence criterion, by default 1e-10.

    Returns
    -------
    np.ndarray
        State values array.
    """
    while True:
        V = np.zeros(len(self.P), dtype=dtype)
        for s in range(len(self.P)):
            for prob, next_state, reward, done in self.P[s][pi[s]]:
                V[s] += prob * (reward + gamma * prev_V[next_state] * (not done))
        if np.max(np.abs(prev_V - V)) < theta:
            break
        prev_V = V.copy()
    return V
```
Policy Evaluation algorithm.
Parameters
- pi (dict): Policy mapping states to actions.
- prev_V (np.ndarray): Previous state values array.
- gamma (float, optional): Discount factor, by default 1.0.
- theta (float, optional): Convergence criterion, by default 1e-10.
- dtype (np.dtype, optional): NumPy dtype used for the value arrays, by default np.float32.
Returns
- np.ndarray: State values array.
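policy_evaluation can also be called on its own to score a fixed policy. A sketch under the same FrozenLake assumption, evaluating the arbitrary policy "always take action 0":

```python
import numpy as np

n_states = len(planner.P)
always_first_action = {s: 0 for s in range(n_states)}   # arbitrary fixed policy
V_init = np.zeros(n_states, dtype=np.float32)           # starting estimate

V_fixed = planner.policy_evaluation(always_first_action, V_init, gamma=0.9)
print(V_fixed)  # value of following that policy from every state
```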
def policy_improvement(self, V, gamma=1.0, dtype=np.float32):
```python
def policy_improvement(self, V, gamma=1.0, dtype=np.float32):
    """
    Policy Improvement algorithm.

    Parameters
    ----------
    V : np.ndarray
        State values array.
    gamma : float, optional
        Discount factor, by default 1.0.

    Returns
    -------
    dict
        Policy mapping states to actions.
    """
    Q = np.zeros((len(self.P), len(self.P[0])), dtype=dtype)
    for s in range(len(self.P)):
        for a in range(len(self.P[s])):
            for prob, next_state, reward, done in self.P[s][a]:
                Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))

    return dict(enumerate(np.argmax(Q, axis=1)))
```
Policy Improvement algorithm.
Parameters
- V (np.ndarray): State values array.
- gamma (float, optional): Discount factor, by default 1.0.
- dtype (np.dtype, optional): NumPy dtype used for the Q array, by default np.float32.
Returns
- dict: Policy mapping states to actions.
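Together, policy_evaluation and policy_improvement form one step of policy iteration. A sketch of a single evaluate-then-improve cycle, under the same assumptions as the earlier examples:

```python
import numpy as np

n_states = len(planner.P)
pi = {s: 0 for s in range(n_states)}   # arbitrary starting policy
V = planner.policy_evaluation(pi, np.zeros(n_states, dtype=np.float32), gamma=0.9)
pi = planner.policy_improvement(V, gamma=0.9)
# Repeating these two calls until pi stops changing is exactly what
# policy_iteration automates.
```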