# -*- coding: utf-8 -*-
"""
This module bundles all the tools of the SNN conversion toolbox.
Important functions:
.. autosummary::
:nosignatures:
run_pipeline
update_setup
@author: rbodo
"""
import os
from importlib import import_module
from snntoolbox.parsing.model_libs.keras_input_lib import load
def run_pipeline(config, queue=None):
    """Convert an analog network to a spiking network and simulate it.

    Complete pipeline of
        1. loading and testing a pretrained ANN,
        2. normalizing parameters
        3. converting it to SNN,
        4. running it on a simulator,
        5. given a specified hyperparameter range ``params``,
           repeat simulations with modified parameters.

    Parameters
    ----------

    config: configparser.ConfigParser
        ConfigParser containing the user settings.

    queue: Optional[Queue.Queue]
        Results are added to the queue to be displayed in the GUI.

    Returns
    -------

    results: list
        List of the accuracies obtained after simulating with each parameter
        value in config.get('parameter_sweep', 'param_values').
    """

    from snntoolbox.datasets.utils import get_dataset
    from snntoolbox.conversion.utils import normalize_parameters

    num_to_test = config.getint('simulation', 'num_to_test')

    # Instantiate an empty spiking network
    target_sim = import_target_sim(config)
    spiking_model = target_sim.SNN(config, queue)

    # ___________________________ LOAD DATASET ______________________________ #

    normset, testset = get_dataset(config)

    results = None
    parsed_model = None
    if config.getboolean('tools', 'parse') and not is_stop(queue):

        # __________________________ LOAD MODEL _____________________________ #

        model_lib = import_module('snntoolbox.parsing.model_libs.' +
                                  config.get('input', 'model_lib') +
                                  '_input_lib')
        input_model = model_lib.load(config.get('paths', 'path_wd'),
                                     config.get('paths', 'filename_ann'))

        # Evaluate input model.
        if config.getboolean('tools', 'evaluate_ann') and not is_stop(queue):
            print("Evaluating input model on {} samples...".format(
                num_to_test))
            acc = model_lib.evaluate(input_model['val_fn'],
                                     config.getint('simulation',
                                                   'batch_size'),
                                     num_to_test, **testset)
            results = [acc]

        # ____________________________ PARSE ________________________________ #

        print("Parsing input model...")
        model_parser = model_lib.ModelParser(input_model['model'], config)
        model_parser.parse()
        parsed_model = model_parser.build_parsed_model()

        # ___________________________ NORMALIZE _____________________________ #

        if config.getboolean('tools', 'normalize') and not is_stop(queue):
            normalize_parameters(parsed_model, config, **normset)

        # Evaluate parsed model.
        if config.getboolean('tools', 'evaluate_ann') and not is_stop(queue):
            print("Evaluating parsed model on {} samples...".format(
                num_to_test))
            score = model_parser.evaluate(
                config.getint('simulation', 'batch_size'),
                num_to_test, **testset)
            # Note: this overwrites the input-model accuracy stored above.
            results = [score[1]]

        # Write parsed model to disk
        parsed_model.save(str(
            os.path.join(config.get('paths', 'path_wd'),
                         config.get('paths', 'filename_parsed_model') +
                         '.h5')))

    # _____________________________ CONVERT _________________________________ #

    if config.getboolean('tools', 'convert') and not is_stop(queue):
        if parsed_model is None:
            try:
                parsed_model = load(
                    config.get('paths', 'path_wd'),
                    config.get('paths', 'filename_parsed_model'),
                    filepath_custom_objects=config.get(
                        'paths', 'filepath_custom_objects'))['model']
            except FileNotFoundError:
                # Bugfix: the filename and path arguments were previously
                # passed to ``format`` in swapped order.
                print("Could not find parsed model {} in path {}. Consider "
                      "setting `parse = True` in your config file.".format(
                          config.get('paths', 'filename_parsed_model'),
                          config.get('paths', 'path_wd')))
        spiking_model.build(parsed_model, **testset)

        # Export network in a format specific to the simulator with which it
        # will be tested later.
        spiking_model.save(config.get('paths', 'path_wd'),
                           config.get('paths', 'filename_snn'))

    # ______________________________ SIMULATE _______________________________ #

    if config.getboolean('tools', 'simulate') and not is_stop(queue):

        # Decorate the 'run' function of the spiking model with a parameter
        # sweep function.
        @run_parameter_sweep(config, queue)
        def run(snn, **test_set):
            return snn.run(**test_set)

        # Simulate network
        results = run(spiking_model, **testset)

        # Clean up
        spiking_model.end_sim()

    # Add results to queue to be displayed in GUI.
    if queue:
        queue.put(results)

    return results
def is_stop(queue):
    """Determine if the user pressed 'stop' in the GUI.

    Parameters
    ----------

    queue: Queue.Queue
        Event queue.

    Returns
    -------

    : bool
        ``True`` if user pressed 'stop' in GUI, ``False`` otherwise.
    """

    if not queue:
        return False

    if queue.empty():
        return False

    if queue.get_nowait() == 'stop':
        print("Skipped step after user interrupt")
        # Put the event back so subsequent pipeline steps are skipped too.
        queue.put('stop')
        return True

    # Bugfix: previously fell through and returned None instead of False.
    # NOTE(review): a non-'stop' event is consumed and discarded here, which
    # matches the original behavior -- confirm this is intended.
    return False
def run_parameter_sweep(config, queue):
    """
    Decorator to perform a parameter sweep using the ``run_single`` function.

    Need an additional wrapping layer to be able to pass decorator arguments.

    Parameters
    ----------

    config: configparser.ConfigParser
        ConfigParser containing the user settings.

    queue: Optional[Queue.Queue]
        Event queue; the sweep is aborted if the user pressed 'stop'.
    """

    def decorator(run_single):

        from functools import wraps

        @wraps(run_single)
        def wrapper(snn, **testset):
            results = []
            # SECURITY NOTE: eval executes arbitrary code from the config
            # file; only use trusted configuration files.
            param_values = eval(config.get('parameter_sweep', 'param_values'))
            param_name = config.get('parameter_sweep', 'param_name')
            param_logscale = config.getboolean('parameter_sweep',
                                               'param_logscale')
            if len(param_values) > 1:
                print("Testing SNN for parameter values {} = ".format(
                    param_name))
                print(['{:.2f}'.format(i) for i in param_values])
                print('\n')
            elif len(param_values) == 0:
                # No sweep requested: run once with the current cell value.
                param_values.append(eval(config.get('cell', param_name)))

            # Loop over parameter to sweep
            for p in param_values:
                if is_stop(queue):
                    break

                # Display current parameter value
                config.set('cell', param_name, str(p))
                if len(param_values) > 1:
                    print("\nCurrent value of parameter to sweep: " +
                          "{} = {:.2f}\n".format(param_name, p))
                results.append(run_single(snn, **testset))

            # Plot and return results of parameter sweep. Plotting is
            # optional; skip silently if the plotting helpers (which need
            # matplotlib) are unavailable.
            try:
                from snntoolbox.simulation.plotting import plot_param_sweep
            except ImportError:
                plot_param_sweep = None
            if plot_param_sweep is not None:
                plot_param_sweep(
                    results, config.getint('simulation', 'num_to_test'),
                    param_values, param_name, param_logscale)
            return results
        return wrapper
    return decorator
def import_target_sim(config):
    """Import and return the simulator-specific target module.

    The module (in ``snntoolbox.simulation.target_simulators``) contains the
    ``SNN`` class for the configured simulator. For the built-in 'INI'
    simulator, the module name additionally encodes the chosen spike code.

    Parameters
    ----------

    config: configparser.ConfigParser
        ConfigParser containing the user settings.
    """

    sim_str = config.get('simulation', 'simulator')
    code_str = '_' + config.get('conversion', 'spike_code') \
        if sim_str == 'INI' else ''

    return import_module('snntoolbox.simulation.target_simulators.'
                         + sim_str + code_str + '_target_sim')
def load_config(filepath):
    """
    Load a config file from ``filepath``.

    Parameters
    ----------

    filepath: str
        Path to an existing configuration file.

    Returns
    -------

    config: configparser.ConfigParser
        Parser populated with the settings from ``filepath``. Option names
        keep their case (``optionxform = str``).
    """

    from snntoolbox.utils.utils import import_configparser
    configparser = import_configparser()

    assert os.path.isfile(filepath), \
        "Configuration file not found at {}.".format(filepath)

    config = configparser.ConfigParser()
    # Preserve the case of option names (ConfigParser lower-cases by
    # default, which would break case-sensitive keys).
    config.optionxform = str
    config.read(filepath)

    return config
def update_setup(config_filepath):
    """Update default settings with user settings and check they are valid.

    Load settings from configuration file at ``config_filepath``, and check
    that parameter choices are valid. Non-specified settings are filled in
    with defaults.

    Parameters
    ----------

    config_filepath: str
        Path to a configuration file with user settings.

    Returns
    -------

    config: configparser.ConfigParser
        Validated settings (defaults overridden by user values). A copy is
        also written to ``<log_dir_of_current_run>/.config``.
    """

    from textwrap import dedent

    # config.read will not throw an error if the filepath does not exist, and
    # user values will not override defaults. So check here:
    assert os.path.isfile(config_filepath), \
        "Config filepath {} does not exist.".format(config_filepath)

    # Load defaults.
    config = load_config(os.path.abspath(os.path.join(
        os.path.dirname(__file__), '..', 'config_defaults')))

    # Overwrite with user settings.
    config.read(config_filepath)

    keras_backend = config.get('simulation', 'keras_backend')
    keras_backends = config_string_to_set_of_strings(
        config.get('restrictions', 'keras_backends'))
    assert keras_backend in keras_backends, \
        "Keras backend {} not supported. Choose from {}.".format(
            keras_backend, keras_backends)
    os.environ['KERAS_BACKEND'] = keras_backend
    # The keras import has to happen after setting the backend environment
    # variable!
    import tensorflow.keras.backend as k
    assert k.backend() == keras_backend, \
        "Keras backend set to {} in snntoolbox config file, but has already " \
        "been set to {} by a previous keras import. Set backend " \
        "appropriately in the keras config file.".format(keras_backend,
                                                         k.backend())

    # Name of input file must be given.
    filename_ann = config.get('paths', 'filename_ann')
    assert filename_ann != '', "Filename of input model not specified."

    # Check that simulator choice is valid.
    simulator = config.get('simulation', 'simulator')
    simulators = config_string_to_set_of_strings(config.get('restrictions',
                                                            'simulators'))
    assert simulator in simulators, \
        "Simulator '{}' not supported. Choose from {}".format(simulator,
                                                              simulators)

    # Warn user that it is not possible to use Brian2 simulator by loading a
    # pre-converted network from disk.
    if simulator == 'brian2' and not config.getboolean('tools', 'convert'):
        print(dedent("""\n
            SNN toolbox Warning: When using Brian 2 simulator, you need to
            convert the network each time you start a new session. (No
            saving/reloading methods implemented.) Setting convert = True.
            \n"""))
        config.set('tools', 'convert', str(True))
    elif simulator in config_string_to_set_of_strings(
            config.get('restrictions', 'simulators_pyNN')):
        # pyNN-based simulators constrain cell parameters relative to the
        # simulation time step.
        delay = config.getfloat('cell', 'delay')
        tau_refrac = config.getfloat('cell', 'tau_refrac')
        v_thresh = config.getfloat('cell', 'v_thresh')
        dt = config.getfloat('simulation', 'dt')
        # We found that in some cases the refractory period can actually be
        # smaller than the time step.
        scale = 1e1 if dt == 0.1 else 1e3
        if tau_refrac < dt / scale and tau_refrac != 0:
            print("\nSNN toolbox WARNING: Refractory period ({}) must be at "
                  "least one time step / {} ({}). Setting tau_refrac = dt / "
                  "{}.".format(tau_refrac, scale, dt / scale, scale))
            config.set('cell', 'tau_refrac', str(dt / scale))
        elif tau_refrac > dt / scale:
            print("\nSNN toolbox WARNING: We recommend to set the refractory "
                  "period ({}) to be as small as possible (one time step / {}"
                  ", {}).".format(tau_refrac, scale, dt / scale))
        if delay < dt:
            print("\nSNN toolbox WARNING: Delay ({}) must be at least one "
                  "time step ({}). Setting delay = dt.".format(delay, dt))
            config.set('cell', 'delay', str(dt))
        elif delay > dt:
            print("\nSNN toolbox WARNING: We recommend to set the delay ({}) "
                  "to be as small as possible (one time step, {})."
                  "".format(delay, dt))
        if v_thresh != 0.01:
            print("\nSNN toolbox WARNING: For optimal correspondence between "
                  "the original ANN and the converted SNN simulated on pyNN, "
                  "the threshold should be 0.01. Current value: {}."
                  "".format(v_thresh))

    # Set default path if user did not specify it.
    if config.get('paths', 'path_wd') == '':
        config.set('paths', 'path_wd', os.path.dirname(config_filepath))

    # Check specified working directory exists.
    path_wd = config.get('paths', 'path_wd')
    assert os.path.exists(path_wd), \
        "Working directory {} does not exist.".format(path_wd)

    # Check that choice of input model library is valid.
    model_lib = config.get('input', 'model_lib')
    model_libs = config_string_to_set_of_strings(config.get('restrictions',
                                                            'model_libs'))
    assert model_lib in model_libs, "ERROR: Input model library '{}' ".format(
        model_lib) + "not supported yet. Possible values: {}".format(
        model_libs)

    # Check input model is found and has the right format for the specified
    # model library.
    if config.getboolean('tools', 'evaluate_ann') \
            or config.getboolean('tools', 'parse'):
        if model_lib == 'caffe':
            caffemodel_filepath = os.path.join(path_wd,
                                               filename_ann + '.caffemodel')
            caffemodel_h5_filepath = os.path.join(path_wd, filename_ann +
                                                  '.caffemodel.h5')
            assert os.path.isfile(caffemodel_filepath) or os.path.isfile(
                caffemodel_h5_filepath), "File {} or {} not found.".format(
                caffemodel_filepath, caffemodel_h5_filepath)
            prototxt_filepath = os.path.join(path_wd, filename_ann +
                                             '.prototxt')
            assert os.path.isfile(prototxt_filepath), \
                "File {} not found.".format(prototxt_filepath)
        elif model_lib == 'keras':
            h5_filepath = str(os.path.join(path_wd, filename_ann + '.h5'))
            assert os.path.isfile(h5_filepath), \
                "File {} not found.".format(h5_filepath)
        elif model_lib == 'lasagne':
            h5_filepath = os.path.join(path_wd, filename_ann + '.h5')
            pkl_filepath = os.path.join(path_wd, filename_ann + '.pkl')
            assert os.path.isfile(h5_filepath) or \
                os.path.isfile(pkl_filepath), \
                "File {} not found.".format('.h5 or .pkl')
            py_filepath = os.path.join(path_wd, filename_ann + '.py')
            assert os.path.isfile(py_filepath), \
                "File {} not found.".format(py_filepath)
        else:
            # For other model libraries, no file-existence check is
            # implemented.
            pass

    # Set default path if user did not specify it.
    if config.get('paths', 'dataset_path') == '':
        config.set('paths', 'dataset_path', os.path.dirname(__file__))

    # Check that the data set path is valid.
    dataset_path = os.path.abspath(config.get('paths', 'dataset_path'))
    config.set('paths', 'dataset_path', dataset_path)
    assert os.path.exists(dataset_path), "Path to data set does not exist: " \
                                         "{}".format(dataset_path)

    # Check that data set path contains the data in the specified format.
    assert os.listdir(dataset_path), "Data set directory is empty."
    normalize = config.getboolean('tools', 'normalize')
    dataset_format = config.get('input', 'dataset_format')
    if dataset_format == 'npz' and normalize and not os.path.exists(
            os.path.join(dataset_path, 'x_norm.npz')):
        # NOTE(review): raising a Warning class as an exception aborts the
        # setup; kept as in the original implementation.
        raise RuntimeWarning(
            "No data set file 'x_norm.npz' found in specified data set path " +
            "{}. Add it, or disable normalization.".format(dataset_path))
    if dataset_format == 'npz' and not (os.path.exists(os.path.join(
            dataset_path, 'x_test.npz')) and os.path.exists(os.path.join(
            dataset_path, 'y_test.npz'))):
        raise RuntimeWarning(
            "Data set file 'x_test.npz' or 'y_test.npz' was not found in "
            "specified data set path {}.".format(dataset_path))

    # SECURITY NOTE: eval executes arbitrary code from the config file; only
    # use trusted configuration files.
    sample_idxs_to_test = eval(config.get('simulation', 'sample_idxs_to_test'))
    num_to_test = config.getint('simulation', 'num_to_test')
    if len(sample_idxs_to_test):
        num_required = max(sample_idxs_to_test) + 1
        if num_required > num_to_test:
            print(dedent("""
            SNN toolbox warning: Settings mismatch. Adjusting 'num_to_test' to
            include all 'sample_idxs_to_test'."""))
            config.set('simulation', 'num_to_test', str(num_required))

    # Create log directory if it does not exist.
    if config.get('paths', 'log_dir_of_current_run') == '':
        config.set('paths', 'log_dir_of_current_run', os.path.join(
            path_wd, 'log', 'gui', config.get('paths', 'runlabel')))
    log_dir_of_current_run = config.get('paths', 'log_dir_of_current_run')
    if not os.path.isdir(log_dir_of_current_run):
        os.makedirs(log_dir_of_current_run)

    # Specify filenames for models at different stages of the conversion.
    if config.get('paths', 'filename_parsed_model') == '':
        config.set('paths', 'filename_parsed_model', filename_ann + '_parsed')
    if config.get('paths', 'filename_snn') == '':
        config.set('paths', 'filename_snn', '{}_{}'.format(filename_ann,
                                                           simulator))

    # Make sure the number of samples to test is not lower than the batch size.
    batch_size = config.getint('simulation', 'batch_size')
    if config.getint('simulation', 'num_to_test') < batch_size:
        print(dedent("""\
        SNN toolbox Warning: 'num_to_test' set lower than 'batch_size'.
        In simulators that test samples batch-wise (e.g. INIsim), this
        can lead to undesired behavior. Setting 'num_to_test' equal to
        'batch_size'."""))
        config.set('simulation', 'num_to_test', str(batch_size))

    plot_var = get_plot_keys(config)
    plot_vars = config_string_to_set_of_strings(config.get('restrictions',
                                                           'plot_vars'))
    assert all([v in plot_vars for v in plot_var]), \
        "Plot variable(s) {} not understood.".format(
            [v for v in plot_var if v not in plot_vars])
    if 'all' in plot_var:
        plot_vars_all = plot_vars.copy()
        plot_vars_all.remove('all')
        config.set('output', 'plot_vars', str(plot_vars_all))

    log_var = get_log_keys(config)
    log_vars = config_string_to_set_of_strings(config.get('restrictions',
                                                          'log_vars'))
    assert all([v in log_vars for v in log_var]), \
        "Log variable(s) {} not understood.".format(
            [v for v in log_var if v not in log_vars])
    if 'all' in log_var:
        log_vars_all = log_vars.copy()
        log_vars_all.remove('all')
        config.set('output', 'log_vars', str(log_vars_all))

    # Change matplotlib plot properties, e.g. label font size
    try:
        import matplotlib
    except ImportError:
        matplotlib = None
        # Bugfix: this warning and the disabling of plotting used to run
        # unconditionally (outside the except clause), i.e. even when
        # matplotlib WAS installed. It belongs in the ImportError branch.
        # NOTE(review): the condition arguably should check the user's
        # selection ``plot_var`` rather than the restriction set
        # ``plot_vars`` -- kept as-is to preserve original intent.
        if len(plot_vars) > 0:
            import warnings
            warnings.warn("Package 'matplotlib' not installed; disabling "
                          "plotting. Run 'pip install matplotlib' to enable "
                          "plotting.", ImportWarning)
            config.set('output', 'plot_vars', str({}))
    if matplotlib is not None:
        matplotlib.rcParams.update(eval(config.get('output',
                                                   'plotproperties')))

    # Check settings for parameter sweep.
    # Bugfix: a missing option raises configparser.NoOptionError, which is
    # not a subclass of KeyError, so the original ``except KeyError`` never
    # triggered. Use an explicit membership test instead.
    param_name = config.get('parameter_sweep', 'param_name')
    if not config.has_option('cell', param_name):
        print("Unknown parameter name {} to sweep.".format(param_name))
        raise RuntimeError

    spike_code = config.get('conversion', 'spike_code')
    spike_codes = config_string_to_set_of_strings(config.get('restrictions',
                                                             'spike_codes'))
    assert spike_code in spike_codes, \
        "Unknown spike code {} selected. Choose from {}.".format(spike_code,
                                                                 spike_codes)
    if spike_code == 'temporal_pattern':
        # Temporal-pattern coding encodes activations in 'num_bits' time
        # steps and currently only supports batch size 1.
        num_bits = str(config.getint('conversion', 'num_bits'))
        config.set('simulation', 'duration', num_bits)
        config.set('simulation', 'batch_size', '1')
    elif 'ttfs' in spike_code:
        # Time-to-first-spike codes: each neuron spikes at most once per run.
        config.set('cell', 'tau_refrac',
                   str(config.getint('simulation', 'duration')))
    assert keras_backend != 'theano' or spike_code == 'temporal_mean_rate', \
        "Keras backend 'theano' only works when the 'spike_code' parameter " \
        "is set to 'temporal_mean_rate' in snntoolbox config."

    # Persist the fully-resolved configuration next to the run's log output.
    with open(os.path.join(log_dir_of_current_run, '.config'), 'w') as f:
        config.write(f)

    return config
def initialize_simulator(config):
    """Import a module that contains utility functions of spiking simulator.

    Parameters
    ----------

    config: configparser.ConfigParser
        ConfigParser containing the user settings.

    Returns
    -------

    sim: module
        The simulator backend module. For pyNN-based simulators,
        ``sim.setup`` has already been called with the configured time step.
    """

    simulator = config.get('simulation', 'simulator')

    print("Initializing {} simulator...\n".format(simulator))

    if simulator in config_string_to_set_of_strings(
            config.get('restrictions', 'simulators_pyNN')):
        if simulator == 'spiNNaker':
            # Some SpiNNaker installations expose the backend as
            # ``spynnaker8`` instead of ``pyNN.spiNNaker``.
            try:
                sim = import_module('pyNN.' + simulator)
            except ImportError:
                sim = import_module('spynnaker8')
        else:
            sim = import_module('pyNN.' + simulator)

        # From the pyNN documentation:
        # "Before using any other functions or classes from PyNN, the user
        # must call the setup() function. Calling setup() a second time
        # resets the simulator entirely, destroying any network that may
        # have been created in the meantime."
        sim.setup(timestep=config.getfloat('simulation', 'dt'))
        return sim

    if simulator == 'brian2':
        return import_module('brian2')

    if simulator == 'loihi':
        import nxsdk.api.n2a as sim
        return sim

    sim_module_str = None
    if simulator == 'INI':
        spike_code = config.get('conversion', 'spike_code')
        sim_module_str = 'inisim.' + spike_code
        if spike_code == 'temporal_mean_rate':
            # The temporal_mean_rate backend is implemented per keras
            # backend.
            sim_module_str += '_' + config.get('simulation', 'keras_backend')
    elif simulator == 'MegaSim':
        sim_module_str = 'megasim.megasim'

    if sim_module_str is None:
        # Fall back to the default INI backend for unrecognized simulators.
        sim_module_str = 'inisim.temporal_mean_rate_theano'

    sim = import_module('snntoolbox.simulation.backends.' + sim_module_str)
    assert sim, "Simulator {} could not be initialized.".format(simulator)

    return sim
def get_log_keys(config):
    """Return the set of variable names to log, from ``output/log_vars``."""

    return config_string_to_set_of_strings(config.get('output', 'log_vars'))
def get_plot_keys(config):
    """Return the set of variable names to plot, from ``output/plot_vars``."""

    return config_string_to_set_of_strings(config.get('output', 'plot_vars'))
def config_string_to_set_of_strings(string):
    """Convert a string holding a python collection literal into a set of str.

    Parameters
    ----------

    string: str
        String representation of an iterable, e.g. "{'a', 'b'}" or
        "['x', 'y']", as stored in a config file.

    Returns
    -------

    : set
        The elements of the parsed collection, each coerced to ``str``.
    """

    import ast

    try:
        # Prefer the safe literal parser for plain literals.
        values = ast.literal_eval(string)
    except (ValueError, SyntaxError):
        # Fall back to eval for non-literal expressions (e.g. "set()").
        # SECURITY NOTE: eval executes arbitrary code from the config file;
        # only use trusted configuration files.
        values = eval(string)

    return {str(s) for s in set(values)}