
I'm designing a neural network model for point clouds. Given n input points, I use an NN model to compute a score for each point. I then use the scores to decide which k points to keep and which to discard. Because the gradients have to backpropagate, I can't simply throw the low-score points away. Instead, I need a mask that sets their corresponding positions to 0. How can I define such a mask in TensorFlow?
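The following is a minimal sketch of why a multiplicative mask keeps backprop working. It assumes per-point features of shape (batch_size, num_point, channels), and the names `features` and `mask` are illustrative only. Multiplying by a 0/1 mask zeroes the discarded points while keeping the tensor shape fixed. The gradient that reaches each point is then just its mask value.

```python
import tensorflow as tf

# Hypothetical per-point features: batch_size=1, num_point=4, channels=2.
features = tf.Variable(tf.ones((1, 4, 2)))
# 0/1 mask over points: keep points 0 and 2, drop points 1 and 3.
mask = tf.constant([[1.0, 0.0, 1.0, 0.0]])

with tf.GradientTape() as tape:
    # Broadcast the mask over the channel axis, zeroing the dropped points.
    masked = features * mask[:, :, tf.newaxis]
    loss = tf.reduce_sum(masked)

# d(loss)/d(features) equals the mask broadcast over channels:
# kept points receive gradient 1, dropped points receive 0.
grad = tape.gradient(loss, features)
```

The mask itself is treated as a constant here. Gradients flow into the kept features but not into the scores that produced the mask.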

I tried to define the mask as a tensor with mask = tf.ones((batch_size, num_point)), but I couldn't change its values because 'Tensor' object does not support item assignment.
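That error occurs because TensorFlow tensors are immutable. The sketch below, using made-up shapes and indices, shows the failure. It then shows `tf.tensor_scatter_nd_update`, which builds a *new* tensor with selected positions replaced instead of modifying the old one in place.

```python
import tensorflow as tf

mask = tf.ones((2, 4))

# Tensors are immutable: item assignment raises TypeError.
try:
    mask[0, 1] = 0.0
    item_assignment_ok = True
except TypeError:
    item_assignment_ok = False

# tf.tensor_scatter_nd_update returns a *new* tensor with the listed
# (row, column) positions overwritten; the original `mask` is unchanged.
indices = tf.constant([[0, 1], [1, 3]])
updates = tf.constant([0.0, 0.0])
new_mask = tf.tensor_scatter_nd_update(mask, indices, updates)
```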

Here's part of my code:

import tensorflow as tf

sort_index = tf.argsort(score, axis=1, direction='DESCENDING')  # score.shape: (batch_size, n)
topn_index = sort_index[:, :16]  # keep the 16 highest-scoring points
mask = tf.zeros((batch_size, num_point))
mask[topn_index] = 1  # fails: 'Tensor' object does not support item assignment
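One way to turn `topn_index` into a mask without item assignment is to pair each kept column index with its batch row and scatter ones into a zero tensor. This is a sketch under assumptions: `score` is a 2-D float tensor, `k` replaces the hard-coded 16, and `mask_from_topn` is a hypothetical helper name.

```python
import tensorflow as tf

def mask_from_topn(score, k):
    """Build a (batch_size, num_point) 0/1 mask that keeps the k highest scores per row."""
    # Same ranking as in the question: indices sorted by descending score.
    sort_index = tf.argsort(score, axis=1, direction='DESCENDING')
    topn_index = sort_index[:, :k]  # (batch_size, k), int32

    # Pair each kept column index with its row (batch) index -> (batch_size, k, 2).
    batch_size = tf.shape(score)[0]
    batch_index = tf.tile(tf.range(batch_size)[:, tf.newaxis], [1, k])
    scatter_index = tf.stack([batch_index, topn_index], axis=-1)

    # Scatter 1.0 into those positions of an all-zero tensor shaped like `score`.
    updates = tf.ones_like(topn_index, dtype=score.dtype)
    return tf.scatter_nd(scatter_index, updates, tf.shape(score))
```

For example, with `score = tf.constant([[0.1, 0.9, 0.3, 0.7], [0.5, 0.2, 0.8, 0.4]])` and `k = 2`, each row keeps its two largest entries. The resulting mask can then multiply the point features as usual.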

I would appreciate any suggestions.

bob wong

1 Answer


One solution might be tf.boolean_mask. You could also look at the assign operator (which requires a tf.Variable).
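Here is a small sketch of the `tf.boolean_mask` route, using made-up shapes (batch_size=2, num_point=4, channels=3, k=2). Unlike a multiplicative mask, this actually removes the unselected points. With a 2-D boolean mask the batch and point axes are flattened together, so the batch axis has to be restored afterwards. The `tf.gather(..., batch_dims=1)` line is an alternative that I'm adding, not something from the answer. It selects the same points directly from an index tensor like `topn_index`.

```python
import tensorflow as tf

# Hypothetical inputs: batch_size=2, num_point=4, channels=3.
points = tf.reshape(tf.range(24, dtype=tf.float32), (2, 4, 3))
keep = tf.constant([[True, False, True, False],
                    [False, True, False, True]])  # exactly k=2 True per row
k = 2

# boolean_mask flattens the masked (batch, point) axes: result is (batch*k, channels).
kept_flat = tf.boolean_mask(points, keep)
# Each row keeps exactly k points, so the batch axis can be restored.
kept = tf.reshape(kept_flat, (2, k, 3))

# Alternative: gather the same points per batch row from an index tensor.
kept_by_index = tf.gather(points, tf.constant([[0, 2], [1, 3]]), batch_dims=1)
```

The reshape relies on every row keeping the same number of points, which holds when exactly k points are selected per row.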

Another way is to create a NumPy array and, each time, use tf.tensordot to multiply the NumPy array with the tensor, or cast the NumPy array to a tensor.
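The NumPy route can work in TensorFlow 2 eager mode, assuming the index tensor is first converted with `.numpy()`. That conversion is what makes NumPy indexing with `topn_index` possible. This is a sketch with made-up scores. It will not work inside a `tf.function` graph, where `.numpy()` is unavailable.

```python
import numpy as np
import tensorflow as tf

score = tf.constant([[0.1, 0.9, 0.3, 0.7],
                     [0.5, 0.2, 0.8, 0.4]])
k = 2
topn_index = tf.argsort(score, axis=1, direction='DESCENDING')[:, :k]

# In eager mode, .numpy() turns the index tensor into an ordinary array
# that NumPy can use for indexing.
mask_np = np.zeros(score.numpy().shape, dtype=np.float32)
np.put_along_axis(mask_np, topn_index.numpy(), 1.0, axis=1)

# Convert back to a tensor so it can multiply other tensors.
mask = tf.convert_to_tensor(mask_np)
```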

Min Bui
  • Thanks for your reply, but the problem is that I don't know how to create the mask. I can easily get the index `topn_index` as a tensor, but I can't set the corresponding positions in the mask to 0/1. I tried defining the mask as a NumPy array first, but that failed because I couldn't index it with the tensor `topn_index`. – bob wong Jul 25 '19 at 03:42
  • @bobwong As I mentioned above, you can use tf.assign to change the value of a tf.Variable. Maybe I don't quite understand **topn_index**; can you explain it a little more? If it's your own function, you'll need to modify it to fit each solution. Sorry for my bad English. Maybe this post will help: https://stackoverflow.com/questions/33769041/tensorflow-indexing-with-boolean-tensor – Min Bui Jul 25 '19 at 03:57
  • `topn_index` holds the indices of the positions I need to keep. I get it directly from the sorted index `sort_index`. – bob wong Jul 25 '19 at 04:35
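The assign approach from the comments can be sketched in TensorFlow 2 with a non-trainable `tf.Variable` and its `scatter_nd_update` method. This assumes fixed, made-up sizes (batch_size=2, num_point=4, k=2) and builds the (row, column) pairs from `topn_index` as in the question. A stateful variable has to be reset for every batch. For that reason, a stateless scatter into a fresh tensor is often simpler inside a model, but both produce the same mask.

```python
import tensorflow as tf

batch_size, num_point, k = 2, 4, 2
# A tf.Variable, unlike a plain tensor, can be updated in place.
mask_var = tf.Variable(tf.zeros((batch_size, num_point)), trainable=False)

score = tf.constant([[0.1, 0.9, 0.3, 0.7],
                     [0.5, 0.2, 0.8, 0.4]])
topn_index = tf.argsort(score, axis=1, direction='DESCENDING')[:, :k]

# (row, column) pairs for every kept point: shape (batch_size, k, 2).
batch_index = tf.tile(tf.range(batch_size)[:, tf.newaxis], [1, k])
scatter_index = tf.stack([batch_index, topn_index], axis=-1)

mask_var.assign(tf.zeros_like(mask_var))  # reset before each new batch
mask_var.scatter_nd_update(scatter_index, tf.ones((batch_size, k)))
```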