I am looking for a GPU-accelerated, n-dimensional sliding-window operation implemented in Python with TensorFlow. Answers in Torch, Caffe or Theano are welcome, but I will accept the TensorFlow implementation as the answer. Please post a working code snippet that performs a 2-D median filter (ideally one that can be applied to n-dimensional images with no code change, or minimal code change).
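To make the n-dimensional requirement concrete, here is one way a median filter can be written in TensorFlow without any 2-D-specific op: zero-pad the tensor, take one shifted slice per kernel offset, stack the slices and keep the middle element of the sorted stack. This is a minimal sketch assuming TensorFlow 2.x with eager execution; `median_filter_nd` is a hypothetical helper name, and zero padding is chosen to mirror `scipy.signal.medfilt`. It materializes one copy of the input per kernel element, so memory grows with the kernel volume.

```python
import itertools

import tensorflow as tf


def median_filter_nd(x, kernel_size):
    """Median filter for an n-D tensor (hypothetical helper, TF 2.x eager).

    Borders are zero-padded, mirroring scipy.signal.medfilt.
    """
    x = tf.cast(tf.convert_to_tensor(x), tf.float32)
    ndim = x.shape.rank
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size,) * ndim
    if len(kernel_size) != ndim or any(k % 2 == 0 for k in kernel_size):
        raise ValueError("kernel_size needs one odd size per dimension")

    # Pad k // 2 zeros on both sides of every axis.
    padded = tf.pad(x, [[k // 2, k // 2] for k in kernel_size])
    shape = tf.shape(x)

    # One shifted view per kernel offset: element i of the view taken at
    # `offset` is the neighbour of x[i] at that offset inside the window.
    shifted = [
        tf.slice(padded, begin=list(offset), size=shape)
        for offset in itertools.product(*(range(k) for k in kernel_size))
    ]
    stack = tf.stack(shifted, axis=-1)  # x.shape + [prod(kernel_size)]

    # The window size is odd, so the median is the middle sorted element.
    return tf.sort(stack, axis=-1)[..., len(shifted) // 2]
```

The same call works for 1-D signals, 2-D images and 3-D volumes, e.g. `median_filter_nd(imgs, 9)` for the 9×9 case in the question.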
With my limited knowledge of TensorFlow, I believe the two modules to start with are `sliding_window_batch` or `extract_image_patches`, followed by some `map`, `apply`, `reshape` magic?
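For the 2-D case, `extract_image_patches` (whose TensorFlow 2.x counterpart is `tf.image.extract_patches`) can gather each k×k neighbourhood into the channel axis, after which the median is the middle element of a sort along that axis. A minimal sketch, assuming TF 2.x eager mode, a single-channel image and an odd `ksize`; `median_filter_2d` is a hypothetical name. With `padding="SAME"` and stride 1 the borders are zero-padded, which matches the boundary handling of `medfilt`.

```python
import tensorflow as tf


def median_filter_2d(image, ksize=3):
    """2-D median filter for a single-channel image (hypothetical helper)."""
    if ksize % 2 == 0:
        raise ValueError("ksize must be odd")
    x = tf.cast(tf.convert_to_tensor(image), tf.float32)
    x = x[tf.newaxis, :, :, tf.newaxis]  # [H, W] -> [1, H, W, 1]

    # Gather every ksize x ksize neighbourhood into the last axis.
    # "SAME" padding with stride 1 zero-pads the borders and keeps H x W.
    patches = tf.image.extract_patches(
        images=x,
        sizes=[1, ksize, ksize, 1],
        strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1],
        padding="SAME",
    )  # [1, H, W, ksize * ksize]

    # ksize is odd, so the median is the middle element after sorting.
    median = tf.sort(patches, axis=-1)[..., (ksize * ksize) // 2]
    return median[0]  # [H, W]
```

To mirror `medfilt(imgs, (9, 9))` this would be `median_filter_2d(imgs, ksize=9)`. The patch tensor holds H×W×k² values, so large kernels on large images cost memory accordingly.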
My failed attempt is posted below, for entertainment. Note that I asked a similar question two years ago, looking for a Theano implementation; nowadays, most people use TensorFlow/Keras or Torch.
```python
import time
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import tensorflow as tf
from tensorflow.contrib.data.python.ops import sliding
from skimage import img_as_float, data
from scipy.signal import medfilt

imgs = img_as_float(data.camera())

### SCIPY median ###
stime = time.time()
scipysmoothed = medfilt(imgs, (9, 9))
etime = time.time()
print('scipy smoothed: {:1.4f} seconds'.format(etime - stime))

### Failed attempt of TF median ###
method = 'Tensorflow'
stime = time.time()
window_func = lambda x: tf.contrib.distributions.percentile(x, 50.0)
# create TensorFlow Dataset object
data = tf.data.Dataset.from_tensor_slices(imgs)
# sliding window - only 1d is allowed?
window = 3
stride = 1
data = data.apply(sliding.sliding_window_batch(window, stride)).map(lambda x: window_func(x))
# create TensorFlow Iterator object
iterator = tf.data.Iterator.from_structure(data.output_types)
next_element = iterator.get_next()
# create initialization ops
init_op = iterator.make_initializer(data)
c = 0
smoothed = np.zeros(imgs.shape)
with tf.Session() as sess:
    # initialize the iterator on the data
    sess.run(init_op)
    while True:
        try:
            elem = sess.run(next_element)
            smoothed[c, :] = elem
            # obviously WRONG.
            c += 1
        except tf.errors.OutOfRangeError:
            # print("End of dataset.")
            break
    # print(c)
etime = time.time()
print('tf smoothed: {:1.4f} seconds'.format(etime - stime))

plt.figure(figsize=(20, 20))
plt.subplot(131)
plt.imshow(imgs, cmap='gray', interpolation='none')
plt.title('original')
plt.subplot(132)
plt.imshow(smoothed, cmap='gray', interpolation='none')
plt.title('actual smoothed\nwith {}'.format(method))
plt.subplot(133)
plt.imshow(scipysmoothed, cmap='gray', interpolation='none')
_ = plt.title('expected smoothed')
```
```
scipy smoothed: 1.1899 seconds
tf smoothed: 0.7485 seconds
```
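The attempt behaves this way because `sliding_window_batch` windows over consecutive dataset *elements*, and `from_tensor_slices(imgs)` makes each image row one element, so every window is a block of `window` full rows. `percentile` with its default `axis=None` then reduces the whole block to a single scalar, which `smoothed[c, :] = elem` broadcasts across an entire row. The two timings are also not comparable, since `medfilt` uses a 9×9 kernel while the TF code uses a 3-row window. The sketch below reproduces the same windowing on a toy 4×5 array with `Dataset.window(...).flat_map(...)`, the replacement suggested when `sliding_window_batch` was deprecated (assuming TF 2.x, where `tf.contrib` no longer exists).

```python
import numpy as np
import tensorflow as tf

# Toy 4x5 "image"; from_tensor_slices turns each row into one dataset element.
imgs = np.arange(20, dtype=np.float32).reshape(4, 5)
ds = tf.data.Dataset.from_tensor_slices(imgs)

# Windows of 3 consecutive rows, shifted by 1 row (the 1-D windowing of the attempt).
windows = ds.window(3, shift=1, drop_remainder=True).flat_map(lambda w: w.batch(3))

# Each window is a (3, 5) block of whole rows, not a small 2-D neighbourhood.
shapes = [tuple(w.shape) for w in windows]

# A median over the whole block (like percentile with axis=None) is one scalar.
medians = [float(np.median(w.numpy())) for w in windows]
print(shapes)   # [(3, 5), (3, 5)]
print(medians)  # [7.0, 12.0]
```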