# -*- coding: utf-8 -*-
"""INI simulator with temporal pattern code.
@author: rbodo
"""
import tensorflow as tf
import numpy as np
from snntoolbox.simulation.target_simulators.\
INI_temporal_mean_rate_target_sim import SNN as SNN_
from snntoolbox.simulation.utils import get_layer_synaptic_operations, \
remove_name_counter
class SNN(SNN_):
    """
    The compiled spiking neural network, using layers derived from
    Keras base classes (see
    `snntoolbox.simulation.backends.inisim.temporal_pattern`).

    Aims at simulating the network on a self-implemented Integrate-and-Fire
    simulator using a timestepped approach.

    Attributes
    ----------

    snn: keras.models.Model
        Keras model. This is the output format of the compiled spiking model
        because INI simulator runs networks of layers that are derived from
        Keras layer base classes.
    num_bits: int
        Fixed-point precision, read from ``[conversion] num_bits`` in the
        config. It is used to turn recorded spike rates into binary spike
        trains.
    """

    def __init__(self, config, queue=None):
        SNN_.__init__(self, config, queue)
        self.num_bits = self.config.getint('conversion', 'num_bits')

    def compile(self):
        """Build and compile the Keras model, then copy the ANN parameters.

        The model maps ``self._input_images`` to the spiking layer that
        corresponds to the last layer of the parsed model. Parameters are
        copied from the parsed model by name, after stripping name counters.
        Biases are then divided by the number of simulation time steps.

        Raises
        ------
        AssertionError
            If not every parameter of the parsed model found a matching
            variable in the spiking model.
        """
        self.snn = tf.keras.models.Model(
            self._input_images,
            self._spiking_layers[self.parsed_model.layers[-1].name])
        self.snn.compile('sgd', 'categorical_crossentropy', ['accuracy'])

        # Tensorflow 2 lists all variables as weights, including our state
        # variables (membrane potential etc). So a simple
        # snn.set_weights(parsed_model.get_weights()) does not work any more.
        # Need to extract the actual weights here, matching them by name:
        parameter_map = {remove_name_counter(p.name): p for p in
                         self.parsed_model.weights}
        count = 0
        for p in self.snn.weights:
            name = remove_name_counter(p.name)
            if name in parameter_map:
                p.assign(parameter_map[name])
                count += 1
        # Explicit raise instead of ``assert`` so the check also runs under
        # ``python -O``. The exception type is unchanged for callers.
        if count != len(parameter_map):
            raise AssertionError("Not all weights have been transferred from "
                                 "ANN to SNN.")

        for layer in self.snn.layers:
            # A layer may expose a ``bias`` attribute that is None (Keras does
            # this for ``use_bias=False``). Dividing None would raise, so
            # skip such layers.
            bias = getattr(layer, 'bias', None)
            if bias is not None:
                # Adjust biases to time resolution of simulator.
                bias.assign(bias / self._num_timesteps)

    def simulate(self, **kwargs):
        """Run one batch through the spiking network and record statistics.

        The network is evaluated with a single ``predict_on_batch`` call.
        The resulting output is broadcast across the time axis, so every
        time step holds the same value.

        Parameters
        ----------
        **kwargs
            Must contain ``'x_b_l'``, the input batch, which is scaled by
            ``self._dt`` before being fed to the network. When
            ``[output] verbose > 0``, it must also contain ``'truth_b'``,
            the true class indices of the batch.

        Returns
        -------
        ndarray
            Cumulative sum over the time axis of the broadcast output, with
            shape ``(batch_size, num_classes, num_timesteps)``.
        """
        from snntoolbox.utils.utils import echo

        input_b_l = kwargs['x_b_l'] * self._dt

        output_b_l_t = np.zeros((self.batch_size, self.num_classes,
                                 self._num_timesteps))

        self._input_spikecount = 0
        self.set_time(self._dt)

        # Main step: Propagate input through network and record output spikes.
        out_spikes = self.snn.predict_on_batch(input_b_l)

        # Broadcast the raw output (softmax) across time axis.
        output_b_l_t[:, :, :] = np.expand_dims(out_spikes, -1)

        # Record neuron variables. Only layers that expose spike rates are
        # considered, which excludes Input, Flatten, Concatenate, etc.
        spiking_layers = [layer for layer in self.snn.layers
                          if getattr(layer, 'spikerates', None) is not None]
        for i, layer in enumerate(spiking_layers):
            spikerates_b_l = layer.spikerates.numpy()
            spiketrains_b_l_t = to_binary_numpy(spikerates_b_l,
                                                self.num_bits)
            self.set_spikerates(spikerates_b_l, i)
            self.set_spiketrains(spiketrains_b_l_t, i)
            if self.synaptic_operations_b_t is not None:
                self.set_synaptic_operations(spiketrains_b_l_t, i)
            if self.neuron_operations_b_t is not None:
                self.set_neuron_operations(i)

        if 'input_b_l_t' in self._log_keys:
            self.input_b_l_t[Ellipsis, 0] = input_b_l
        if self.neuron_operations_b_t is not None:
            self.neuron_operations_b_t[:, 0] += self.fanin[1] * \
                self.num_neurons[1] * np.ones(self.batch_size) * 2

        if self.config.getint('output', 'verbose') > 0:
            # The header is only useful when the accuracy is echoed after it.
            print("Current accuracy of batch:")
            guesses_b = np.argmax(np.sum(output_b_l_t, 2), 1)
            echo('{:.2%}_'.format(np.mean(kwargs['truth_b'] == guesses_b)))

        return np.cumsum(output_b_l_t, 2)

    def load(self, path, filename):
        """Load a model by delegating to the parent implementation."""
        SNN_.load(self, path, filename)

    def set_spiketrains(self, spiketrains_b_l_t, i):
        """Store the spike trains of recorded layer ``i``.

        Does nothing when ``self.spiketrains_n_b_l_t`` is None.
        """
        if self.spiketrains_n_b_l_t is not None:
            self.spiketrains_n_b_l_t[i][0][:] = spiketrains_b_l_t

    def set_spikerates(self, spikerates_b_l, i):
        """Store the spike rates of recorded layer ``i``.

        Does nothing when ``self.spikerates_n_b_l`` is None.
        """
        if self.spikerates_n_b_l is not None:
            self.spikerates_n_b_l[i][0][:] = spikerates_b_l

    def set_neuron_operations(self, i):
        """Add ``self.num_neurons_with_bias[i + 1]`` to every neuron-op entry."""
        # NOTE(review): the ``i + 1`` offset presumably skips the input
        # layer entry; confirm against how num_neurons_with_bias is built.
        self.neuron_operations_b_t += self.num_neurons_with_bias[i + 1]

    def set_synaptic_operations(self, spiketrains_b_l_t, i):
        """Accumulate synaptic operations caused by the spikes of layer ``i``.

        For each time step ``t``, ``get_layer_synaptic_operations`` is
        applied to ``spiketrains_b_l_t[..., t]`` with ``self.fanout[i + 1]``.
        Twice the result is added to ``self.synaptic_operations_b_t[:, t]``.
        """
        # NOTE(review): the time axis of synaptic_operations_b_t indexes the
        # last axis of spiketrains_b_l_t. In simulate() that axis has length
        # num_bits, so this presumably relies on num_timesteps <= num_bits;
        # confirm.
        for t in range(self.synaptic_operations_b_t.shape[-1]):
            ops = get_layer_synaptic_operations(spiketrains_b_l_t[Ellipsis, t],
                                                self.fanout[i + 1])
            self.synaptic_operations_b_t[:, t] += 2 * ops
def to_binary_numpy(x, num_bits):
    """Transform an array of floats into binary representation.

    Each value is first rounded to the nearest multiple of
    ``1 / (2**num_bits - 1)``. It is then decomposed greedily into the
    fractional powers of two ``2**-1, 2**-2, ..., 2**-num_bits``. The bits
    of a value ``k / (2**num_bits - 1)`` therefore sum to ``k / 2**num_bits``
    (up to floating-point rounding).

    Values at or below 0 yield all-zero bits. Values above 1 saturate, with
    all bits set.

    Parameters
    ----------

    x: array_like
        Input values of any shape. Scalars and nested sequences are accepted
        and converted with ``np.asarray``. ``x`` itself is not modified.
    num_bits: int
        The fixed point precision to be used when converting to binary.

    Returns
    -------

    y: ndarray
        Float array with shape ``x.shape + (num_bits,)``. ``y[..., i]`` is
        either 0 or ``2**-(i + 1)``, i.e. the binary representation of each
        value in ``x`` is distributed across the last dimension of ``y``.
    """
    # Accept scalars and sequences as well as ndarrays; x.shape is needed.
    x = np.asarray(x)

    n = 2 ** num_bits - 1
    # Quantize to num_bits levels. This creates a new array, so the
    # in-place subtraction below never touches the caller's data.
    a = np.round(x * n) / n

    y = np.zeros(list(x.shape) + [num_bits])
    for i in range(num_bits):
        p = 2 ** -(i + 1)
        # A strict comparison is required for the k / (2**num_bits - 1)
        # -> k / 2**num_bits mapping to hold exactly.
        b = np.greater(a, p) * p
        y[Ellipsis, i] = b
        a -= b

    return y