Source code for navis.transforms.cmtk
# This script is part of navis (http://www.github.com/navis-org/navis).
# Copyright (C) 2018 Philipp Schlegel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
"""Functions to use CMTK transforms."""
import os
import re
import nrrd
import copy
import pathlib
import tempfile
import platform
import functools
import subprocess
import numpy as np
import pandas as pd
from subprocess import check_call
from .. import utils, config
from .base import BaseTransform, TransformSequence
__all__ = ['xform_cmtk']
_search_path = os.environ['PATH']
_search_path = [i for i in _search_path.split(os.pathsep) if len(i) > 0]
_search_path += ['~/bin',
'/usr/lib/cmtk/bin/',
'/usr/local/lib/cmtk/bin',
'/usr/local/bin',
'/opt/local/bin',
'/opt/local/lib/cmtk/bin/',
'/Applications/IGSRegistrationTools/bin']
if platform.system() == 'Windows':
_search_path += [r'C:\cygwin64\usr\local\lib\cmtk\bin',
r'C:\Program Files\CMTK-3.3\CMTK\lib\cmtk\bin']
def find_cmtkbin(tool: str = 'streamxform') -> str:
"""Find directory with CMTK binaries."""
for path in _search_path:
path = pathlib.Path(path)
if not path.is_dir():
continue
try:
return next(path.glob(tool)).resolve().parent
except StopIteration:
continue
except BaseException:
raise
_cmtkbin = find_cmtkbin()
def requires_cmtk(func):
"""Check if CMTK is available."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not _cmtkbin:
raise ValueError("Cannot find CMTK. Please install from "
"http://www.nitrc.org/projects/cmtk and "
"make sure that it is your path!")
return func(*args, **kwargs)
return wrapper
@requires_cmtk
def cmtk_version(as_string=False):
"""Get CMTK version."""
p = subprocess.run([_cmtkbin / 'streamxform', '--version'],
capture_output=True)
version = p.stdout.decode('utf-8').rstrip()
if as_string:
return version
else:
return tuple(int(v) for v in version.split('.'))
@requires_cmtk
def xform_cmtk(points: np.ndarray, transforms, inverse: bool = False,
affine_fallback: bool = False, **kwargs) -> np.ndarray:
"""Xform 3d coordinates.
Parameters
----------
points : (N, 3) array | pandas.DataFrame
Points to transform. DataFrame must have x/y/z columns.
transforms : filepath(s) | CMTKtransform(s)
Either filepath to CMTK transform or ``CMTKtransform``.
Multiple regs must be given as list and will be applied
sequentially in the order provided.
inverse : bool | list thereof
Whether to invert transforms. If single boolean will
apply to all transforms. Can also provide ``inverse` as
list of booleans.
affine_fallback : bool
If True, points that failed to transform during warping
transform will be transformed using only the affine
transform.
Returns
-------
pointsxf : (N, 3) numpy.ndarray
Transformed points. Will contain `np.nan` for points
that did not transform.
"""
transforms = list(utils.make_iterable(transforms))
if isinstance(inverse, bool):
inverse = [inverse] * len(transforms)
directions = ['forward' if not i else 'inverse' for i in inverse]
for i, r in enumerate(transforms):
if not isinstance(r, CMTKtransform):
if not isinstance(r, (str, pathlib.Path)):
raise TypeError('`reg` must be filepath or CMTKtransform')
transforms[i] = CMTKtransform(r, directions=directions[i])
# Combine all transforms into a sequence of transforms
seq = TransformSequence(*transforms)
# Transform points
xf = seq.xform(points)
# If requested, try again with affine only for points that failed to xform
if affine_fallback:
isnan = np.any(np.isnan(xf), axis=1)
if np.any(isnan):
xf[isnan] = seq.xform(points[isnan], affine_only=True)
return xf
[docs]
class CMTKtransform(BaseTransform):
"""CMTK transforms of 3D spatial data.
Requires `CMTK <https://www.nitrc.org/projects/cmtk/>`_ to be installed.
Parameters
----------
regs : str | list of str
Path(s) to CMTK transformations(s).
directions : "forward" | "inverse" | list thereof
Direction of transformation. Must provide one direction per
``reg``.
threads : int, optional
Number of threads to use.
Examples
--------
>>> from navis import transforms
>>> tr = transforms.cmtk.CMTKtransform('/path/to/CMTK_directory.list')
>>> tr.xform(points) # doctest: +SKIP
"""
[docs]
def __init__(self, regs: list, directions: str = 'forward', threads: int = None):
self.directions = list(utils.make_iterable(directions))
for d in self.directions:
assert d in ('forward', 'inverse'), ('`direction` must be "foward"'
f'or "inverse", not "{d}"')
self.regs = list(utils.make_iterable(regs))
self.command = 'streamxform'
self.threads = threads
if len(directions) == 1 and len(regs) >= 1:
directions = directions * len(regs)
if len(self.regs) != len(self.directions):
raise ValueError('Must provide one direction per regs')
def __eq__(self, other: 'CMTKtransform') -> bool:
"""Implement equality comparison."""
if isinstance(other, CMTKtransform):
if len(self) == len(other):
if all([self.regs[i] == other.regs[i] for i in range(len(self))]):
if all([self.directions[i] == other.directions[i] for i in range(len(self))]):
return True
return False
def __len__(self) -> int:
return len(self.regs)
def __neg__(self) -> 'CMTKtransform':
"""Invert direction."""
x = self.copy()
# Swap directions
x.directions = [{'forward': 'inverse',
'inverse': 'forward'}[d] for d in x.directions]
# Reverse order
x.regs = x.regs[::-1]
x.directions = x.directions[::-1]
return x
def __str__(self):
return self.__repr__()
def __repr__(self):
return f'CMTKtransform with {len(self)} transform(s)'
@staticmethod
def from_file(filepath: str, **kwargs) -> 'CMTKtransform':
"""Generate CMTKtransform from file.
Parameters
----------
filepath : str
Path to CMTK transform.
**kwargs
Keyword arguments passed to CMTKtransform.__init__
Returns
-------
CMTKtransform
"""
defaults = {'directions': 'forward'}
defaults.update(kwargs)
return CMTKtransform(str(filepath), **defaults)
def make_args(self, affine_only: bool = False) -> list:
"""Generate arguments passed to subprocess."""
# Generate the arguments
# The actual command (i.e. streamxform)
args = [str(_cmtkbin / self.command)]
if affine_only:
args.append('--affine-only')
if self.threads:
args.append(f'--threads {int(self.threads)}')
# Add the regargs
args += self.regargs
return args
@property
def regargs(self) -> list:
"""Generate regargs."""
regargs = []
for i, (reg, dir) in enumerate(zip(self.regs, self.directions)):
if dir == 'inverse':
# For the first transform we need to prefix "--inverse" with
# a solitary "--"
if i == 0:
regargs.append('--')
regargs.append('--inverse')
# Note no double quotes!
regargs.append(f'{reg}')
return regargs
def append(self, transform: 'CMTKtransform', direction: str = None):
"""Add another transform.
Parameters
----------
transform : str | CMTKtransform
Either another CMTKtransform or filepath to registration.
direction : "forward" | "inverse"
Only relevant if transform is filepath.
"""
if isinstance(transform, CMTKtransform):
if self.command != transform.command:
raise ValueError('Unable to merge CMTKtransforms using '
'different commands.')
self.regs += transform.regs
self.directions += transform.directions
elif isinstance(transform, str):
if not direction:
raise ValueError('Must provide direction along with new transform')
self.regs.append(transform)
self.directions.append(direction)
else:
raise NotImplementedError(f'Unable to append {type(transform)} to {type(self)}')
def check_if_possible(self, on_error: str = 'raise'):
"""Check if this transform is possible."""
if not _cmtkbin:
msg = 'Folder with CMTK binaries not found. Make sure the ' \
'directory is in your PATH environment variable.'
if on_error == 'raise':
raise BaseException(msg)
return msg
for r in self.regs:
if not os.path.isdir(r) and not os.path.isfile(r):
msg = f'Registration {r} not found.'
if on_error == 'raise':
raise BaseException(msg)
return msg
def copy(self) -> 'CMTKtransform':
"""Return copy."""
# Attributes not to copy
no_copy = []
# Generate new empty transform
x = self.__class__(None)
# Override with this neuron's data
x.__dict__.update({k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy})
return x
def parse_cmtk_output(self, output: str, fail_value=np.nan) -> np.ndarray:
r"""Parse CMTK output.
Briefly, CMTK output will be a byte literal like this:
b'311 63 23 \n275 54 25 \n'
In case of failed transforms we will get something like this where the
original coordinates are returned with a "FAILED" flag
b'343 72 23 \n-10 -10 -10 FAILED \n'
Parameter
---------
output : tuple of (b'', None)
Stdout of CMTK call.
fail_value
Value to use for points that failed to transform. By
default we use ``np.nan``.
Returns
-------
pointsxf : (N, 3) numpy array
The parse transformed points.
"""
# The original stout is tuple where we care only about the second one
if isinstance(output, tuple):
output = output[0]
pointsx = []
# Split string into rows - lazily using a generator
for row in (x.group(1) for x in re.finditer(r"(.*?) \n",
output.decode())):
# Split into values
values = row.split(' ')
# If this point failed
if len(values) != 3:
values = [fail_value] * 3
else:
values = [float(v) for v in values]
pointsx.append(values)
return np.asarray(pointsx)
def xform(self, points: np.ndarray,
affine_only: bool = False,
affine_fallback: bool = False) -> np.ndarray:
"""Xform data.
Parameters
----------
points : (N, 3) numpy array | pandas.DataFrame
Points to xform. DataFrame must have x/y/z columns.
affine_only : bool
Whether to apply only the non-rigid affine
transform. This is useful if points are outside
the deformation field and would therefore not
transform properly.
affine_fallback : bool
If True and some points did not transform during the
non-rigid part of the transformation, we will apply
only the affine transformation to those points.
Returns
-------
pointsxf : (N, 3) numpy array
Transformed points. Points that failed to transform will
be ``np.nan``.
"""
self.check_if_possible(on_error='raise')
if isinstance(points, pd.DataFrame):
# Make sure x/y/z columns are present
if np.any([c not in points for c in ['x', 'y', 'z']]):
raise ValueError('points DataFrame must have x/y/z columns.')
elif isinstance(points, np.ndarray) and points.ndim == 2 and points.shape[1] == 3:
points = pd.DataFrame(points, columns=['x', 'y', 'z'])
else:
raise TypeError('`points` must be numpy array of shape (N, 3) or '
'pandas DataFrame with x/y/z columns')
# Generate the result
args = self.make_args(affine_only=affine_only)
proc = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
# Pipe in the points
points_str = points[['x', 'y', 'z']].to_string(index=False,
header=False)
# Do not use proc.stdin.write to avoid output buffer becoming full
# before we finish piping in stdin.
#proc.stdin.write(points_str.encode())
#output = proc.communicate()
# Read out results
# This is equivalent to e.g.:
# $ streamxform -args <<< "10, 10, 10"
output = proc.communicate(input=points_str.encode())
# If no output, something went wrong
if not output[0]:
raise utils.CMTKError('CMTK produced no output. Check points?')
# Xformed points
xf = self.parse_cmtk_output(output, fail_value=np.nan)
# Check if any points not xformed
if affine_fallback and not affine_only:
not_xf = np.any(np.isnan(xf), axis=1)
if np.any(not_xf):
xf[not_xf] = self.xform(points.loc[not_xf], affine_only=True)
return xf
def xform_image(self,
im,
target,
out=None,
interpolation="linear",
verbose=False,
):
"""Transform an image using CMTK's reformatx.
Parameters
----------
im : 3D numpy array | filepath
The floating image to transform.
target : str | TemplateBrain | (Nx, Ny, Nz, dx, dy, dz) | (Nx, Ny, Nz, dx, dy, dz, Ox, Oy, Oz)
Defines the target image: dimensions in voxels (N), the voxel size (d) and optionally
an origin (0) for the target image. Can be provided as a string (name of a template),
a TemplateBrain object, a tuple/list/array with the target specs.
out : str, optional
The filepath to save the transformed image. If None (default), will return the
transformed image as np.ndarray.
interpolation : "linear" | "nn" | "cubic" | "pv" | "sinc-cosine" | "sinc-hamming"
The interpolation method to use.
verbose : bool
Whether to print CMTK output.
Returns
-------
np.ndarray | None
If out is None, returns the transformed image as np.ndarray. Otherwise, None.
"""
assert interpolation in ("linear", "nn", "cubic", "pv", "sinc-cosine", "sinc-hamming")
# `reformatx` expects this format:
# ./reformatx --floating {INPUT_FILE} -o {OUTPUT_FILE} {REFERENCE_SPECS} {TRANSFORMS}
# where:
# - {INPUT_FILE} is the image to transform
# - {OUTPUT_FILE} is where the output will be saved
# - {REFERENCE_SPECS} defines the target space; this needs to be eitheran NRRD
# file from which CMTK can extract the target grid or the actual specs:
# "--target-grid Nx,Ny,Nz:dX,dY,dZ:[Ox,Oy,Oz]" where N is the number of
# voxels in each dimension and d is the voxel size. The optional O is the
# origin of the image. If not provided, it is assumed to be (0, 0, 0).
# - {TRANSFORMS} are the CMTK transform(s) to apply; prefix with "--inverse" to invert
# Below command works to convert JFRC2 to FCWB:
# /opt/local/lib/cmtk/bin/reformatx --verbose --floating JFRC2.nrrd -o JFRC2_xf.nrrd FCWB.nrrd --inverse /Users/philipps/flybrain-data/BridgingRegistrations/JFRC2_FCWB.list
# This took XX minutes - should check if that is actually faster than the look-up approach we
# use in `images.py`
target_specs = parse_target_specs(target)
to_remove = []
if isinstance(im, (str, pathlib.Path)):
floating = pathlib.Path(im)
if not im.is_file():
raise ValueError(f"Image file not found: {im}")
elif isinstance(im, np.ndarray):
assert im.ndim == 3
# Save to temporary file
with tempfile.NamedTemporaryFile(suffix=".nrrd", delete=False, delete_on_close=False) as tf:
nrrd.write(tf.name, im)
floating = tf.name
to_remove.append(tf.name)
else:
raise ValueError(f"Invalid image type: {type(im)}")
if out is None:
outfile = tempfile.NamedTemporaryFile(suffix=".nrrd", delete=False, delete_on_close=False).name
to_remove.append(outfile)
elif isinstance(out, (str, pathlib.Path)):
outfile = pathlib.Path(out).resolve()
else:
raise ValueError(f"Invalid output type: {type(out)}")
# Compile the command
args = [str(_cmtkbin / 'reformatx')]
args += [f'-o {outfile}']
args += [f'--floating {floating}']
args += [f'--{interpolation}']
args += [target_specs]
# Add the regargs
args += self.regargs
try:
# run the binary
# avoid resourcewarnings with null
with open(os.devnull, "w") as devnull:
startupinfo = None
if platform.system() == "Windows":
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
if verbose:
# in debug mode print the output
stdout = None
else:
stdout = devnull
if verbose:
config.logger.info("executing: {}".format(" ".join(args)))
check_call(
args,
stdout=stdout,
stderr=subprocess.STDOUT,
startupinfo=startupinfo,
)
if out is None:
# Return transformed image
return nrrd.read(outfile)
elif verbose:
config.logger.info(f"Transformed image saved to {outfile}")
except BaseException:
raise
finally:
# Clean up temporary files
for f in to_remove:
os.remove(f)
def parse_target_specs(target):
"""Parse target specs into argument that can be passed to CMTK."""
# Note to self: this function should also deal with VoxelNeurons and NRRD filepaths
# For NRRD filepaths: we need to add an empty "--" before the filepath (I think)
from .templates import TemplateBrain
assert isinstance(target, (str, TemplateBrain, np.ndarray, list, tuple))
if isinstance(target, str):
from . import registry
target = registry.find_template(target)
if isinstance(target, TemplateBrain):
specs = list(target.dims) + list(target.voxdims)
# Note to self: need to check TemplateBrain (and flybrains) consistent definition of
# dims, voxdims and origin (maybe even add origin)
# At this point we expect specs to be an iterable
specs = np.asarray(target)
assert len(specs) in (6, 9), f"Target specs must be of length 6 or 9, got {len(specs)}"
target = "--target-grid "
target += ",".join(map(str, specs[:3].astype(int))) # Number of voxels (must be integer)
target += ":"
target += ",".join(map(str, specs[3:].astype(float))) # Voxel size (can be float)
if len(specs) == 9:
target += ":"
target += ",".join(map(str, specs[6:].astype(float))) # Origin (can be float)
return target