# This script is part of navis (http://www.github.com/navis-org/navis).
# Copyright (C) 2018 Philipp Schlegel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
"""Module contains functions to plot neurons as flat structures."""
import math
import time
import matplotlib.pyplot as plt
import matplotlib.colors as mcl
import pandas as pd
import numpy as np
import networkx as nx
from typing import Optional, Union, Any, List
from typing_extensions import Literal
from .. import core, config, utils
from ..morpho.mmetrics import parent_dist
from .colors import prepare_connector_cmap
logger = config.get_logger(__name__)
__all__ = ['plot_flat']
_DEFAULTS = dict(origin=(0, 0), # Origin in coordinate system
start_angle=0, # Start angle (0 -> to right)
angle_change=45, # Angle between branch and its child
angle_decrease=0, # Angle decrease with each branch point
syn_marker_size=.5, # Length of orthogonal synapse markers/size of scatter
switch_dist=1, # Distance threshold for inverting angle (i.e. flip branch direction)
syn_linewidth=1.5, # Line width for connectors
syn_highlight_color=(1, 0, 0), # Color for highlighted connectors
force_nx=False, # Force using networkx over igraph
color=(0.1, 0.1, 0.1) # Color for neurites
)
[docs]
def plot_flat(x,
layout: Union[Literal['subway'],
Literal['dot'],
Literal['neato'],
Literal['fpd'],
Literal['sfpd'],
Literal['twopi'],
Literal['circo'],
] = 'subway',
connectors: bool = False,
highlight_connectors: Optional[List[int]] = None,
shade_by_length: bool = False,
normalize_distance: bool = False,
reroot_soma: bool = False,
ax: Optional[Any] = None,
**kwargs):
"""Plot neuron as flat diagrams.
Parameters
----------
x : TreeNeuron
A single neuron to plot.
layout : 'subway' | 'dot' | 'neato' | 'fdp' | 'sfpd' | 'twopi' | 'circo'
Layout to use. All but 'subway' require graphviz to
be installed. For the 'fdp' and 'neato' it is highly
recommended to downsample the neuron first.
connectors : bool
If True and neuron has connectors, will plot
connectors.
highlight_connectors : list of connector IDs, optional
Will highlight these connector IDs.
ax : matplotlib.ax, optional
Ax to plot on. Will create new one if not provided.
shade_by_length : bool, optional
Change shade of branch with length. For layout
"subway" only.
normalize_distance : bool, optional
If True, will normalise all distances to the longest
neurite. For layout "subway" only.
**kwargs
Keyword argument passed on to the respective
plotting functions.
Returns
-------
ax : matplotlib.ax
pos : dict
(X, Y) positions for each node: ``{node_id: (x, y)}``.
Examples
--------
Plot neuron in "subway" layout:
.. plot::
:context: close-figs
>>> import navis
>>> n = navis.example_neurons(1).convert_units('nm')
>>> ax, pos = navis.plot_flat(n, layout='subway',
... figsize=(12, 2),
... connectors=True)
>>> _ = ax.set_xlabel('distance [nm]')
>>> plt.show() # doctest: +SKIP
Plot neuron in "dot" layout (requires pygraphviz and graphviz):
.. plot::
:context: close-figs
>>> # First downsample to speed up processing
>>> ds = navis.downsample_neuron(n, 10, preserve_nodes='connectors')
>>> ax, pos = navis.plot_flat(ds, layout='dot', connectors=True) # doctest: +SKIP
>>> plt.show() # doctest: +SKIP
To close all figures (only for doctests)
>>> plt.close('all')
See the :ref:`plotting tutorial <plot_intro>` for more examples.
"""
if isinstance(x, core.NeuronList) and len(x) == 1:
x = x[0]
utils.eval_param(x, name='x', allowed_types=(core.TreeNeuron,))
utils.eval_param(layout, name='layout',
allowed_values=('subway', 'dot', 'neato', 'fdp', 'sfdp', 'twopi', 'circo'))
# Work on the copy of the neuron
x = x.copy()
# Reroot to soma (if applicable)
if reroot_soma and x.soma:
x.reroot(x.soma, inplace=True)
if layout == 'subway':
return _plot_subway(x,
connectors=connectors,
highlight_connectors=highlight_connectors,
shade_by_length=shade_by_length,
normalize_distance=normalize_distance,
ax=ax, **kwargs)
else:
return _plot_force(x,
prog=layout,
connectors=connectors,
highlight_connectors=highlight_connectors,
ax=ax, **kwargs)
def _plot_subway(x, connectors=False, highlight_connectors=[],
shade_by_length=False, normalize_distance=False, ax=None,
**kwargs):
"""Plot neuron as dendrogram. Preserves distances along branches."""
DEFAULTS = _DEFAULTS.copy()
DEFAULTS.update(kwargs)
if len(x.root) > 1:
raise ValueError('Unable to plot neuron with multiple roots. Use '
'`navis.heal_skeleton` to merge the fragments.')
# Change scale of markers if we normalise to max neurite length
if normalize_distance:
DEFAULTS['syn_marker_len'] /= 1000
DEFAULTS['switch_dist'] /= 1000
if not ax:
fig, ax = plt.subplots(figsize=kwargs.get('figsize', (10, 10)))
# Make background transparent (nicer for dark themes)
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
# For each node get the distance to its root
if 'parent_dist' not in x.nodes.columns:
x.nodes['parent_dist'] = parent_dist(x, 0)
# First collect leafs, branches and root
leaf_nodes = x.leafs.node_id.values
root_nodes = x.root
branch_nodes = set(x.branch_points.node_id.values)
# Use igraph if possible:
if x.igraph and not DEFAULTS['force_nx']:
# Convert node IDs to igraph vertex indices
leaf_vs = x.igraph.vs.select(node_id_in=leaf_nodes)
root_vs = x.igraph.vs.select(node_id_in=root_nodes)
# Now get paths from all tips to the root
paths = x.igraph.get_shortest_paths(root_vs[0], leaf_vs, mode='ALL')
# Translate indices back into node ids
ids = np.array(x.igraph.vs.get_attribute_values('node_id'))
paths_tn = [ids[p] for p in paths]
else:
# Fall back to networkX
iterator = nx.shortest_path(x.graph, target=root_nodes[0])
paths_tn = [iterator[l][::-1] for l in leaf_nodes]
# Generate DataFrame with all the info
nodes = x.nodes.set_index('node_id')
path_df = pd.DataFrame()
path_df['path'] = paths_tn
pdist = nodes.parent_dist.to_dict()
path_df['distances'] = path_df.path.map(lambda x: np.array([pdist[n] for n in x]))
path_df['cable'] = path_df.distances.map(lambda x: sum(x))
# Sort DataFrame by cable length
path_df.sort_values('cable', inplace=True, ascending=False)
path_df.reset_index(inplace=True)
# Prepare for plotting by finding starts points and defining angles
positions = {x.root[0]: DEFAULTS['origin']}
angles = {x.root[0]: DEFAULTS['start_angle']}
seen = {x.root[0]}
for k, path in enumerate(path_df.path.values):
# Because the paths are always from tip to root, we have to find out
# which of the nodes have already been plotted and at which branch point
# we should add this path to the dendrogram
path = np.asarray(path)
exists = path[np.isin(path, list(seen))]
n_branch_points = len(branch_nodes & set(exists))
# Prune path to the bit that does not yet exist
is_new = ~np.isin(path, list(seen)) # numpy.isin doesn't like sets
path = path[is_new]
start_point = positions[exists[-1]]
last_angle = angles[exists[-1]]
# Get distance of the remaining path
distances = path_df.iloc[k].distances[is_new]
distances[0] = 0
distances = np.cumsum(distances)
# Normalise distances
if normalize_distance:
longest_dist = path_df.iloc[0].cable
distances /= longest_dist
# Make sure the longest neurite goes horizontally
# (or whatever START_ANGLE is)
if k != 0:
angle = DEFAULTS['angle_change'] - (DEFAULTS['angle_decrease'] * n_branch_points)
angle = max(angle, DEFAULTS['angle_decrease'])
else:
angle = DEFAULTS['start_angle']
# Invert angle depending on odd or even branch points
# (only to this if major branch -> SWITCH_DIST)
if n_branch_points % 2 != 0 and distances[-1] >= DEFAULTS['switch_dist']:
angle *= - 1
# Angle to radians
angle *= math.pi/180
# Add to last angle
angle += last_angle
# Calc x/y positions
y_coords = np.array([math.sin(angle) * v for v in distances])
x_coords = np.array([math.cos(angle) * v for v in distances])
# Offset by starting point
y_coords += start_point[1]
x_coords += start_point[0]
# Apply shade
color = DEFAULTS['color']
if shade_by_length:
a = .8 - .8 * distances[-1] / path_df.cable.max()
color = mcl.to_rgba(color, alpha=a)
# Change linewidths with path length
lw = 1 * distances[-1] / path_df.cable.max() + .5
# Plot
ax.plot(x_coords, y_coords, color=color, zorder=path_df.shape[0]-k, linewidth=lw)
# Keep track of positions for each treenode and angle of each path
for i, coords in enumerate(zip(x_coords, y_coords)):
positions[path[i]] = coords
if path[i] not in angles:
angles[path[i]] = angle
seen = seen | set(path)
# Plot connectors
if connectors and x.has_connectors:
# Get centers for each connector
centers = np.vstack(x.connectors.node_id.map(positions))
# Angle of the branch they belong to
angles = (x.connectors.node_id.map(angles) + 90 * (math.pi / 180)).values
# Create lines orthogonal to parent branch
y_coords = np.sin(angles) * DEFAULTS['syn_marker_size']
y_coords = np.dstack((y_coords + centers[:, 1],
-y_coords + centers[:, 1],
[None] * len(y_coords))
).flatten()
x_coords = np.cos(angles) * DEFAULTS['syn_marker_size']
x_coords = np.dstack((x_coords + centers[:, 0],
-x_coords + centers[:, 0],
[None] * len(x_coords))
).flatten()
cn_cmap = prepare_connector_cmap(x)
for ty in x.connectors.type.unique():
is_type = x.connectors.type == ty
is_type = np.dstack((is_type, is_type, is_type)).flatten()
ax.plot(x_coords[is_type],
y_coords[is_type],
color=cn_cmap[ty]['color'],
zorder=path_df.shape[0] + 1,
linewidth=DEFAULTS['syn_linewidth'])
# Plot highlighted connectors
if not isinstance(highlight_connectors, type(None)) and x.has_connectors:
this = x.connectors[x.connectors.connector_id.isin(highlight_connectors)]
# Get centers for each connector
centers = np.vstack(this.node_id.map(positions))
# Angle of the branch they belong to
angles = (this.node_id.map(angles) + 90 * (math.pi / 180)).values
# Create lines orthogonal to parent branch
y_coords = np.sin(angles) * DEFAULTS['syn_marker_size']
y_coords = np.dstack((y_coords + centers[:, 1],
-y_coords + centers[:, 1],
[None] * len(y_coords))
).flatten()
x_coords = np.cos(angles) * DEFAULTS['syn_marker_size']
x_coords = np.dstack((x_coords + centers[:, 0],
-x_coords + centers[:, 0],
[None] * len(x_coords))
).flatten()
ax.plot(x_coords,
y_coords,
color=DEFAULTS['syn_highlight_color'],
zorder=path_df.shape[0] + 2,
linewidth=DEFAULTS['syn_linewidth'])
# Plot soma
if x.has_soma:
soma = utils.make_iterable(x.soma)[0]
soma_pos = positions[soma]
ax.scatter([soma_pos[0]], [soma_pos[1]],
s=40, color=(.1, .1, .1))
# Make sure x/y axis are equal
ax.set_aspect('equal')
# Return axis
return ax, positions
def _plot_force(x, connectors=False, highlight_connectors=None, prog='dot',
ax=None, **kwargs):
"""Plot neurons as dendrograms using graphviz layouts."""
DEFAULTS = _DEFAULTS.copy()
DEFAULTS.update(kwargs)
# Save start time
start = time.time()
# Generate and populate networkX graph representation of the neuron
G = x.graph.copy()
# graphviz needs "len" not "weight"
nx.set_edge_attributes(G, nx.get_edge_attributes(G, 'weight'), name='len')
# Calculate layout
logger.info('Calculating node positions.')
positions = nx.nx_agraph.graphviz_layout(G, prog=prog,
root=utils.make_iterable(x.soma)[0] if x.has_soma else None)
# Plot tree with above layout
logger.info('Plotting tree.')
if not ax:
fig, ax = plt.subplots(figsize=kwargs.get('figsize', (12, 6)))
# Make background transparent (nicer for dark themes)
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
nx.draw(G, positions,
node_size=0,
arrows=False,
edge_color=DEFAULTS['color'],
ax=ax)
# Add soma
if x.has_soma:
for s in utils.make_iterable(x.soma):
ax.scatter([positions[s][0]], [positions[s][1]],
s=40, color=DEFAULTS['color'],
zorder=1)
if connectors and x.has_connectors:
cn_cmap = prepare_connector_cmap(x)
for ty in x.connectors.type.unique():
this = x.connectors[x.connectors.type == ty]
coords = np.vstack(this.node_id.map(positions))
ax.scatter(coords[:, 0], coords[:, 1],
color=cn_cmap[ty]['color'],
zorder=2,
s=DEFAULTS['syn_marker_size'] * 10)
if not isinstance(highlight_connectors, type(None)) and x.has_connectors:
this = x.connectors[x.connectors.connector_id.isin(highlight_connectors)]
coords = np.vstack(this.node_id.map(positions))
ax.scatter(coords[:, 0], coords[:, 1],
color=DEFAULTS['syn_highlight_color'],
zorder=3,
s=DEFAULTS['syn_marker_size'] * 10)
logger.debug(f'Done in {time.time()-start}s')
return ax, positions