# Source code for miblab_dl.fatwatermap

"""
Compute fat and water maps from opposed-phase and in-phase data, using a
predicted water-dominance mask
"""

import os
import subprocess
import shutil
import tempfile

import numpy as np
import nibabel as nib
from miblab_data.zenodo import fetch as zenodo_fetch
from platformdirs import user_cache_dir
# from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
# from nnunetv2.postprocessing.remove_connected_components import apply_postprocessing_to_folder



def _cache_dir(cache=None):

    # 1. User override via environment variable
    if cache:
        try:
            os.makedirs(cache, exist_ok=True)
        except Exception:
            # If user has set an invalid/unwritable path, raise an error
            raise ValueError(
                f"{cache} is not a valid cache directory for miblab-dl."
            )
        else:
            return cache

    # 2. Fallback to platform-specific user cache (~/.cache/miblab-dl)
    cache_dir = user_cache_dir("miblab-dl")
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


def clear_cache(cache=None):
    """Delete the miblab-dl cache directory and everything inside it.

    Args:
        cache (str, optional): cache directory to remove. When omitted, the
            default location resolved by ``_cache_dir`` is removed.
    """
    # _cache_dir resolves (and creates, if missing) the directory, so the
    # removal below always has an existing target.
    shutil.rmtree(_cache_dir(cache))
def fatwater(
    op_phase,
    in_phase,
    te_o=None,
    te_i=None,
    t2s_w=30,
    t2s_f=15,
    tr=None,
    fa=None,
    t1_w=1400,
    t1_f=350,
    cache=None,
):
    """Compute fat and water maps from opposed-phase and in-phase arrays

    A water-dominance mask is predicted from the two input arrays and then
    used to select, per voxel, which 2x2 signal model is inverted to obtain
    fat and water.

    Args:
        op_phase (np.ndarray): opposed phase data
        in_phase (np.ndarray): in-phase data
        te_o (float, optional): opposed-phase echo time (presumably; it is
            used as exp(-te_o/T2*)). When given, a T2* correction is applied
            and ``te_i`` must be given as well. Same units as ``t2s_w`` and
            ``t2s_f``.
        te_i (float, optional): in-phase echo time (presumably), only used
            when ``te_o`` is given.
        t2s_w (float, optional): T2* of water used in the T2* correction.
        t2s_f (float, optional): T2* of fat used in the T2* correction.
        tr (float, optional): repetition time (presumably; it is used as
            exp(-tr/T1)). When given, a T1 correction is applied and ``fa``
            must be given as well. Same units as ``t1_w`` and ``t1_f``.
        fa (float, optional): flip angle in degrees, only used when ``tr``
            is given.
        t1_w (float, optional): T1 of water used in the T1 correction.
        t1_f (float, optional): T1 of fat used in the T1 correction.
        cache (str, optional): directory to use for storing model weights
            and temp files. This defaults to the standard cache dir location
            of the operating system.

    Returns:
        fat, water: float numpy arrays of the same shape as the input
        arrays. Negative values are set to 0.
    """
    print('Downloading model..')

    # Persistent cache memory for storing model weights avoids the
    # need to download every time.
    cachedir = _cache_dir(cache)
    model = zenodo_fetch("FatWaterPredictor.zip", cachedir, "17791059", extract=True)

    print('Predicting fat and water images..')

    # Making temporary folders in persistent cache is safer on HPC
    tmp = tempfile.mkdtemp(prefix="tmp_", dir=cachedir)
    try:
        # Compute
        waterdom = _predict_mask_numpy(model, op_phase, in_phase, tmp)
        fat, water = _compute_fatwater(
            waterdom, op_phase, in_phase,
            te_o, te_i, t2s_w, t2s_f,
            tr, fa, t1_w, t1_f,
        )
    finally:
        # Clean up temp dirs even when prediction or computation fails, so
        # failed runs do not pile up scratch data in the persistent cache.
        shutil.rmtree(tmp)

    fat[fat < 0] = 0
    water[water < 0] = 0

    return fat, water
def _predict_mask_numpy(model, op_phase, in_phase, tmp):
    """Predict the water-dominance mask for a pair of in-memory arrays.

    The arrays are written as NIfTI files into a scratch tree under ``tmp``,
    the folder-based predictor is run on that tree, and the post-processed
    mask is read back into memory.

    Args:
        model (str): folder holding the model files, passed on to
            ``_predict_mask_folder``.
        op_phase (np.ndarray): opposed phase data
        in_phase (np.ndarray): in-phase data
        tmp (str): existing scratch directory; three subfolders are created
            inside it (they must not exist yet).

    Returns:
        np.ndarray: the predicted mask cast to ``np.int8``.
    """
    # Scratch layout: raw inputs, raw predictions, post-processed output
    in_dir = os.path.join(tmp, 'input_folder')
    pred_dir = os.path.join(tmp, 'predictions')
    out_dir = os.path.join(tmp, 'output_folder')
    for folder in (in_dir, pred_dir, out_dir):
        os.makedirs(folder)

    # Save numpy arrays as nifti, one file per input channel
    case_id = "dixon"
    channel_paths = [
        os.path.join(in_dir, f"{case_id}_0000.nii.gz"),
        os.path.join(in_dir, f"{case_id}_0001.nii.gz"),
    ]
    channel_images = [
        nib.Nifti1Image(volume, np.eye(4))
        for volume in (op_phase, in_phase)
    ]
    for image, path in zip(channel_images, channel_paths):
        nib.save(image, path)

    # Create predictions in a temporary output_folder
    _predict_mask_folder(model, in_dir, out_dir, pred_dir)
    #__predict_mask_folder(model, input_folder, output_folder)

    # Return result as binary numpy array
    mask_path = os.path.join(out_dir, f"{case_id}.nii.gz")
    return nib.load(mask_path).get_fdata().astype(np.int8)


# def __predict_mask_folder(model, input_folder, predictions):

#     # TODO: consider API for numpy arrays to avoid read and write to temp folders

#     # Initialize predictor
#     plans_dir = os.path.join(
#         model, "Dataset001_FatWaterPredictor",
#         "nnUNetTrainer__nnUNetPlans__3d_fullres"
#     )
#     predictor = nnUNetPredictor()
#     predictor.initialize_from_trained_model_folder(plans_dir, None)
#     predictor.predict_from_files(input_folder, predictions)

#     # Skip the postprocessing - not managed to get this to work
#     # yet in the python API but should be fixable.
def _predict_mask_folder(model, input_folder, output_folder, predictions): # These two variables are not used but we are setting to a # dummy value to silence the warnings os.environ["nnUNet_raw"] = input_folder os.environ["nnUNet_preprocessed"] = input_folder # Folder containing the model weights os.environ["nnUNet_results"] = model # Predict and save results in the temporary folder cmd = [ "nnUNetv2_predict", "-d", "Dataset001_FatWaterPredictor", "-i", input_folder, "-o", predictions, "-f", "0", "1", "2", "3", "4", "-tr", "nnUNetTrainer", "-c", "3d_fullres", "-p", "nnUNetPlans", ] process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, encoding="utf-8", # <-- force UTF-8 decoding errors="replace" # <-- avoids crash if weird bytes appear ) # Stream logs in real-time for line in process.stdout: print(line, end="") retcode = process.wait() if retcode != 0: raise RuntimeError(f"Prediction failed with exit code {retcode}") # Run post-processing # os.makedirs(output_folder, exist_ok=True) source = os.path.join(model, 'Dataset001_FatWaterPredictor', 'nnUNetTrainer__nnUNetPlans__3d_fullres', "crossval_results_folds_0_1_2_3_4") pproc = os.path.join(source, 'postprocessing.pkl') plans = os.path.join(source, 'plans.json') cmd = [ "nnUNetv2_apply_postprocessing", "-i", predictions, "-o", output_folder, "-pp_pkl_file", pproc, "-np", "8", "-plans_json", plans, ] process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, encoding="utf-8", # <-- force UTF-8 decoding errors="replace" # <-- avoids crash if weird bytes appear ) # Stream logs in real-time for line in process.stdout: print(line, end="") retcode = process.wait() if retcode != 0: raise RuntimeError(f"Postprocessing failed with exit code {retcode}") def _compute_fatwater(waterdom, op_phase, in_phase, te_o, te_i, t2s_w, t2s_f, tr, fa, t1_w, t1_f): Eof, Eif, Eow, Eiw = 1, 1, 1, 1 if te_o is not None: # Add T2* correction Eof = np.exp(-te_o/t2s_f) Eif 
= np.exp(-te_i/t2s_f) Eow = np.exp(-te_o/t2s_w) Eiw = np.exp(-te_i/t2s_w) if tr is not None: # Add T1 correction ef = np.exp(-tr/t1_f) ew = np.exp(-tr/t1_w) cfa = np.cos(np.deg2rad(fa)) Af = (1 - ef) / (1 - cfa * ef) Aw = (1 - ew) / (1 - cfa * ew) Eof *= Af Eif *= Af Eow *= Aw Eiw *= Aw Efatdom = np.array([[Eof, -Eow], [Eif, Eiw]]) Ewatdom = np.array([[-Eof, Eow], [Eif, Eiw]]) Efatdom_inv = np.linalg.inv(Efatdom) Ewatdom_inv = np.linalg.inv(Ewatdom) fat, water = _apply_pixelwise_matrix(op_phase, in_phase, waterdom, Efatdom_inv, Ewatdom_inv) return fat, water def _apply_pixelwise_matrix(a, b, mask, M0, M1) -> np.ndarray: """ For each pixel/voxel combine [a, b] as a 2-vector v and compute: result = M0 @ v if mask == 0/False result = M1 @ v if mask == 1/True Returns arrays c, d of same shape and type Parameters ---------- a, b : np.ndarray Input 3D arrays of the same shape (spatial). mask : np.ndarray Boolean/0-1 array same shape as `a`/`b`. True selects M1. M0, M1 : array-like (2x2) Two 2x2 matrices. Returns ------- np.ndarray Output 3D arrays of the same shape (spatial). """ # stack components into last axis: shape (..., 2) v = np.stack((a, b), axis=-1).astype(float) # shape (Z, Y, X, 2) # compute results for both matrices: result = v @ M.T (vector @ M.T => M @ v per-voxel) res0 = v @ M0.T # shape (..., 2) res1 = v @ M1.T # Select based on mask; expand mask to last axis mask_bool = np.asarray(mask, dtype=bool) mask_expanded = mask_bool[..., None] # shape (..., 1) result = np.where(mask_expanded, res1, res0) # shape (..., 2) return result[...,0], result[...,1]