State Action Value Function

The state-action value function, commonly denoted Q(s,a), is the expected return (cumulative discounted reward) of taking action a in state s and following a particular policy thereafter. It is a fundamental concept in reinforcement learning: it lets an agent evaluate and select actions by their long-term outcome in a given environment.

Import Libraries

import numpy as np
import matplotlib.pyplot as plt
Data
num_states = 8
num_actions = 2

terminal_left_reward = 100
terminal_right_reward = 40
each_step_reward = 0

gamma = 0.5        # Discount factor
misstep_prob = 0   # Probability of going in the wrong direction
Q Values
def generate_rewards(num_states, each_step_reward, terminal_left_reward, terminal_right_reward):
    rewards = [each_step_reward] * num_states
    rewards[0] = terminal_left_reward
    rewards[-1] = terminal_right_reward

    return rewards
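
A quick illustrative check of generate_rewards (this call is an added example, not part of the lab code): with the Data-section settings, the corridor has its large reward on the left end, a smaller one on the right end, and nothing in between.

print(generate_rewards(8, 0, 100, 40))   # [100, 0, 0, 0, 0, 0, 0, 40]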
def generate_transition_prob(num_states, num_actions, misstep_prob = 0):
    p = np.zeros((num_states, num_actions, num_states))

    for i in range(num_states):
        if i != 0:
            p[i, 0, i-1] = 1 - misstep_prob
            p[i, 1, i-1] = misstep_prob

        if i != num_states - 1:
            p[i, 1, i+1] = 1 - misstep_prob
            p[i, 0, i+1] = misstep_prob

    # Terminal States
    p[0] = np.zeros((num_actions, num_states))
    p[-1] = np.zeros((num_actions, num_states))

    return p
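
Another illustrative check (the 0.1 misstep probability is just an example value): each non-terminal row of the tensor puts probability 1 - misstep_prob on the intended neighbour and misstep_prob on the opposite one, while the two terminal states have no outgoing transitions.

p_example = generate_transition_prob(8, 2, misstep_prob=0.1)
print(p_example[3, 1])      # from state 3, action "right": 0.1 on state 2, 0.9 on state 4
print(p_example[0].sum())   # terminal state 0: all zeros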
def calculate_Q_value(num_states, rewards, transition_prob, gamma, V_states, state, action):
    q_sa = rewards[state] + gamma * sum([transition_prob[state, action, sp] * V_states[sp] for sp in range(num_states)])
    return q_sa
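
In standard notation, calculate_Q_value performs the one-step Bellman backup

$Q^{\pi}(s, a) = R(s) + \gamma \sum_{s'} P(s' \mid s, a)\, V^{\pi}(s'),$

where $V^{\pi}$ is the current value estimate of the policy being evaluated; rewards here depend only on the state, matching generate_rewards above.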
def evaluate_policy(num_states, rewards, transition_prob, gamma, policy):
    max_policy_eval = 10000
    threshold = 1e-10

    V = np.zeros(num_states)

    for i in range(max_policy_eval):
        delta = 0
        for s in range(num_states):
            v = V[s]
            V[s] = calculate_Q_value(num_states, rewards, transition_prob, gamma, V, s, policy[s])
            delta = max(delta, abs(v - V[s]))

        if delta < threshold:
            break

    return V
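
As a sketch of policy evaluation in action (the _demo names and the all-left policy are illustrative additions, assuming the parameters from the Data section), the values of the fixed policy "always move left" can be inspected directly:

rewards_demo = generate_rewards(num_states, each_step_reward, terminal_left_reward, terminal_right_reward)
p_demo = generate_transition_prob(num_states, num_actions, misstep_prob)
always_left = np.zeros(num_states, dtype=int)   # action 0 in every state
print(evaluate_policy(num_states, rewards_demo, p_demo, gamma, always_left))

Terminal states keep their terminal reward as their value; with the deterministic all-left policy and gamma = 0.5, each interior state's value is half of its left neighbour's.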
def calculate_Q_values(num_states, rewards, transition_prob, gamma, optimal_policy):
    # Left and then optimal policy
    q_left_star = np.zeros(num_states)

    # Right and optimal policy
    q_right_star = np.zeros(num_states)

    V_star = evaluate_policy(num_states, rewards, transition_prob, gamma, optimal_policy)

    for s in range(num_states):
        q_left_star[s] = calculate_Q_value(num_states, rewards, transition_prob, gamma, V_star, s, 0)
        q_right_star[s] = calculate_Q_value(num_states, rewards, transition_prob, gamma, V_star, s, 1)

    return q_left_star, q_right_star
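
Continuing the sketch above (again an illustrative addition, not lab code): feeding the same all-left policy to calculate_Q_values returns both action values per state, and an argmax over them recovers the greedy action with respect to that policy's value function.

ql_demo, qr_demo = calculate_Q_values(num_states, rewards_demo, p_demo, gamma, always_left)
greedy_demo = np.argmax(np.stack([ql_demo, qr_demo]), axis=0)   # 0 = left, 1 = right
print(greedy_demo)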
Optimize Policy
def improve_policy(num_states, num_actions, rewards, transition_prob, gamma, V, policy):
    policy_stable = True

    for s in range(num_states):
        q_best = V[s]
        for a in range(num_actions):
            q_sa = calculate_Q_value(num_states, rewards, transition_prob, gamma, V, s, a)
            if q_sa > q_best and policy[s] != a:
                policy[s] = a
                q_best = q_sa
                policy_stable = False

    return policy, policy_stable
def get_optimal_policy(num_states, num_actions, rewards, transition_prob, gamma):
    optimal_policy = np.zeros(num_states, dtype=int)
    max_policy_iter = 10000

    for i in range(max_policy_iter):
        policy_stable = True

        V = evaluate_policy(num_states, rewards, transition_prob, gamma, optimal_policy)
        optimal_policy, policy_stable = improve_policy(num_states, num_actions, rewards, transition_prob, gamma, V, optimal_policy)

        if policy_stable:
            break

    return optimal_policy, V
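
One more step in the same sketch: running full policy iteration on the corridor. With gamma = 0.5 and terminal rewards of 100 and 40, states near the left end should prefer going left and those near the right end going right, which the plots in the next section show graphically.

pi_star_demo, V_star_demo = get_optimal_policy(num_states, num_actions, rewards_demo, p_demo, gamma)
print(pi_star_demo)             # 0 = left, 1 = right; terminal states keep the initial action 0
print(np.round(V_star_demo, 2))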
Visualization
def plot_optimal_policy_return(num_states, optimal_policy, rewards, V):
    actions = [r"$\leftarrow$" if a == 0 else r"$\rightarrow$" for a in optimal_policy]
    actions[0] = ""
    actions[-1] = ""

    fig, ax = plt.subplots(figsize=(2*num_states, 2))

    for i in range(num_states):
        ax.text(i+0.5, 0.5, actions[i], fontsize=32, ha="center", va="center", color="orange")
        ax.text(i+0.5, 0.25, rewards[i], fontsize=16, ha="center", va="center", color="black")
        ax.text(i+0.5, 0.75, round(V[i], 2), fontsize=16, ha="center", va="center", color="firebrick")
        ax.axvline(i, color="black")

    ax.set_xlim([0, num_states])
    ax.set_ylim([0, 1])

    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_title("Optimal policy", fontsize=16)
def plot_q_values(num_states, q_left_star, q_right_star, rewards):
    fig, ax = plt.subplots(figsize=(3*num_states, 2))

    for i in range(num_states):
        ax.text(i+0.2, 0.6, round(q_left_star[i], 2), fontsize=16, ha="center", va="center", color="firebrick")
        ax.text(i+0.8, 0.6, round(q_right_star[i], 2), fontsize=16, ha="center", va="center", color="firebrick")

        ax.text(i+0.5, 0.25, rewards[i], fontsize=20, ha="center", va="center", color="black")
        ax.axvline(i, color="black")

    ax.set_xlim([0, num_states])
    ax.set_ylim([0, 1])

    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_title("Q(s,a)", fontsize=16)
def generate_visualization(num_states, num_actions, terminal_left_reward, terminal_right_reward, each_step_reward, gamma, misstep_prob):
    rewards = generate_rewards(num_states, each_step_reward, terminal_left_reward, terminal_right_reward)
    transition_prob = generate_transition_prob(num_states, num_actions, misstep_prob)

    optimal_policy, V = get_optimal_policy(num_states, num_actions, rewards, transition_prob, gamma)
    q_left_star, q_right_star = calculate_Q_values(num_states, rewards, transition_prob, gamma, optimal_policy)

    plot_optimal_policy_return(num_states, optimal_policy, rewards, V)
    plot_q_values(num_states, q_left_star, q_right_star, rewards)
generate_visualization(num_states, num_actions, terminal_left_reward, terminal_right_reward, each_step_reward, gamma, misstep_prob)
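
An optional follow-up experiment (an added example, not part of the original call above): rerunning the visualization with a nonzero misstep probability shows how stochastic transitions lower the state values and can shift the optimal policy.

generate_visualization(num_states, num_actions, terminal_left_reward, terminal_right_reward, each_step_reward, gamma, misstep_prob=0.1)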