I am trying to get DCGAN (Deep Convolutional Generative Adversarial Networks) to work with TensorFlow for Java.

I have added the following code to DCGAN’s model.py to output a graph that can later be used in TensorFlow for Java:

    # At the beginning, to define where the model will be saved:
    #
    self.load_dir = load_dir
    self.models_dir = models_dir

    graph = tf.Graph()
    self.graph = graph

    self.graph.as_default()
    #
    # Near the end, where the session is run, to build and save the model for use
    # in TensorFlow for Java. A model is saved every 200 samples, as defined by
    # DCGAN’s default settings.
    #
    steps = "training_steps-" + "{:08d}".format(step)
    set_models_dir = os.path.join(self.models_dir, steps)
    builder = tf.saved_model.builder.SavedModelBuilder(set_models_dir)
    self.builder = builder
    self.builder.add_meta_graph_and_variables(self.sess, [tf.saved_model.tag_constants.SERVING])
    self.builder.save()
    #

The code above outputs a SavedModel, which is loaded by the following Java code:

package Main;

import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Random;

import javax.imageio.ImageIO;

import org.tensorflow.Tensor;

public class DCGAN {
    public static void main(String[] args) throws Exception {
        String model_dir = "E:\\AgentWeb\\mnist-steps\\training_steps-00050000";
        //SavedModelBundle model = SavedModelBundle.load(model_dir , "serve");
        //Session sess = model.session();

        Random rand = new Random();
        int sample_num = 64;
        int z_dim = 100;
        float [][] gen_random = new float [sample_num][z_dim];
        for(int i = 0 ; i < sample_num ; i++) {
            for(int j = 0 ; j < z_dim ; j++) {
                gen_random[i][j] = (float)rand.nextGaussian();
            }
        }
        Tensor <Float> sample_z = Tensor.<Float>create(gen_random, Float.class);


        // placeholder is the tensor which I want to create after solving the problem below,
        // so this line stays commented out until it exists.
        //Tensor <Float> sample_inputs = Tensor.<Float>create(placeholder, Float.class);

        //Tensor result = sess.runner().fetch("t_vars").feed("z", sample_z).feed("inputs", sample_inputs).run().get(3);
    }
}

(I have left some lines commented out, as I used them for debugging.)

With this approach, I am stuck translating one part of the Python code to Java for use with TensorFlow for Java. In DCGAN’s model.py, where the images are processed, there is the following code:

          get_image(sample_file,
                    input_height=self.input_height,
                    input_width=self.input_width,
                    resize_height=self.output_height,
                    resize_width=self.output_width,
                    crop=self.crop,
                    grayscale=self.grayscale) for sample_file in sample_files]

which calls get_image in saved_utils.py:

def get_image(image_path, input_height, input_width,
              resize_height=64, resize_width=64,
              crop=True, grayscale=False):
  image = imread(image_path, grayscale)
  return transform(image, input_height, input_width,
                   resize_height, resize_width, crop)

which in turn calls a function called imread:

def imread(path, grayscale = False):
  if (grayscale):
    return scipy.misc.imread(path, flatten = True).astype(np.float)
  else:
    # Reference: https://github.com/carpedm20/DCGAN-tensorflow/issues/162#issuecomment-315519747
    img_bgr = cv2.imread(path)
    # Reference: https://stackoverflow.com/a/15074748/
    img_rgb = img_bgr[..., ::-1]
    return img_rgb.astype(np.float)
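For comparison, here is a minimal plain-Java sketch of what imread returns, assuming the file is read with javax.imageio.ImageIO instead of OpenCV. BufferedImage.getRGB packs each pixel as 0xAARRGGBB, so the channels already come out in RGB order and no BGR flip is needed on this path. The grayscale branch assumes flatten=True corresponds to the ITU-R 601-2 luma weights (L = 0.299 R + 0.587 G + 0.114 B), PIL's RGB-to-grayscale formula. The class and method names are hypothetical.

```java
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;

public class ImreadSketch {
    // Hypothetical counterpart of imread(path, grayscale=False):
    // float[height][width][3] in RGB order, values in 0..255.
    static float[][][] toRgbFloats(BufferedImage img) {
        int h = img.getHeight(), w = img.getWidth();
        float[][][] out = new float[h][w][3];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int p = img.getRGB(x, y);        // 0xAARRGGBB, already RGB order
                out[y][x][0] = (p >> 16) & 0xFF; // red
                out[y][x][1] = (p >> 8) & 0xFF;  // green
                out[y][x][2] = p & 0xFF;         // blue
            }
        }
        return out;
    }

    // Hypothetical counterpart of imread(path, grayscale=True), assuming
    // ITU-R 601-2 luma weights: L = 0.299 R + 0.587 G + 0.114 B.
    static float[][] toGrayFloats(BufferedImage img) {
        int h = img.getHeight(), w = img.getWidth();
        float[][] out = new float[h][w];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int p = img.getRGB(x, y);
                int r = (p >> 16) & 0xFF, g = (p >> 8) & 0xFF, b = p & 0xFF;
                out[y][x] = 0.299f * r + 0.587f * g + 0.114f * b;
            }
        }
        return out;
    }

    // Reads an image file; ImageIO.read returns null for unsupported formats.
    static BufferedImage read(String path) throws IOException {
        BufferedImage img = ImageIO.read(new File(path));
        if (img == null) {
            throw new IOException("Unsupported image format: " + path);
        }
        return img;
    }

    public static void main(String[] args) {
        // 2x1 test image: a pure red pixel and a pure blue pixel.
        BufferedImage img = new BufferedImage(2, 1, BufferedImage.TYPE_INT_RGB);
        img.setRGB(0, 0, 0xFF0000);
        img.setRGB(1, 0, 0x0000FF);
        float[][][] rgb = toRgbFloats(img);
        System.out.println(rgb[0][0][0] + " " + rgb[0][1][2]); // 255.0 255.0
    }
}
```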

My question is: what does the img_rgb = img_bgr[..., ::-1] part do, and how do I translate it for use in my Java file with TensorFlow for Java?

I am familiar with the way Python slices arrays, but I am unfamiliar with the three dots used there. I read the Stack Overflow answer referenced in the code, and it says this is similar to img[:, :, ::-1], but I am not really sure what it is doing exactly.
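In NumPy, ... (Ellipsis) stands for as many full slices (:) as are needed to cover the leading axes, so on a 3-D height x width x channels array img_bgr[..., ::-1] is the same as img_bgr[:, :, ::-1]: every row and column is kept, and the last (channel) axis is read backwards with step -1, which turns B, G, R into R, G, B. The sketch below does the same in plain Java, assuming the image is held as a float[height][width][channels] array. The class and method names are hypothetical, and unlike NumPy, which returns a view, it copies the data.

```java
import java.util.Arrays;

public class ReverseChannels {
    // Java equivalent of img[..., ::-1] for float[height][width][channels]:
    // the leading axes ("...") are kept as-is, the last axis is reversed ("::-1").
    static float[][][] reverseLastAxis(float[][][] img) {
        int h = img.length;
        float[][][] out = new float[h][][];
        for (int y = 0; y < h; y++) {
            int w = img[y].length;
            out[y] = new float[w][];
            for (int x = 0; x < w; x++) {
                int c = img[y][x].length;
                out[y][x] = new float[c];
                for (int k = 0; k < c; k++) {
                    out[y][x][k] = img[y][x][c - 1 - k];
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // One pixel stored as B, G, R = 10, 20, 30 (OpenCV order).
        float[][][] bgr = {{{10f, 20f, 30f}}};
        float[][][] rgb = reverseLastAxis(bgr);
        System.out.println(Arrays.toString(rgb[0][0])); // [30.0, 20.0, 10.0]
    }
}
```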

Any help is appreciated, and thank you for taking the time to read this long post.

1 Answer

Basically, imread and get_image 1) read an image, 2) convert it from BGR to RGB, 3) convert it to floats, and 4) rescale (and, with crop=True, crop) the image.

You can do this in Java either by using an imaging library, such as JMagick or AWT, or by using TensorFlow.
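A minimal AWT-only sketch of the read, convert and rescale steps, which stacks the images into the [batch][height][width][3] float array that Tensor.create(..., Float.class) from the question’s code accepts. It uses Graphics2D bilinear interpolation, whose output will not match the Python resize pixel for pixel, and it only resizes: with crop=True the Python code also crops, which this sketch does not do. The class and method names are hypothetical.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import javax.imageio.ImageIO;

public class AwtPreprocess {
    // Rescales an image with bilinear interpolation into a TYPE_INT_RGB image.
    static BufferedImage resize(BufferedImage src, int width, int height) {
        BufferedImage dst = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = dst.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(src, 0, 0, width, height, null);
        g.dispose();
        return dst;
    }

    // Stacks equally sized images into float[batch][height][width][3]
    // in RGB order (NHWC), with values in 0..255.
    static float[][][][] toBatch(List<BufferedImage> images) {
        int n = images.size();
        if (n == 0) {
            return new float[0][][][];
        }
        int h = images.get(0).getHeight(), w = images.get(0).getWidth();
        float[][][][] batch = new float[n][h][w][3];
        for (int i = 0; i < n; i++) {
            BufferedImage img = images.get(i);
            for (int y = 0; y < h; y++) {
                for (int x = 0; x < w; x++) {
                    int p = img.getRGB(x, y); // 0xAARRGGBB, already RGB order
                    batch[i][y][x][0] = (p >> 16) & 0xFF;
                    batch[i][y][x][1] = (p >> 8) & 0xFF;
                    batch[i][y][x][2] = p & 0xFF;
                }
            }
        }
        return batch;
    }

    // Reads, resizes and stacks files, similar to the list comprehension
    // over sample_files in model.py.
    static float[][][][] loadBatch(List<String> paths, int height, int width) throws IOException {
        List<BufferedImage> resized = new ArrayList<>();
        for (String path : paths) {
            BufferedImage img = ImageIO.read(new File(path));
            if (img == null) {
                throw new IOException("Unsupported image format: " + path);
            }
            resized.add(resize(img, width, height));
        }
        return toBatch(resized);
    }
}
```

With the TensorFlow 1.x Java API used in the question, the resulting array could then be wrapped with Tensor.create(batch, Float.class) and fed to the model’s input placeholder. Check the transform function (not shown here) for any extra pixel scaling it applies, since the model expects inputs in the same range it was trained on.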

If you use TensorFlow, you can run this preprocessing either in eager mode or by building and running a small graph. For example, given tf, an instance of org.tensorflow.op.Ops:

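  • tf.image.decode* can read the content of an image (you need to know the type of your image, though, to pick the right operation, such as decodeJpeg or decodePng).
  • tf.reverse can reverse the order of values along your channel dimension (BGR to RGB). Note that TensorFlow’s decode operations already return channels in RGB order, so this is only needed if your pixels come in BGR order, as they do from OpenCV’s cv2.imread.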
  • tf.dtypes.cast can convert the image to floats
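  • tf.image.resizeBilinear can rescale your image (it expects a 4-D [batch, height, width, channels] input, so add a batch dimension first, e.g. with tf.expandDims).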
Karl Lessard