# This script is part of navis (http://www.github.com/navis-org/navis).
# Copyright (C) 2018 Philipp Schlegel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
"""Module containing functions and classes to build `NEURON` compartment models.
Useful resources
----------------
- http://www.inf.ed.ac.uk/teaching/courses/nc/NClab1.pdf
ToDo
----
- connect neurons
- use neuron ID as GID
- [x] add spike recorder
- [x] make a subplot for each recording type (V, current, spikes)
- consider adding 3d points to more accurately represent the neuron
Examples
--------
Initialize and run a simple model. For debugging/testing only
>>> import navis
>>> import navis.interfaces.neuron as nrn
>>> import neuron
>>> # Set finer time steps
>>> neuron.h.dt = 0.025 # 0.025 ms
>>> # Set the temperature - how much does this matter?
>>> # Default is 6.3 (from HH model)
>>> # neuron.h.celsius = 24
>>> # This is a DA1 PN from the hemibrain dataset
>>> # It's in 8x8x8 nm voxels so we need to convert to microns
>>> n = navis.example_neurons(1) / 125
>>> n.reroot(n.soma, inplace=True)
>>> navis.smooth_skeleton(n, to_smooth='radius', inplace=True, window=3)
>>> # Get dendritic postsynapses
>>> post = n.connectors[n.connectors.type == 'post']
>>> post = post[post.y >= 250]
>>> # Initialize as a DrosophilaPN which automatically assigns a couple
>>> # properties known from the literature.
>>> cmp = nrn.DrosophilaPN(n, res=10)
>>> # Simulate some synaptic inputs on the first 10 input synapse
>>> cmp.add_synaptic_current(post.node_id.unique()[0:10], max_syn_cond=.1,
...                          rev_pot=-10)
>>> # Add voltage recording at the soma and some of the synapses
>>> cmp.add_voltage_record(n.soma, label='soma')
>>> cmp.add_voltage_record(post.node_id.unique()[0:3])
>>> # Let's also check out the synaptic current at one of the synapses
>>> cmp.add_current_record(post.node_id.unique()[0])
>>> # Initialize and run for 200ms
>>> print('Running model')
>>> cmp.run_simulation(200, v_init=-60)
>>> print('Done')
>>> # Plot
>>> cmp.plot_results()
Simulate some presynaptic spikes
>>> cmp = nrn.DrosophilaPN(n, res=1)
>>> cmp.add_voltage_record(n.soma, label='soma')
>>> cmp.add_voltage_record(post.node_id.unique()[0:10])
>>> cmp.add_synaptic_input(post.node_id.unique()[0:10], spike_no=5,
...                        spike_int=50, spike_noise=1, syn_tau2=1.1,
...                        syn_rev_pot=-10, cn_weight=0.04)
>>> cmp.run_simulation(200, v_init=-60)
>>> cmp.plot_results()
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from numbers import Number
from ... import config, core, utils, graph
from .utils import is_NEURON_object, is_section, is_segment
# `neuron` is a hard requirement for this interface: turn a failed import into
# an ImportError that tells the user how to install it
try:
    import neuron
except ImportError:
    raise ImportError('This interface requires the `neuron` libary to be '
                      'installed:\n pip3 install neuron\n'
                      'See also https://neuron.yale.edu/neuron/')

from neuron.units import ms, mV

# Load NEURON's standard run library (`run_simulation` relies on
# `neuron.h.continuerun`)
neuron.h.load_file('stdrun.hoc')

# Set up logging
logger = config.get_logger(__name__)

__all__ = []

# It looks like there can only ever be one reference to the time
# If we have multiple models, we will each reference them to this variable.
# Stays `None` until a simulation has been run.
main_t = None
class CompartmentModel:
    """Compartment model representing a single neuron in NEURON.

    Each linear stretch of the skeleton (one of its ``small_segments``) is
    turned into one NEURON section; sections are then connected according
    to the skeleton's topology.

    Parameters
    ----------
    x :     navis.TreeNeuron
            Neuron to generate model for. Has to be in microns!
    res :   int
            Target length [um] of segments. Each section of length ``L`` is
            split into ``1 + 2 * int(L / (2 * res))`` segments, i.e. always
            an odd number. Lower ``res`` = more detailed simulation.

    """

    def __init__(self, x: 'core.TreeNeuron', res=10):
        """Initialize Neuron."""
        utils.eval_param(x, name='x', allowed_types=(core.TreeNeuron, ))

        # Note that we make a copy to make sure that the data underlying the
        # model will not accidentally be changed
        self.skeleton = x.copy()

        # Max section resolution per segment
        self.res = res

        # Some placeholders
        self._sections = []
        self._stimuli = {}
        self._records = {}
        self._synapses = {}
        # NetCon objects backing the spike recordings; we have to hold on to
        # them to save them from the garbage collector
        self._spike_det = []

        # Generate the actual model
        self._validate_skeleton()
        self._generate_sections()

    def __repr__(self):
        s = (f'CompartmentModel<id={self.skeleton.label},'
             f'sections={self.n_sections};'
             f'stimuli={self.n_stimuli};'
             f'records={self.n_records}>'
             )
        return s

    @property
    def label(self):
        """Name/label of the neuron."""
        return f'CompartmentModel[{self.skeleton.label}]'

    @property
    def n_records(self):
        """Number of records (across all types) active on this model."""
        return sum(len(recs) for recs in self.records.values())

    @property
    def n_sections(self):
        """Number of sections in this model."""
        return len(self.sections)

    @property
    def n_stimuli(self):
        """Number of nodes that have at least one stimulus attached."""
        return len(self.stimuli)

    @property
    def nodes(self) -> pd.DataFrame:
        """Node table of the skeleton."""
        return self.skeleton.nodes

    @property
    def cm(self) -> np.ndarray:
        """Membrane capacity [micro Farads / cm^2], one value per section."""
        return np.array([s.cm for s in self.sections])

    @cm.setter
    def cm(self, value: float):
        """Set membrane capacity [micro Farads / cm^2] for all sections."""
        for s in self.sections:
            s.cm = value

    @property
    def Ra(self) -> np.ndarray:
        """Axial resistance [Ohm * cm], one value per section."""
        return np.array([s.Ra for s in self.sections])

    @Ra.setter
    def Ra(self, value: float):
        """Set axial resistance [Ohm * cm] for all sections."""
        for s in self.sections:
            s.Ra = value

    @property
    def records(self) -> dict:
        """Return mapping of record type -> {node ID or label: recording}."""
        return self._records

    @property
    def sections(self) -> np.ndarray:
        """Sections making up this model."""
        return self._sections

    @property
    def stimuli(self) -> dict:
        """Return mapping of node ID(s) to stimuli."""
        return self._stimuli

    @property
    def synapses(self) -> dict:
        """Return mapping of node ID(s) to synapses."""
        return self._synapses

    @property
    def t(self):
        """The global time vector. Shared by all models.

        This is ``None`` until a simulation has been run.
        """
        return main_t

    def _generate_sections(self):
        """Generate sections from the neuron.

        This will automatically be called at initialization and should not be
        called again.

        Adds two columns to ``self.skeleton.nodes``: ``sec_ix`` (index of the
        node's section in ``self.sections``) and ``sec_pos`` (normalized
        position of the node within that section).
        """
        # First generate sections
        self._sections = []
        nodes = self.skeleton.nodes.set_index('node_id')
        roots = self.skeleton.root
        bp = self.skeleton.branch_points.node_id.values
        G = self.skeleton.graph
        node2sec = {}
        node2pos = {}
        for i, seg in enumerate(self.skeleton.small_segments):
            # Get child -> parent distances in this segment
            dists = np.array([G.edges[(c, p)]['weight']
                              for c, p in zip(seg[:-1], seg[1:])])

            # Invert the sections
            # That's because in navis sections go from tip -> root (i.e.
            # child -> parent) but in neuron section(0) is the base and
            # section(1) is the tip.
            seg = np.asarray(seg)[::-1]
            dists = dists[::-1]

            # Grab the coordinates and radii
            seg_nodes = nodes.loc[seg]
            locs = seg_nodes[['x', 'y', 'z']].values
            radii = seg_nodes.radius.values

            # Generate section
            sec = neuron.h.Section(name=f'segment_{i}')

            # Set 3D points -> this automatically sets length L
            xvec = neuron.h.Vector(locs[:, 0])
            yvec = neuron.h.Vector(locs[:, 1])
            zvec = neuron.h.Vector(locs[:, 2])
            dvec = neuron.h.Vector(radii * 2)
            neuron.h.pt3dadd(xvec, yvec, zvec, dvec, sec=sec)

            # Set number of segments for this section
            # We also will make sure that each section has an odd
            # number of segments
            sec.nseg = 1 + 2 * int(sec.L / (self.res * 2))

            # Keep track of section
            self._sections.append(sec)

            # While we're at it: for each point (except the root of this
            # section) find the relative position within the section

            # Get normalized positions within this segment
            norm_pos = dists.cumsum() / dists.sum()

            # Update positional dictionaries (required for connecting the
            # segments in the next step)
            node2pos.update(dict(zip(seg[1:], norm_pos)))
            node2sec.update({n: i for n in seg[1:]})

            # If this happens to be the segment with the skeleton's root, keep
            # track of it too
            if seg[0] in roots:
                node2pos[seg[0]] = 0
                node2sec[seg[0]] = i

        self._sections = np.array(self._sections)

        self.skeleton.nodes['sec_ix'] = self.skeleton.nodes.node_id.map(node2sec)
        self.skeleton.nodes['sec_pos'] = self.skeleton.nodes.node_id.map(node2pos)

        # Connect segments
        for i, seg in enumerate(self.skeleton.small_segments):
            # In navis order the last node of a segment is its parent-most node
            parent_id = seg[-1]
            # By default the parent node is the distal tip of its own section
            parent_x = 1

            # Root is special in that it only needs to be connected if it's also
            # a branch point
            if parent_id in roots:
                # Skip if root is not a branch point
                if parent_id not in bp:
                    continue
                # If root is also a branch point, it will be part of more than
                # one section but in the positional dicts we will have kept track
                # of only one of them. That's the one we pick as base segment
                if node2sec[parent_id] == i:
                    continue
                # The root was registered at position 0 of the base section
                # (see `node2pos` above), so that is where the remaining
                # root sections have to attach
                parent_x = 0

            parent_sec = self.sections[node2sec[parent_id]]
            self.sections[i].connect(parent_sec(parent_x))

    def _validate_skeleton(self):
        """Validate skeleton.

        Warns if the units are not microns or if the neuron has more than
        one root.

        Raises
        ------
        ValueError
            If the node table has no ``radius`` column or contains
            radii <= 0.
        """
        if self.skeleton.units and not self.skeleton.units.dimensionless:
            not_um = self.skeleton.units.units != config.ureg.Unit('um')
            not_microns = self.skeleton.units.units != config.ureg.Unit('microns')
            if not_um and not_microns:
                logger.warning('Model expects coordinates in microns but '
                               f'neuron has units "{self.skeleton.units}"!')

        if len(self.skeleton.root) > 1:
            logger.warning('Neuron has multiple roots and hence consists of '
                           'multiple disconnected fragments!')

        if 'radius' not in self.skeleton.nodes.columns:
            raise ValueError('Neuron node table must have `radius` column')

        if np.any(self.skeleton.nodes.radius.values <= 0):
            raise ValueError('Neuron node table contains radii <= 0.')

    def inject_current_pulse(self, where, start=5,
                             duration=1, current=0.1):
        """Add current injection (IClamp) stimulation to model.

        Parameters
        ----------
        where :     int | list of int
                    Node ID(s) at which to stimulate.
        start :     int
                    Onset (delay) [ms] from beginning of simulation.
        duration :  int
                    Duration (dur) [ms] of injection.
        current :   float
                    Amount (i) [nA] of injected current.

        """
        self._add_stimulus('IClamp', where=where, delay=start,
                           dur=duration, amp=current)

    def add_synaptic_current(self, where, start=5, tau=0.1, rev_pot=0,
                             max_syn_cond=0.1):
        """Add synaptic current(s) (AlphaSynapse) to model.

        Parameters
        ----------
        where :         int | list of int
                        Node ID(s) at which to stimulate.
        start :         int
                        Onset [ms] from beginning of simulation.
        tau :           int
                        Decay time constant [ms].
        rev_pot :       int
                        Reverse potential (e) [mV].
        max_syn_cond :  float
                        Max synaptic conductance (gmax) [uS].

        """
        self._add_stimulus('AlphaSynapse', where=where, onset=start,
                           tau=tau, e=rev_pot, gmax=max_syn_cond)

    def _add_stimulus(self, stimulus, where, **kwargs):
        """Add generic stimulus.

        Parameters
        ----------
        stimulus :  str | callable
                    Either a callable that takes a segment and returns the
                    stimulus or the name of one in ``neuron.h``.
        where :     int | list of int
                    Node ID(s) at which to place the stimulus.
        **kwargs
                    Set as attributes on each generated stimulus.

        """
        if not callable(stimulus):
            stimulus = getattr(neuron.h, stimulus)

        where = utils.make_iterable(where)

        nodes = self.nodes.set_index('node_id')
        for node in nodes.loc[where].itertuples():
            seg = self.sections[node.sec_ix](node.sec_pos)
            stim = stimulus(seg)

            for k, v in kwargs.items():
                setattr(stim, k, v)

            self.stimuli.setdefault(node.Index, []).append(stim)

    def add_voltage_record(self, where, label=None):
        """Add voltage recording to model.

        Parameters
        ----------
        where :     int | list of int
                    Node ID(s) at which to record.
        label :     str, optional
                    If label is given, this recording will be added as
                    ``self.records['v'][label]`` else ``self.records['v'][node_id]``.
                    Note that with more than one node, all recordings end up
                    under the same label, i.e. only the last one is kept.

        """
        self._add_record(where, what='v', label=label)

    def add_current_record(self, where, label=None):
        """Add current recording to model.

        This only works if nodes map to segments that have point processes.

        Parameters
        ----------
        where :     int | list of int
                    Node ID(s) at which to record.
        label :     str, optional
                    If label is given, this recording will be added as
                    ``self.records['i'][label]`` else ``self.records['i'][node_id]``.

        Raises
        ------
        TypeError
                    If the segment for a node has no point process.

        """
        nodes = utils.make_iterable(where)

        # Map nodes to point processes
        segments = self.get_node_segment(nodes)
        point_procs = []
        for n, seg in zip(nodes, segments):
            pp = seg.point_processes()
            if not pp:
                raise TypeError(f'Section for node {n} has no point process '
                                '- unable to add current record')
            elif len(pp) > 1:
                logger.warning(f'Section for node {n} has more than on point '
                               'process. Recording current at first.')
                pp = pp[:1]
            point_procs += pp

        # Pass the node IDs as keys so that recordings are stored by node ID
        # as documented rather than by the point process object
        self._add_record(point_procs, what='i', label=label, keys=nodes)

    def add_spike_detector(self, where, threshold=20, label=None):
        """Add a spike detector at given node(s).

        Parameters
        ----------
        where :     int | list of int
                    Node ID(s) at which to record.
        threshold : float, optional
                    Threshold in mV for a spike to be counted. If ``None``,
                    the NetCon's own default threshold is left untouched.
        label :     str, optional
                    If label is given, this recording will be added as
                    ``self.records['spikes'][label]`` else
                    ``self.records['spikes'][node_id]``.

        """
        where = utils.make_iterable(where)

        self.records['spikes'] = self.records.get('spikes', {})
        segments = self.get_node_segment(where)
        sections = self.get_node_section(where)
        for n, sec, seg in zip(where, sections, segments):
            # Generate a NetCon object that has no target
            sp_det = neuron.h.NetCon(seg._ref_v, None, sec=sec)

            # Set threshold - explicit `None` check so that a threshold of
            # 0 mV is not silently dropped
            if threshold is not None:
                sp_det.threshold = threshold

            # Keeping track of this to save it from garbage collector
            self._spike_det.append(sp_det)

            # Create a vector for the spike timings
            vec = neuron.h.Vector()
            # Tell the NetCon object to record into that vector
            sp_det.record(vec)

            if label:
                self.records['spikes'][label] = vec
            else:
                self.records['spikes'][n] = vec

    def _add_record(self, where, what, label=None, keys=None):
        """Add a recording to given node.

        Parameters
        ----------
        where :     int | list of int | point process | section
                    Node ID(s) (or NEURON objects) at which to record.
        what :      str
                    What to record. Can be e.g. `v` or `_ref_v` for Voltage.
        label :     str, optional
                    If label is given, this recording will be added as
                    ``self.records[what][label]`` else
                    ``self.records[what][key]``.
        keys :      iterable, optional
                    Keys under which to store the recordings, one per entry
                    in ``where``. Defaults to the entries of ``where``.

        Raises
        ------
        TypeError
                    If ``what`` is not a string.

        """
        where = utils.make_iterable(where)
        keys = where if keys is None else utils.make_iterable(keys)

        if not isinstance(what, str):
            raise TypeError(f'Required str e.g. "v", got {type(what)}')

        if not what.startswith('_ref_'):
            what = f'_ref_{what}'

        rec_type = what.split('_')[-1]
        if rec_type not in self.records:
            self.records[rec_type] = {}

        for key, w in zip(keys, where):
            # If this is a neuron object (e.g. segment, section or point
            # process) we assume this does not need mapping
            if is_NEURON_object(w):
                seg = w
            else:
                seg = self.get_node_segment(w)

            rec = neuron.h.Vector().record(getattr(seg, what))

            if label:
                self.records[rec_type][label] = rec
            else:
                self.records[rec_type][key] = rec

    def connect(self, pre, where, syn_tau1=.1 * ms, syn_tau2=10 * ms,
                syn_rev_pot=0, cn_thresh=10, cn_delay=1 * ms, cn_weight=0):
        """Connect object to model.

        This uses the Exp2Syn synapse and treats `pre` as the presynaptic
        object.

        Parameters
        ----------
        pre :       NetStim | section
                    The presynaptic object to connect to this neuron.
        where :     int | list of int
                    Node IDs at which to simulate synaptic input.

        Synapse properties:

        syn_tau1 :  int
                    Rise time constant [ms].
        syn_tau2 :  int
                    Decay time constant [ms].
        syn_rev_pot : int
                    Reversal potential (e) [mV].

        Connection properties:

        cn_thresh : int
                    Presynaptic membrane potential [mV] at which synaptic
                    event is triggered.
        cn_delay :  int
                    Delay [ms] between presynaptic trigger and postsynaptic
                    event.
        cn_weight : int
                    Weight variable. This bundles a couple of synaptic
                    properties such as e.g. how much transmitter is released
                    or binding affinity at postsynaptic receptors.

        Raises
        ------
        ValueError
                    If ``pre`` is not a NEURON object.

        """
        where = utils.make_iterable(where)

        if not is_NEURON_object(pre):
            raise ValueError(f'Expected NEURON object, got {type(pre)}')

        # Turn section into segment
        if isinstance(pre, neuron.nrn.Section):
            pre = pre()

        # Go over the nodes
        nodes = self.nodes.set_index('node_id')
        for node in nodes.loc[where].itertuples():
            # Generate synapses for the nodes in question
            # Note that we are not reusing existing synapses
            # in case the properties are different
            seg = self.sections[node.sec_ix](node.sec_pos)
            syn = neuron.h.Exp2Syn(seg)
            syn.tau1 = syn_tau1
            syn.tau2 = syn_tau2
            syn.e = syn_rev_pot

            self.synapses.setdefault(node.Index, []).append(syn)

            # Connect spike stimulus and synapse
            if isinstance(pre, neuron.nrn.Segment):
                nc = neuron.h.NetCon(pre._ref_v, syn, sec=pre.sec)
            else:
                nc = neuron.h.NetCon(pre, syn)

            # Set connection parameters
            nc.threshold = cn_thresh
            nc.delay = cn_delay
            nc.weight[0] = cn_weight

            # Keep both the connection and its source referenced
            self.stimuli.setdefault(node.Index, []).extend([nc, pre])

    def clear_records(self):
        """Clear records (including the spike detectors feeding them)."""
        self._records = {}
        # The NetCon objects only exist to fill the spike records and would
        # otherwise keep referencing the model
        self._spike_det = []

    def clear_stimuli(self):
        """Clear stimuli."""
        self._stimuli = {}

    def clear_synapses(self):
        """Clear synapses."""
        self._synapses = {}

    def clear(self):
        """Attempt to remove model from NEURON space.

        This is not guaranteed to work. Check `neuron.h.topology()` to inspect.

        """
        # Basically we have to bring the reference count to zero: all we can
        # do from here is drop every reference this object holds
        self.clear_records()
        self.clear_stimuli()
        self.clear_synapses()
        self._sections = []

    def get_node_section(self, node_ids):
        """Return section(s) for given node(s).

        Parameters
        ----------
        node_ids :  int | list of int
                    Node IDs.

        Returns
        -------
        section(s) : section or list of sections
                    Depends on input.

        """
        nodes = self.nodes.set_index('node_id')
        if not utils.is_iterable(node_ids):
            n = nodes.loc[node_ids]
            # A single row may come back upcast (e.g. to float)
            return self.sections[int(n.sec_ix)]

        return [self.sections[node.sec_ix]
                for node in nodes.loc[node_ids].itertuples()]

    def get_node_segment(self, node_ids):
        """Return segment(s) for given node(s).

        Parameters
        ----------
        node_ids :  int | list of int
                    Node IDs.

        Returns
        -------
        segment(s) : segment or list of segments
                    Depends on input.

        """
        nodes = self.nodes.set_index('node_id')
        if not utils.is_iterable(node_ids):
            n = nodes.loc[node_ids]
            # A single row may come back upcast (e.g. to float)
            return self.sections[int(n.sec_ix)](n.sec_pos)

        return [self.sections[node.sec_ix](node.sec_pos)
                for node in nodes.loc[node_ids].itertuples()]

    def _resolve_sections(self, subset):
        """Turn `subset` (None, sections or section indices) into sections."""
        if subset is None:
            return self.sections

        subset = utils.make_iterable(subset)

        if all(is_section(s) for s in subset):
            return subset
        if all(isinstance(s, Number) for s in subset):
            return self.sections[subset]

        raise TypeError('`subset` must be None, a list of sections or '
                        'a list of section indices')

    def insert(self, mechanism, subset=None, **kwargs):
        """Insert biophysical mechanism for model.

        Parameters
        ----------
        mechanism : str
                    Mechanism to insert - e.g. "hh" for Hodgkin-Huxley kinetics.
        subset :    list of sections | list of int
                    Sections (or indices thereof) to set mechanism for.
                    If ``None`` will add mechanism to all sections.
        **kwargs
                    Use to set properties for mechanism.

        Raises
        ------
        TypeError
                    If ``subset`` is neither ``None`` nor a list of sections
                    or section indices.

        """
        for sec in np.unique(self._resolve_sections(subset)):
            sec.insert(mechanism)
            # Mechanism properties are set per segment
            for seg in sec:
                mech = getattr(seg, mechanism)
                for k, v in kwargs.items():
                    setattr(mech, k, v)

    def uninsert(self, mechanism, subset=None):
        """Remove biophysical mechanism from model.

        Parameters
        ----------
        mechanism : str
                    Mechanism to remove - e.g. "hh" for Hodgkin-Huxley kinetics.
        subset :    list of sections | list of int
                    Sections (or indices thereof) to remove mechanism from.
                    If ``None`` will remove mechanism from all sections.

        Raises
        ------
        TypeError
                    If ``subset`` is neither ``None`` nor a list of sections
                    or section indices.

        """
        for sec in np.unique(self._resolve_sections(subset)):
            if hasattr(sec, mechanism):
                sec.uninsert(mechanism)

    def plot_structure(self):
        """Visualize structure in 3D using matplotlib."""
        _ = neuron.h.PlotShape().plot(plt)

    def run_simulation(self, duration=25 * ms, v_init=-65 * mV):
        """Run the simulation.

        Parameters
        ----------
        duration :  int | float
                    Time [ms] to run the simulation for.
        v_init :    int | float
                    Membrane potential [mV] to initialize to.

        """
        # Add recording of time
        global main_t
        main_t = neuron.h.Vector().record(neuron.h._ref_t)

        # This resets the entire model space not just this neuron!
        neuron.h.finitialize(v_init)
        neuron.h.continuerun(duration)

    def plot_results(self, axes=None):
        """Plot results.

        Parameters
        ----------
        axes :      matplotlib axes
                    Axes to plot onto. Must have one ax for each recording
                    type (mV, spike count, etc) in `self.records`. A single
                    ax is re-used for all recording types.

        Returns
        -------
        axes
                    ``None`` if there is nothing to plot.

        """
        if self.t is None or not len(self.t):
            logger.warning('Looks like the simulation has not yet been run.')
            return
        if not self.records:
            logger.warning('Nothing to plot: no recordings found.')
            return

        # Must be an identity check: the truth value of an array of axes
        # is ambiguous
        if axes is None:
            _, axes = plt.subplots(len(self.records), sharex=True)

        # Make sure that even a single ax is a list
        if not isinstance(axes, (np.ndarray, list, tuple)):
            axes = [axes] * len(self.records)

        for t, ax in zip(self.records, axes):
            for i, (k, v) in enumerate(self.records[t].items()):
                if not len(v):
                    continue
                v = v.as_numpy()
                # For spikes the vector contains the times
                if t == 'spikes':
                    # Calculate spike rate
                    bins = np.linspace(0, max(self.t), 10)
                    hist, _ = np.histogram(v, bins=bins)
                    width = bins[1] - bins[0]
                    rate = hist * (1000 / width)
                    ax.plot(bins[:-1] + (width / 2), rate, label=k)

                    # Raster of the individual spikes, one row per recording
                    ax.scatter(v, [-i] * len(v), marker='|', s=100)
                else:
                    ax.plot(self.t, v, label=k)

            ax.set_xlabel('time [ms]')
            ax.set_ylabel(f'{t}')

            ax.legend()

        return axes
class DrosophilaPN(CompartmentModel):
    """Compartment model of an olfactory projection neuron in Drosophila.

    A ``CompartmentModel`` that comes preset with the passive membrane
    properties from Tobin et al. (2017):

    - specific axial resistivity (``Ra``) of 266.1 Ohm * cm
    - specific membrane capacitance (``cm``) of 0.8 uF / cm**2
    - specific leakage conductance (``g``) of 1/Rm, where Rm is the specific
      membrane resistance of 20800 Ohm cm**2
    - leakage reverse potential (``e``) of -60 mV

    Parameters
    ----------
    x :     navis.TreeNeuron
            Neuron to generate model for. Has to be in microns!
    res :   int
            Target length [um] of segments. Passed on unchanged to
            ``CompartmentModel``. Lower ``res`` = more detailed simulation.

    """

    def __init__(self, x, res=10):
        super().__init__(x, res=res)

        # Specific membrane resistance in Ohm cm**2
        membrane_resistance = 20800
        # Leakage reverse potential
        leak_reversal = -60

        # Specific axial resistivity in Ohm cm
        self.Ra = 266.1
        # Specific membrane capacitance
        self.cm = 0.8

        # Add passive membrane properties; the leakage conductance is the
        # inverse of the membrane resistance
        self.insert('pas', g=1/membrane_resistance, e=leak_reversal)