from __future__ import annotations
from abc import ABC, abstractmethod
from itertools import permutations
import sys
import os
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from typing import (
Generic,
Hashable,
Iterator,
Mapping,
Optional,
Sequence,
Callable,
List,
Iterable,
Any,
Tuple,
TypeVar,
Union,
)
import logging
from pathlib import Path
from functools import lru_cache
from copy import deepcopy
import operator
import math
from collections import defaultdict
import numpy as np
import pandas as pd
from .. import config
from ..core import Dotprops
logger = logging.getLogger(__name__)
DEFAULT_SEED = 1991
epsilon = sys.float_info.epsilon
cpu_count = max(1, (os.cpu_count() or 2) - 1)  # leave one core free; os.cpu_count() may return None
fp = Path(__file__).resolve().parent
smat_path = fp / "score_mats"
def chunksize(it_len, cpu_count, min_chunk=50):
return max(min_chunk, int(it_len / (cpu_count * 4)))
def yield_not_same(pairs: Iterable[Tuple[Any, Any]]) -> Iterator[Tuple[Any, Any]]:
for a, b in pairs:
if a != b:
yield a, b
def concat_results(results: Iterable[List[np.ndarray]],
total: Optional[int] = None,
desc: str = 'Querying',
progress: bool = True) -> List[np.ndarray]:
"""Helper function to concatenate batches of e.g. [(dist, dots), (dist, dots)]
into single (dist, dot) arrays.
"""
intermediate = defaultdict(list)
with config.tqdm(desc=desc,
total=total,
leave=False,
disable=not progress) as pbar:
for result_lst in results:
for idx, array in enumerate(result_lst):
intermediate[idx].append(array)
pbar.update(1)
return [np.concatenate(arrs) for arrs in intermediate.values()]
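# Illustrative sketch (hypothetical arrays): two batches of per-pair results
# are merged per result index into single concatenated arrays.
#
#   batches = [[np.array([1., 2.]), np.array([0.9, 0.8])],
#              [np.array([3.]), np.array([0.7])]]
#   dists, dots = concat_results(batches, total=2, progress=False)
#   # dists -> array([1., 2., 3.]), dots -> array([0.9, 0.8, 0.7])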
def _nblast_v1_scoring(dist: float, dp: float, sigma_scoring: int = 10):
"""NBLAST analytical scoring function following Kohl et al. (2013).
Parameters
----------
dist : float | array thereof
Distance between two points.
dp : float | array thereof
Absolute dot product between points.
sigma_scoring : int
Sigma of the exponential decrease.
It determines how close in space points must be to be
considered similar. Defaults to 10.
Returns
-------
scores : float or array thereof
Score value(s).
References
----------
Kohl J, Ostrovsky AD, Frechter S, Jefferis GS. A bidirectional circuit switch
reroutes pheromone signals in male and female brains.
Cell. 2013 Dec;155(7):1610-1623. doi: 10.1016/j.cell.2013.11.025.
"""
return np.sqrt(np.abs(dp) * np.exp(-(dist ** 2)/(2 * sigma_scoring ** 2)))
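# A quick, hedged sanity check of the analytical scoring function
# (arbitrary illustrative values): scores lie in [0, 1], decay with
# distance and grow with the absolute dot product.
#
#   dists = np.array([0.0, 5.0, 20.0])
#   dots = np.array([1.0, 0.8, 0.1])
#   _nblast_v1_scoring(dists, dots, sigma_scoring=10)
#   # -> approx. array([1.0, 0.84, 0.12])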
NeuronKey = Hashable
T = TypeVar("T")
class LookupNdBuilder:
def __init__(
self,
neurons: Union[List[T], Mapping[NeuronKey, T]],
matching_lists: List[List[NeuronKey]],
match_fn: Callable[[T, T], List[np.ndarray]],
nonmatching_list: Optional[List[NeuronKey]] = None,
draw_strat: str = 'batched',
seed: int = DEFAULT_SEED,
) -> None:
f"""Class for building an N-dimensional score lookup (for e.g. NBLAST).
Once instantiated, the axes of the lookup table must be defined.
Call ``.with_digitizers()`` to manually define them, or
``.with_bin_counts()`` to learn them from the matched-pair data.
Then call ``.build()`` to build the lookup table.
Parameters
----------
neurons : dict or list of objects (e.g. Dotprops)
An indexable, consistently-ordered sequence of all
objects (typically neurons) which will be used as the
training set. Importantly: each object must have a
``len()``!
matching_lists : list of lists of indices into ``neurons``
Lists of neurons, as indices into ``neurons``, which
should be considered matches:
[[0, 1, 2, 4], [5, 6], [9, 10, 11]]
match_fn : Callable[[object, object], List[np.ndarray[float]]]
Function taking 2 arguments, both instances of type
``neurons``, and returning a list of 1D
``numpy.ndarray``s of floats. The length of the list
must be the same as the number of lookup dimensions
(i.e. the number of digitizers or bin counts).
The length of the ``array``s must be the same as the
number of points in the first argument. This function
returns values describing the quality of point matches
from a query to a target neuron.
nonmatching_list : list of indices into ``neurons``, optional
List of objects, as indices into ``neurons``, which
should be considered NON-matches. If not given,
all ``neurons`` will be used (on the assumption that
matches are a small subset of possible pairs).
draw_strat : "batched" | "greedy"
Strategy for randomly drawing non-matching pairs. Only
relevant if ``nonmatching`` is not provided.
"batched" should be the right choice in most scenarios.
"greedy" can be better if your pool of neurons is very
small.
seed : int, optional
Non-matching pairs are drawn at random using this seed,
by default {DEFAULT_SEED}.
"""
self.objects = neurons
self.matching_lists = matching_lists
self._nonmatching_list = nonmatching_list
self.match_fn = match_fn
self.nonmatching_draw = draw_strat
self.digitizers: Optional[List[Digitizer]] = None
self.bin_counts: Optional[List[int]] = None
self.seed = seed
self._ndim: Optional[int] = None
@property
def ndim(self) -> int:
if self._ndim is None:
idx1, idx2 = list(self._object_keys())[:2]
self._ndim = len(self._query(idx1, idx2))
return self._ndim
def with_digitizers(self, digitizers: List[Digitizer]):
"""Specify the axes of the output lookup table directly.
Parameters
----------
digitizers : List[Digitizer]
Returns
-------
self
For chaining convenience.
"""
if len(digitizers) != self.ndim:
raise ValueError(
f"Match function returns {self.ndim} values "
f"but provided {len(digitizers)} digitizers"
)
self.digitizers = digitizers
self.bin_counts = None
return self
def with_bin_counts(self, bin_counts: List[int], method='quantile'):
"""Specify the number of bins on each axis of the output lookup table.
The bin boundaries will be determined by evenly partitioning the data
from the matched pairs into quantiles, in each dimension.
Parameters
----------
bin_counts : List[int]
method : 'quantile' | 'geometric' | 'linear'
Method used to tile the data space.
Returns
-------
self
For chaining convenience.
"""
if len(bin_counts) != self.ndim:
raise ValueError(
f"Match function returns {self.ndim} values "
f"but provided {len(bin_counts)} bin counts"
)
self.bin_counts = bin_counts
self.digitizers = None
self.bin_method = method
return self
def _object_keys(self) -> Sequence[NeuronKey]:
"""Get all indices into objects instance member."""
try:
return self.objects.keys()
except AttributeError:
return range(len(self.objects))
@property
def nonmatching(self) -> List[NeuronKey]:
"""Indices of nonmatching set of neurons."""
if self._nonmatching_list is None:
return list(self._object_keys())
return self._nonmatching_list
def _yield_matching_pairs(self) -> Iterator[Tuple[NeuronKey, NeuronKey]]:
"""Yield all index pairs within all matching pairs."""
for ms in self.matching_lists:
yield from yield_not_same(permutations(ms, 2))
def _yield_nonmatching_pairs(self) -> Iterator[Tuple[NeuronKey, NeuronKey]]:
"""Yield all index pairs within all non-matching pairs."""
if self._nonmatching_list is None:
raise ValueError('Must provide non-matching pairs explicitly.')
yield from yield_not_same(permutations(self._nonmatching_list, 2))
def _yield_nonmatching_pairs_greedy(self, rng=None) -> Iterator[Tuple[NeuronKey, NeuronKey]]:
"""Yield all index pairs within nonmatching list."""
return yield_not_same(permutations(self.nonmatching, 2))
def _yield_nonmatching_pairs_batched(self) -> Iterator[Tuple[NeuronKey, NeuronKey]]:
"""Yield all index pairs within nonmatching list.
This function tries to generate truly random draws from all possible
non-matching pairs without actually having to generate all pairs.
Instead, we generate new randomly permuted pairs in batches from which
we then remove previously seen pairs.
This works reasonably well as long as we only need a small subset
of all possible non-matches. Otherwise this becomes inefficient.
"""
nonmatching = np.array(self.nonmatching)
rng = np.random.default_rng(self.seed)
# Generate random pairs
pairs = np.vstack((rng.permutation(nonmatching),
rng.permutation(nonmatching))).T
pairs = pairs[pairs[:, 0] != pairs[:, 1]] # drop self hits
seen = set([tuple(p) for p in pairs]) # track already seen pairs
i = 0
while True:
# If exhausted, generate a new batch of random permutation
if i >= len(pairs):
pairs = np.vstack((rng.permutation(nonmatching),
rng.permutation(nonmatching))).T
pairs = pairs[pairs[:, 0] != pairs[:, 1]] # drop self hits
pairs = set([tuple(p) for p in pairs])
pairs = pairs - seen
seen = seen | pairs
pairs = list(pairs)
i = 0
# Pick a pair
ix1, ix2 = pairs[i]
i += 1
yield (ix1, ix2)
def _empty_counts(self) -> np.ndarray:
"""Create an empty array in which to store counts; shape determined by digitizer sizes."""
shape = [len(b) for b in self.digitizers]
return np.zeros(shape, int)
def _query(self, q_idx, t_idx) -> List[np.ndarray]:
"""Get the results of applying the match function to objects specified by indices."""
return self.match_fn(self.objects[q_idx], self.objects[t_idx])
def _query_many(self, idx_pairs, threads=None) -> Iterator[List[np.ndarray]]:
"""Yield results from querying many pairs of neuron indices."""
if threads is None or (threads == 0 and cpu_count == 1):
for q_idx, t_idx in idx_pairs:
yield self._query(q_idx, t_idx)
return
threads = threads or cpu_count
idx_pairs = np.asarray(idx_pairs)
chunks = chunksize(len(idx_pairs), threads)
with ProcessPoolExecutor(threads) as exe:
yield from exe.map(self.match_fn,
[self.objects[ix] for ix in idx_pairs[:, 0]],
[self.objects[ix] for ix in idx_pairs[:, 1]],
chunksize=chunks)
def _query_to_idxs(self, q_idx, t_idx, counts=None):
"""Produce a digitized counts array from a given query-target pair."""
return self._count_results(self._query(q_idx, t_idx), counts)
def _count_results(self, results: List[np.ndarray], counts=None):
"""Convert raw match function ouput into a digitized counts array.
Requires digitizers.
"""
# Digitize
idxs = [dig(r) for dig, r in zip(self.digitizers, results)]
# Make a stack
stack = np.vstack(idxs).T
# Create empty matrix if necessary
if counts is None:
counts = self._empty_counts()
# Get counts per cell -> this is the actual bottleneck of this function
cells, cnt = np.unique(stack, axis=0, return_counts=True)
# Fill matrix
counts[tuple(cells[:, i] for i in range(cells.shape[1]))] += cnt
return counts
def _counts_array(self,
idx_pairs,
threads=None,
progress=True,
desc=None,
):
"""Convert index pairs into a digitized counts array.
Requires digitizers.
"""
counts = self._empty_counts()
if threads is None or (threads == 0 and cpu_count == 1):
for q_idx, t_idx in config.tqdm(idx_pairs,
leave=False,
desc=desc,
disable=not progress):
counts = self._query_to_idxs(q_idx, t_idx, counts)
return counts
threads = threads or cpu_count
idx_pairs = np.asarray(idx_pairs, dtype=int)
chunks = chunksize(len(idx_pairs), threads)
# because digitizing is not necessarily free,
# keep this parallelisation separate to that in _query_many
with ProcessPoolExecutor(threads) as exe:
# This is the progress bar
with config.tqdm(desc=desc,
total=len(idx_pairs),
leave=False,
disable=not progress) as pbar:
for distdots in exe.map(
self.match_fn,
[self.objects[ix] for ix in idx_pairs[:, 0]],
[self.objects[ix] for ix in idx_pairs[:, 1]],
chunksize=chunks,
):
counts = self._count_results(distdots, counts)
pbar.update(1)
return counts
def _pick_nonmatching_pairs(self, n_matching_qual_vals, progress=True):
"""Using the seeded RNG, pick which non-matching pairs to use."""
# pre-calculating which pairs we're going to use,
# rather than drawing them as we need them,
# means that we can parallelise the later step more effectively.
# Slowdowns here are practically meaningless
# because of how long distdot calculation will take
nonmatching_pairs = []
n_nonmatching_qual_vals = 0
if self.nonmatching_draw == 'batched':
# This is a generator that tries to generate random pairs in
# batches to avoid having to calculate all possible pairs
gen = self._yield_nonmatching_pairs_batched()
with config.tqdm(desc='Drawing non-matching pairs',
total=n_matching_qual_vals,
leave=False,
disable=not progress) as pbar:
# Draw non-matching pairs until we have enough data
for nonmatching_pair in gen:
nonmatching_pairs.append(nonmatching_pair)
new_vals = len(self.objects[nonmatching_pair[0]])
n_nonmatching_qual_vals += new_vals
pbar.update(new_vals)
if n_nonmatching_qual_vals >= n_matching_qual_vals:
break
elif self.nonmatching_draw == 'greedy':
# Generate all possible non-matching pairs
possible_pairs = len(self.nonmatching) ** 2 - len(self.nonmatching)
all_nonmatching_pairs = [p for p in config.tqdm(self._yield_nonmatching_pairs_greedy(),
total=possible_pairs,
desc='Generating non-matching pairs')]
# Randomly pick non-matching pairs until we have enough data
rng = np.random.default_rng(self.seed)
with config.tqdm(desc='Drawing non-matching pairs',
total=n_matching_qual_vals,
leave=False,
disable=not progress) as pbar:
while n_nonmatching_qual_vals < n_matching_qual_vals:
idx = rng.integers(0, len(all_nonmatching_pairs))
nonmatching_pair = all_nonmatching_pairs.pop(idx)
nonmatching_pairs.append(nonmatching_pair)
new_vals = len(self.objects[nonmatching_pair[0]])
n_nonmatching_qual_vals += new_vals
pbar.update(new_vals)
else:
raise ValueError('Unknown strategy for non-matching pair draw: '
f'{self.nonmatching_draw}')
return nonmatching_pairs
def _get_pairs(self):
matching_pairs = list(set(self._yield_matching_pairs()))
# If no explicit non-matches provided, pick them from the entire pool
if self._nonmatching_list is None:
# need to know the eventual distdot count
# so we know how many non-matching pairs to draw
q_idx_count = Counter(p[0] for p in matching_pairs)
n_matching_qual_vals = sum(
len(self.objects[q_idx]) * n_reps for q_idx, n_reps in q_idx_count.items()
)
nonmatching_pairs = self._pick_nonmatching_pairs(n_matching_qual_vals)
else:
nonmatching_pairs = list(set(self._yield_nonmatching_pairs()))
return matching_pairs, nonmatching_pairs
def _build(self, threads, progress=True) -> Tuple[List[Digitizer], np.ndarray]:
# Asking for more threads than available CPUs seems to crash on Github
# actions
if threads and threads >= cpu_count:
threads = cpu_count
if self.digitizers is None and self.bin_counts is None:
raise ValueError("Builder needs either digitizers or bin_counts - "
"see with_* methods.")
self.matching_pairs, self.nonmatching_pairs = self._get_pairs()
logger.info('Comparing matching pairs')
if self.digitizers:
self.match_counts_ = self._counts_array(self.matching_pairs,
threads=threads,
progress=progress,
desc='Comparing matching pairs')
else:
match_results = concat_results(self._query_many(self.matching_pairs, threads),
progress=progress,
desc='Comparing matching pairs',
total=len(self.matching_pairs))
self.digitizers = []
for i, (data, nbins) in enumerate(zip(match_results, self.bin_counts)):
if not isinstance(nbins, Digitizer):
try:
self.digitizers.append(Digitizer.from_data(data, nbins,
method=self.bin_method))
except BaseException as e:
logger.error(f'Error creating digitizer for axis {i + 1}')
raise e
else:
self.digitizers.append(nbins)
logger.info('Counting results (this may take a while)')
self.match_counts_ = self._count_results(match_results)
logger.info('Comparing non-matching pairs')
self.nonmatch_counts_ = self._counts_array(self.nonmatching_pairs,
threads=threads,
progress=progress,
desc='Comparing non-matching pairs')
# Account for there being different total numbers of datapoints for
# matches and nonmatches
self.matching_factor_ = self.nonmatch_counts_.sum() / self.match_counts_.sum()
if np.any(self.match_counts_ + self.nonmatch_counts_ == 0):
logger.warning("Some lookup cells have no data in them")
self.cells_ = np.log2(
(self.match_counts_ * self.matching_factor_ + epsilon) / (self.nonmatch_counts_ + epsilon)
)
return self.digitizers, self.cells_
def build(self, threads=None) -> LookupNd:
"""Build the score matrix.
All non-identical neuron pairs within all matching sets are selected,
and the scoring function is evaluated for those pairs.
Then, non-matching pairs are drawn at random until they provide
at least as many data points as the matching pairs.
In each bin of the score matrix, the log2 odds ratio of a score
in that bin belonging to a match vs. non-match is calculated.
Parameters
----------
threads : int, optional
If None, act in serial.
If 0, use cpu_count - 1.
Otherwise, use the given value.
Will be clipped at number of available cores - 1.
Note that with the current implementation a large number
of threads might (somewhat counterintuitively) actually
be slower than building the scoring function in serial.
Returns
-------
LookupNd
"""
dig, cells = self._build(threads)
return LookupNd(dig, cells)
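# Sketch of the generic builder workflow (hypothetical inputs: ``neurons`` is a
# list of comparable objects, ``matching_lists`` groups indices of known
# matches, and ``my_match_fn`` returns one 1D array per lookup dimension):
#
#   builder = LookupNdBuilder(neurons, matching_lists, my_match_fn)
#   builder.with_bin_counts([10, 10])      # learn bin edges from matched-pair data
#   lookup = builder.build()               # LookupNd of log2 odds scores
#
# Alternatively, pass pre-defined axes via
# ``builder.with_digitizers([Digitizer(...), Digitizer(...)])``.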
def dist_dot(q: Dotprops, t: Dotprops):
return list(q.dist_dots(t))
def dist_dot_alpha(q: Dotprops, t: Dotprops):
dist, dot, alpha = q.dist_dots(t, alpha=True)
return [dist, dot * np.sqrt(alpha)]
class LookupDistDotBuilder(LookupNdBuilder):
def __init__(
self,
dotprops: Union[List[Dotprops], Mapping[NeuronKey, Dotprops]],
matching_lists: List[List[NeuronKey]],
nonmatching_list: Optional[List[NeuronKey]] = None,
use_alpha: bool = False,
draw_strat: str = 'batched',
seed: int = DEFAULT_SEED,
):
f"""Class for building a 2-dimensional score lookup for NBLAST.
The scores are
1. The distances between best-matching points
2. The dot products of direction vectors around those points,
optionally scaled by the colinearity ``alpha``.
Parameters
----------
dotprops : dict or list of Dotprops
An indexable sequence of all neurons which will be
used as the training set, as Dotprops objects.
matching_lists : list of lists of indices into dotprops
List of neurons, as indices into ``dotprops``, which
should be considered matches.
nonmatching_list : list of indices into dotprops, optional
List of neurons, as indices into ``dotprops``,
which should not be considered matches.
If not given, all ``dotprops`` will be used
(on the assumption that matches are a small subset
of possible pairs).
use_alpha : bool, optional
If true, multiply the dot product by the geometric
mean of the matched points' alpha values
(i.e. ``sqrt(alpha1 * alpha2)``).
draw_strat : "batched" | "greedy"
Strategy for randomly drawing non-matching pairs.
"batched" should be the right choice in most scenarios.
"greedy" can be better if your pool of neurons is very
small.
seed : int, optional
Non-matching pairs are drawn at random using this
seed, by default {DEFAULT_SEED}.
"""
match_fn = dist_dot_alpha if use_alpha else dist_dot
super().__init__(
dotprops,
matching_lists,
match_fn,
nonmatching_list,
draw_strat=draw_strat,
seed=seed,
)
self._ndim = 2
def build(self, threads=None) -> Lookup2d:
(dig0, dig1), cells = self._build(threads)
return Lookup2d(dig0, dig1, cells)
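# Minimal usage sketch (hypothetical data: ``dps`` is a list of Dotprops and
# ``matches`` lists indices of neurons known to be matches):
#
#   builder = LookupDistDotBuilder(dps, matches, use_alpha=False, seed=2021)
#   builder.with_bin_counts([21, 10])      # 21 distance bins, 10 dot-product bins
#   smat = builder.build(threads=0)        # returns a Lookup2d
#   df = smat.to_dataframe()               # shareable pandas DataFrame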
def is_monotonically_increasing(lst):
for prev_idx, item in enumerate(lst[1:]):
if item <= lst[prev_idx]:
return False
return True
def parse_boundary(item: str):
explicit_interval = item[0] + item[-1]
if explicit_interval == "[)":
right = False
elif explicit_interval == "(]":
right = True
else:
raise ValueError(
f"Enclosing characters '{explicit_interval}' do not match a half-open interval"
)
return tuple(float(i) for i in item[1:-1].split(",")), right
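# Examples of the expected interval syntax:
#
#   parse_boundary("[0.5,1.5)")   # -> ((0.5, 1.5), False), i.e. right-open
#   parse_boundary("(0.5,1.5]")   # -> ((0.5, 1.5), True), i.e. right-closed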
class LookupAxis(ABC, Generic[T]):
"""Class converting some data into a linear index."""
@abstractmethod
def __len__(self) -> int:
"""Number of bins represented by this instance."""
pass
@abstractmethod
def __call__(self, value: Union[T, Sequence[T]]) -> Union[int, Sequence[int]]:
"""Convert some data into a linear index.
Parameters
----------
value : Union[T, Sequence[T]]
Value to convert into an index
Returns
-------
Union[int, Sequence[int]]
If a scalar was given, return a scalar; otherwise, a numpy array of ints.
"""
pass
class SimpleLookup(LookupAxis[Hashable]):
def __init__(self, items: List[Hashable]):
"""Look up in a list of items and return their index.
Parameters
----------
items : List[Hashable]
The item's position in the list is the index which will be returned.
Raises
------
ValueError
items are non-unique.
"""
self.items = {item: idx for idx, item in enumerate(items)}
if len(self.items) != len(items):
raise ValueError("Items are not unique")
def __len__(self) -> int:
return len(self.items)
def __call__(self, value: Union[Hashable, Sequence[Hashable]]) -> Union[int, Sequence[int]]:
if np.isscalar(value):
return self.items[value]
else:
return np.array([self.items[v] for v in value], int)
class Digitizer(LookupAxis[float]):
def __init__(
self,
boundaries: Sequence[float],
clip: Tuple[bool, bool] = (True, True),
right=False,
):
"""Class converting continuous values into discrete indices.
Parameters
----------
boundaries : Sequence[float]
N boundaries specifying N-1 bins.
Must be monotonically increasing.
clip : Tuple[bool, bool], optional
Whether to set the bottom and top boundaries to -infinity and
infinity respectively, effectively clipping incoming values: by
default (True, True).
False means "add a new bin for out-of-range values".
right : bool, optional
Whether bins should include their right (rather than left) boundary,
by default False.
"""
self.right = right
boundaries = list(boundaries)
self._min = -math.inf
if clip[0]:
self._min = boundaries[0]
boundaries[0] = -math.inf
elif boundaries[0] != -math.inf:
self._min = -math.inf
boundaries.insert(0, -math.inf)
self._max = math.inf
if clip[1]:
self._max = boundaries[-1]
boundaries[-1] = math.inf
elif boundaries[-1] != math.inf:
boundaries.append(math.inf)
if not is_monotonically_increasing(boundaries):
raise ValueError("Boundaries are not monotonically increasing: "
f"{boundaries}")
self.boundaries = np.asarray(boundaries)
def __len__(self):
return len(self.boundaries) - 1
def __call__(self, value: float):
# searchsorted is marginally faster than digitize as it skips monotonicity checks
return (
np.searchsorted(
self.boundaries, value, side="left" if self.right else "right"
)
- 1
)
def to_strings(self, round=None) -> List[str]:
"""Turn boundaries into list of labels.
Parameters
----------
round : int, optional
Use to round bounds to the Nth decimal.
"""
if self.right:
lb = "("
rb = "]"
else:
lb = "["
rb = ")"
b = self.boundaries.copy()
b[0] = self._min
b[-1] = self._max
if round:
b = [np.round(x, round) for x in b]
return [
f"{lb}{lower},{upper}{rb}"
for lower, upper in zip(b[:-1], b[1:])
]
@classmethod
def from_strings(cls, interval_strs: Sequence[str]):
"""Set digitizer boundaries based on a sequence of interval expressions.
e.g. ``["(0, 1]", "(1, 5]", "(5, 10]"]``
The lowermost and uppermost boundaries are converted to -infinity and
infinity respectively.
Parameters
----------
interval_strs : Sequence[str]
Strings representing intervals, which must abut and have open/closed
boundaries specified by brackets.
Returns
-------
Digitizer
"""
bounds: List[float] = []
last_upper = None
last_right = None
for item in interval_strs:
(lower, upper), right = parse_boundary(item)
bounds.append(float(lower))
if last_right is not None:
if right != last_right:
raise ValueError("Inconsistent half-open interval")
else:
last_right = right
if last_upper is not None:
if lower != last_upper:
raise ValueError("Half-open intervals do not abut")
last_upper = upper
bounds.append(float(last_upper))
return cls(bounds, right=last_right)
@classmethod
def from_linear(cls, lower: float, upper: float, nbins: int, right=False):
"""Choose digitizer boundaries spaced linearly between two values.
Input values will be clipped to fit within the given interval.
Parameters
----------
lower : float
Lowest value
upper : float
Highest value
nbins : int
Number of bins
right : bool, optional
Whether bins should include their right (rather than left) boundary,
by default False
Returns
-------
Digitizer
"""
arr = np.linspace(lower, upper, nbins + 1, endpoint=True)
return cls(arr, right=right)
@classmethod
def from_geom(cls, lowest_upper: float, highest_lower: float, nbins: int, right=False):
"""Choose digitizer boundaries in a geometric sequence.
Additional bins will be added above and below the given values.
Parameters
----------
lowest_upper : float
Upper bound of the lowest bin. The lower bound of the lowest bin is
often 0, which cannot be represented in a nontrivial geometric
sequence.
highest_lower : float
Lower bound of the highest bin.
nbins : int
Number of bins
right : bool, optional
Whether bins should include their right (rather than left) boundary,
by default False
Returns
-------
Digitizer
"""
arr = np.geomspace(lowest_upper, highest_lower, nbins - 1, True)
return cls(arr, clip=(False, False), right=right)
@classmethod
def from_data(cls,
data: Sequence[float],
nbins: int,
right=False,
method='quantile'):
"""Choose digitizer boundaries to evenly partition the given values.
Parameters
----------
data : Sequence[float]
Data which should be partitioned by the resulting digitizer.
nbins : int
Number of bins
right : bool, optional
Whether bins should include their right (rather than left) boundary,
by default False
method : "quantile" | "linear" | "geometric"
Method to use for partitioning the data space:
- 'quantile' (default) will partition the data such that each bin
contains the same number of data points. This is usually the
method of choice because it is robust against outliers and because
we are guaranteed not to have empty bins.
- 'linear' will partition the data into evenly spaced bins.
- 'geometric' will produce a log scale partition. This will not work
if data has negative values.
Returns
-------
Digitizer
"""
assert method in ('quantile', 'linear', 'geometric')
if method == 'quantile':
arr = np.quantile(data, np.linspace(0, 1, nbins + 1, True))
elif method == 'linear':
arr = np.linspace(min(data), max(data), nbins + 1, True)
elif method == 'geometric':
if min(data) <= 0:
raise ValueError('Data must not have values <= 0 for creating '
'geometric (logarithmic) bins.')
arr = np.geomspace(min(data), max(data), nbins + 1, True)
return cls(arr, right=right)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Digitizer):
return NotImplemented
return self.right == other.right and np.allclose(
self.boundaries, other.boundaries
)
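# Behaviour sketch with illustrative boundaries: with the default
# ``clip=(True, True)``, out-of-range values fall into the outermost bins.
#
#   dig = Digitizer([0, 1, 2, 5])
#   len(dig)                               # -> 3 bins
#   dig(np.array([-1.0, 0.5, 4.9, 99.0]))  # -> array([0, 0, 2, 2])
#   dig.to_strings()                       # -> ['[0.0,1.0)', '[1.0,2.0)', '[2.0,5.0)']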
class LookupNd:
def __init__(self, axes: List[LookupAxis], cells: np.ndarray):
if [len(b) for b in axes] != list(cells.shape):
raise ValueError("boundaries and cells have inconsistent bin counts")
self.axes = axes
self.cells = cells
def __call__(self, *args):
if len(args) != len(self.axes):
raise TypeError(
f"Lookup takes {len(self.axes)} arguments but {len(args)} were given"
)
idxs = tuple(d(arg) for d, arg in zip(self.axes, args))
out = self.cells[idxs]
return out
class Lookup2d(LookupNd):
"""Convenience class inheriting from LookupNd for the common 2D float case.
Provides IO with pandas DataFrames.
"""
def __init__(self, axis0: Digitizer, axis1: Digitizer, cells: np.ndarray):
"""2D lookup table for convert NBLAST matches to scores.
Commonly read from a ``pandas.DataFrame``
or trained on data using a ``LookupDistDotBuilder``.
Parameters
----------
axis0 : Digitizer
How to convert continuous values into an index for the first axis.
axis1 : Digitizer
How to convert continuous values into an index for the second axis.
cells : np.ndarray
Values to look up in the table.
"""
super().__init__([axis0, axis1], cells)
def to_dataframe(self) -> pd.DataFrame:
"""Convert the lookup table into a ``pandas.DataFrame``.
From there, it can be shared, saved, and so on.
The index and column labels describe the intervals represented by that axis.
Returns
-------
pd.DataFrame
"""
return pd.DataFrame(
self.cells,
self.axes[0].to_strings(),
self.axes[1].to_strings(),
)
@classmethod
def from_dataframe(cls, df: pd.DataFrame):
f"""Parse score matrix from a dataframe with string index and column labels.
Expects the index and column labels to specify an interval
like ``f"[{{lower}},{{upper}})"``.
Will replace the lowermost and uppermost bound with -inf and inf
if they are not already.
"""
return cls(
Digitizer.from_strings(df.index),
Digitizer.from_strings(df.columns),
df.to_numpy(),
)
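# Round-trip sketch (hypothetical file name): a trained Lookup2d can be
# shared as a CSV via pandas and restored later.
#
#   smat.to_dataframe().to_csv("my_smat.csv")
#   restored = Lookup2d.from_dataframe(pd.read_csv("my_smat.csv", index_col=0))
#   restored(np.array([1.0]), np.array([0.5]))   # score(s) for (dist, dot) pairs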
@lru_cache(maxsize=None)
def _smat_fcwb(alpha=False):
# cached private function defers construction
# until needed (speeding startup),
# but avoids repeated reads (speeding later uses)
fname = ("smat_fcwb.csv", "smat_alpha_fcwb.csv")[alpha]
fpath = smat_path / fname
return Lookup2d.from_dataframe(pd.read_csv(fpath, index_col=0))
def smat_fcwb(alpha=False):
# deepcopied so that mutations do not propagate to cache
return deepcopy(_smat_fcwb(alpha))
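# Example: the packaged FCWB lookup is itself a valid score function and can
# be called directly on distance / dot product arrays.
#
#   smat = smat_fcwb(alpha=False)
#   smat(np.array([1.0, 10.0]), np.array([0.9, 0.2]))   # -> log2 odds scores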
def check_score_fn(fn: Callable, nargs=2, scalar=True, array=True):
"""Checks functionally that the callable can be used as a score function.
Parameters
----------
nargs : optional int, default 2
How many positional arguments the score function should have.
scalar : optional bool, default True
Check that the function can be used on ``nargs`` scalars.
array : optional bool, default True
Check that the function can be used on ``nargs`` 1D ``numpy.ndarray``s.
Raises
------
ValueError
If the score function is not appropriate.
"""
if scalar:
scalars = [0.5] * nargs
if not isinstance(fn(*scalars), float):
raise ValueError("smat does not take 2 floats and return a float")
if array:
test_arr = np.array([0.5] * 3)
arrs = [test_arr] * nargs
try:
out = fn(*arrs)
except Exception as e:
raise ValueError(f"Failed to use smat with numpy arrays: {e}")
if out.shape != test_arr.shape:
raise ValueError(
f"smat produced inconsistent shape: input {test_arr.shape}; output {out.shape}"
)
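# Example: ``operator.mul`` satisfies both checks (two floats -> float,
# two same-length arrays -> same-shaped array), so this raises nothing.
#
#   check_score_fn(operator.mul)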
SCORE_FN_DESCR = """
NBLAST score functions take 2 floats or N-length numpy arrays of floats
(for matched dotprop points/tangents, distance and dot product;
the latter possibly scaled by the geometric mean of the alpha colinearity values)
and return a float or N-length numpy array of floats.
""".strip().replace(
"\n", " "
)
def parse_score_fn(smat, alpha=False):
f"""Interpret ``smat`` as a score function.
Primarily for backwards compatibility.
{SCORE_FN_DESCR}
Parameters
----------
smat : None | "auto" | str | os.PathLike | pandas.DataFrame | Callable[[float, float], float]
If ``None``, use ``operator.mul``.
If ``"auto"``, use ``navis.nbl.smat.smat_fcwb(alpha)``.
If a dataframe, use ``navis.nbl.smat.Lookup2d.from_dataframe(smat)``.
If another string or path-like, load the CSV into a dataframe and use as above.
Also checks the signature of the callable.
Raises an error, probably a ValueError, if it can't be interpreted.
alpha : optional bool, default False
If ``smat`` is ``"auto"``, choose whether to use the FCWB matrices
with or without alpha.
Returns
-------
Callable
Raises
------
ValueError
If score function cannot be interpreted.
"""
if smat is None:
smat = operator.mul
elif smat == "auto":
smat = smat_fcwb(alpha)
if isinstance(smat, (str, os.PathLike)):
smat = pd.read_csv(smat, index_col=0)
if isinstance(smat, pd.DataFrame):
smat = Lookup2d.from_dataframe(smat)
if not callable(smat):
raise ValueError(
"smat should be a callable, a path, a pandas.DataFrame, or 'auto'"
)
check_score_fn(smat)
return smat
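# Usage sketch ("my_smat.csv" is a hypothetical path):
#
#   parse_score_fn(None)            # -> operator.mul
#   parse_score_fn("auto")          # -> packaged FCWB Lookup2d
#   parse_score_fn("my_smat.csv")   # CSV loaded via pandas, wrapped in Lookup2d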