
We are currently trying to replicate the results of the following paper: https://openreview.net/forum?id=H1S8UE-Rb

To do so, we need to run backpropagation on a neural network that contains complex-valued weights.

When we try to do so (with code [0]), we get an error (at [1]). We cannot find the source code for any project that trains a neural network containing complex-valued weights.

We were wondering if we would need to implement the paper's backpropagation adjustments ourselves or if this is already part of some neural network library. If it needs to be implemented in TensorFlow, what would be the proper steps to achieve that?

[0]:

def define_neuron(x):
    """
    x is input tensor
    """

    x = tf.cast(x, tf.complex64)

    mnist_x = mnist_y = 28
    n = mnist_x * mnist_y
    c = 10
    m = 10  # m needs to be calculated

    with tf.name_scope("linear_combination"):
        complex_weight = weight_complex_variable([n,m])
        complex_bias = bias_complex_variable([m])
        h_1 = x @ complex_weight + complex_bias

    return h_1

def main(_):
    mnist = input_data.read_data_sets(
        FLAGS.data_dir,
        one_hot=True,
    )

    # `None` for the first dimension in this shape means that it is variable.
    x_shape = [None, 784]
    x = tf.placeholder(tf.float32, x_shape)
    y_ = tf.placeholder(tf.float32, [None, 10])

    yz = h_1 = define_neuron(x)

    y = tf.nn.softmax(tf.abs(yz))

    with tf.name_scope('loss'):
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
            labels=y_,
            logits=y,
        )

    cross_entropy = tf.reduce_mean(cross_entropy)

    with tf.name_scope('adam_optimizer'):
        optimizer = tf.train.AdamOptimizer(1e-4)
        optimizer = tf.train.GradientDescentOptimizer(1e-4)
        train_step = optimizer.minimize(cross_entropy)

[1]:

Extracting /tmp/tensorflow/mnist/input_data/train-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/train-labels-idx1-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/t10k-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/t10k-labels-idx1-ubyte.gz
Traceback (most recent call last):
  File "complex.py", line 156, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "/Users/kevin/wdev/learn_tensor/env/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "complex.py", line 58, in main
    train_step = optimizer.minimize(cross_entropy)
  File "/Users/kevin/wdev/learn_tensor/env/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 343, in minimize
    grad_loss=grad_loss)
  File "/Users/kevin/wdev/learn_tensor/env/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 419, in compute_gradients
    [v for g, v in grads_and_vars
  File "/Users/kevin/wdev/learn_tensor/env/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 547, in _assert_valid_dtypes
    dtype, t.name, [v for v in valid_dtypes]))
ValueError: Invalid type tf.complex64 for linear_combination/Variable:0, expected: [tf.float32, tf.float64, tf.float16].
Slackware

2 Answers


I have also tried to implement a similar network in TensorFlow and saw that the optimizer cannot do backpropagation using complex-valued tensors. The workaround is to keep separate real tensors for the real and imaginary parts. You will have to write a function that computes the magnitude of the "complex" output of the network, which is simply sqrt(Re^2 + Im^2). That output value is what you will use to compute the loss.
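A minimal sketch of that workaround in TF 1.x, sized like the layer in the question (the complex_dense helper and its initializers are illustrative names, not part of the original answer):

import tensorflow as tf

def complex_dense(x_re, x_im, n_in, n_out):
    # Store the complex weight W = W_re + i*W_im as two real float32 variables,
    # so the optimizer only ever sees dtypes it accepts.
    w_re = tf.Variable(tf.truncated_normal([n_in, n_out], stddev=0.1))
    w_im = tf.Variable(tf.truncated_normal([n_in, n_out], stddev=0.1))
    b_re = tf.Variable(tf.zeros([n_out]))
    b_im = tf.Variable(tf.zeros([n_out]))

    # Complex multiplication written out with real ops:
    # (x_re + i*x_im)(W_re + i*W_im) = (x_re*W_re - x_im*W_im) + i*(x_re*W_im + x_im*W_re)
    h_re = tf.matmul(x_re, w_re) - tf.matmul(x_im, w_im) + b_re
    h_im = tf.matmul(x_re, w_im) + tf.matmul(x_im, w_re) + b_im
    return h_re, h_im

x = tf.placeholder(tf.float32, [None, 784])
h_re, h_im = complex_dense(x, tf.zeros_like(x), 784, 10)

# Real-valued logits from the squared magnitude of the "complex" output,
# so the usual softmax cross-entropy / AdamOptimizer pipeline works unchanged.
logits = tf.square(h_re) + tf.square(h_im)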


Using the optimizer won't work; it is a reported issue, and I don't think TF 2 supports it yet. You can, however, do it by hand, for example:

[...]
# Gradient of the (real-valued) loss with respect to the complex weights
gradients = tf.gradients(mse, [weights])[0]
# One manual gradient-descent step instead of optimizer.minimize()
training_op = tf.assign(weights, weights - learning_rate * gradients)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    sess.run(training_op)

The gradients here behave as expected and compute the gradient as it should. Here is the discussion of what the gradient computes for complex variables.
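For context, a self-contained toy version of that manual update, assuming a single complex64 weight and a made-up least-squares target (the names mse, weights and learning_rate mirror the snippet above; the data is purely illustrative):

import numpy as np
import tensorflow as tf

learning_rate = 0.1

# Toy problem: find a complex w with x @ w close to y.
x = tf.constant(np.array([[1.0 + 2.0j]], dtype=np.complex64))
y = tf.constant(np.array([[3.0 - 1.0j]], dtype=np.complex64))
weights = tf.Variable(np.array([[0.0 + 0.0j]], dtype=np.complex64))

# The loss itself must be real-valued; tf.abs of a complex tensor is real.
mse = tf.reduce_mean(tf.square(tf.abs(x @ weights - y)))

gradients = tf.gradients(mse, [weights])[0]
training_op = tf.assign(weights, weights - learning_rate * gradients)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for _ in range(200):
        sess.run(training_op)
    print(sess.run(weights))  # should converge toward y / x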

J Agustin Barrachina