Source code for navis.plotting.dd

#    This script is part of navis (https://github.com/navis-org/navis).
#    Copyright (C) 2018 Philipp Schlegel
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.

"""Module contains functions to plot neurons in 2D/2.5D."""

import copy
import warnings

from typing import Union, List, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import matplotlib.colors as mcl
import mpl_toolkits
import numpy as np
import pint

from matplotlib.cm import ScalarMappable
from matplotlib.collections import LineCollection, PatchCollection
from mpl_toolkits.mplot3d.art3d import (Line3DCollection, Poly3DCollection,
                                        Path3DCollection, Patch3DCollection)
from typing_extensions import Literal

from .. import utils, config, core, conversion
from .colors import prepare_colormap, vertex_colors
from .plot_utils import segments_to_coords, tn_pairs_to_coords

__all__ = ['plot2d']

logger = config.get_logger(__name__)

# NOTE(review): the body of this ``with`` statement was missing, which made
# the module unparsable. The two statements below are the idiom recommended
# by the pint documentation for silencing the one-off warning pint emits the
# first time a Quantity wraps an array - confirm against upstream.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    pint.Quantity([])

def plot2d(x: Union[core.NeuronObject,
                    core.Volume,
                    np.ndarray,
                    List[Union[core.NeuronObject, np.ndarray, core.Volume]]
                    ],
           method: Union[Literal['2d'],
                         Literal['3d'],
                         Literal['3d_complex']] = '3d',
           **kwargs) -> Tuple[mpl.figure.Figure, mpl.axes.Axes]:
    """Generate 2D plots of neurons and neuropils.

    The main advantage of this is that you can save plot as vector graphics.

    Important
    ---------
    This function uses matplotlib which "fakes" 3D as it has only very limited
    control over layering objects in 3D. Therefore neurites are not necessarily
    plotted in the right Z order. This becomes especially troublesome when
    plotting a complex scene with lots of neurons criss-crossing. See the
    ``method`` parameter for details. All methods use orthogonal projection.

    Parameters
    ----------
    x :               TreeNeuron | MeshNeuron | NeuronList | Volume | Dotprops | np.ndarray
                      Objects to plot:

                       - multiple objects can be passed as list (see examples)
                       - numpy array of shape (n,3) is interpreted as points
                         for scatter plots

    method :          '2d' | '3d' (default) | '3d_complex'
                      Method used to generate plot. Comes in three flavours:

                       1. '2d' uses normal matplotlib. Neurons are plotted on
                          top of one another in the order they are passed to
                          the function. Use the ``view`` parameter (below) to
                          set the view (default = xy).
                       2. '3d' uses matplotlib's 3D axis. Here, matplotlib
                          decides the depth order (zorder) of plotting. Can
                          change perspective either interactively or by code
                          (see examples).
                       3. '3d_complex' same as 3d but each neuron segment is
                          added individually. This allows for more complex
                          zorders to be rendered correctly. Slows down
                          rendering though.

    soma :            bool, default=True
                      Plot soma if one exists. Size of the soma is determined
                      by the neuron's ``.soma_radius`` property which defaults
                      to the "radius" column for ``TreeNeurons``.
    connectors :      bool, default=False
                      Plot connectors.
    connectors_only : boolean, default=False
                      Plot only connectors, not the neuron.
    cn_size :         int | float, default = 1
                      Size of connectors.
    linewidth :       int | float, default=.5
                      Width of neurites. Also accepts alias ``lw``.
    linestyle :       str, default='-'
                      Line style of neurites. Also accepts alias ``ls``.
    autoscale :       bool, default=True
                      If True, will scale the axes to fit the data.
    scalebar :        int | float | str | pint.Quantity, default=False
                      Adds scale bar. Provide integer, float or str to set
                      size of scalebar. Int|float are assumed to be in same
                      units as data. You can specify units in as string:
                      e.g. "1 um". For methods '3d' and '3d_complex', this
                      will create an axis object.
    ax :              matplotlib ax, default=None
                      Pass an ax object if you want to plot on an existing
                      canvas. Must match ``method`` - i.e. 2D or 3D axis.
    figsize :         tuple, default=(8, 8)
                      Size of figure.
    color :           None | str | tuple | list | dict, default=None
                      Use single str (e.g. ``'red'``) or ``(r, g, b)`` tuple
                      to give all neurons the same color. Use ``list`` of
                      colors to assign colors: ``['red', (1, 0, 1), ...]``.
                      Use ``dict`` to map colors to neuron IDs:
                      ``{id: (r, g, b), ...}``.
    palette :         str | array | list of arrays, default=None
                      Name of a matplotlib or seaborn palette. If ``color`` is
                      not specified will pick colors from this palette.
    color_by :        str | array | list of arrays, default = None
                      Can be the name of a column in the node table of
                      ``TreeNeurons`` or an array of (numerical or
                      categorical) values for each node. Numerical values will
                      be normalized. You can control the normalization by
                      passing a ``vmin`` and/or ``vmax`` parameter.
    shade_by :        str | array | list of arrays, default=None
                      Similar to ``color_by`` but will affect only the alpha
                      channel of the color. If ``shade_by='strahler'`` will
                      compute Strahler order if not already part of the node
                      table (TreeNeurons only). Numerical values will be
                      normalized. You can control the normalization by passing
                      a ``smin`` and/or ``smax`` parameter.
    alpha :           float [0-1], default=.9
                      Alpha value for neurons. Overridden if alpha is provided
                      as fourth value in ``color`` (rgb*a*). You can override
                      alpha value for connectors by using ``cn_alpha``.
    clusters :        list, default=None
                      A list assigning a cluster to each neuron (e.g.
                      ``[0, 0, 0, 1, 1]``). Overrides ``color`` and uses
                      ``palette`` to generate colors according to clusters.
    depth_coloring :  bool, default=False
                      If True, will color encode depth (Z). Overrides
                      ``color``. Does not work with ``method = '3d_complex'``.
    depth_scale :     bool, default=True
                      If True and ``depth_coloring=True`` will plot a scale.
    cn_mesh_colors :  bool, default=False
                      If True, will use the neuron's color for its connectors.
    group_neurons :   bool, default=False
                      If True, neurons will be grouped. Works with SVG export
                      (not PDF). Does NOT work with ``method='3d_complex'``.
    scatter_kws :     dict, default={}
                      Parameters to be used when plotting points. Accepted
                      keywords are: ``size`` and ``color``.
    view :            tuple, default = ("x", "y")
                      Sets view for ``method='2d'``.
    orthogonal :      bool, default=True
                      Whether to use orthogonal or perspective view for
                      methods '3d' and '3d_complex'.
    volume_outlines : bool | "both", default=False
                      If True will plot volume outline with no fill. Only
                      works with `method="2d"`.
    dps_scale_vec :   float
                      Scale vector for dotprops.
    rasterize :       bool, default=False
                      Neurons produce rather complex vector graphics which can
                      lead to large files when saving to SVG, PDF or PS. Use
                      this parameter to rasterize neurons and meshes/volumes
                      (but not axes or labels) to reduce file size.

    Examples
    --------

    .. plot::
       :context: close-figs

       >>> import navis
       >>> import matplotlib.pyplot as plt

       Plot list of neurons as simple 2d

       >>> nl = navis.example_neurons()
       >>> fig, ax = navis.plot2d(nl, method='2d', view=('x', '-y'))
       >>> plt.show() # doctest: +SKIP

    Add a volume

    .. plot::
       :context: close-figs

       >>> vol = navis.example_volume('LH')
       >>> fig, ax = navis.plot2d([nl, vol], method='2d', view=('x', '-y'))
       >>> plt.show() # doctest: +SKIP

    Change neuron colors

    .. plot::
       :context: close-figs

       >>> fig, ax = navis.plot2d(nl,
       ...                        method='2d',
       ...                        view=('x', '-y'),
       ...                        color=['r', 'g', 'b', 'm', 'c', 'y'])
       >>> plt.show() # doctest: +SKIP

    Plot in "fake" 3D

    .. plot::
       :context: close-figs

       >>> fig, ax = navis.plot2d(nl, method='3d')
       >>> plt.show() # doctest: +SKIP
       >>> # In an interactive window you can drag the plot to rotate

    Plot in "fake" 3D and change perspective

    .. plot::
       :context: close-figs

       >>> fig, ax = navis.plot2d(nl, method='3d')
       >>> # Change view to frontal (for example neurons)
       >>> ax.azim = ax.elev = 90
       >>> # Change view to lateral
       >>> ax.azim, ax.elev = 180, 180
       >>> ax.elev = 0
       >>> # Change view to top
       >>> ax.azim, ax.elev = 90, 180
       >>> # Tilted top view
       >>> ax.azim, ax.elev = -130, -150
       >>> # Move camera
       >>> ax.dist = 6
       >>> plt.show() # doctest: +SKIP

    Plot using depth-coloring

    .. plot::
       :context: close-figs

       >>> fig, ax = navis.plot2d(nl, method='3d', depth_coloring=True)
       >>> plt.show() # doctest: +SKIP

    To close all figures

    >>> plt.close('all')

    See the :ref:`plotting tutorial <plot_intro>` for more examples.

    Returns
    -------
    fig, ax :      matplotlib figure and axis object

    Raises
    ------
    KeyError
                   If an unknown keyword argument is passed.
    ValueError
                   If ``method`` is unknown or ``color_by`` is used without
                   ``palette``.
    TypeError
                   If ``ax`` does not match ``method`` or an object of
                   unknown type is encountered.

    See Also
    --------
    :func:`navis.plot3d`
            Use this if you want interactive, perspectively correct renders
            and if you don't need vector graphics as outputs.
    :func:`navis.plot1d`
            A nifty way to visualise neurons in a single dimension.
    :func:`navis.plot_flat`
            Plot neurons as flat structures (e.g. dendrograms).

    """
    # Filter kwargs
    _ACCEPTED_KWARGS = ['soma', 'connectors', 'connectors_only',
                        'ax', 'color', 'colors', 'c', 'view', 'scalebar',
                        'cn_mesh_colors', 'linewidth', 'cn_size', 'cn_alpha',
                        'orthogonal', 'group_neurons', 'scatter_kws',
                        'figsize', 'linestyle', 'rasterize', 'clusters',
                        'synapse_layout', 'alpha', 'depth_coloring',
                        'autoscale', 'depth_scale', 'ls', 'lw',
                        'volume_outlines', 'radius', 'dps_scale_vec',
                        'palette', 'color_by', 'shade_by', 'vmin', 'vmax',
                        'smin', 'smax', 'norm_global']
    wrong_kwargs = [a for a in kwargs if a not in _ACCEPTED_KWARGS]
    if wrong_kwargs:
        raise KeyError(f'Unknown kwarg(s): {", ".join(wrong_kwargs)}. '
                       f'Currently accepted: {", ".join(_ACCEPTED_KWARGS)}')

    _METHOD_OPTIONS = ['2d', '3d', '3d_complex']
    if method not in _METHOD_OPTIONS:
        raise ValueError(f'Unknown method "{method}". Please use either: '
                         f'{",".join(_METHOD_OPTIONS)}')

    connectors = kwargs.get('connectors', False)
    connectors_only = kwargs.get('connectors_only', False)
    ax = kwargs.pop('ax', None)
    scalebar = kwargs.get('scalebar', None)

    # Depth coloring
    depth_coloring = kwargs.get('depth_coloring', False)
    depth_scale = kwargs.get('depth_scale', True)

    scatter_kws = kwargs.get('scatter_kws', {})
    autoscale = kwargs.get('autoscale', True)

    # Parse objects
    (neurons, volumes, points, _) = utils.parse_objects(x)

    # Generate colors (all three aliases are removed from kwargs so they are
    # not forwarded to the plotting helpers)
    colors = kwargs.pop('color',
                        kwargs.pop('c',
                                   kwargs.pop('colors', None)))
    palette = kwargs.get('palette', None)
    color_by = kwargs.get('color_by', None)
    shade_by = kwargs.get('shade_by', None)

    # Generate the colormaps
    (neuron_cmap,
     volumes_cmap) = prepare_colormap(colors,
                                      neurons=neurons,
                                      volumes=volumes,
                                      palette=palette,
                                      clusters=kwargs.get('clusters', None),
                                      alpha=kwargs.get('alpha', None),
                                      color_range=1)

    if color_by is not None:
        if not palette:
            raise ValueError('Must provide `palette` (e.g. "viridis") argument '
                             'if using `color_by`')

        neuron_cmap = vertex_colors(neurons,
                                    by=color_by,
                                    use_alpha=False,
                                    palette=palette,
                                    norm_global=kwargs.get('norm_global', True),
                                    vmin=kwargs.get('vmin', None),
                                    vmax=kwargs.get('vmax', None),
                                    na=kwargs.get('na', 'raise'),
                                    color_range=1)

    if shade_by is not None:
        alphamap = vertex_colors(neurons,
                                 by=shade_by,
                                 use_alpha=True,
                                 palette='viridis',  # palette is irrelevant here
                                 norm_global=kwargs.get('norm_global', True),
                                 vmin=kwargs.get('smin', None),
                                 vmax=kwargs.get('smax', None),
                                 na=kwargs.get('na', 'raise'),
                                 color_range=1)

        # Merge the alpha channel of `alphamap` into the neuron colors
        new_colormap = []
        for c, a in zip(neuron_cmap, alphamap):
            # Expand a single color into one color per row of the alpha map
            if not (isinstance(c, np.ndarray) and c.ndim == 2):
                c = np.tile(c, (a.shape[0], 1))

            if c.shape[1] == 4:
                c[:, 3] = a[:, 3]
            else:
                c = np.insert(c, 3, a[:, 3], axis=1)

            new_colormap.append(c)
        neuron_cmap = new_colormap

    # Set axis projection
    if method in ['3d', '3d_complex']:
        if kwargs.get('orthogonal', True):
            mpl_toolkits.mplot3d.proj3d.persp_transformation = _orthogonal_proj
        else:
            mpl_toolkits.mplot3d.proj3d.persp_transformation = _perspective_proj

    # Generate axes
    if not ax:
        if method == '2d':
            fig, ax = plt.subplots(figsize=kwargs.get('figsize', (8, 8)))
            ax.set_aspect('equal')
        elif method in ['3d', '3d_complex']:
            fig = plt.figure(figsize=kwargs.get('figsize',
                                                plt.figaspect(1) * 1.5))
            ax = fig.add_subplot(111, projection='3d')

            # This sets front view
            ax.azim = -90
            ax.elev = 0
            ax.dist = 7
            # Disallowed for 3D in matplotlib 3.1.0
            # ax.set_aspect('equal')
        # Make background transparent (nicer for dark themes)
        fig.patch.set_alpha(0)
        ax.patch.set_alpha(0)
    # Check if correct axis were provided
    else:
        if not isinstance(ax, mpl.axes.Axes):
            raise TypeError('Ax must be of type "mpl.axes.Axes", '
                            f'not "{type(ax)}"')
        fig = ax.get_figure()
        if method in ['3d', '3d_complex']:
            if ax.name != '3d':
                raise TypeError('Axis must be 3d.')
        elif method == '2d':
            if ax.name == '3d':
                raise TypeError('Axis must be 2d.')

    # Remember whether the axis already held data before we add ours
    ax.had_data = ax.has_data()

    # Prepare some stuff for depth coloring
    if depth_coloring and not neurons.empty:
        if method == '3d_complex':
            raise Exception(f'Depth coloring unavailable for method "{method}"')
        elif method == '2d':
            bbox = neurons.bbox
            # Add to kwargs
            xy = [v.replace('-', '').replace('+', '')
                  for v in kwargs.get('view', ('x', 'y'))]
            z_ix = [v[1] for v in [('x', 0), ('y', 1), ('z', 2)]
                    if v[0] not in xy]

            kwargs['norm'] = plt.Normalize(vmin=bbox[z_ix, 0],
                                           vmax=bbox[z_ix, 1])

    # Plot volumes first
    if volumes:
        for i, v in enumerate(volumes):
            _ = _plot_volume(v, volumes_cmap[i], method, ax, **kwargs)

    # Create lines from segments
    visuals = {}
    # NOTE: this must be `or` - `a | b < 2` parses as `(a | b) < 2` which
    # showed the progress bar even when `config.pbar_hide` was set
    for i, neuron in enumerate(config.tqdm(neurons,
                                           desc='Plot neurons',
                                           leave=False,
                                           disable=(config.pbar_hide
                                                    or len(neurons) < 2))):
        if not connectors_only:
            if isinstance(neuron, core.TreeNeuron) and kwargs.get('radius', False):
                _neuron = conversion.tree2meshneuron(neuron)
                _neuron.connectors = neuron.connectors
                neuron = _neuron

            # The empty-skeleton check must be part of the elif chain -
            # otherwise the neuron is plotted despite the "skipping" warning
            if isinstance(neuron, core.TreeNeuron) and neuron.nodes.empty:
                logger.warning(f'Skipping TreeNeuron w/o nodes: {neuron.label}')
            elif isinstance(neuron, core.TreeNeuron) and neuron.nodes.shape[0] == 1:
                logger.warning(f'Skipping single-node TreeNeuron: {neuron.label}')
            elif isinstance(neuron, core.MeshNeuron) and neuron.faces.size == 0:
                logger.warning(f'Skipping MeshNeuron w/o faces: {neuron.label}')
            elif isinstance(neuron, core.Dotprops) and neuron.points.size == 0:
                logger.warning(f'Skipping Dotprops w/o points: {neuron.label}')
            elif isinstance(neuron, core.TreeNeuron):
                lc, sc = _plot_skeleton(neuron, neuron_cmap[i], method, ax,
                                        **kwargs)
                # Keep track of visuals related to this neuron
                visuals[neuron] = {'skeleton': lc, 'somata': sc}
            elif isinstance(neuron, core.MeshNeuron):
                m = _plot_mesh(neuron, neuron_cmap[i], method, ax, **kwargs)
                visuals[neuron] = {'mesh': m}
            elif isinstance(neuron, core.Dotprops):
                dp = _plot_dotprops(neuron, neuron_cmap[i], method, ax,
                                    **kwargs)
                visuals[neuron] = {'dotprop': dp}
            elif isinstance(neuron, core.VoxelNeuron):
                dp = _plot_voxels(neuron, neuron_cmap[i], method, ax, kwargs,
                                  **scatter_kws)
                visuals[neuron] = {'dotprop': dp}
            else:
                raise TypeError(f"Don't know how to plot neuron of type '{type(neuron)}' ")

        if (connectors or connectors_only) and neuron.has_connectors:
            _ = _plot_connectors(neuron, neuron_cmap[i], method, ax, **kwargs)

    for p in points:
        _ = _plot_scatter(p, method, ax, kwargs, **scatter_kws)

    if autoscale:
        if method == '2d':
            ax.autoscale(tight=True)
        elif method in ['3d', '3d_complex']:
            # Make sure data lims are set correctly
            update_axes3d_bounds(ax)
            # Resize to have equal aspect
            set_axes3d_equal(ax)

    # `scalebar=False` (the documented default) must not draw a scale bar
    if scalebar is not None and scalebar is not False:
        _ = _add_scalebar(scalebar, neurons, method, ax)

    def set_depth():
        """Set depth information for neurons according to camera position."""
        # Get projected coordinates
        proj_co = mpl_toolkits.mplot3d.proj3d.proj_points(all_co,
                                                          ax.get_proj())

        # Get min and max of z coordinates
        z_min, z_max = min(proj_co[:, 2]), max(proj_co[:, 2])

        # Generate a new normaliser
        norm = plt.Normalize(vmin=z_min, vmax=z_max)

        # Go over all neurons and update Z information
        for neuron in visuals:
            # Get this neuron's collection and coordinates
            if 'skeleton' in visuals[neuron]:
                c = visuals[neuron]['skeleton']
                this_co = c._segments3d[:, 0, :]
            elif 'mesh' in visuals[neuron]:
                c = visuals[neuron]['mesh']
                # Note that we only get every third position -> that's because
                # these vectors actually represent faces, i.e. each vertex
                this_co = c._vec.T[::3, [0, 1, 2]]
            else:
                raise ValueError('Neither mesh nor skeleton found for neuron '
                                 f'{neuron.id}')

            # Get projected coordinates
            this_proj = mpl_toolkits.mplot3d.proj3d.proj_points(this_co,
                                                                ax.get_proj())

            # Normalise z coordinates
            ns = norm(this_proj[:, 2]).data

            # Set array
            c.set_array(ns)
            # No need for normaliser - already happened
            c.set_norm(None)

            if (isinstance(neuron, core.TreeNeuron)
                    and getattr(neuron, 'soma', None) is not None):
                # Get depth of soma(s)
                soma = utils.make_iterable(neuron.soma)
                soma_co = neuron.nodes.set_index('node_id').loc[soma][['x', 'y', 'z']].values
                soma_proj = mpl_toolkits.mplot3d.proj3d.proj_points(soma_co,
                                                                    ax.get_proj())
                soma_cs = norm(soma_proj[:, 2]).data

                # Set soma color
                for cs, s in zip(soma_cs, visuals[neuron]['somata']):
                    s.set_color(cmap(cs))

    def Update(event):
        set_depth()

    if depth_coloring:
        cmap = mpl.cm.jet
        # `norm` is only set if there were neurons to derive it from
        if method == '2d' and depth_scale and 'norm' in kwargs:
            sm = ScalarMappable(norm=kwargs['norm'], cmap=cmap)
            fig.colorbar(sm, ax=ax, fraction=.075, shrink=.5, label='Depth')
        elif method == '3d':
            # Collect all coordinates
            all_co = []
            for n in visuals:
                if 'skeleton' in visuals[n]:
                    all_co.append(visuals[n]['skeleton']._segments3d[:, 0, :])
                if 'mesh' in visuals[n]:
                    all_co.append(visuals[n]['mesh']._vec.T[:, [0, 1, 2]])

            # Nothing to depth-color if no skeleton/mesh was plotted
            if all_co:
                all_co = np.concatenate(all_co, axis=0)
                fig.canvas.mpl_connect('draw_event', Update)
                set_depth()

    # Turn off the axis we actually plotted on (``plt.axis`` would act on
    # the *current* axis which is not necessarily ``ax``)
    ax.axis('off')

    return fig, ax
def _add_scalebar(scalebar, neurons, method, ax):
    """Add scalebar.

    Parameters
    ----------
    scalebar :  bool | int | float | str | pint.Quantity
                Size of the scale bar. ``True`` is interpreted as "1 um",
                ``False``/``None`` adds nothing. Strings are parsed with
                ``config.ureg``. Quantities are converted to the units of the
                first neuron (if any).
    neurons :   list-like of neurons used as unit reference.
    method :    '2d' | '3d' | '3d_complex'
    ax :        matplotlib axis to draw on.

    """
    # `False` used to be turned into a "1 um" bar because it is a bool
    if scalebar is None or scalebar is False:
        return

    if isinstance(scalebar, bool):
        scalebar = '1 um'

    if isinstance(scalebar, str):
        scalebar = config.ureg(scalebar)

    if isinstance(scalebar, pint.Quantity):
        # If we have neurons as points of reference convert
        if neurons:
            scalebar = scalebar.to(neurons[0].units).magnitude
        # If no reference, assume it's the same units
        else:
            scalebar = scalebar.magnitude

    # Hard-coded offset from figure boundaries (5% of the x-range)
    ax_offset = (ax.get_xlim()[1] - ax.get_xlim()[0]) / 100 * 5

    if method == '2d':
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

        coords = np.array([[xlim[0] + ax_offset, ylim[0] + ax_offset],
                           [xlim[0] + ax_offset + scalebar, ylim[0] + ax_offset]
                           ])

        sbar = mlines.Line2D(coords[:, 0], coords[:, 1],
                             lw=3, alpha=.9, color='black')
        sbar.set_gid(f'{scalebar}_scalebar')

        ax.add_line(sbar)
    elif method in ['3d', '3d_complex']:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        zlim = ax.get_zlim()

        left = xlim[0] + ax_offset
        bottom = zlim[0] + ax_offset
        front = ylim[0] + ax_offset

        # One line per axis, all starting in the same corner
        sbar = [np.array([[left, front, bottom],
                          [left, front, bottom]]),
                np.array([[left, front, bottom],
                          [left, front, bottom]]),
                np.array([[left, front, bottom],
                          [left, front, bottom]])]
        sbar[0][1][0] += scalebar
        sbar[1][1][1] += scalebar
        sbar[2][1][2] += scalebar

        lc = Line3DCollection(sbar, color='black', lw=1)
        lc.set_gid(f'{scalebar}_scalebar')
        ax.add_collection3d(lc)


def _plot_scatter(points, method, ax, kwargs, **scatter_kws):
    """Plot points as scatter.

    ``kwargs`` is the dictionary of plot2d keyword arguments (only ``view`` is
    read here); ``scatter_kws`` override the default scatter settings.
    """
    if method == '2d':
        default_settings = dict(
            c='black',
            zorder=4,
            edgecolor='none',
            s=1
        )
        default_settings.update(scatter_kws)
        default_settings = _fix_default_dict(default_settings)

        view = kwargs.get('view', ('x', 'y'))
        x, y = _parse_view2d(points, view)

        ax.scatter(x, y, **default_settings)
    elif method in ['3d', '3d_complex']:
        default_settings = dict(
            c='black',
            s=1,
            depthshade=False,
            edgecolor='none'
        )
        default_settings.update(scatter_kws)
        default_settings = _fix_default_dict(default_settings)

        ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                   **default_settings
                   )


def _plot_voxels(vx, color, method, ax, kwargs, **scatter_kws):
    """Plot VoxelNeuron as scatter plot.

    Only the (up to) 1,000,000 voxels with the highest values are plotted.
    Voxel values are mapped onto the alpha channel of ``color``.
    """
    # Use only the top N voxels
    assert isinstance(vx, core.VoxelNeuron)
    n_pts = 1000000
    v = vx.values
    pts = vx.voxels
    srt = np.argsort(v)[::-1]
    pts = pts[srt][: n_pts]
    v = v[srt][: n_pts]

    # Scale points by units
    pts = pts * vx.units_xyz.magnitude + vx.offset

    # Calculate colors
    cmap = color_to_cmap(color)
    colors = cmap(v / v.max())

    if method == '2d':
        view = kwargs.get('view', ('x', 'y'))
        x, y = _parse_view2d(pts, view)

        ax.scatter(x, y,
                   c=colors,
                   s=scatter_kws.get('size', 20))
    elif method in ['3d', '3d_complex']:
        ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2],
                   c=colors,
                   marker=scatter_kws.get('marker', 'o'),
                   s=scatter_kws.get('size', .1))


def color_to_cmap(color):
    """Convert single color to color palette.

    The returned colormap runs from fully transparent to fully opaque
    ``color``.
    """
    color = mcl.to_rgb(color)

    colors = [[color[0], color[1], color[2], 0],
              [color[0], color[1], color[2], 1]]

    return mcl.LinearSegmentedColormap.from_list('Palette', colors, N=256)


def _plot_dotprops(dp, color, method, ax, **kwargs):
    """Plot dotprops."""
    # Here, we will effectively cheat and turn the dotprops into a skeleton
    # which we can then pass to _plot_skeleton
    tn = dp.to_skeleton(scale_vec=kwargs.get('dps_scale_vec', 'auto'))

    return _plot_skeleton(tn, color, method, ax, **kwargs)


def _plot_connectors(neuron, color, method, ax, **kwargs):
    """Plot connectors.

    Connector colors/sizes come from ``config.default_connector_colors``,
    optionally overridden by the neuron's ``color`` (``cn_mesh_colors``) and
    the ``synapse_layout`` keyword argument.
    """
    view = kwargs.get('view', ('x', 'y'))

    if not kwargs.get('cn_mesh_colors', False):
        cn_layout = config.default_connector_colors.copy()
    else:
        # Deep copy because we are about to modify the nested dictionaries
        cn_layout = copy.deepcopy(config.default_connector_colors)
        # Change all of the colors to the neuron's color
        for inner_dict in cn_layout.values():
            if not isinstance(inner_dict, dict):
                continue
            inner_dict["color"] = color

    cn_layout.update(kwargs.get('synapse_layout', {}))

    if method == '2d':
        for c in neuron.connectors.type.unique():
            this_cn = neuron.connectors[neuron.connectors.type == c]

            x, y = _parse_view2d(this_cn[['x', 'y', 'z']].values, view)

            sc = ax.scatter(x, y,
                            color=cn_layout[c]['color'],
                            edgecolor='none',
                            s=kwargs.get('cn_size', cn_layout['size']))
            # Tag the scatter itself: `ax.get_children()[-1]` is the axis'
            # background patch, not the most recently added artist
            sc.set_gid(f'CN_{neuron.id}')
    elif method in ['3d', '3d_complex']:
        all_cn = neuron.connectors
        c = [cn_layout[i]['color'] for i in all_cn.type.values]
        sc = ax.scatter(all_cn.x.values, all_cn.y.values, all_cn.z.values,
                        color=c,
                        s=kwargs.get('cn_size', cn_layout['size']),
                        depthshade=cn_layout.get('depthshade', False),
                        edgecolor='none')
        sc.set_gid(f'CN_{neuron.id}')


def _plot_mesh(neuron, color, method, ax, **kwargs):
    """Plot mesh (i.e. MeshNeuron).

    Returns
    -------
    The collection returned by ``ax.plot_trisurf`` for the 3D methods,
    ``None`` for ``method='2d'``.
    """
    name = getattr(neuron, 'name', None)
    depth_coloring = kwargs.get('depth_coloring', False)
    alpha = kwargs.get('alpha', None)
    group_neurons = kwargs.get('group_neurons', False)
    view = kwargs.get('view', ('x', 'y'))
    rasterize = kwargs.get('rasterize', False)

    # Add alpha
    if alpha:
        color = (color[0], color[1], color[2], alpha)

    ts = None
    if method == '2d':
        # Generate 2d representation
        xy = np.dstack(_parse_view2d(neuron.vertices, view))[0]

        # Map vertex colors to faces
        if isinstance(color, np.ndarray) and color.ndim == 2:
            if len(color) != len(neuron.faces) and len(color) == len(neuron.vertices):
                color = [color[f].mean(axis=0)[:3].tolist() for f in neuron.faces]

        # Generate a patch for each face
        patches = [mpatches.Polygon(xy[f], closed=True) for f in neuron.faces]
        pc = PatchCollection(patches, linewidth=0, facecolor=color,
                             rasterized=rasterize,
                             edgecolor='none', alpha=alpha)
        ax.add_collection(pc)
    else:
        ts = ax.plot_trisurf(neuron.vertices[:, 0],
                             neuron.vertices[:, 1],
                             neuron.faces,
                             neuron.vertices[:, 2],
                             label=name,
                             rasterized=rasterize,
                             cmap=mpl.cm.jet if depth_coloring else None,
                             color=color)

        if group_neurons:
            ts.set_gid(neuron.id)
    return ts


def _get_depth_axis(view):
    """Return index of axis which is not used for x/y."""
    view = [v.replace('-', '').replace('+', '') for v in view]
    # The depth axis is the one *not* part of the view
    depth = [ax for ax in ['x', 'y', 'z'] if ax not in view][0]
    axis_index = {'x': 0, 'y': 1, 'z': 2}
    return axis_index[depth]


def _parse_view2d(co, view):
    """Parse view parameter and returns x/y parameter.

    Parameters
    ----------
    co :    array-like
            Either 2-dimensional ``(N, 3)`` or 3-dimensional ``(N, M, 3)``
            coordinates.
    view :  tuple of str
            Axes to use for x and y, e.g. ``('x', '-y')``. A ``-`` prefix
            inverts the axis.

    Returns
    -------
    ``(x, y)`` tuple of lists for 2-dimensional input, ``(N, M, 2)`` array
    for 3-dimensional input.

    Raises
    ------
    ValueError
            If ``co`` is neither 2- nor 3-dimensional.
    """
    if not isinstance(co, np.ndarray):
        co = np.array(co)

    axis_index = {'x': 0, 'y': 1, 'z': 2}

    x_ix = axis_index[view[0].replace('-', '').replace('+', '')]
    y_ix = axis_index[view[1].replace('-', '').replace('+', '')]

    x_mod = -1 if '-' in view[0] else 1
    y_mod = -1 if '-' in view[1] else 1

    if co.ndim == 2:
        x = co[:, x_ix]
        y = co[:, y_ix]

        # Multiply only where co is not None. The elementwise `!= None` is
        # intentional: the array may contain None to break up lines
        x = np.multiply(x, x_mod, where=x != None, subok=False)  # noqa: E711
        y = np.multiply(y, y_mod, where=y != None, subok=False)  # noqa: E711

        # Do NOT remove the list() here - for some reason the multiplication
        # above causes issues in matplotlib
        return (list(x), list(y))
    elif co.ndim == 3:
        xy = co[:, :, [x_ix, y_ix]] * [x_mod, y_mod]
        return xy
    else:
        raise ValueError(f'Expect coordinates to have 2 or 3 dimensions, got {co.ndim}')


def _plot_skeleton(neuron, color, method, ax, **kwargs):
    """Plot skeleton.

    Returns
    -------
    ``(None, None)`` for ``method='2d'``. For the 3D methods a tuple of the
    ``Line3DCollection`` (``None`` for '3d_complex') and a list of soma
    surfaces.
    """
    depth_coloring = kwargs.get('depth_coloring', False)
    linewidth = kwargs.get('linewidth', kwargs.get('lw', .5))
    linestyle = kwargs.get('linestyle', kwargs.get('ls', '-'))
    alpha = kwargs.get('alpha', None)
    norm = kwargs.get('norm')
    plot_soma = kwargs.get('soma', True)
    group_neurons = kwargs.get('group_neurons', False)
    view = kwargs.get('view', ('x', 'y'))
    rasterize = kwargs.get('rasterize', False)

    # Do we have one color per node (as opposed to a single color)?
    per_node_color = isinstance(color, np.ndarray) and color.ndim == 2

    if method == '2d':
        if not depth_coloring and not per_node_color:
            # Generate by-segment coordinates
            coords = segments_to_coords(neuron,
                                        neuron.segments,
                                        modifier=(1, 1, 1))

            # We have to add (None, None, None) to the end of each
            # slab to make that line discontinuous there
            coords = np.vstack([np.append(t, [[None] * 3],
                                          axis=0) for t in coords])

            x, y = _parse_view2d(coords, view)
            this_line = mlines.Line2D(x, y, lw=linewidth, ls=linestyle,
                                      alpha=alpha, color=color,
                                      rasterized=rasterize,
                                      label=f'{getattr(neuron, "name", "NA")} - #{neuron.id}')
            ax.add_line(this_line)
        else:
            coords = tn_pairs_to_coords(neuron, modifier=(1, 1, 1))
            xy = _parse_view2d(coords, view)
            lc = LineCollection(xy,
                                cmap='jet' if depth_coloring else None,
                                norm=norm if depth_coloring else None,
                                rasterized=rasterize,
                                joinstyle='round')

            lc.set_linewidth(linewidth)
            lc.set_linestyle(linestyle)
            lc.set_label(f'{getattr(neuron, "name", "NA")} - #{neuron.id}')

            if depth_coloring:
                lc.set_alpha(alpha)
                # Depth is the axis not shown in this view (not always z)
                depth_col = ['x', 'y', 'z'][_get_depth_axis(view)]
                lc.set_array(neuron.nodes.loc[neuron.nodes.parent_id >= 0,
                                              depth_col].values)
            elif per_node_color:
                # If we have a color for each node (rather than each edge),
                # we need to drop the roots
                if color.shape[0] != coords.shape[0]:
                    lc.set_color(color[neuron.nodes.parent_id.values >= 0])
                else:
                    lc.set_color(color)

            ax.add_collection(lc)

        if plot_soma and np.any(neuron.soma):
            soma = utils.make_iterable(neuron.soma)
            # If soma detection is messed up we might end up producing
            # dozens of soma which will freeze the kernel
            if len(soma) >= 10:
                logger.warning(f'{neuron.id} - {len(soma)} somas found.')
            nodes_by_id = neuron.nodes.set_index('node_id')
            for s in soma:
                if isinstance(color, np.ndarray) and color.ndim > 1:
                    s_ix = np.where(neuron.nodes.node_id == s)[0][0]
                    soma_color = color[s_ix]
                else:
                    soma_color = color

                n = nodes_by_id.loc[s]
                # `soma_radius` is either a column name or a fixed value
                if isinstance(neuron.soma_radius, str):
                    r = getattr(n, neuron.soma_radius)
                else:
                    r = neuron.soma_radius

                if depth_coloring:
                    d = [n.x, n.y, n.z][_get_depth_axis(view)]
                    soma_color = mpl.cm.jet(norm(d))

                sx, sy = _parse_view2d(np.array([[n.x, n.y, n.z]]), view)
                c = mpatches.Circle((sx[0], sy[0]), radius=r,
                                    alpha=alpha, fill=True, fc=soma_color,
                                    rasterized=rasterize,
                                    zorder=4, edgecolor='none')
                ax.add_patch(c)
        return None, None

    elif method in ['3d', '3d_complex']:
        # For simple scenes, add whole neurons at a time to speed up rendering
        if method == '3d':
            if per_node_color or depth_coloring:
                coords = tn_pairs_to_coords(neuron, modifier=(1, 1, 1))
                # If we have a color for each node (rather than each edge),
                # we need to drop the roots
                if per_node_color and color.shape[0] != coords.shape[0]:
                    line_color = color[neuron.nodes.parent_id.values >= 0]
                else:
                    line_color = color
            else:
                # Generate by-segment coordinates
                coords = segments_to_coords(neuron,
                                            neuron.segments,
                                            modifier=(1, 1, 1))
                line_color = color

            lc = Line3DCollection(coords,
                                  color=line_color if not depth_coloring else None,
                                  label=neuron.id,
                                  alpha=alpha,
                                  cmap=None if not depth_coloring else mpl.cm.jet,
                                  lw=linewidth,
                                  joinstyle='round',
                                  rasterized=rasterize,
                                  linestyle=linestyle)
            if group_neurons:
                lc.set_gid(neuron.id)
            # Need to get this before adding data
            line3D_collection = lc
            ax.add_collection3d(lc)

        # For complex scenes, add each segment as a single collection
        # -> helps reducing Z-order errors
        elif method == '3d_complex':
            # Generate by-segment coordinates
            coords = segments_to_coords(neuron,
                                        neuron.segments,
                                        modifier=(1, 1, 1))
            for c in coords:
                lc = Line3DCollection([c],
                                      color=color,
                                      lw=linewidth,
                                      alpha=alpha,
                                      rasterized=rasterize,
                                      linestyle=linestyle)
                if group_neurons:
                    lc.set_gid(neuron.id)
                ax.add_collection3d(lc)
            line3D_collection = None

        surf3D_collections = []
        if plot_soma and getattr(neuron, 'soma', None) is not None:
            soma = utils.make_iterable(neuron.soma)
            # If soma detection is messed up we might end up producing
            # dozens of soma which will freeze the kernel
            if len(soma) >= 5:
                logger.warning(f'Neuron {neuron.id} appears to have {len(soma)}'
                               ' somas. Skipping plotting its somas.')
            else:
                nodes_by_id = neuron.nodes.set_index('node_id')
                for s in soma:
                    if isinstance(color, np.ndarray) and color.ndim > 1:
                        s_ix = np.where(neuron.nodes.node_id == s)[0][0]
                        soma_color = color[s_ix]
                    else:
                        soma_color = color

                    n = nodes_by_id.loc[s]
                    if isinstance(neuron.soma_radius, str):
                        r = getattr(n, neuron.soma_radius)
                    else:
                        r = neuron.soma_radius

                    # Generate a sphere around the soma
                    resolution = 20
                    u = np.linspace(0, 2 * np.pi, resolution)
                    v = np.linspace(0, np.pi, resolution)
                    x = r * np.outer(np.cos(u), np.sin(v)) + n.x
                    y = r * np.outer(np.sin(u), np.sin(v)) + n.y
                    z = r * np.outer(np.ones(np.size(u)), np.cos(v)) + n.z
                    surf = ax.plot_surface(x, y, z,
                                           color=soma_color,
                                           shade=False,
                                           rasterized=rasterize,
                                           alpha=alpha)
                    if group_neurons:
                        surf.set_gid(neuron.id)

                    surf3D_collections.append(surf)

        return line3D_collection, surf3D_collections


def _plot_volume(volume, color, method, ax, **kwargs):
    """Plot volume.

    For ``method='2d'`` the ``volume_outlines`` keyword argument decides
    whether faces (``False``), the outline (``True``) or both (``'both'``)
    are drawn.
    """
    name = getattr(volume, 'name', None)
    rasterize = kwargs.get('rasterize', False)

    if len(color) == 4:
        this_alpha = color[3]
    else:
        this_alpha = 1

    if kwargs.get('volume_outlines', False):
        fill, lw, fc, ec = False, 1, 'none', color
    else:
        fill, lw, fc, ec = True, 0, color, 'none'

    if method == '2d':
        view = kwargs.get('view', ('x', 'y'))
        volume_outlines = kwargs.get('volume_outlines', False)

        if volume_outlines in (False, 'both'):
            # Generate 2d representation
            xy = np.dstack(_parse_view2d(volume.verts, view))[0]

            # Generate a patch for each face
            patches = [mpatches.Polygon(xy[f], closed=True, fill=fill)
                       for f in volume.faces]
            pc = PatchCollection(patches, linewidth=lw, facecolor=fc,
                                 rasterized=rasterize,
                                 edgecolor=ec, alpha=this_alpha, zorder=0)
            ax.add_collection(pc)

        if volume_outlines in (True, 'both'):
            verts = volume.to_2d(view=view, alpha=0.001)
            vpatch = mpatches.Polygon(verts, closed=True, lw=lw, fill=fill,
                                      rasterized=rasterize,
                                      fc=fc, ec=ec, zorder=0,
                                      alpha=1 if volume_outlines == 'both' else this_alpha)
            ax.add_patch(vpatch)

    elif method in ['3d', '3d_complex']:
        verts = np.vstack(volume.vertices)

        # Add alpha
        if len(color) == 3:
            color = (color[0], color[1], color[2], .1)

        ts = ax.plot_trisurf(verts[:, 0], verts[:, 1],
                             volume.faces, verts[:, 2],
                             label=name,
                             rasterized=rasterize,
                             color=color)
        ts.set_gid(name)


def update_axes3d_bounds(ax):
    """Update axis bounds and remove default points (0,0,0) and (1,1,1)."""
    # Collect data points present in the figure
    points = []
    for c in ax.collections:
        if isinstance(c, Line3DCollection):
            for s in c._segments3d:
                points.append(s)
        elif isinstance(c, Poly3DCollection):
            points.append(c._vec[:3, :].T)
        elif isinstance(c, (Path3DCollection, Patch3DCollection)):
            points.append(np.array(c._offsets3d).T)

    if not len(points):
        return

    points = np.vstack(points)

    # If this is the first set of points, we need to overwrite the defaults
    # That should happen automatically but for some reason doesn't for 3d axes
    if not getattr(ax, 'had_data', False):
        mn = points.min(axis=0)
        mx = points.max(axis=0)
        new_xybounds = np.array([[mn[0], mn[1]],
                                 [mx[0], mx[1]]])
        new_zzbounds = np.array([[mn[2], mn[2]],
                                 [mx[2], mx[2]]])
        ax.xy_dataLim.set_points(new_xybounds)
        ax.zz_dataLim.set_points(new_zzbounds)
        ax.xy_viewLim.set_points(new_xybounds)
        ax.zz_viewLim.set_points(new_zzbounds)
        ax.had_data = True
    else:
        ax.auto_scale_xyz(points[:, 0].tolist(),
                          points[:, 1].tolist(),
                          points[:, 2].tolist(),
                          had_data=True)


def set_axes3d_equal(ax):
    """Make axes of 3D plot have equal scale.

    This requires the viewLim to be set correctly: see
    `update_axes3d_bounds()`.

    Parameters
    ----------
    ax :    a matplotlib axis, e.g., as output from plt.gca().

    """
    x_limits = ax.get_xlim3d()
    y_limits = ax.get_ylim3d()
    z_limits = ax.get_zlim3d()

    x_range = abs(x_limits[1] - x_limits[0])
    x_middle = np.mean(x_limits)
    y_range = abs(y_limits[1] - y_limits[0])
    y_middle = np.mean(y_limits)
    z_range = abs(z_limits[1] - z_limits[0])
    z_middle = np.mean(z_limits)

    # The plot bounding box is a sphere in the sense of the infinity
    # norm, hence I call half the max range the plot radius.
    plot_radius = 0.5 * max([x_range, y_range, z_range])

    ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
    ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
    ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])

    # Set 1:1:1 box ratio
    ax.set_box_aspect((1, 1, 1))


def __old__update_axes3d_bounds(ax, points):
    """Update axis bounds and remove default points (0,0,0) and (1,1,1)."""
    if not isinstance(points, np.ndarray):
        # `np.ndarray(points)` would interpret `points` as a *shape*
        points = np.asarray(points)

    # If this is the first set of points, we need to overwrite the defaults
    # That should happen automatically but for some reason doesn't for 3d axes
    if not getattr(ax, 'had_data', False):
        mn = points.min(axis=0)
        mx = points.max(axis=0)
        new_xybounds = np.array([[mn[0], mn[1]],
                                 [mx[0], mx[1]]])
        new_zzbounds = np.array([[mn[2], mn[2]],
                                 [mx[2], mx[2]]])
        ax.xy_dataLim.set_points(new_xybounds)
        ax.zz_dataLim.set_points(new_zzbounds)
        ax.had_data = True
    else:
        ax.auto_scale_xyz(points[:, 0].tolist(),
                          points[:, 1].tolist(),
                          points[:, 2].tolist(),
                          had_data=True)


def _fix_default_dict(x):
    """Consolidate duplicate settings.

    E.g. scatter kwargs when 'c' and 'color' is provided. The dictionary is
    modified in place and returned.

    """
    # The first entry is the "survivor"
    duplicates = [['color', 'c'], ['size', 's'], ['alpha', 'a']]

    for dupl in duplicates:
        present = [v for v in dupl if v in x]
        # Keep the first alias that is present and drop all others
        for v in present[1:]:
            x.pop(v)

    return x


def _perspective_proj(zfront, zback, focal_length=1):
    """Copy of the original matplotlib projection matrix.

    Notably, we set a default value for focal_length because this was only
    added with version 3.6 of matplotlib.
    """
    e = focal_length
    a = 1  # aspect ratio
    b = (zfront + zback) / (zfront - zback)
    c = -2 * (zfront * zback) / (zfront - zback)
    proj_matrix = np.array([[e,   0,  0, 0],
                            [0, e / a,  0, 0],
                            [0,   0,  b, c],
                            [0,   0, -1, 0]])
    return proj_matrix


def _orthogonal_proj(zfront, zback, focal_length=None):
    """Get matplotlib to use orthogonal instead of perspective view.

    Usage:
    proj3d.persp_transformation = _orthogonal_proj

    ``focal_length`` is accepted for signature compatibility but not used.
    """
    a = (zfront + zback) / (zfront - zback)
    b = -2 * (zfront * zback) / (zfront - zback)
    # -0.0001 added for numerical stability as suggested in:
    # http://stackoverflow.com/questions/23840756
    return np.array([[1, 0, 0, 0],
                     [0, 1, 0, 0],
                     [0, 0, a, b],
                     [0, 0, -0.0001, zback]])