"""mdreg.main: remove motion from 2D or 3D image series by alternating signal-model fitting and coregistration."""


import time
import numpy as np
from tqdm import tqdm
import dask.array as da

from mdreg import fit_models, elastix, skimage, ants, io

# TODO: test optional dependencies -> skimage default
# TODO: user guide introduction


def fit(
        moving,
        fit_pixels=None,
        fit_coreg=None,
        fit_image=None,
        tol=1e-6,
        maxit=5,
        verbose=0,
        force_2d=False,
        path=None,
):
    """Remove motion from a series of 2D- or 3D images.

    Parameters
    ----------
    moving : numpy.ndarray | zarr.Array
        The series of images to be corrected, with dimensions (x,y,t) or
        (x,y,z,t).
    fit_pixels : dict, optional
        A dictionary defining a single-pixel signal model. The possible
        items in the dictionary are the keywords of the function
        `mdreg.fit_pixels`. For a slice-by-slice computation (4D array with
        force_2d=True), *fit_pixels* can be a list of dictionaries, one for
        each slice. The default is None.
    fit_coreg : dict, optional
        The parameters for coregistering the images. *fit_coreg* has one
        required item 'package' with possible values 'skimage' (default),
        'elastix' and 'ants'. The other parameters are the possible keywords
        of the *coreg_series* function of the package specified. The
        caller's dictionary is not modified.
    fit_image : dict or list, optional
        A dictionary defining the function to fit the signal data, and its
        parameter values. This argument is ignored if *fit_pixels* is
        already provided. *fit_image* has one required key 'func' that
        specifies the fit function to use. The other entries are the keyword
        arguments of this fit function. The fit function can be one of the
        functions built in to mdreg, or a custom made function. A valid fit
        function *must* take a signal array as argument, and return two
        variables: an array with the same shape containing the fit to the
        model, and a second variable that contains the fitted parameters.
        For a slice-by-slice computation (4D array with force_2d=True),
        *fit_image* can be a list of dictionaries, one for each slice. If
        *fit_image* is not provided, a constant model is used.
    tol : float, optional
        Stopping criterion for the iteration. The iteration stops if the
        largest absolute difference between the new and old coregistered
        series, over all pixels and time points, is less than *tol* times
        the largest absolute value of the old series. The default is 1e-6.
    maxit : int, optional
        The maximum number of iterations. Values below 1 disable the limit,
        so the loop runs until convergence. The default is 5.
    verbose : int, optional
        The level of feedback to provide to the user. 0: no feedback;
        1: text output only; 2: text output and progress bars. The default
        is 0.
    force_2d : bool, optional
        By default, a 3-dimensional moving array will be coregistered with a
        3-dimensional deformation field. To perform slice-by-slice
        2-dimensional registration instead, set *force_2d* to True. This
        keyword is ignored when the arrays are 2-dimensional. The default is
        False.
    path : str, optional
        Path on disk where to save the results. If no path is provided, the
        results are not saved to disk. Defaults to None.

    Returns
    -------
    coreg : numpy.ndarray | zarr.Array
        The coregistered images with the same dimensions as *moving*.
    fit : numpy.ndarray | zarr.Array
        The fitted signal model with the same dimensions as *moving*.
    transfo : numpy.ndarray | zarr.Array | list
        The parameters of the transformation deforming the moving image to
        the coregistered image. With skimage, this is the deformation field
        with the same dimensions as *moving*, and one additional dimension
        for the components of the vector field. With elastix this is an
        array of parameter objects and with ants this is an array of files
        with transform parameters. Note when force_2d = True these are
        2-dimensional arrays with one transform per slice and per time
        point.
    pars : numpy.ndarray | zarr.Array
        The parameters of the fitted signal model with dimensions (x,y,n)
        or (x,y,z,n), where n is the number of free parameters of the signal
        model.
    defo : optional
        Only returned when the coregistration backend returns a third value
        (e.g. when a deformation field is requested).

    Raises
    ------
    ValueError
        If *fit_image* is not a dictionary (outside slice-by-slice mode), or
        if a 'path' entry in one of the dictionaries conflicts with *path*.
    """
    # Work on a shallow copy so defaults and the output path are never
    # written back into the caller's dictionary. Otherwise, reusing the same
    # dict with another `path` would raise "Two different paths are
    # provided.".
    fit_coreg = {} if fit_coreg is None else dict(fit_coreg)
    fit_coreg.setdefault('package', 'skimage')
    fit_coreg.setdefault('progress_bar', verbose > 1)
    fit_coreg.setdefault('name', 'coreg')

    # 2D slice-by-slice coregistration of a 4D series
    if moving.ndim == 4 and force_2d:
        return _fit_force_2d(
            moving, fit_image, fit_coreg, fit_pixels,
            tol, maxit, verbose, path,
        )

    # Default signal model: constant
    if fit_image is None:
        fit_image = {'func': fit_models.fit_constant}
    if not isinstance(fit_image, dict):
        raise ValueError('The fit_image argument must be a dictionary.')
    fit_image = dict(fit_image)
    if isinstance(fit_pixels, dict):
        fit_pixels = dict(fit_pixels)

    # Propagate the output path to every configuration dictionary.
    _set_path(fit_coreg, path)
    _set_path(fit_image, path)
    _set_path(fit_pixels, path)

    converged = False
    it = 1
    start = time.time()
    if verbose > 0:
        print('Initializing..')
    coreg = io._copy(moving, path, 'coreg')
    while not converged:
        startit = time.time()

        # Fit the signal model to the current coregistered series.
        if verbose > 0:
            print(f'Iteration {it}: fitting signal model')
        if fit_pixels is not None:
            model_fit, pars = fit_models.fit_pixels(coreg, **fit_pixels)
        else:
            kwargs = {key: val for key, val in fit_image.items()
                      if key != 'func'}
            model_fit, pars = fit_image['func'](coreg, **kwargs)

        # Register the original images to the model fit.
        if verbose > 0:
            print(f'Iteration {it}: fitting deformation fields')
        # Keep the previous result to measure the change in this iteration.
        coreg_curr = io._copy(coreg, path, 'tmp')
        vals = _coreg_series(moving, model_fit, **fit_coreg)
        coreg, transfo = vals[:2]

        converged = _diff(coreg, coreg_curr) < tol
        if verbose > 0:
            print(f'Calculation time for iteration {it}: '
                  f'{(time.time()-startit)/60} min')
        if it == maxit:
            break
        it += 1

    if verbose > 0:
        print(f'Total calculation time: {(time.time()-start)/60} min')
    io._remove(path, 'tmp')

    if len(vals) > 2:
        # The backend returned an optional third value.
        defo = vals[2]
        return coreg, model_fit, transfo, pars, defo
    return coreg, model_fit, transfo, pars
def _fit_force_2d( moving, fit_image, fit_coreg, fit_pixels, tol, maxit, verbose, path, ): # Required outputs coreg = io._copy(moving, path, 'coreg') if fit_coreg['package'] == 'skimage': transfo = io._defo( moving, path, force_2d=True, name=fit_coreg['name']+'_defo', ) else: transfo = np.empty(moving.shape[-2:], dtype=object) # Optional outputs defo = None if 'return_deformation' in fit_coreg: if fit_coreg['return_deformation']: defo = io._defo( moving, path, force_2d=True, name=fit_coreg['name']+'_defo', ) for k in tqdm( range(moving.shape[2]), desc='Fitting slice', disable=verbose<2, ): if verbose == 1: print(f'Fitting slice {k+1} / {moving.shape[2]}') if fit_image is None: fit_image_k = None elif isinstance(fit_image, dict): fit_image_k = fit_image else: fit_image_k = fit_image[k] if fit_pixels is None: fit_pixels_k = None elif isinstance(fit_pixels, dict): fit_pixels_k = fit_pixels else: fit_pixels_k = fit_pixels[k] vals = fit( moving[:,:,k,:], fit_pixels = fit_pixels_k, fit_image = fit_image_k, fit_coreg = fit_coreg, tol = tol, maxit = maxit, verbose = verbose, ) coreg[:,:,k,:], fit_k, transfo_k, pars_k = vals[:4] if k == 0: fit_arr, pars = io._fit_models_init(moving, path, pars_k.shape[-1]) if fit_coreg['package'] == 'skimage': transfo[:,:,k,:,:] = transfo_k else: transfo[k,:] = transfo_k fit_arr[:,:,k,:] = fit_k pars[:,:,k,:] = pars_k if defo is not None: defo[:,:,k,:,:] = vals[4] if defo is None: return coreg, fit_arr, transfo, pars else: return coreg, fit_arr, transfo, pars, defo def _set_path(dct, path): if dct is None: return if path is None: return if 'path' in dct: if path != dct['path']: raise ValueError("Two different paths are provided.") else: dct['path'] = path def _diff(coreg, coreg_curr): if isinstance(coreg, np.ndarray): corr = np.max(np.abs(coreg-coreg_curr))/np.max(np.abs(coreg_curr)) else: coreg = da.from_zarr(coreg) coreg_curr = da.from_zarr(coreg_curr) corr = da.max(da.abs(coreg-coreg_curr))/da.max(da.abs(coreg_curr)) corr.compute() 
return corr def _coreg_series(moving, fit, package='elastix', **fit_coreg): if package == 'elastix': fit_coreg = _set_mdreg_elastix_defaults(fit_coreg) return elastix.coreg_series(moving, fit, **fit_coreg) elif package == 'skimage': return skimage.coreg_series(moving, fit, **fit_coreg) elif package == 'ants': return ants.coreg_series(moving, fit, **fit_coreg) else: raise NotImplementedError( 'This coregistration package is not implemented') def _set_mdreg_elastix_defaults(params): if "WriteResultImage" not in params: params["WriteResultImage"] = "false" if "WriteDeformationField" not in params: params["WriteDeformationField"] = "false" if "ResultImagePixelType" not in params: params["ResultImagePixelType"] = "float" # # Removing this for v0.4.2 as results appear to be worse # if 'Metric' not in params: # params["Metric"] = "AdvancedMeanSquares" # # Settings pre v0.4.0 - unclear why - removed for now # if "FinalGridSpacingInPhysicalUnits" not in params: # params["FinalGridSpacingInPhysicalUnits"] = "50.0" # if "AutomaticParameterEstimation" not in params: # params["AutomaticParameterEstimation"] = "true" # if "ASGDParameterEstimationMethod" not in params: # params["ASGDParameterEstimationMethod"] = "Original" # if "MaximumStepLength" not in params: # params["MaximumStepLength"] = "1.0" # if "CheckNumberOfSamples" not in params: # params["CheckNumberOfSamples"] = "true" # if "RandomCoordinate" not in params: # params["ImageSampler"] = "RandomCoordinate" return params