2017-01-15 Bogacz (2017) A tutorial on free-energy

I enjoyed reading "A tutorial on the free-energy framework for modelling perception and learning" by Rafal Bogacz, which is freely available here. In particular, the author encourages readers to replicate the results in the paper. He gives his own solutions in MATLAB, so I had to do the same in Python, all within a notebook...

Let's first initialize the notebook:

In [1]:
from __future__ import division, print_function
import numpy as np
np.set_printoptions(precision=6, suppress=True)
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt
phi = (np.sqrt(5)+1)/2
fig_width = 10
figsize = (fig_width, fig_width/phi)

exercise 1 : defining probabilities

In [2]:
u_obs = 2 # observation
var_u = 1 # variance of the observation noise
v_p = 3 # prior expectation
var_p = 1 # variance of prior
In [3]:
def gauss(x, mean, variance):
    # normal probability density function
    return 1 / np.sqrt(2 * np.pi * variance) * np.exp(- (x - mean)**2 / (2 * variance))

g = lambda v: v**2 # generative function mapping size to light intensity

sizes = np.linspace(0.01, 5, 100)
fig, axs = plt.subplots(1, 3, figsize=figsize)

prior = gauss(sizes, v_p, var_p)
axs[0].plot(sizes, prior, 'k')
axs[0].set_title('Prior')


for var_u_ in np.logspace(-1, 1, 7, base=10)*var_u:
    likelihood = gauss(u_obs, g(sizes), var_u_)
    axs[1].plot(sizes, likelihood/likelihood.sum())
    axs[1].set_title('Likelihood')


    posterior = prior * likelihood
    posterior /= posterior.sum()
    axs[2].plot(sizes, posterior, label=r'$\sigma_u^2$ ={0:.2f}'.format(var_u_))
    axs[2].set_title('Posterior')

axs[2].legend()
for ax in axs:
    ax.set_xlabel('Size')
    ax.set_ylabel('Probability')

plt.tight_layout()
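
As a quick sanity check (my addition, not in the paper), the posterior mode can be read directly off the grid; it sits near 1.6, which is also where the gradient scheme of exercise 2 will converge:

In [ ]:
# locate the posterior mode on the grid, reusing gauss, g, sizes
# and the parameters defined above
posterior = gauss(sizes, v_p, var_p) * gauss(u_obs, g(sizes), var_u)
posterior /= posterior.sum()
print('posterior mode =', sizes[np.argmax(posterior)]) # close to 1.56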

exercise 2 : an online solution

Let's define $F = \ln p(u, v) = \ln p(v) + \ln p(u | v)$, the logarithm of the joint density, which we maximize by gradient ascent:

$$ \frac{\partial F}{\partial v} = \frac{v_p - v}{\Sigma_p} + g'(v) \cdot \frac{u - g(v)}{\Sigma_u} $$
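
Indeed, writing out the two Gaussian densities and dropping the terms that do not depend on $v$ gives

$$ F = - \frac{(v - v_p)^2}{2 \Sigma_p} - \frac{(u - g(v))^2}{2 \Sigma_u} + \mathrm{const.} $$

so differentiating with respect to $v$ yields the two terms above.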
In [4]:
dg = lambda v: 2*v # derivative of the generative function g

T, dt = 5, 0.01
times = np.linspace(0., T, int(T/dt))

# gradient ascent on F
v = np.zeros_like(times)
for i_time, time in enumerate(times):
    if i_time == 0:
        v[i_time] = v_p # start at the prior expectation
    else:
        v[i_time] = v[i_time-1] + dt * ( (v_p - v[i_time-1]) / var_p + dg(v[i_time-1]) * (u_obs - g(v[i_time-1])) / var_u )


fig, [ax1, ax2] = plt.subplots(1, 2, figsize=figsize)
ax1.plot(times, v)
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Size')
ax1.set_ylim(0, 5)

# now the online version, using explicit prediction-error nodes
v = np.zeros_like(times)
eps_u = np.zeros_like(times)
eps_p = np.zeros_like(times)

for i_time, time in enumerate(times):
    if i_time == 0:
        v[i_time], eps_u[i_time], eps_p[i_time] = v_p, 0., 0.
    else:
        # the estimate phi is driven by the two prediction-error nodes
        v[i_time] = v[i_time-1] + dt * ( - eps_p[i_time-1] + dg(v[i_time-1]) * eps_u[i_time-1] )
        # at equilibrium, eps_p = (phi - v_p)/var_p and eps_u = (u - g(phi))/var_u
        eps_p[i_time] = eps_p[i_time-1] + dt * ( (v[i_time] - v_p) - var_p * eps_p[i_time-1] )
        eps_u[i_time] = eps_u[i_time-1] + dt * ( (u_obs - g(v[i_time])) - var_u * eps_u[i_time-1] )

ax2.plot(times, v, label=r'$\phi$')
ax2.set_xlabel('Time (s)')
ax2.set_ylabel('Size')
ax2.plot(times, eps_p, 'g--', label=r'$\epsilon_p$')
ax2.plot(times, eps_u, 'r--', label=r'$\epsilon_u$')
ax2.legend();
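
As a check (my addition), both schemes settle at the zero of $\partial F / \partial v$: with $u = 2$, $v_p = 3$ and unit variances, this reads $3 - v + 2 v (2 - v^2) = 0$, i.e. $2 v^3 - 3 v - 3 = 0$:

In [ ]:
# fixed-point check: the unique real root of 2*v**3 - 3*v - 3
roots = np.roots([2, 0, -3, -3])
print(roots[np.isreal(roots)].real) # ~1.566, where v converges above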

exercise 5 : estimating variance

From the paper:

Simulate learning of variance $\Sigma_i$ over trials. For simplicity, only simulate the network described by Eqs. (59)–(60), and assume that variables $\phi$ are constant. On each trial generate input $\phi_i$ from a normal distribution with mean 5 and variance 2, while set $g_i(\phi_{i+1})=5$ (so that the upper level correctly predicts the mean of $\phi_i$). Simulate the network for 20 time units, and then update weight $\Sigma_i$ with learning rate $\alpha=0.01$. Simulate 1000 trials and plot how $\Sigma_i$ changes across trials.
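
In equations, the cell below integrates the paper's Eqs. (59)–(60), with the prediction-error node $\epsilon_i$ stored as error and the interneuron $e_i$ as e:

$$ \dot{\epsilon}_i = \phi_i - g_i(\phi_{i+1}) - e_i \qquad \dot{e}_i = \Sigma_i \epsilon_i - e_i $$

and, at the end of each trial, updates the weight as $\Delta \Sigma_i = \alpha (\epsilon_i e_i - 1)$.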

In [5]:
mean_u_obs = 5 # observation
var_u = 2 # noise in the observation
v_p = 5 # prior expectation (from the node above)

eta = .01 # learning rate (alpha in the paper)
N_trials = 1000

T, dt = 50, 0.01 # simulation length per trial (the paper suggests 20 time units)
times = np.linspace(0., T, int(T/dt))


e = np.zeros_like(times)     # interneuron
error = np.zeros_like(times) # prediction-error node

var_u_ = 1. * np.ones(N_trials)

for i_trial in range(1, N_trials):
    # making an observation
    u_obs = mean_u_obs + np.sqrt(var_u) * np.random.randn()

    for i_time, time in enumerate(times):
        if i_time == 0:
            e[i_time], error[i_time] = 0., 0.
        else:
            # Eq. (59): the error node is excited by (u_obs - v_p) and inhibited by the interneuron
            error[i_time] = error[i_time-1] + dt * ( (u_obs - v_p) - e[i_time-1] )
            # Eq. (60): the interneuron receives the error through the variance weight
            e[i_time] = e[i_time-1] + dt * ( var_u_[i_trial-1] * error[i_time-1] - e[i_time-1] )

    # end-of-trial Hebbian-like update of the variance weight
    var_u_[i_trial] = var_u_[i_trial-1] + eta * (error[-1] * e[-1] - 1)

fig, ax = plt.subplots(1, 1, figsize=figsize)
ax.plot(var_u_)
ax.set_ylim(0, 2.5)
ax.set_xlabel('trials')
ax.set_ylabel(r'$\Sigma$');
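
Why does this converge to the true variance? At the fixed point of the within-trial dynamics, $e = u - v_p$ and $\epsilon = e / \Sigma_i$, so $\epsilon e = (u - v_p)^2 / \Sigma_i$: the update $\alpha (\epsilon e - 1)$ thus vanishes on average exactly when $\Sigma_i = \langle (u - v_p)^2 \rangle$, the variance of the input. A quick numerical check of that expectation (my addition):

In [ ]:
# E[(u - v_p)^2] for u ~ N(5, 2) and v_p = 5: the value Sigma settles at
samples = 5 + np.sqrt(2) * np.random.randn(100000)
print(np.mean((samples - 5)**2)) # ~2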
