# coding: utf-8
# This is the interface API for the deepseg_gm model
# that implements the model for the Spinal Cord Gray Matter Segmentation.
#
# Reference paper:
# Perone, C. S., Calabrese, E., & Cohen-Adad, J. (2017).
# Spinal cord gray matter segmentation using deep dilated convolutions.
# URL: https://arxiv.org/abs/1710.01269
import warnings
import json
import os
import sys
import io
import nipy
import nibabel as nib
from nipy.io.nifti_ref import nipy2nifti, nifti2nipy
import numpy as np
# Avoid Keras logging: anything written to stderr while Keras is being
# imported is diverted into a throw-away in-memory buffer.
original_stderr = sys.stderr
if sys.hexversion < 0x03000000:
    # Python 2: a plain byte buffer is used as the replacement stream.
    sys.stderr = io.BytesIO()
else:
    # Python 3: wrap the byte buffer in a text layer that reuses the
    # encoding of the real stderr.
    sys.stderr = io.TextIOWrapper(io.BytesIO(), sys.stderr.encoding)
try:
    from keras import backend as K
finally:
    # Put the real stderr back on every exit path. The import error (if any)
    # still propagates, and interruptions that are not ``Exception``
    # subclasses (e.g. KeyboardInterrupt) no longer leave stderr redirected.
    sys.stderr = original_stderr

from spinalcordtoolbox.resample import nipy_resample
from . import model

# Suppress warnings and TensorFlow logging
warnings.simplefilter(action='ignore', category=FutureWarning)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
def check_backend():
    """Report the active Keras backend, warning if it is not TensorFlow.

    A warning is printed to standard output for any backend whose name is
    not ``'tensorflow'``; nothing is raised.

    :return: the backend name as given by ``K.backend()``.
    """
    # Guard clause: the recommended backend needs no message.
    if K.backend() == 'tensorflow':
        return K.backend()

    print("\nWARNING: you're using a Keras backend different than\n"
          "Tensorflow, which is not recommended. Please verify\n"
          "your configuration file according to: https://keras.io/backend/\n"
          "to make sure you're using Tensorflow Keras backend.\n")
    return K.backend()
class DataResource(object):
    """Locate resource files (such as model weights) under a data directory."""

    def __init__(self, dirname):
        """Resolve the data root for the given sub-directory.

        The root is ``<base>/data/<dirname>``, where ``<base>`` is the value
        of the ``SCT_DIR`` environment variable when it is set, otherwise the
        directory three levels above this source file.

        :param dirname: the root directory name.
        """
        # Walk three directory levels up from this file for the fallback.
        fallback_base = __file__
        for _ in range(3):
            fallback_base = os.path.dirname(fallback_base)

        base_dir = os.environ.get("SCT_DIR", fallback_base)
        resource_dir = os.path.join(base_dir, "data", dirname)
        self.data_root = os.path.abspath(resource_dir)

    def get_file_path(self, filename):
        """Join a file name onto the data root directory.

        :param filename: the filename.
        :return: the path of the file inside the data root.
        """
        return os.path.join(self.data_root, filename)
class CroppedRegion(object):
    """Record of a 2D crop, sufficient to pad a cropped image back to the
    shape it was cut from.
    """

    def __init__(self, original_shape, starts, crops):
        """Store the geometry of the crop.

        :param original_shape: the original shape, as (x, y).
        :param starts: crop beginning (x, y).
        :param crops: the crop sizes (x, y).
        """
        self.originalx, self.originaly = original_shape[0], original_shape[1]
        self.startx, self.starty = starts[0], starts[1]
        self.cropx, self.cropy = crops[0], crops[1]

    def pad(self, image):
        """Zero-pad a cropped image back to the recorded original size.

        :param image: the image to pad; its first axis is treated as y and
                      its second axis as x.
        :return: padded image.
        """
        # Padding after the crop is whatever remains of the original extent.
        y_margins = (self.starty,
                     self.originaly - (self.starty + self.cropy))
        x_margins = (self.startx,
                     self.originalx - (self.startx + self.cropx))
        return np.pad(image, (y_margins, x_margins), mode="constant")
def crop_center(img, cropx, cropy):
    """Cut a centered window out of a 2D image.

    :param img: image to be cropped; its shape is unpacked as (y, x).
    :param cropx: size of the window along x (second axis).
    :param cropy: size of the window along y (first axis).
    :return: (cropped image, cropped region)
    :raises RuntimeError: if the window would start at a negative index,
                          i.e. the image is smaller than the requested crop.
    """
    rows, cols = img.shape
    left = cols // 2 - (cropx // 2)
    top = rows // 2 - (cropy // 2)

    if min(left, top) < 0:
        raise RuntimeError("Negative crop.")

    # Remember the geometry so the crop can be undone later.
    region = CroppedRegion((cols, rows), (left, top), (cropx, cropy))
    window = img[top:top + cropy, left:left + cropx]
    return window, region
def threshold_predictions(predictions, thr=0.999):
    """Binarize predictions around a threshold.

    Values less than or equal to ``thr`` become 0 and values greater than
    ``thr`` become 1. Values that compare neither way (NaN) are left as is.

    Note that ``predictions[:]`` on a NumPy array is a view, not a copy, so
    the input array itself is modified. This in-place behavior is kept on
    purpose because callers may rely on it.

    :param predictions: the array of predictions to threshold.
    :param thr: the threshold.
    :return: thresholded predictions.
    """
    thresholded_preds = predictions[:]
    # Compute both masks before writing anything. If the second mask were
    # taken after the low values had been zeroed, a negative threshold would
    # see those zeros as "above threshold" and turn the whole array into 1.
    low_values_indices = thresholded_preds <= thr
    high_values_indices = thresholded_preds > thr
    thresholded_preds[low_values_indices] = 0
    thresholded_preds[high_values_indices] = 1
    return thresholded_preds
def segment_volume(ninput_volume, model_name):
    """Segment a nifti volume slice by slice.

    Each slice along the third axis is center-cropped, normalized with the
    training mean/std stored in the model metadata, run through the network,
    thresholded, and padded back to the size it was cropped from.

    :param ninput_volume: the input volume.
    :param model_name: the name of the model to use (a key of
                       ``model.MODELS``).
    :return: the segmented slices as a uint8 array, with the slice index
             moved to the last axis.
    """
    resources = DataResource('deepseg_gm_models')
    weights_name, metadata_name = model.MODELS[model_name]

    with open(resources.get_file_path(metadata_name)) as metadata_file:
        metadata = json.load(metadata_file)

    network = model.create_model(metadata['filters'])
    network.load_weights(resources.get_file_path(weights_name))

    volume = ninput_volume.get_data()

    # NOTE(review): CROP_HEIGHT is passed as cropx and CROP_WIDTH as cropy;
    # presumably harmless because the crop is square -- confirm in model.
    cropped = [crop_center(volume[..., index],
                           model.CROP_HEIGHT, model.CROP_WIDTH)
               for index in range(volume.shape[2])]
    regions = [region for _, region in cropped]

    # Stack the patches with a trailing axis of length 1, then normalize.
    batch = np.asarray([patch for patch, _ in cropped], dtype=np.float32)
    batch = np.expand_dims(batch, axis=3)
    batch -= metadata['mean_train']
    batch /= metadata['std_train']

    raw_preds = network.predict(batch, batch_size=8, verbose=True)
    preds = threshold_predictions(raw_preds)

    # Un-cropping: pad every predicted slice back to its pre-crop size.
    restored = [regions[index].pad(pred[..., 0])
                for index, pred in enumerate(preds)]
    restored = np.asarray(restored, dtype=np.uint8)
    return np.transpose(restored, (1, 2, 0))
def segment_file(input_filename, output_filename,
                 model_name, verbosity):
    """Segment a volume file and write the result to disk.

    The image is resampled to "0.25x0.25" in the first two dimensions (the
    third keeps its original pixdim), segmented, resampled back to the
    original pixdims, thresholded at 0.5 and saved.

    :param input_filename: the input filename.
    :param output_filename: the output filename.
    :param model_name: the name of model to use.
    :param verbosity: the verbosity level, forwarded to the resampler.
    :return: the output filename.
    :raises RuntimeError: if either of the first two dimensions of the
                          resampled image is smaller than 200.
    """
    source_image = nipy.load_image(input_filename)

    third_pixdim = source_image.header["pixdim"][3]
    upsampled = nipy_resample.resample_image(
        source_image,
        "0.25x0.25x{:.5f}".format(third_pixdim),
        'mm', 'linear',
        verbosity)

    dim0, dim1 = upsampled.shape[0], upsampled.shape[1]
    if min(dim0, dim1) < 200:
        raise RuntimeError("Image too small ({}, {})".format(dim0, dim1))

    upsampled = nipy2nifti(upsampled)
    labels = segment_volume(upsampled, model_name)

    # Resolution string of the input image, used to resample back.
    native_res = "{:.5f}x{:.5f}x{:.5f}".format(
        *(source_image.header["pixdim"][axis] for axis in (1, 2, 3)))

    label_image = nifti2nipy(
        nib.Nifti1Image(labels, upsampled.affine, upsampled.header))
    restored = nipy_resample.resample_image(label_image,
                                            native_res,
                                            'mm', 'linear',
                                            verbosity)

    # Linear resampling yields non-binary values, so binarize again. The
    # return value is not stored back into the image: this relies on
    # threshold_predictions modifying the array it receives in place.
    # NOTE(review): assumes get_data() hands back the image's own buffer
    # rather than a copy -- confirm against nipy.
    threshold_predictions(restored.get_data(), 0.5)

    nipy.save_image(restored, output_filename)
    return output_filename