# This script is part of navis (http://www.github.com/navis-org/navis).
# Copyright (C) 2018 Philipp Schlegel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
import warnings
import trimesh as tm
import pandas as pd
import numpy as np
import scipy.spatial
import scipy.interpolate
from typing import Union, Optional, List, overload
from typing_extensions import Literal
from .. import config, core, utils, graph
# Set up logging
logger = config.get_logger(__name__)
__all__ = ['resample_skeleton', 'resample_along_axis']
@overload
def resample_skeleton(x: 'core.TreeNeuron',
resample_to: int,
inplace: bool = False,
method: str = 'linear',
skip_errors: bool = True
) -> 'core.TreeNeuron': ...
@overload
def resample_skeleton(x: 'core.NeuronList',
resample_to: int,
inplace: bool = False,
method: str = 'linear',
skip_errors: bool = True
) -> 'core.NeuronList': ...
[docs]
@utils.map_neuronlist(desc='Resampling', allow_parallel=True)
def resample_skeleton(x: 'core.NeuronObject',
resample_to: Union[int, str],
inplace: bool = False,
method: str = 'linear',
skip_errors: bool = True
) -> Optional['core.NeuronObject']:
"""Resample skeleton(s) to given resolution.
Preserves root, leafs and branchpoints. Soma, connectors and node tags
(if present) are mapped onto the closest node in the resampled neuron.
Important
---------
A few things to keep in mind:
- This generates an entirely new set of node IDs! They will be unique
within a neuron, but you may encounter duplicates across neurons.
- Any non-standard node table columns (e.g. "labels") will be lost.
- Soma(s) will be pinned to the closest node in the resampled neuron.
Also: be aware that high-resolution neurons will use A LOT of memory.
Parameters
----------
x : TreeNeuron | NeuronList
Neuron(s) to resample.
resample_to : int | float | str
Target sampling resolution, i.e. one node every
N units of cable. Note that hitting the exact
sampling resolution might not be possible e.g. if
a branch is shorter than the target resolution. If
neuron(s) have their `.units` parameter, you can also
pass a string such as "1 micron".
method : str, optional
See ``scipy.interpolate.interp1d`` for possible
options. By default, we're using linear interpolation.
inplace : bool, optional
If True, will modify original neuron. If False, a
resampled copy is returned.
skip_errors : bool, optional
If True, will skip errors during interpolation and
only print summary.
Returns
-------
TreeNeuron/List
Downsampled neuron(s).
Examples
--------
>>> import navis
>>> n = navis.example_neurons(1)
>>> # Check sampling resolution (nodes/cable)
>>> round(n.sampling_resolution)
60
>>> # Resample to 1 micron (example neurons are in 8x8x8nm)
>>> n_rs = navis.resample_skeleton(n,
... resample_to=1000 / 8,
... inplace=False)
>>> round(n_rs.sampling_resolution)
134
See Also
--------
:func:`navis.downsample_neuron`
This function reduces the number of nodes instead of
resample to certain resolution. Useful if you are
just after some simplification - e.g. for speeding up
your calculations or you want to preserve node IDs.
:func:`navis.resample_along_axis`
Resample neuron along a single axis such that nodes
align with given 1-dimensional grid.
"""
if not isinstance(x, core.TreeNeuron):
raise TypeError(f'Unable to resample data of type "{type(x)}"')
# Map units (non-str are just passed through)
resample_to = x.map_units(resample_to, on_error='raise')
if not inplace:
x = x.copy()
# Collect some information for later
locs = dict(zip(x.nodes.node_id.values, x.nodes[['x', 'y', 'z']].values))
radii = dict(zip(x.nodes.node_id.values, x.nodes.radius.values))
new_nodes: List = []
max_tn_id = x.nodes.node_id.max() + 1
errors = 0
# Iterate over segments
for i, seg in enumerate(x.small_segments):
# Get coordinates
coords = np.vstack([locs[n] for n in seg])
# Get radii
rad = [radii[tn] for tn in seg]
# Vecs between subsequently measured points
vecs = np.diff(coords.T)
# path: cum distance along points (norm from first to Nth point)
dist = np.cumsum(np.linalg.norm(vecs, axis=0))
dist = np.insert(dist, 0, 0)
# If path is too short, just keep the first and last node
if dist[-1] < resample_to or (method == 'cubic' and len(seg) <= 3):
new_nodes += [[seg[0], seg[-1],
coords[0][0], coords[0][1], coords[0][2],
radii[seg[0]]]]
continue
# Distances (i.e. resolution) of interpolation
n_nodes = np.round(dist[-1] / resample_to)
new_dist = np.linspace(dist[0], dist[-1], int(n_nodes))
try:
sampleX = scipy.interpolate.interp1d(dist, coords[:, 0],
kind=method)
sampleY = scipy.interpolate.interp1d(dist, coords[:, 1],
kind=method)
sampleZ = scipy.interpolate.interp1d(dist, coords[:, 2],
kind=method)
sampleR = scipy.interpolate.interp1d(dist, rad,
kind=method)
except ValueError as e:
if skip_errors:
errors += 1
new_nodes += x.nodes.loc[x.nodes.node_id.isin(seg[:-1]),
['node_id', 'parent_id',
'x', 'y', 'z',
'radius']].values.tolist()
continue
else:
raise e
# Sample each dim
xnew = sampleX(new_dist)
ynew = sampleY(new_dist)
znew = sampleZ(new_dist)
rnew = sampleR(new_dist)
# Generate new coordinates
new_coords = np.array([xnew, ynew, znew]).T
# Generate new ids (start and end node IDs of this segment are kept)
new_ids = seg[:1] + [max_tn_id +
i for i in range(len(new_coords) - 2)] + seg[-1:]
# Increase max index
max_tn_id += len(new_ids)
# Keep track of new nodes
new_nodes += [[tn, pn, co[0], co[1], co[2], r]
for tn, pn, co, r in zip(new_ids[:-1],
new_ids[1:],
new_coords,
rnew)]
if errors:
logger.warning(f'{errors} ({errors/i:.0%}) segments skipped due to '
'errors')
# Add root node(s)
root = x.nodes.loc[x.nodes.node_id.isin(utils.make_iterable(x.root)),
['node_id', 'parent_id', 'x', 'y', 'z', 'radius']]
new_nodes += [list(r) for r in root.values]
# Generate new nodes dataframe
new_nodes = pd.DataFrame(data=new_nodes,
columns=['node_id', 'parent_id',
'x', 'y', 'z', 'radius'])
# Convert columns to appropriate dtypes
dtypes = {k: x.nodes[k].dtype for k in ['node_id', 'parent_id', 'x', 'y', 'z', 'radius']}
for cols in new_nodes.columns:
new_nodes = new_nodes.astype(dtypes, errors='ignore')
# Remove duplicate nodes (branch points)
new_nodes = new_nodes[~new_nodes.node_id.duplicated()]
# Generate KDTree
tree = scipy.spatial.cKDTree(new_nodes[['x', 'y', 'z']].values)
# Map soma onto new nodes if required
# Note that if `._soma` is a soma detection function we can't tell
# how to deal with it. Ideally the new soma node will
# be automatically detected but it is possible, for example, that
# the radii of nodes have changed due to interpolation such that more
# than one soma is detected now. Also a "label" column in the node
# table would be lost at this point.
# We will go for the easy option which is to pin the soma at this point.
nodes = x.nodes.set_index('node_id', inplace=False)
if np.any(getattr(x, 'soma')):
soma_nodes = utils.make_iterable(x.soma)
old_pos = nodes.loc[soma_nodes, ['x', 'y', 'z']].values
# Get nearest neighbours
dist, ix = tree.query(old_pos)
node_map = dict(zip(soma_nodes, new_nodes.node_id.values[ix]))
# Map back onto neuron
if utils.is_iterable(x.soma):
# Use _soma to avoid checks - the new nodes have not yet been
# assigned to the neuron!
x._soma = [node_map[n] for n in x.soma]
else:
x._soma = node_map[x.soma]
else:
# If `._soma` was (read: is) a function but it didn't detect anything in
# the original neurons, this makes sure that the resampled neuron
# doesn't have a soma either:
x.soma = None
# Map connectors back if necessary
if x.has_connectors:
# Get position of old synapse-bearing nodes
old_tn_position = nodes.loc[x.connectors.node_id, ['x', 'y', 'z']].values
# Get nearest neighbours
dist, ix = tree.query(old_tn_position)
# Map back onto neuron
x.connectors['node_id'] = new_nodes.node_id.values[ix]
# Map tags back if necessary
# Expects `tags` to be a dictionary {'tag': [node_id1, node_id2, ...]}
if x.has_tags and isinstance(x.tags, dict):
# Get nodes that need remapping
nodes_to_remap = list({n for l in x.tags.values() for n in l})
# Get position of old tag-bearing nodes
old_tn_position = nodes.loc[nodes_to_remap, ['x', 'y', 'z']].values
# Get nearest neighbours
dist, ix = tree.query(old_tn_position)
# Map back onto tags
node_map = dict(zip(nodes_to_remap, new_nodes.node_id.values[ix]))
x.tags = {k: [node_map[n] for n in v] for k, v in x.tags.items()}
# Set nodes
x.nodes = new_nodes
# Clear and regenerate temporary attributes
x._clear_temp_attr()
return x
[docs]
@utils.map_neuronlist(desc='Binning', allow_parallel=True)
def resample_along_axis(x: 'core.TreeNeuron',
interval: Union[int, float, str],
axis: int = 2,
old_nodes: Union[Literal['remove'],
Literal['keep'],
Literal['snap']] = 'remove',
inplace: bool = False
) -> Optional['core.TreeNeuron']:
"""Resample neuron such that nodes lie exactly on given 1d grid.
This function does not simply snap nodes to the closest grid line but
instead adds new nodes where edges between existing nodes intersect
with the planes defined by the grid.
Parameters
----------
x : TreeNeuron | NeuronList
Neuron(s) to resample.
interval : float | int | str
Intervals defining a 1-dimensional grid along given axes
(see examples). If neuron(s) have `.units` set, you can also
pass a string such as "50 nm".
axis : 0 | 1 | 2
Along which axes (x/y/z) to resample.
old_nodes : "remove" | "keep" | "snap"
Existing nodes are unlikely to intersect with the planes as
defined by the grid interval. There are three possible ways
to deal with them:
- "remove" (default) will simply drop old nodes: this
guarantees all remaining nodes will lie on a plane
- "keep" will keep old nodes without changing them
- "snap" will snap those nodes to the closest coordinate
on the grid without interpolation
inplace : bool
If False, will resample and return a copy of the original. If
True, will resample input neuron in place.
Returns
-------
TreeNeuron/List
The resampled neuron(s).
See Also
--------
:func:`navis.resample_skeleton`
Resample neuron such that edges between nodes have a
given length.
:func:`navis.downsample_neuron`
This function reduces the number of nodes instead of
resample to certain resolution. Useful if you are
just after some simplification e.g. for speeding up
your calculations or you want to preserve node IDs.
Examples
--------
Resample neuron such that we have one node in every 40nm slice along z axis
>>> import navis
>>> n = navis.example_neurons(1)
>>> n.n_nodes
4465
>>> res = navis.resample_along_axis(n, interval='40 nm',
... axis=2, old_nodes='remove')
>>> res.n_nodes < n.n_nodes
True
"""
utils.eval_param(axis, name='axis', allowed_values=(0, 1, 2))
utils.eval_param(old_nodes, name='old_nodes',
allowed_values=("remove", "keep", "snap"))
utils.eval_param(x, name='x', allowed_types=(core.TreeNeuron, ))
interval = x.map_units(interval, on_error='raise')
if not inplace:
x = x.copy()
# Collect coordinates of nodes and their parents
nodes = x.nodes
not_root = nodes.loc[nodes.parent_id >= 0]
node_locs = not_root[['x', 'y', 'z']].values
parent_locs = nodes.set_index('node_id').loc[not_root.parent_id.values,
['x', 'y', 'z']].values
# Get all vectors
vecs = parent_locs - node_locs
# Get coordinates along this axis
loc1 = node_locs[:, axis]
loc2 = parent_locs[:, axis]
# This prevents runtime warnings e.g. from division by zero
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Find out which grid interval these are on
int1 = (loc1 / interval).astype(int)
int2 = (loc2 / interval).astype(int)
# Difference in bin between both locs
diff = int2 - int1
sign = diff / np.abs(diff)
# Figure out by how far we are from the gridline
dist = np.zeros(diff.shape[0])
dist[diff < 0] = loc1[diff < 0] % interval
dist[diff > 0] = -loc1[diff > 0] % interval
# Now we need to calculate the new position
# Get other axes
other_axes = list({0, 1, 2} - {axis})
# Normalize other vectors by this vector
other_vecs_norm = vecs[:, other_axes] / vecs[:, [axis]]
# Get offset for other axis
other_offset = other_vecs_norm * dist.reshape(dist.shape[0], 1)
# Offset for this axis
this_offset = dist * sign
# Apply offsets
new_coords = node_locs.copy()
new_coords[:, other_axes] += other_offset * sign.reshape(sign.shape[0], 1)
new_coords[:, [axis]] += this_offset.reshape(this_offset.shape[0], 1)
# Now extract nodes that need to be inserted
insert_between = not_root.loc[diff != 0, ['node_id', 'parent_id']].values
new_coords = new_coords[diff != 0]
# Insert nodes
graph.insert_nodes(x, where=insert_between, coords=new_coords, inplace=True)
# Figure out what to do with nodes that are not on the grid
if old_nodes == 'remove':
mod = x.nodes[['x', 'y', 'z'][axis]].values % interval
not_lined_up = mod != 0
to_remove = x.nodes.loc[not_lined_up, 'node_id'].values
elif old_nodes == 'keep':
to_remove = insert_between[:, 0]
elif old_nodes == 'snap':
not_lined_up = x.nodes[['x', 'y', 'z']].values[:, axis] % interval != 0
to_snap = x.nodes.loc[not_lined_up, ['x', 'y', 'z'][axis]].values
snapped = (to_snap / interval).round() * interval
x.nodes.loc[not_lined_up, ['x', 'y', 'z'][axis]] = snapped
to_remove = []
if np.any(to_remove):
graph.remove_nodes(x, which=to_remove, inplace=True)
return x
def _make_grid(interval, axis, neuron):
"""Generate Volume visualizing 1d grid."""
assert axis in (0, 1, 2)
bounds = neuron.bbox
# Generate a box for each plane - just a face won't render properly
b = tm.primitives.Box()
box_verts = np.array(b.vertices)
box_faces = np.array(b.faces)
for i in range(3):
is_low = box_verts[:, i] < 0
box_verts[is_low, i] = bounds[i][0]
box_verts[~is_low, i] = bounds[i][1]
is_low = b.vertices[:, axis] < 0
start = (bounds[axis][0] / interval).astype(int) * interval
end = ((bounds[axis][1] / interval).astype(int) + 1) * interval
depth = np.arange(start, end + interval, interval)
faces = []
vertices = []
for i, d in enumerate(depth):
this_verts = box_verts.copy()
this_faces = box_faces.copy()
this_verts[is_low, axis] = d - 0.01 * interval
this_verts[~is_low, axis] = d + 0.01 * interval
this_faces += this_verts.shape[0] * i
vertices.append(this_verts)
faces.append(this_faces)
faces = np.vstack(faces)
vertices = np.vstack(vertices)
return core.Volume(vertices=vertices, faces=faces, color=(1, 1, 1, .1))