Source code for OCT_GUI.Algorithms.ManualCorrection.GraphCutBM

""" 
GraphCutBM: Manual Refinement Graph-Cut Pipeline for Bruch's Membrane
-------------------------------------------------------------------------
PRLEC Framework for OCT Processing and Visualization 
"""
# This framework evolved from a collaboration of:
# - Research Laboratory of Electronics, Massachusetts Institute of Technology, Cambridge, MA, US
# - Pattern Recognition Lab, Friedrich-Alexander-Universitaet Erlangen-Nuernberg, Germany
# - Department of Biomedical Engineering, Peking University, Beijing, China
# - New England Eye Center, Tufts Medical Center, Boston, MA, US
# v1.0: Updated on Mar 20, 2019
# @author: Daniel Stromer - EMAIL:daniel.stromer@fau.de
# Copyright (C) 2018-2019 - Daniel Stromer
# PRLE is developed as an Open Source project under the GNU General Public License (GPL) v3.0.
from igraph import *
import numpy as np
import cv2
import scipy.signal as sp
from FileHandler.ParameterReader import readManRefParameters
# Minimum edge weight of the graph: used directly for the vertical edges of
# the helper border columns and added to every intensity-based edge weight.
w_min = 1e-5

def propagateBM(volumeOriginal, segmentation, segmentation_original, saved_slices, rect_correction, mode):
    """
    Manual refinement algorithm for Bruch's Membrane (BM).

    The manually corrected slices are copied into a cropped label volume and
    connected in slow-scan direction by lines (see ``connectpoints``). In
    mode ``'high'`` the connection lines are thick and a shortest-path search
    (``getGraph`` / ``inpaint``) followed by a median filter picks the final
    BM position inside the painted band. In any other mode the drawn lines
    are taken as the result.

    Keys read from the parameter dictionary returned by
    ``readManRefParameters()``:
        - BM_VALUE, RPE_VALUE, ILM_VALUE: label values of the layers
        - MAN_BM_MEDIAN: kernel size of the median filter (mode 'high')
        - MAN_BM_THICKNESS: line thickness (read by ``connectpoints``)

    Parameters
    ----------
    volumeOriginal: ndarray
        original volume, indexed as [slice, depth, x]
    segmentation: ndarray
        current segmentation; supplies the BM position on the crop border
    segmentation_original: ndarray
        original segmentation; supplies the RPE and ILM labels
    saved_slices: dictionary
        (scalar, ndarray) = (slice number, slice), the manually corrected
        slices
    rect_correction: sequence of four scalars
        [0] start x-axis, [1] end x-axis, [2] start y-axis (slice),
        [3] end y-axis (slice). The crop is extended by a margin around this
        rectangle, so indices ``[2]-1``, ``[3]+1``, ``[0]-1`` and ``[1]+1``
        are accessed as well.
    mode: string
        'high' enables the graph search and median filter; 'low' draws
        lines of thickness 1.

    Returns
    ---------
    cropped_segmentation: ndarray
        segmentation result in the cropped volume. If the median filter step
        fails, the intermediate result is returned as-is, i.e. without the
        RPE/ILM inpainting.
    """
    dictParameters = readManRefParameters()
    bm_value = dictParameters['BM_VALUE']

    x_first, x_last = rect_correction[0], rect_correction[1]
    y_first, y_last = rect_correction[2], rect_correction[3]
    # crop window including its margin
    crop_y = slice(y_first - 1, y_last + 1)
    crop_x = slice(x_first - 1, x_last + 2)

    # crop volume, normalise to a maximum of 1 and swap axes so that the
    # first axis runs along x
    part_volume = volumeOriginal[crop_y, :, crop_x]
    part_volume = part_volume / np.max(part_volume)
    part_volume = np.swapaxes(part_volume, axis1=0, axis2=2)

    cropped_segmentation = np.zeros((y_last - y_first + 2,
                                     segmentation.shape[1],
                                     x_last - x_first + 3)).astype('uint8')

    # copy the manually corrected slices into the crop
    for key, value in saved_slices.items():
        cropped_segmentation[key - y_first + 1, :, :] = value[:, crop_x].astype('uint8')

    # keep BM only, then take the BM position on the four crop borders from
    # the current segmentation
    cropped_segmentation = np.where(cropped_segmentation != bm_value, 0, cropped_segmentation)
    cropped_segmentation[:, :, 0] = np.where(
        segmentation[crop_y, :, x_first - 1] == bm_value, bm_value, 0)
    cropped_segmentation[:, :, -1] = np.where(
        segmentation[crop_y, :, x_last + 1] == bm_value, bm_value, 0)
    cropped_segmentation[0, :, :] = np.where(
        segmentation[y_first - 1, :, crop_x] == bm_value, bm_value, 0)
    # NOTE(review): the last row of the crop corresponds to slice y_last, but
    # its border is read from slice y_last + 1 - confirm this is intended.
    cropped_segmentation[-1, :, :] = np.where(
        segmentation[y_last + 1, :, crop_x] == bm_value, bm_value, 0)

    cropped_segmentation = np.swapaxes(cropped_segmentation, axis1=0, axis2=2)

    # connect points with lines
    cropped_segmentation = connectpoints(cropped_segmentation, mode, dictParameters)

    # Graph search inside the thick lines, if too few corrected slices exist.
    # Compare strings by value: "is" tests identity and is only true by
    # accident of string interning.
    if mode == 'high':
        for i in range(1, cropped_segmentation.shape[0] - 1):
            graph, endnode = getGraph(part_volume[i], cropped_segmentation[i], dictParameters)
            shortest_path = np.asarray(graph.get_shortest_paths(v=0, to=endnode, weights='weight'))
            cropped_segmentation[i] = inpaint(cropped_segmentation[i], shortest_path, dictParameters)

    cropped_segmentation = np.swapaxes(cropped_segmentation, axis1=2, axis2=0)

    # median filter of the BM depth along each slice
    if mode == 'high':
        try:
            for i in range(1, cropped_segmentation.shape[0] - 1):
                result = np.zeros(cropped_segmentation[i].shape)
                values = np.where(cropped_segmentation[i].transpose() == bm_value)
                y_values = sp.medfilt(values[1], dictParameters['MAN_BM_MEDIAN']).astype('uint16')
                result[y_values, values[0]] = bm_value
                cropped_segmentation[i] = result
        except Exception:
            # Best effort: hand back what has been computed so far instead of
            # failing the whole refinement. KeyboardInterrupt/SystemExit are
            # deliberately no longer swallowed.
            return cropped_segmentation

    rpe_value = dictParameters['RPE_VALUE']
    ilm_value = dictParameters['ILM_VALUE']
    original_crop = segmentation_original[crop_y, :, crop_x]

    # Inpaint RPE and BM: where the new BM collides with the original RPE,
    # BM keeps the pixel and RPE is moved one pixel up.
    vals = np.where(np.logical_and(original_crop == rpe_value,
                                   cropped_segmentation == bm_value))
    cropped_segmentation = np.where(original_crop == rpe_value, rpe_value, cropped_segmentation)
    # Kept as a sequential loop: for vertically adjacent collisions the
    # write order decides the final label.
    for idx_y, idx_z, idx_x in zip(vals[0], vals[1], vals[2]):
        cropped_segmentation[idx_y, idx_z - 1, idx_x] = rpe_value
        cropped_segmentation[idx_y, idx_z, idx_x] = bm_value

    # Inpaint ILM
    cropped_segmentation = np.where(original_crop == ilm_value, ilm_value,
                                    cropped_segmentation).astype('uint8')
    return cropped_segmentation
def connectpoints(cropped_segmentation, mode, dictParameters):
    """
    Connect the BM points of every inner slice by straight lines.

    For each slice ``cropped_segmentation[k]`` (first and last index are
    skipped) the BM pixels are collected column by column. Two successive
    points whose columns are more than one apart are joined with
    ``cv2.line``. Afterwards the column of every original point is cleared
    and only the original point is restored, so the given points stay fixed.

    Parameters
    ----------
    cropped_segmentation: ndarray
        cropped segmentation volume; modified in place and returned
    mode: string
        'low' draws lines of thickness 1, any other value uses
        ``dictParameters['MAN_BM_THICKNESS']``
    dictParameters: dictionary
        Parameters from parameter.txt; reads 'BM_VALUE' and
        'MAN_BM_THICKNESS'

    Returns
    ---------
    cropped_segmentation: ndarray
        segmentation result in cropped volume
    """
    bm_value = dictParameters['BM_VALUE']

    for k in range(1, cropped_segmentation.shape[0] - 1):
        # transposing makes np.where return the points ordered by column
        points = np.asarray(np.where(cropped_segmentation[k].transpose() == bm_value)).astype('int32')
        cols, rows = points[0], points[1]

        # nothing to connect in a slice without any BM pixel
        if cols.size == 0:
            continue

        buffer_slice = cropped_segmentation[k].copy()

        for i in range(cols.shape[0] - 1):
            if np.abs(cols[i] - cols[i + 1]) > 1:
                # compare strings by value, not by identity
                if mode == 'low':
                    thickness = 1
                else:
                    thickness = dictParameters['MAN_BM_THICKNESS']
                # cv2.line takes (x, y) pixel coordinates; pass plain ints
                cv2.line(buffer_slice,
                         (int(cols[i]), int(rows[i])),
                         (int(cols[i + 1]), int(rows[i + 1])),
                         bm_value, thickness)
            # NOTE(review): nesting reconstructed from a whitespace-stripped
            # source; the reset is applied to every point, matching the
            # unconditional handling of the last point below - confirm.
            buffer_slice[:, cols[i]] = 0
            buffer_slice[rows[i], cols[i]] = bm_value

        # the loop above stops one short, so pin the last point here
        buffer_slice[:, cols[-1]] = 0
        buffer_slice[rows[-1], cols[-1]] = bm_value

        cropped_segmentation[k] = buffer_slice

    return cropped_segmentation
def getGraph(_slice, shortest_path, dictParameters):
    """
    Construct the directed graph for the manual BM refinement.

    Every pixel of the slice, plus one helper column on the left and one on
    the right, becomes a vertex. Only pixels marked with BM_VALUE in
    ``shortest_path`` (and the helper columns) are usable; all others are
    set to NaN and get no edges. Each usable vertex outside the last column
    is linked to up to seven vertices in the next column (same row and
    +-1, +-2, +-3 rows) and, in addition, to its upper and lower neighbour.

    Parameters
    ----------
    _slice: ndarray
        2-D intensity slice from which the vertical gradient is computed
    shortest_path: ndarray
        2-D mask of the same shape as ``_slice``; pixels equal to
        ``dictParameters['BM_VALUE']`` are the allowed search region
    dictParameters: dictionary
        Parameters from parameter.txt; reads 'BM_VALUE'

    Returns
    ---------
    g: graph
        resulting graph with the edge attribute 'weight'
    endpoint: scalar
        index of the last vertex (bottom-right corner of the padded slice)
    """
    intensities = _slice.copy()

    # padded[:, 0] and padded[:, -1] are the helper columns, tagged with -2
    # and 2 so that they can be told apart from gradient values
    padded = np.zeros((_slice.shape[0], _slice.shape[1] + 2))
    padded[:, 0] = -2
    padded[:, -1] = 2
    padded[:, 1:-1] = shortest_path

    # Gradient filter with kernel width 3 to enforce smoothness
    kernel = np.array([[1, 1, 1], [-1, -1, -1]])
    grad = cv2.filter2D(intensities, -1, kernel)

    # allow only areas where a line was inpainted
    inside_band = padded[:, 1:-1] == dictParameters['BM_VALUE']
    padded[:, 1:-1] = np.where(inside_band, grad, np.nan)
    padded[:, 0] = -2
    padded[:, -1] = 2

    # 2-D to 1-D conversion
    slice_1D = padded.flatten()

    # set up a graph where every pixel is a node (edges are transitions in between)
    g = Graph(directed=True)
    g.add_vertices(slice_1D.size)

    sx = padded.shape[1]
    node_count = slice_1D.shape[0]

    # 7-pixel neighborhood in the next column; to increase/decrease, change here
    forward_steps = (1,
                     -sx + 1, sx + 1,
                     -2 * sx + 1, 2 * sx + 1,
                     -3 * sx + 1, 3 * sx + 1)
    # upper and lower neighbour in the same column
    vertical_steps = (-sx, sx)

    def usable(node):
        # inside the flattened array and not masked out
        return 0 <= node < node_count and not np.isnan(slice_1D[node])

    edge_list = []
    weight_list = []

    for v_i in range(slice_1D.size):
        if np.isnan(slice_1D[v_i]):
            continue

        # vertices in the last column have no successor column
        if (v_i + 1) % sx != 0:
            for step in forward_steps:
                v_j = v_i + step
                if usable(v_j):
                    edge_list.append((v_i, v_j))
                    weight_list.append(calculateWeightStandard(slice_1D[v_i], slice_1D[v_j]))

        # Values outside [-1, 1] are treated as helper-column vertices: their
        # value is reset to 0 and they move vertically at minimum cost.
        # NOTE(review): a gradient value outside [-1, 1] takes this path too.
        is_border = slice_1D[v_i] > 1.0 or slice_1D[v_i] < -1.0
        if is_border:
            slice_1D[v_i] = 0.0

        for step in vertical_steps:
            v_j = v_i + step
            if usable(v_j):
                edge_list.append((v_i, v_j))
                if is_border:
                    weight_list.append(w_min)
                else:
                    weight_list.append(calculateWeightStandard(slice_1D[v_i], slice_1D[v_j]))

    g.add_edges(edge_list)
    g.es['weight'] = weight_list

    return g, slice_1D.size - 1
def calculateWeightStandard(value_v_i, value_v_j):
    """
    Calculate the linear weight of the edge between two vertices vi and vj.

    The weight is ``2.0 - (vi + vj) + w_min``. If the value of vj lies
    outside [-1, 1], it is left out of the sum.

    Parameters
    ----------
    value_v_i: scalar, float
        value at vertex vi
    value_v_j: scalar, float
        value at vertex vj

    Returns
    ---------
    weight: scalar, float
        linear weight of the edge from vi to vj
    """
    combined = value_v_i
    if not (value_v_j > 1.0 or value_v_j < -1.0):
        combined = value_v_i + value_v_j
    return 2.0 - combined + w_min
def inpaint(_slice, shortest_path, dictParameters):
    """
    Inpaint a path of vertex indices into an empty segmentation slice.

    The indices refer to a row-major flattened slice that is two columns
    wider than ``_slice`` (one helper column on each side). The helper
    columns are cut off again before returning.

    Parameters
    ----------
    _slice: ndarray
        2-D slice; only its shape is used
    shortest_path: ndarray
        vertex indices of the path in the padded, flattened slice
    dictParameters: dictionary
        Parameters from parameter.txt; reads 'BM_VALUE'

    Returns
    ---------
    segmentation: ndarray
        uint8 slice of the shape of ``_slice`` with the path pixels set to
        BM_VALUE and everything else 0
    """
    padded_width = _slice.shape[1] + 2
    canvas = np.zeros((_slice.shape[0], padded_width), dtype='uint8')

    # flat vertex index -> (row, column) in the padded slice
    columns = (shortest_path % padded_width).astype(np.int32)
    rows = (shortest_path / padded_width).astype(np.int32)

    canvas[rows, columns] = dictParameters['BM_VALUE']

    # drop the two helper columns
    return canvas[:, 1:-1]