from rl_opts.analytics import pdf_powerlaw, pdf_discrete_sample, get_policy_from_dist
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
Imitation learning
PS_imitation
PS_imitation (num_states:int, eta:float, gamma:float)
Constructs a PS agent with two actions (continue and rotate) that performs imitation learning in the search scenario. Instead of following a full trajectory of state-action tuples, the agent is directly given the rewarded state (the step length in this case). The agent then updates all previous continue actions and the current rotate action.
| | Type | Details |
|---|---|---|
| num_states | int | Number of states |
| eta | float | Glow parameter of PS |
| gamma | float | Damping parameter of PS |
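For orientation, constructing an agent from this signature might look as follows. This is a minimal sketch: the parameter values are arbitrary, and the (2, num_states) shape of the h-matrix is inferred from the example at the end of this page rather than stated here.

agent = PS_imitation(num_states = 100, eta = 1e-7, gamma = 0)
# The agent keeps one h-value per (action, state) pair; the example below
# accesses them as agent.h_matrix, suggesting a (2, num_states) array.
print(agent.h_matrix.shape)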
PS_imitation.update
PS_imitation.update (length:int, reward:int=1)
Updates the policy based on the imitation scheme (see the paper for details).
NOTE: the state is length-1 because the counter starts at 0 (at counter 0 the agent has already performed a step of length 1, resulting from the previous “rotate” action).
| | Type | Default | Details |
|---|---|---|---|
| length | int | | Step length rewarded |
| reward | int | 1 | Value of the reward |
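To make the update scheme concrete, below is a stripped-down sketch of what a single call to `update` does according to the description above. It assumes the h-matrix has shape (2, num_states), with row 0 holding the h-values of “continue” and row 1 those of “rotate” (as suggested by the example at the end of this page), and it ignores the glow (eta) and damping (gamma) mechanics of the full PS update:

import numpy as np

def imitation_update(h_matrix, length, reward = 1):
    state = length - 1                 # counter starts at 0 (see the NOTE above)
    h_matrix[0, :state] += reward      # reward "continue" for all previous counter values
    h_matrix[1, state] += reward       # reward "rotate" at the current counter value
    return h_matrix

h = np.ones((2, 100))                  # toy h-matrix for a 100-state agent
imitation_update(h, length = 5)        # rewards continue at counters 0-3 and rotate at counter 4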
Example
We showcase how to imitate the policy arising from a given step-length distribution, in particular a Lévy (power-law) distribution. For further examples, see the Tutorials section.
NUM_STATES = 100 # size of the state space
EPOCHS = 100 # number of epochs
NUM_STEPS = 1000 # number of learning steps per episode
steps = pdf_discrete_sample(pdf_func = pdf_powerlaw,
                            beta = 1,
                            L = np.arange(1, NUM_STATES),
                            num_samples = (EPOCHS, NUM_STEPS))
imitator = PS_imitation(num_states = NUM_STATES,
                        eta = int(1e-7),
                        gamma = 0)
for e in tqdm(range(EPOCHS)):
    imitator.reset()
    for s in steps[e]:
        imitator.update(length = s)
100%|██████████| 100/100 [00:01<00:00, 86.11it/s]
policy_theory = get_policy_from_dist(n_max = NUM_STATES,
                                     func = pdf_powerlaw,
                                     beta = 1)

policy_imitat = imitator.h_matrix[0,:] / imitator.h_matrix.sum(0)
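Normalizing the first row of the h-matrix by the column-wise sum over both actions turns the h-values into action probabilities, so `policy_imitat` is the imitated probability of the first action (continue, judging from the plot below) at each counter value.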
_, ax = plt.subplots(figsize = (5,3))
ax.plot(policy_imitat, 'o')
ax.plot(np.arange(1, NUM_STATES), policy_theory[1:])
plt.setp(ax, xscale = 'log', xlim = (0.9, NUM_STATES/2), xlabel = r'Counter $n$',
         ylim = (0.5, 1.1), ylabel = 'Policy');
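The resulting figure compares the imitated policy (circles) against the theoretical policy derived from the target power-law distribution (line) as a function of the counter $n$; after training, the two curves should closely overlap.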