In this notebook, we study how homeostasis (cooperation) may be an essential ingredient for this algorithm, which otherwise operates on a winner-take-all basis (competition). This extension has been published as Perrinet, Neural Computation (2010) (see http://invibe.net/LaurentPerrinet/Publications/Perrinet10shl ). Compared to the previous post, we optimize the code to make it faster.

See also the other posts on unsupervised learning.

This is joint work with Victor Boutin.

Summary: using fast P_cum lookup functions yields an approximately 80-fold speed-up, but the non-linear modulation functions still need to be learned.

In [1]:
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
import numpy as np
np.set_printoptions(formatter = dict( float = lambda x: "%.3g" % x ), precision=3, suppress=True, threshold=np.inf)
from shl_scripts.shl_experiments import SHL
#from nengo.utils.ipython import hide_input
import time
%load_ext autoreload
%autoreload 2
In [2]:
matname = '2017-05-31_Testing_COMPs'
DEBUG_DOWNSCALE = 1
#matname = '2017-05-31_Testing_COMPs-DEBUG'
#DEBUG_DOWNSCALE = 10

seed = 42
nb_quant = 256
C = 4.
do_sym = False

from shl_scripts.shl_experiments import SHL
shl = SHL(DEBUG_DOWNSCALE=DEBUG_DOWNSCALE, 
           eta=0.05, verbose=2, record_each=50, n_iter=1000, eta_homeo=0., alpha_homeo=1., 
          do_sym=do_sym, nb_quant=nb_quant, C=C)
data = shl.get_data(seed=seed, matname=matname)
loading the data called : /tmp/data_cache/2017-05-31_Testing_COMPs_data
In [3]:
test_size = data.shape[0]//2
data_training = data[:test_size, :]
data_test = data[test_size:,:]   
#DEBUG
test_size = data.shape[0]//20
data_training = data[:(data.shape[0]-test_size),:].copy()
data_test = data[:test_size, :].copy()
In [4]:
dico_partial_learning = shl.learn_dico(data=data_training, matname=matname)
loading the dico called : 2017-05-31_Testing_COMPs

We start off with a short learning phase without homeostasis, such that we end up with an unbalanced dictionary:

In [5]:
fig, ax = shl.show_dico(dico_partial_learning, data=data, title=matname)
fig.show()
fig, ax = shl.time_plot(dico_partial_learning, variable='prob_active');
fig.show()

Classic MP

In [6]:
n_samples, n_pixels = data_test.shape
n_dictionary, n_pixels = dico_partial_learning.dictionary.shape
norm_each_filter = np.sqrt(np.sum(dico_partial_learning.dictionary**2, axis=1))
dico_partial_learning.dictionary /= norm_each_filter[:,np.newaxis]

sparse_code_mp = shl.code(data_test, dico_partial_learning, matname=matname)


def plot_proba_histogram(coding, verbose=False):
    n_dictionary=coding.shape[1]

    p = np.count_nonzero(coding, axis=0)/coding.shape[0]  # selection frequency of each atom
    p /= p.sum()  # normalize to a probability distribution

    rel_ent = np.sum( -p * np.log(p)) / np.log(n_dictionary)
    if verbose: print('Entropy / Entropy_max=', rel_ent )

    fig = plt.figure(figsize=(16, 4))
    ax = fig.add_subplot(111)
    ax.bar(np.arange(n_dictionary), p*n_dictionary)
    ax.set_title('distribution of the selection probability - entropy= ' + str(rel_ent)  )
    ax.set_ylabel('pdf')
    ax.set_xlim(0)
    ax.axis('tight')
    return fig, ax

fig, ax = plot_proba_histogram(sparse_code_mp)
loading the code called : /tmp/data_cache/2017-05-31_Testing_COMPs_coding.npy

COMP : learning modulations

computing histograms is rather fast:

In [7]:
def prior(code, C=5., do_sym=False):
    if do_sym:
        return 1.-np.exp(-np.abs(code)/C)
    else:
        return (1.-np.exp(-code/C))*(code>0)

def get_P_cum(code, nb_quant=100, C=5., do_sym=True):
    from shl_scripts.shl_encode import prior
    n_samples, nb_filter = code.shape
    code_bins = np.linspace(0., 1, nb_quant, endpoint=True)
    P_cum = np.zeros((nb_filter, nb_quant))
    for i in range(nb_filter):
        p, bins = np.histogram(prior(code[:, i], C=C, do_sym=do_sym), bins=code_bins, density=True)
        p /= p.sum()
        P_cum[i, :] = np.hstack((0, np.cumsum(p)))
    return P_cum
In [8]:
%%time
P_cum = get_P_cum(sparse_code_mp, shl.nb_quant)
CPU times: user 736 ms, sys: 8.89 ms, total: 744 ms
Wall time: 751 ms
In [42]:
P_cum = get_P_cum(sparse_code_mp, nb_quant=shl.nb_quant, C=C, do_sym=do_sym)
from shl_scripts.shl_tools import plot_P_cum
fig, ax = plot_P_cum(P_cum, verbose=False, alpha=.05)
ax.set_ylim(0.92, 1.01);

COMP : optimizing the z_score function

but the z_score function is used at each iteration of comp and should be optimized:

In [10]:
corr = (data_test @ dico_partial_learning.dictionary.T)[0, :]
print('correlation=', corr)
print('transformed correlation=', prior(corr, C=C, do_sym=do_sym))
correlation= [0.473 0.262 -0.179 0.0131 -0.411 -0.156 -0.72 1.04 -0.193 -0.0292 -0.133
 0.508 -0.0219 -0.218 1.32 0.203 0.0923 0.155 0.322 -0.101 -0.128 0.256
 0.478 -0.116 0.0748 0.255 0.739 0.733 0.195 0.0246 0.316 0.517 -0.107
 -0.0646 0.35 0.549 0.243 -0.116 -0.207 0.0741 0.775 0.332 -0.101 0.22
 -0.447 1.31 -1.72 -0.0894 -0.092 0.00393 -0.255 0.504 -0.315 0.3 -0.444
 -0.327 0.0701 0.187 -0.898 -0.443 0.355 -0.698 -0.48 0.386 -0.259 0.138
 0.334 -0.146 0.119 0.951 0.429 -0.379 -0.441 0.253 0.342 -0.169 -0.467
 -0.203 -1.11 0.567 -0.055 -0.601 0.428 0.153 -0.428 0.712 0.528 0.136
 -0.332 -0.603 0.0228 0.0413 -0.956 -0.477 -0.212 0.382 0.486 0.16 -0.271
 -0.278 0.784 0.321 -0.658 0.214 -0.0523 0.311 0.06 -0.148 0.151 0.32 0.491
 0.16 0.831 -0.992 -0.187 0.0966 -0.125 -0.685 -0.625 0.0762 0.00117 0.556
 1.11 -0.737 0.425 0.42 0.118 -0.987 -0.119 1.26 -0.613 0.944 0.269 0.2
 0.107 0.054 0.721 -0.682 0.489 0.488 -0.195 -0.717 0.726 0.152 -0.153
 -0.0952 0.427 -1.06 0.177 -0.347 0.112 0.473 -0.00344 -0.426 0.675 -0.318
 0.218 0.0969 -0.377 -0.891 -0.16 -0.0186 0.0333 -0.299 -0.893 -0.181 1.99
 0.626 -0.292 0.136 0.36 0.471 0.147 0.24 0.533 0.493 0.244 0.189 -0.0662
 -0.605 0.591 -0.123 -0.299 0.295 0.0535 0.179 0.22 -0.288 -0.523 1.05 0.47
 0.238 -0.221 -0.0599 -0.83 0.527 0.301 -0.111 0.0515 0.694 -0.317 -0.014
 0.0442 0.0848 -0.481 0.232 0.697 0.0962 0.25 -0.112 0.178 0.295 -0.383
 0.0509 -0.276 -0.191 0.391 -0.0278 0.645 -0.176 -0.123 0.124 0.0391 0.112
 -0.263 -0.45 0.298 -0.509 -0.452 0.567 -0.489 -1.02 -0.24 -0.235 0.148
 0.426 -0.273 -0.35 -0.143 -0.57 -0.426 -0.157 -0.332 0.717 -0.628 0.0103
 -0.825 0.24 -0.997 0.339 -0.569 -0.649 0.244 -0.365 0.443 -0.882 0.00136
 0.498 -0.482 -0.335 0.756 -0.448 0.469 1.13 -0.164 -0.108 -0.0952 0.743
 -0.149 -0.706 -0.397 -0.0483 -0.678 0.274 -0.658 -0.345 0.355 0.0348
 -0.617 -0.109 1.05 0.574 -0.55 0.415 -0.201 -0.907 -1.17 -1.12 -1.46
 -0.785 0.0877 -1.08 0.223 0.0791 0.0797 -0.28 0.264 0.962 -0.24 -0.484
 -0.0307 0.805 0.467 -0.218 0.0526 0.102 -0.449 -0.252 -0.115 0.575 0.722
 0.147 -0.0878 0.372 -0.315 -0.321 0.0828 0.322 -0.414 -0.547 -0.134 1.04
 0.208 -0.852]
transformed correlation= [0.112 0.0635 -0 0.00326 -0 -0 -0 0.23 -0 -0 -0 0.119 -0 -0 0.282 0.0495
 0.0228 0.0379 0.0772 -0 -0 0.0619 0.113 -0 0.0185 0.0618 0.169 0.167
 0.0477 0.00613 0.0761 0.121 -0 -0 0.0837 0.128 0.0589 -0 -0 0.0183 0.176
 0.0796 -0 0.0534 -0 0.279 -0 -0 -0 0.000982 -0 0.118 -0 0.0722 -0 -0
 0.0174 0.0458 -0 -0 0.0848 -0 -0 0.0921 -0 0.0339 0.0801 -0 0.0292 0.212
 0.102 -0 -0 0.0613 0.082 -0 -0 -0 -0 0.132 -0 -0 0.102 0.0375 -0 0.163
 0.124 0.0334 -0 -0 0.00568 0.0103 -0 -0 -0 0.0911 0.114 0.0393 -0 -0 0.178
 0.0771 -0 0.052 -0 0.0748 0.0149 -0 0.0369 0.0768 0.116 0.0391 0.188 -0 -0
 0.0239 -0 -0 -0 0.0189 0.000292 0.13 0.241 -0 0.101 0.0997 0.029 -0 -0
 0.27 -0 0.21 0.065 0.0487 0.0263 0.0134 0.165 -0 0.115 0.115 -0 -0 0.166
 0.0373 -0 -0 0.101 -0 0.0434 -0 0.0277 0.112 -0 -0 0.155 -0 0.053 0.0239
 -0 -0 -0 -0 0.00829 -0 -0 -0 0.391 0.145 -0 0.0335 0.086 0.111 0.0361
 0.0583 0.125 0.116 0.0592 0.0461 -0 -0 0.137 -0 -0 0.0712 0.0133 0.0438
 0.0534 -0 -0 0.231 0.111 0.0577 -0 -0 -0 0.123 0.0724 -0 0.0128 0.159 -0
 -0 0.011 0.021 -0 0.0564 0.16 0.0238 0.0605 -0 0.0436 0.0711 -0 0.0127 -0
 -0 0.0931 -0 0.149 -0 -0 0.0304 0.00973 0.0275 -0 -0 0.0718 -0 -0 0.132 -0
 -0 -0 -0 0.0364 0.101 -0 -0 -0 -0 -0 -0 -0 0.164 -0 0.00257 -0 0.0583 -0
 0.0812 -0 -0 0.0592 -0 0.105 -0 0.000341 0.117 -0 -0 0.172 -0 0.111 0.245
 -0 -0 -0 0.17 -0 -0 -0 -0 -0 0.0662 -0 -0 0.0849 0.00866 -0 -0 0.23 0.134
 -0 0.0986 -0 -0 -0 -0 -0 -0 0.0217 -0 0.0542 0.0196 0.0197 -0 0.0639 0.214
 -0 -0 -0 0.182 0.11 -0 0.0131 0.0252 -0 -0 -0 0.134 0.165 0.0361 -0 0.0888
 -0 -0 0.0205 0.0774 -0 -0 -0 0.229 0.0506 -0]
In [11]:
%%time
code_bins = np.linspace(0., 1, shl.nb_quant, endpoint=True)
def z_score(P_cum, c):
    z_res = np.zeros_like(c)
    for i in range(P_cum.shape[0]):
        z_res[i] = np.interp(c[i], code_bins, P_cum[i, :])
    return z_res

z_vanilla = z_score(P_cum, prior(corr, C=C, do_sym=do_sym))
CPU times: user 2.6 ms, sys: 913 µs, total: 3.51 ms
Wall time: 3.29 ms
In [12]:
%%timeit
z = z_score(P_cum, prior(corr, C=C, do_sym=do_sym))
2.98 ms ± 497 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [13]:
print('Z-scores=', z_vanilla)
Z-scores= [0.966 0.952 0 0.793 0 0 0 0.969 0 0 0 0.963 0 0 0.974 0.959 0.963 0.965
 0.965 0 0 0.968 0.969 0 0.956 0.964 0.973 0.972 0.961 0.952 0.955 0.965 0
 0 0.966 0.968 0.953 0 0 0.95 0.973 0.96 0 0.969 0 0.972 0 0 0 0.233 0
 0.968 0 0.966 0 0 0.947 0.958 0 0 0.961 0 0 0.959 0 0.961 0.967 0 0.966
 0.972 0.976 0 0 0.956 0.962 0 0 0 0 0.974 0 0 0.962 0.967 0 0.969 0.975
 0.966 0 0 0.966 0.961 0 0 0 0.96 0.969 0.959 0 0 0.963 0.969 0 0.966 0
 0.971 0.972 0 0.961 0.967 0.963 0.97 0.968 0 0 0.957 0 0 0 0.958 0.0716
 0.969 0.972 0 0.965 0.964 0.951 0 0 0.967 0 0.969 0.958 0.971 0.969 0.957
 0.973 0 0.965 0.962 0 0 0.971 0.958 0 0 0.964 0 0.959 0 0.961 0.969 0 0
 0.97 0 0.958 0.963 0 0 0 0 0.949 0 0 0 0.983 0.968 0 0.957 0.964 0.957
 0.971 0.968 0.956 0.968 0.957 0.959 0 0 0.969 0 0 0.965 0.932 0.964 0.953
 0 0 0.975 0.958 0.962 0 0 0 0.97 0.961 0 0.962 0.971 0 0 0.954 0.953 0
 0.962 0.965 0.958 0.963 0 0.964 0.949 0 0.955 0 0 0.962 0 0.965 0 0 0.959
 0.955 0.957 0 0 0.968 0 0 0.968 0 0 0 0 0.956 0.962 0 0 0 0 0 0 0 0.965 0
 0.606 0 0.967 0 0.966 0 0 0.95 0 0.966 0 0.0838 0.965 0 0 0.975 0 0.962
 0.966 0 0 0 0.97 0 0 0 0 0 0.967 0 0 0.968 0.962 0 0 0.969 0.963 0 0.972 0
 0 0 0 0 0 0.955 0 0.957 0.959 0.942 0 0.956 0.968 0 0 0 0.968 0.955 0
 0.961 0.966 0 0 0 0.957 0.968 0.964 0 0.963 0 0 0.95 0.962 0 0 0 0.967
 0.967 0]

Since the sampling points (the bins) are linearly spaced, we can skip the interpolation and index directly into the bins.
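For example, with nb_quant = 256, the first transformed correlation printed above (0.112) maps to bin $\lfloor 0.112 \times 256 \rfloor = 28$, which is indeed the first entry of the index array shown in Out[18] below. This direct lookup is what the next cell implements: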

In [14]:
%%time
def z_score(P_cum, c):
    z_res = np.zeros_like(c)
    for i in range(P_cum.shape[0]):
        z_res[i] = P_cum[i, int(c[i]*shl.nb_quant)]
    return z_res


z_ind = z_score(P_cum, prior(corr, C=C, do_sym=do_sym))
CPU times: user 499 µs, sys: 36 µs, total: 535 µs
Wall time: 529 µs
In [15]:
%%timeit
z = z_score(P_cum, prior(corr, C=C, do_sym=do_sym))
421 µs ± 41 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [16]:
print('Z-scores=', z_ind)
Z-scores= [0.966 0.952 0 0 0 0 0 0.969 0 0 0 0.963 0 0 0.974 0.959 0.963 0.965 0.965
 0 0 0.968 0.969 0 0.955 0.963 0.973 0.972 0.961 0.952 0.955 0.965 0 0
 0.966 0.968 0.953 0 0 0.945 0.973 0.959 0 0.969 0 0.972 0 0 0 0 0 0.968 0
 0.966 0 0 0.947 0.958 0 0 0.961 0 0 0.959 0 0.961 0.967 0 0.966 0.972
 0.976 0 0 0.956 0.962 0 0 0 0 0.974 0 0 0.962 0.967 0 0.968 0.975 0.965 0
 0 0.966 0.961 0 0 0 0.96 0.969 0.959 0 0 0.963 0.969 0 0.966 0 0.971 0.971
 0 0.961 0.967 0.963 0.97 0.969 0 0 0.957 0 0 0 0.957 0 0.969 0.972 0 0.965
 0.964 0.951 0 0 0.967 0 0.969 0.958 0.971 0.969 0.957 0.973 0 0.965 0.962
 0 0 0.971 0.958 0 0 0.964 0 0.958 0 0.961 0.969 0 0 0.97 0 0.958 0.963 0 0
 0 0 0.949 0 0 0 0.983 0.968 0 0.957 0.964 0.957 0.971 0.968 0.955 0.968
 0.957 0.959 0 0 0.969 0 0 0.965 0.932 0.963 0.953 0 0 0.975 0.958 0.962 0
 0 0 0.97 0.961 0 0.962 0.971 0 0 0.954 0.953 0 0.962 0.964 0.958 0.963 0
 0.964 0.948 0 0.955 0 0 0.962 0 0.965 0 0 0.959 0.955 0.957 0 0 0.968 0 0
 0.968 0 0 0 0 0.956 0.962 0 0 0 0 0 0 0 0.965 0 0 0 0.966 0 0.966 0 0 0.95
 0 0.966 0 0 0.965 0 0 0.975 0 0.962 0.966 0 0 0 0.97 0 0 0 0 0 0.967 0 0
 0.968 0.962 0 0 0.969 0.963 0 0.972 0 0 0 0 0 0 0.955 0 0.956 0.959 0.942
 0 0.956 0.968 0 0 0 0.968 0.954 0 0.961 0.966 0 0 0 0.957 0.968 0.964 0
 0.963 0 0 0.95 0.962 0 0 0 0.967 0.967 0]
In [17]:
print('difference=', (z_vanilla-z_ind).std() )
difference= 0.0570146945833

This difference is expected, since we switched from a linear interpolation to a nearest-neighbor one. We go roughly 8 times faster...

Let's now vectorize this function:

In [18]:
(prior(corr, C=C, do_sym=do_sym)*shl.nb_quant).astype(np.int)
Out[18]:
array([ 28,  16,   0,   0,   0,   0,   0,  58,   0,   0,   0,  30,   0,
         0,  72,  12,   5,   9,  19,   0,   0,  15,  28,   0,   4,  15,
        43,  42,  12,   1,  19,  31,   0,   0,  21,  32,  15,   0,   0,
         4,  45,  20,   0,  13,   0,  71,   0,   0,   0,   0,   0,  30,
         0,  18,   0,   0,   4,  11,   0,   0,  21,   0,   0,  23,   0,
         8,  20,   0,   7,  54,  26,   0,   0,  15,  20,   0,   0,   0,
         0,  33,   0,   0,  25,   9,   0,  41,  31,   8,   0,   0,   1,
         2,   0,   0,   0,  23,  29,  10,   0,   0,  45,  19,   0,  13,
         0,  19,   3,   0,   9,  19,  29,  10,  48,   0,   0,   6,   0,
         0,   0,   4,   0,  33,  61,   0,  25,  25,   7,   0,   0,  69,
         0,  53,  16,  12,   6,   3,  42,   0,  29,  29,   0,   0,  42,
         9,   0,   0,  25,   0,  11,   0,   7,  28,   0,   0,  39,   0,
        13,   6,   0,   0,   0,   0,   2,   0,   0,   0, 100,  37,   0,
         8,  22,  28,   9,  14,  31,  29,  15,  11,   0,   0,  35,   0,
         0,  18,   3,  11,  13,   0,   0,  59,  28,  14,   0,   0,   0,
        31,  18,   0,   3,  40,   0,   0,   2,   5,   0,  14,  40,   6,
        15,   0,  11,  18,   0,   3,   0,   0,  23,   0,  38,   0,   0,
         7,   2,   7,   0,   0,  18,   0,   0,  33,   0,   0,   0,   0,
         9,  25,   0,   0,   0,   0,   0,   0,   0,  42,   0,   0,   0,
        14,   0,  20,   0,   0,  15,   0,  26,   0,   0,  29,   0,   0,
        44,   0,  28,  62,   0,   0,   0,  43,   0,   0,   0,   0,   0,
        16,   0,   0,  21,   2,   0,   0,  58,  34,   0,  25,   0,   0,
         0,   0,   0,   0,   5,   0,  13,   5,   5,   0,  16,  54,   0,
         0,   0,  46,  28,   0,   3,   6,   0,   0,   0,  34,  42,   9,
         0,  22,   0,   0,   5,  19,   0,   0,   0,  58,  12,   0])
In [19]:
print(P_cum[0, :], P_cum.ravel()[:shl.nb_quant])
[0 0.958 0.959 0.959 0.959 0.96 0.96 0.96 0.961 0.961 0.962 0.962 0.962
 0.962 0.962 0.963 0.963 0.963 0.964 0.964 0.964 0.964 0.965 0.965 0.965
 0.966 0.966 0.966 0.966 0.966 0.966 0.967 0.967 0.967 0.967 0.967 0.967
 0.968 0.968 0.968 0.968 0.968 0.969 0.969 0.969 0.969 0.969 0.969 0.97
 0.97 0.97 0.97 0.971 0.971 0.971 0.971 0.971 0.971 0.971 0.972 0.972 0.972
 0.972 0.973 0.973 0.973 0.973 0.973 0.974 0.974 0.974 0.974 0.974 0.974
 0.975 0.975 0.975 0.975 0.976 0.976 0.976 0.976 0.976 0.977 0.977 0.977
 0.977 0.978 0.978 0.978 0.978 0.978 0.979 0.979 0.979 0.979 0.979 0.98
 0.98 0.98 0.98 0.981 0.981 0.981 0.981 0.982 0.982 0.982 0.982 0.982 0.983
 0.983 0.983 0.983 0.983 0.984 0.984 0.984 0.985 0.985 0.985 0.985 0.985
 0.986 0.986 0.986 0.986 0.986 0.986 0.987 0.987 0.987 0.987 0.987 0.988
 0.988 0.988 0.988 0.989 0.989 0.989 0.989 0.989 0.99 0.99 0.99 0.99 0.99
 0.99 0.991 0.991 0.991 0.991 0.991 0.991 0.991 0.992 0.992 0.992 0.992
 0.993 0.993 0.993 0.993 0.993 0.993 0.994 0.994 0.994 0.994 0.994 0.994
 0.994 0.994 0.995 0.995 0.995 0.995 0.995 0.995 0.995 0.996 0.996 0.996
 0.996 0.996 0.996 0.996 0.996 0.997 0.997 0.997 0.997 0.997 0.997 0.997
 0.997 0.997 0.997 0.998 0.998 0.998 0.998 0.998 0.998 0.998 0.998 0.998
 0.998 0.998 0.998 0.998 0.998 0.998 0.998 0.999 0.999 0.999 0.999 0.999
 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999
 0.999 0.999 0.999 0.999 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] [0 0.958 0.959 0.959 0.959 0.96 0.96 0.96 0.961 0.961 0.962 0.962 0.962
 0.962 0.962 0.963 0.963 0.963 0.964 0.964 0.964 0.964 0.965 0.965 0.965
 0.966 0.966 0.966 0.966 0.966 0.966 0.967 0.967 0.967 0.967 0.967 0.967
 0.968 0.968 0.968 0.968 0.968 0.969 0.969 0.969 0.969 0.969 0.969 0.97
 0.97 0.97 0.97 0.971 0.971 0.971 0.971 0.971 0.971 0.971 0.972 0.972 0.972
 0.972 0.973 0.973 0.973 0.973 0.973 0.974 0.974 0.974 0.974 0.974 0.974
 0.975 0.975 0.975 0.975 0.976 0.976 0.976 0.976 0.976 0.977 0.977 0.977
 0.977 0.978 0.978 0.978 0.978 0.978 0.979 0.979 0.979 0.979 0.979 0.98
 0.98 0.98 0.98 0.981 0.981 0.981 0.981 0.982 0.982 0.982 0.982 0.982 0.983
 0.983 0.983 0.983 0.983 0.984 0.984 0.984 0.985 0.985 0.985 0.985 0.985
 0.986 0.986 0.986 0.986 0.986 0.986 0.987 0.987 0.987 0.987 0.987 0.988
 0.988 0.988 0.988 0.989 0.989 0.989 0.989 0.989 0.99 0.99 0.99 0.99 0.99
 0.99 0.991 0.991 0.991 0.991 0.991 0.991 0.991 0.992 0.992 0.992 0.992
 0.993 0.993 0.993 0.993 0.993 0.993 0.994 0.994 0.994 0.994 0.994 0.994
 0.994 0.994 0.995 0.995 0.995 0.995 0.995 0.995 0.995 0.996 0.996 0.996
 0.996 0.996 0.996 0.996 0.996 0.997 0.997 0.997 0.997 0.997 0.997 0.997
 0.997 0.997 0.997 0.998 0.998 0.998 0.998 0.998 0.998 0.998 0.998 0.998
 0.998 0.998 0.998 0.998 0.998 0.998 0.998 0.999 0.999 0.999 0.999 0.999
 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999
 0.999 0.999 0.999 0.999 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]

adding an offset index to scan the raveled P_cum matrix
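Since P_cum.ravel() concatenates the rows of the matrix, the entry for (filter i, bin j) sits at flat index i*nb_quant + j. A quick sanity check of this bookkeeping (a small sketch using the variables already defined; the pair of indices is arbitrary):

i, j = 3, 17  # an arbitrary (filter, bin) pair, for illustration only
assert P_cum.ravel()[i * shl.nb_quant + j] == P_cum[i, j]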

In [20]:
np.arange(P_cum.shape[0])*shl.nb_quant
Out[20]:
array([    0,   256,   512,   768,  1024,  1280,  1536,  1792,  2048,
        2304,  2560,  2816,  3072,  3328,  3584,  3840,  4096,  4352,
        4608,  4864,  5120,  5376,  5632,  5888,  6144,  6400,  6656,
        6912,  7168,  7424,  7680,  7936,  8192,  8448,  8704,  8960,
        9216,  9472,  9728,  9984, 10240, 10496, 10752, 11008, 11264,
       11520, 11776, 12032, 12288, 12544, 12800, 13056, 13312, 13568,
       13824, 14080, 14336, 14592, 14848, 15104, 15360, 15616, 15872,
       16128, 16384, 16640, 16896, 17152, 17408, 17664, 17920, 18176,
       18432, 18688, 18944, 19200, 19456, 19712, 19968, 20224, 20480,
       20736, 20992, 21248, 21504, 21760, 22016, 22272, 22528, 22784,
       23040, 23296, 23552, 23808, 24064, 24320, 24576, 24832, 25088,
       25344, 25600, 25856, 26112, 26368, 26624, 26880, 27136, 27392,
       27648, 27904, 28160, 28416, 28672, 28928, 29184, 29440, 29696,
       29952, 30208, 30464, 30720, 30976, 31232, 31488, 31744, 32000,
       32256, 32512, 32768, 33024, 33280, 33536, 33792, 34048, 34304,
       34560, 34816, 35072, 35328, 35584, 35840, 36096, 36352, 36608,
       36864, 37120, 37376, 37632, 37888, 38144, 38400, 38656, 38912,
       39168, 39424, 39680, 39936, 40192, 40448, 40704, 40960, 41216,
       41472, 41728, 41984, 42240, 42496, 42752, 43008, 43264, 43520,
       43776, 44032, 44288, 44544, 44800, 45056, 45312, 45568, 45824,
       46080, 46336, 46592, 46848, 47104, 47360, 47616, 47872, 48128,
       48384, 48640, 48896, 49152, 49408, 49664, 49920, 50176, 50432,
       50688, 50944, 51200, 51456, 51712, 51968, 52224, 52480, 52736,
       52992, 53248, 53504, 53760, 54016, 54272, 54528, 54784, 55040,
       55296, 55552, 55808, 56064, 56320, 56576, 56832, 57088, 57344,
       57600, 57856, 58112, 58368, 58624, 58880, 59136, 59392, 59648,
       59904, 60160, 60416, 60672, 60928, 61184, 61440, 61696, 61952,
       62208, 62464, 62720, 62976, 63232, 63488, 63744, 64000, 64256,
       64512, 64768, 65024, 65280, 65536, 65792, 66048, 66304, 66560,
       66816, 67072, 67328, 67584, 67840, 68096, 68352, 68608, 68864,
       69120, 69376, 69632, 69888, 70144, 70400, 70656, 70912, 71168,
       71424, 71680, 71936, 72192, 72448, 72704, 72960, 73216, 73472,
       73728, 73984, 74240, 74496, 74752, 75008, 75264, 75520, 75776,
       76032, 76288, 76544, 76800, 77056, 77312, 77568, 77824, 78080,
       78336, 78592, 78848, 79104, 79360, 79616, 79872, 80128, 80384,
       80640, 80896, 81152, 81408, 81664, 81920, 82176, 82432, 82688])

such that the indices in the raveled matrix are:

In [21]:
(prior(corr, C=C, do_sym=do_sym)*shl.nb_quant).astype(np.int) + np.arange(P_cum.shape[0])*shl.nb_quant
Out[21]:
array([   28,   272,   512,   768,  1024,  1280,  1536,  1850,  2048,
        2304,  2560,  2846,  3072,  3328,  3656,  3852,  4101,  4361,
        4627,  4864,  5120,  5391,  5660,  5888,  6148,  6415,  6699,
        6954,  7180,  7425,  7699,  7967,  8192,  8448,  8725,  8992,
        9231,  9472,  9728,  9988, 10285, 10516, 10752, 11021, 11264,
       11591, 11776, 12032, 12288, 12544, 12800, 13086, 13312, 13586,
       13824, 14080, 14340, 14603, 14848, 15104, 15381, 15616, 15872,
       16151, 16384, 16648, 16916, 17152, 17415, 17718, 17946, 18176,
       18432, 18703, 18964, 19200, 19456, 19712, 19968, 20257, 20480,
       20736, 21017, 21257, 21504, 21801, 22047, 22280, 22528, 22784,
       23041, 23298, 23552, 23808, 24064, 24343, 24605, 24842, 25088,
       25344, 25645, 25875, 26112, 26381, 26624, 26899, 27139, 27392,
       27657, 27923, 28189, 28426, 28720, 28928, 29184, 29446, 29696,
       29952, 30208, 30468, 30720, 31009, 31293, 31488, 31769, 32025,
       32263, 32512, 32768, 33093, 33280, 33589, 33808, 34060, 34310,
       34563, 34858, 35072, 35357, 35613, 35840, 36096, 36394, 36617,
       36864, 37120, 37401, 37632, 37899, 38144, 38407, 38684, 38912,
       39168, 39463, 39680, 39949, 40198, 40448, 40704, 40960, 41216,
       41474, 41728, 41984, 42240, 42596, 42789, 43008, 43272, 43542,
       43804, 44041, 44302, 44575, 44829, 45071, 45323, 45568, 45824,
       46115, 46336, 46592, 46866, 47107, 47371, 47629, 47872, 48128,
       48443, 48668, 48910, 49152, 49408, 49664, 49951, 50194, 50432,
       50691, 50984, 51200, 51456, 51714, 51973, 52224, 52494, 52776,
       52998, 53263, 53504, 53771, 54034, 54272, 54531, 54784, 55040,
       55319, 55552, 55846, 56064, 56320, 56583, 56834, 57095, 57344,
       57600, 57874, 58112, 58368, 58657, 58880, 59136, 59392, 59648,
       59913, 60185, 60416, 60672, 60928, 61184, 61440, 61696, 61952,
       62250, 62464, 62720, 62976, 63246, 63488, 63764, 64000, 64256,
       64527, 64768, 65050, 65280, 65536, 65821, 66048, 66304, 66604,
       66816, 67100, 67390, 67584, 67840, 68096, 68395, 68608, 68864,
       69120, 69376, 69632, 69904, 70144, 70400, 70677, 70914, 71168,
       71424, 71738, 71970, 72192, 72473, 72704, 72960, 73216, 73472,
       73728, 73984, 74245, 74496, 74765, 75013, 75269, 75520, 75792,
       76086, 76288, 76544, 76800, 77102, 77340, 77568, 77827, 78086,
       78336, 78592, 78848, 79138, 79402, 79625, 79872, 80150, 80384,
       80640, 80901, 81171, 81408, 81664, 81920, 82234, 82444, 82688])

However, one must be careful with the case where $p_c=1$, for which we would otherwise select the next filter.
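To make this off-by-one concrete: with nb_quant = 256 and $p_c = 1$, the raw bin index is $1 \times 256 = 256$, which, once the row offset is added, points to the first bin of the next filter's row in the raveled matrix; subtracting the boolean (p_c==1) brings it back to index 255, the last bin of the current row, as the two printouts below show: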

In [22]:
p_c = np.ones_like(corr)
print('Index before correction', (p_c*shl.nb_quant ).astype(np.int) + np.arange(P_cum.shape[0])*shl.nb_quant)
print('Index after correction', (p_c*shl.nb_quant - (p_c==1)).astype(np.int) + np.arange(P_cum.shape[0])*shl.nb_quant)
Index before correction [  256   512   768  1024  1280  1536  1792  2048  2304  2560  2816  3072
  3328  3584  3840  4096  4352  4608  4864  5120  5376  5632  5888  6144
  6400  6656  6912  7168  7424  7680  7936  8192  8448  8704  8960  9216
  9472  9728  9984 10240 10496 10752 11008 11264 11520 11776 12032 12288
 12544 12800 13056 13312 13568 13824 14080 14336 14592 14848 15104 15360
 15616 15872 16128 16384 16640 16896 17152 17408 17664 17920 18176 18432
 18688 18944 19200 19456 19712 19968 20224 20480 20736 20992 21248 21504
 21760 22016 22272 22528 22784 23040 23296 23552 23808 24064 24320 24576
 24832 25088 25344 25600 25856 26112 26368 26624 26880 27136 27392 27648
 27904 28160 28416 28672 28928 29184 29440 29696 29952 30208 30464 30720
 30976 31232 31488 31744 32000 32256 32512 32768 33024 33280 33536 33792
 34048 34304 34560 34816 35072 35328 35584 35840 36096 36352 36608 36864
 37120 37376 37632 37888 38144 38400 38656 38912 39168 39424 39680 39936
 40192 40448 40704 40960 41216 41472 41728 41984 42240 42496 42752 43008
 43264 43520 43776 44032 44288 44544 44800 45056 45312 45568 45824 46080
 46336 46592 46848 47104 47360 47616 47872 48128 48384 48640 48896 49152
 49408 49664 49920 50176 50432 50688 50944 51200 51456 51712 51968 52224
 52480 52736 52992 53248 53504 53760 54016 54272 54528 54784 55040 55296
 55552 55808 56064 56320 56576 56832 57088 57344 57600 57856 58112 58368
 58624 58880 59136 59392 59648 59904 60160 60416 60672 60928 61184 61440
 61696 61952 62208 62464 62720 62976 63232 63488 63744 64000 64256 64512
 64768 65024 65280 65536 65792 66048 66304 66560 66816 67072 67328 67584
 67840 68096 68352 68608 68864 69120 69376 69632 69888 70144 70400 70656
 70912 71168 71424 71680 71936 72192 72448 72704 72960 73216 73472 73728
 73984 74240 74496 74752 75008 75264 75520 75776 76032 76288 76544 76800
 77056 77312 77568 77824 78080 78336 78592 78848 79104 79360 79616 79872
 80128 80384 80640 80896 81152 81408 81664 81920 82176 82432 82688 82944]
Index after correction [  255   511   767  1023  1279  1535  1791  2047  2303  2559  2815  3071
  3327  3583  3839  4095  4351  4607  4863  5119  5375  5631  5887  6143
  6399  6655  6911  7167  7423  7679  7935  8191  8447  8703  8959  9215
  9471  9727  9983 10239 10495 10751 11007 11263 11519 11775 12031 12287
 12543 12799 13055 13311 13567 13823 14079 14335 14591 14847 15103 15359
 15615 15871 16127 16383 16639 16895 17151 17407 17663 17919 18175 18431
 18687 18943 19199 19455 19711 19967 20223 20479 20735 20991 21247 21503
 21759 22015 22271 22527 22783 23039 23295 23551 23807 24063 24319 24575
 24831 25087 25343 25599 25855 26111 26367 26623 26879 27135 27391 27647
 27903 28159 28415 28671 28927 29183 29439 29695 29951 30207 30463 30719
 30975 31231 31487 31743 31999 32255 32511 32767 33023 33279 33535 33791
 34047 34303 34559 34815 35071 35327 35583 35839 36095 36351 36607 36863
 37119 37375 37631 37887 38143 38399 38655 38911 39167 39423 39679 39935
 40191 40447 40703 40959 41215 41471 41727 41983 42239 42495 42751 43007
 43263 43519 43775 44031 44287 44543 44799 45055 45311 45567 45823 46079
 46335 46591 46847 47103 47359 47615 47871 48127 48383 48639 48895 49151
 49407 49663 49919 50175 50431 50687 50943 51199 51455 51711 51967 52223
 52479 52735 52991 53247 53503 53759 54015 54271 54527 54783 55039 55295
 55551 55807 56063 56319 56575 56831 57087 57343 57599 57855 58111 58367
 58623 58879 59135 59391 59647 59903 60159 60415 60671 60927 61183 61439
 61695 61951 62207 62463 62719 62975 63231 63487 63743 63999 64255 64511
 64767 65023 65279 65535 65791 66047 66303 66559 66815 67071 67327 67583
 67839 68095 68351 68607 68863 69119 69375 69631 69887 70143 70399 70655
 70911 71167 71423 71679 71935 72191 72447 72703 72959 73215 73471 73727
 73983 74239 74495 74751 75007 75263 75519 75775 76031 76287 76543 76799
 77055 77311 77567 77823 78079 78335 78591 78847 79103 79359 79615 79871
 80127 80383 80639 80895 81151 81407 81663 81919 82175 82431 82687 82943]
In [23]:
p_c = prior(corr, C=C, do_sym=do_sym)
In [24]:
%%time
def z_score(P_cum, p_c):
    return P_cum.ravel()[(p_c*shl.nb_quant - (p_c==1)).astype(np.int) + np.arange(P_cum.shape[0])*shl.nb_quant]

z_vec = z_score(P_cum, p_c)
CPU times: user 78 µs, sys: 42 µs, total: 120 µs
Wall time: 98 µs
In [25]:
%%timeit
z = z_score(P_cum, p_c)
23.3 µs ± 2.09 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

we can shave off a little more:

In [26]:
stick = np.arange(P_cum.shape[0])*shl.nb_quant
In [27]:
print("shape of stick is ", stick.shape)
print("shape of vector ", (p_c*P_cum.shape[1]).astype(np.int).shape)
shape of stick is  (324,)
shape of vector  (324,)
In [28]:
%%time
def z_score(P_cum, p_c, stick):
    return P_cum.ravel()[(p_c*shl.nb_quant - (p_c==1)).astype(np.int) + stick]

z_vec = z_score(P_cum, p_c, stick)
CPU times: user 62 µs, sys: 19 µs, total: 81 µs
Wall time: 72 µs
In [29]:
%%timeit
z = z_score(P_cum, p_c, stick)
17.7 µs ± 2.5 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [30]:
print('difference=', (z_vanilla-z_ind).std() )
difference= 0.0570146945833
In [31]:
print('difference=', (z_vec-z_ind).std() )
difference= 0.0

Thanks to vectorization, we go roughly $10 \times 8$ times faster.
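If one wanted to keep the accuracy of the linear interpolation while still avoiding the Python loop, a fully vectorized variant could look like the following sketch (not used in the rest of this notebook; z_score_linear is a name of my choosing):

def z_score_linear(P_cum, p_c):
    # vectorized equivalent of np.interp over the linearly spaced code_bins
    nb_quant = P_cum.shape[1]
    x = p_c * (nb_quant - 1)                           # position on the bin grid
    ind = np.minimum(x.astype(np.int), nb_quant - 2)   # index of the lower bin
    frac = x - ind                                     # fractional part
    rows = np.arange(P_cum.shape[0])
    return (1. - frac) * P_cum[rows, ind] + frac * P_cum[rows, ind + 1]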

Note: one must also be careful with the edge cases $p_c=1$ and $p_c=0$; let's check the values returned in those cases:

In [32]:
from shl_scripts.shl_encode import z_score, prior

stick = np.arange(shl.n_dictionary)*shl.nb_quant
print('Value for ones = ', z_score(P_cum, prior(np.inf*np.ones(shl.n_dictionary), C=C), stick))
print('Value for zeros = ', z_score(P_cum, prior(np.zeros(shl.n_dictionary), C=C), stick))
Value for ones =  [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
Value for zeros =  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

COMP : using modulations

let's use this new z_score function

In [33]:
l0_sparseness = shl.l0_sparseness
def comp(data, dico, P_cum, C=C, do_sym=do_sym, verbose=0):
    if verbose!=0: t0 = time.time()
    n_samples, n_dictionary = data.shape[0], dico.shape[0]
    sparse_code = np.zeros((n_samples, n_dictionary))
    corr = (data @ dico.T)     # correlation of each sample with each atom
    Xcorr = (dico @ dico.T)    # Gram matrix of the dictionary
    nb_quant = P_cum.shape[1]
    stick = np.arange(n_dictionary)*nb_quant  # offsets into the raveled P_cum

    for i_sample in range(n_samples):
        c = corr[i_sample, :].copy()
        if verbose!=0: ind_list=list()
        for i_l0 in range(int(l0_sparseness)):
            # select the atom with the highest modulated (z-scored) correlation
            zi = z_score(P_cum, prior(c, C=C, do_sym=do_sym), stick)
            ind  = np.argmax(zi)
            if verbose!=0: ind_list.append(ind)

            # update its coefficient and subtract its contribution from the residual correlations
            c_ind = c[ind] / Xcorr[ind, ind]
            sparse_code[i_sample, ind] += c_ind
            c -= c_ind * Xcorr[ind, :]

        if verbose!=0 and i_sample in range(2):
            zi = z_score(P_cum, prior(c, C=C, do_sym=do_sym), stick)
            print(ind_list, [zi[i] for i in ind_list], np.median(zi), zi.max(), [c[i] for i in ind_list], c.min(), c.max())
    if verbose!=0:
        duration = time.time()-t0
        print('coding duration : {0}s'.format(duration))
    return sparse_code

#sparse_code = comp(data_test, dico_partial_learning.dictionary, code_bins, P_cum, verbose=1)
In [34]:
%%timeit
sparse_code = comp(data_test, dico_partial_learning.dictionary, P_cum, C=C, do_sym=do_sym, verbose=0)
28.5 s ± 1.62 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [35]:
def plot_scatter_MpVsComp(sparse_vector, my_sparse_code):
    fig = plt.figure(figsize=(16, 16))
    ax = fig.add_subplot(111)
    a_min = np.min((sparse_vector.min(), my_sparse_code.min()))
    a_max = np.max((sparse_vector.max(), my_sparse_code.max()))
    ax.plot(np.array([a_min, a_max]), np.array([a_min, a_max]), 'k--', lw=2)
    print(sparse_vector.shape, my_sparse_code.shape)
    ax.scatter(sparse_vector.ravel(), my_sparse_code.ravel(), alpha=0.01)
    ax.set_title('MP')
    ax.set_ylabel('COMP')
    #ax.set_xlim(0)
    #ax.set_ylim(0)
    ax.axis('equal')
    return fig, ax


#fig, ax = plot_scatter_MpVsComp(sparse_code_mp, sparse_code)
#fig.show()

testing that COMP with a fixed, linear P_cum is equivalent to MP
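The reasoning: with a flat, linear P_cum, the z_score function reduces (up to the bin width) to the identity on $[0, 1]$, so the ranking of the coefficients, and hence the atom selected at each step, is the same as in plain MP. A minimal sanity check of this claim (a sketch, assuming the imported z_score behaves like the vectorized lookup derived above):

P_lin = np.linspace(0, 1, shl.nb_quant, endpoint=True)[np.newaxis, :] * np.ones((shl.n_dictionary, 1))
stick = np.arange(shl.n_dictionary) * shl.nb_quant
p_c = prior(corr, C=C, do_sym=do_sym)
# with a linear modulation, the z-score equals the prior up to quantization error
assert np.allclose(z_score(P_lin, p_c, stick), p_c, atol=2. / shl.nb_quant)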

In [36]:
print(dico_partial_learning.P_cum)
None
In [37]:
n_samples, nb_filter = sparse_code_mp.shape

P_cum = np.linspace(0, 1, nb_quant, endpoint=True)[np.newaxis, :] * np.ones((nb_filter, 1))

sparse_code_comp = comp(data_test, dico_partial_learning.dictionary, P_cum, C=C, do_sym=do_sym, verbose=1)

fig, ax = plot_proba_histogram(sparse_code_mp)
fig.show()
fig, ax = plot_proba_histogram(sparse_code_comp)

fig, ax = plot_scatter_MpVsComp(sparse_code_mp, sparse_code_comp)
[166, 122, 129, 14, 131, 199, 263, 297, 313, 69, 310, 95, 40, 167, 7] [0.0, 0.0, 0.0, 0.027450980392156862, 0.058823529411764705, 0.047058823529411764, 0.015686274509803921, 0.0, 0.011764705882352941, 0.0078431372549019607, 0.0078431372549019607, 0.0078431372549019607, 0.0, 0.0039215686274509803, 0.0] 0.00392156862745 0.152941176471 [-0.014144489199123081, -0.29764948677990238, -0.22589313762029484, 0.12454694665524793, 0.24573651026061044, 0.19922381923942906, 0.075139139142544137, -0.041528525401208057, 0.050108691526164269, 0.03632414578129424, 0.043871921218744878, 0.044485218588332848, 0.012594896004748071, 0.019488163017774635, 0.0] -0.978132955591 0.669435349149
[85, 149, 211, 310, 146, 162, 200, 236, 140, 97, 41, 216, 312, 241, 93] [0.0, 0.0, 0.0, 0.0, 0.0, 0.054901960784313725, 0.062745098039215685, 0.12549019607843137, 0.0, 0.0, 0.0, 0.12549019607843137, 0.15686274509803921, 0.0, 0.0] 0.0 0.498039215686 [-0.16016636223710271, -0.92849054720520463, -0.32991236180700345, -0.63060880676904096, -0.76948077001650084, 0.24106229657345207, 0.26276735755475855, 0.5441138008472135, -0.11677406956091178, -0.15219006796748036, -0.025554513076037559, 0.54287319002786116, 0.68963322131096594, -0.021338539642562732, 0.0] -3.37701399069 2.75299303249
coding duration : 28.301383018493652s
(40960, 324) (40960, 324)

gradient descent
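At each learning step, the cumulative modulation functions are re-estimated on the sparse code obtained with the current P_cum and blended into the running estimate with learning rate $\eta_{\mathrm{homeo}}$; this restates the update performed in the loop below: $P_{\mathrm{cum}} \leftarrow (1-\eta_{\mathrm{homeo}}) \, P_{\mathrm{cum}} + \eta_{\mathrm{homeo}} \, \hat{P}_{\mathrm{cum}}$, where $\hat{P}_{\mathrm{cum}}$ is the output of get_P_cum on the current code.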

In [38]:
P_cum = np.linspace(0, 1, nb_quant, endpoint=True)[np.newaxis, :] * np.ones((nb_filter, 1))
print('Shape of modulation function', P_cum.shape)

eta_homeo = .01

for i in range(1000):
    sparse_code = comp(data_test, dico_partial_learning.dictionary, P_cum, C=C, do_sym=do_sym, verbose=0)
    P_cum_ = get_P_cum(sparse_code, nb_quant=nb_quant)
    P_cum = (1-eta_homeo) * P_cum + eta_homeo  * P_cum_
    if i % 100 == 0:
        print('Learning step', i)
        fig, ax = plot_proba_histogram(sparse_code)
Shape of modulation function (324, 256)
Learning step 0
Learning step 100
Learning step 200
Learning step 300
Learning step 400
Learning step 500
Learning step 600
Learning step 700
Learning step 800
Learning step 900
In [39]:
from shl_scripts.shl_tools import plot_P_cum
fig, ax = plot_P_cum(P_cum, verbose=False);