
I trained a model in Keras and I'm thinking of pruning my fully connected network. I'm a little bit lost on how to prune the layers.

The authors of 'Learning both Weights and Connections for Efficient Neural Networks' say that they add a mask to threshold the weights of a layer. I can try to do the same and fine-tune the trained model. But how does that reduce the model size and the number of computations?
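The masking idea described in the paper can be sketched in NumPy as follows. This is a minimal illustration, not the paper's code. It assumes a single dense layer fine-tuned with plain SGD on a squared-error loss. The names `magnitude_mask` and `masked_sgd_step` are hypothetical, and the threshold value is arbitrary. The mask zeroes small weights and also zeroes their updates during fine-tuning, so pruned connections stay at exactly 0. Note that the weight matrix keeps its full dense shape.

```python
import numpy as np

def magnitude_mask(w: np.ndarray, threshold: float) -> np.ndarray:
    """Return a 0/1 mask that keeps only weights with |w| >= threshold."""
    return (np.abs(w) >= threshold).astype(w.dtype)

def masked_sgd_step(w, b, mask, x, y, lr=0.1):
    """One SGD step on 0.5 * ||x @ w + b - y||^2 / batch, keeping pruned weights at 0."""
    pred = x @ (w * mask) + b
    err = (pred - y) / x.shape[0]        # dLoss/dpred, averaged over the batch
    grad_w = x.T @ err                   # gradient w.r.t. the effective weights
    grad_b = err.sum(axis=0)
    w = (w - lr * grad_w) * mask         # masking the update freezes pruned entries at 0
    b = b - lr * grad_b
    return w, b

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 3))              # stands in for an already-trained kernel
b = np.zeros(3)
mask = magnitude_mask(w, threshold=0.5)  # hypothetical threshold
w = w * mask                             # prune once
x, y = rng.normal(size=(8, 4)), rng.normal(size=(8, 3))
for _ in range(10):                      # fine-tune the surviving weights
    w, b = masked_sgd_step(w, b, mask, x, y)
print("nonzero weights:", int(np.count_nonzero(w)), "of", w.size)
print("stored dense shape:", w.shape)    # still (4, 3): the mask alone does not shrink storage
```

The last line shows why the question arises. Masking changes which weights are non-zero, but the dense array is the same size. The savings only appear once the zeros are exploited by a sparse format or by physically removing rows and columns.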

Illuminati0x5B
  • To be specific, you want to know how to prune specific weights in the neural network? For example, given a `W` matrix, you want to set some of the elements to 0? – gorjan Jun 03 '19 at 22:51
  • @gorjan My goal is to reduce the final model size and speed up inference. I'm not sure if setting some of the values of `W` to zero would reduce the model size. I need a way to remove the connections. As far as I understand, TensorRT and TensorFlow Lite do this? – Illuminati0x5B Jun 04 '19 at 00:01
  • You can't essentially "delete" weights. What you can do is set certain weights to 0s and then treat the matrices as sparse matrices. TF has some minor support for dense-sparse/sparse-sparse matrix multiplication that can be used to accelerate inference. Here is a related Stack Overflow thread: https://stackoverflow.com/questions/44859321/how-to-perform-efficient-sparse-matrix-multiplication-by-using-tf-matmul – gorjan Jun 04 '19 at 00:11
  • @gorjan Makes sense. I thought there was more to it than this. Let me try implementing something similar to this. – Illuminati0x5B Jun 04 '19 at 02:07
  • Sure! As an answer I will post a method that given a weight matrix `w: tf.Variable`, and `k: int`, it will delete the `k%` smallest weights (elements in the matrix) based on their norm. – gorjan Jun 04 '19 at 21:40
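To see where the savings from "treating the matrices as sparse" come from, here is a minimal pure-NumPy sketch of compressed sparse row (CSR) storage. `to_csr` and `csr_matmul` are hypothetical helpers written for illustration. In practice you would use a library sparse type such as `tf.sparse.SparseTensor` or `scipy.sparse.csr_matrix`. Only the non-zero weights and their column indices are stored, and the product touches only those entries. Both memory and multiply-adds therefore scale with the number of surviving weights, which pays off once enough weights are pruned to outweigh the index overhead. The sketch uses the `W @ x` convention (rows of `W` are outputs).

```python
import numpy as np

def to_csr(w: np.ndarray):
    """Convert a dense 2-D matrix to CSR arrays: (values, column indices, row pointers)."""
    rows, cols = np.nonzero(w)           # returned in row-major order, as CSR requires
    values = w[rows, cols]
    row_ptr = np.zeros(w.shape[0] + 1, dtype=np.int64)
    np.add.at(row_ptr, rows + 1, 1)      # count non-zeros per row...
    row_ptr = np.cumsum(row_ptr)         # ...then turn the counts into offsets
    return values, cols, row_ptr

def csr_matmul(values, col_idx, row_ptr, x):
    """Compute W @ x using only the stored non-zero entries of W."""
    n_rows = len(row_ptr) - 1
    out = np.zeros((n_rows,) + x.shape[1:], dtype=np.result_type(values, x))
    for r in range(n_rows):
        start, end = row_ptr[r], row_ptr[r + 1]
        out[r] = values[start:end] @ x[col_idx[start:end]]
    return out

rng = np.random.default_rng(0)
w = rng.normal(size=(6, 5))
w[np.abs(w) < 1.0] = 0.0                 # prune small weights (threshold is arbitrary)
values, col_idx, row_ptr = to_csr(w)
x = rng.normal(size=(5, 2))
print(np.allclose(csr_matmul(values, col_idx, row_ptr, x), w @ x))  # True
print("dense entries:", w.size, "stored non-zeros:", values.size)
```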

2 Answers


Based on the discussion in the comments, here is a way to prune a layer (a weight matrix) of your neural network. The method selects the smallest k·100% of the weights (elements of the matrix) by absolute value and sets them to zero; k is passed as a fraction, e.g. 0.5 to prune 50%. That way, the corresponding matrix can be treated as a sparse matrix, and we can perform dense-sparse matrix multiplication, which can be faster if enough weights are pruned.

```python
import tensorflow as tf


def weight_pruning(w: tf.Variable, k: float) -> tf.Variable:
    """Performs pruning on a weight matrix w in the following way:

    - The absolute values of all elements in the weight matrix are computed.
    - The indices of the k * 100% elements with the smallest absolute values are selected.
    - All elements with the matching indices are set to 0.

    Args:
        w: The weight matrix.
        k: The fraction (between 0 and 1) of values that should be pruned from the matrix.

    Returns:
        The weight-pruned weight matrix.

    """
    # Number of elements to prune: round(k * total number of elements).
    k = tf.cast(tf.round(tf.cast(tf.size(w), tf.float32) * k), dtype=tf.int32)
    w_reshaped = tf.reshape(w, [-1])
    # top_k of the negated magnitudes returns the indices of the k smallest |w|.
    _, indices = tf.nn.top_k(tf.negative(tf.abs(w_reshaped)), k, sorted=True)
    # 0/1 mask with zeros at the selected indices (no extra tf.Variable needed).
    mask = tf.tensor_scatter_nd_update(
        tf.ones_like(w_reshaped, dtype=tf.float32),
        tf.reshape(indices, [-1, 1]),
        tf.zeros([k], tf.float32),
    )

    return w.assign(tf.reshape(w_reshaped * mask, tf.shape(w)))
```
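If you want to sanity-check the behaviour without TensorFlow, the same magnitude-based weight pruning can be sketched in NumPy. This is an illustrative re-implementation using the same convention (`k` is a fraction between 0 and 1); `weight_pruning_np` is a hypothetical name, not part of the method above.

```python
import numpy as np

def weight_pruning_np(w: np.ndarray, k: float) -> np.ndarray:
    """Return a copy of w with the round(k * w.size) smallest-magnitude entries set to 0."""
    n_prune = int(round(w.size * k))
    flat = w.reshape(-1).copy()
    # a stable argsort on |w| lists the smallest magnitudes first, ties broken by position
    prune_idx = np.argsort(np.abs(flat), kind="stable")[:n_prune]
    flat[prune_idx] = 0.0
    return flat.reshape(w.shape)

w = np.array([[0.1, -2.0, 0.3],
              [1.5, -0.05, 0.7]])
print(weight_pruning_np(w, 0.5))  # the three smallest magnitudes (0.05, 0.1, 0.3) are zeroed
```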

While the method above prunes individual connections (weights), the method below prunes whole neurons from a weight matrix. Namely, it selects the k·100% smallest neurons (columns of the weight matrix) based on their Euclidean norm and sets them to zero.

```python
def unit_pruning(w: tf.Variable, k: float) -> tf.Variable:
    """Performs pruning on a weight matrix w in the following way:

    - The Euclidean norm of each column is computed.
    - The indices of the k * 100% columns with the smallest Euclidean norms are selected.
    - All elements in the columns that have the matching indices are set to 0.

    Args:
        w: The weight matrix.
        k: The fraction (between 0 and 1) of columns that should be pruned from the matrix.

    Returns:
        The unit-pruned weight matrix.

    """
    # Number of columns (units) to prune.
    k = tf.cast(
        tf.round(tf.cast(tf.shape(w)[1], tf.float32) * k), dtype=tf.int32
    )
    norm = tf.norm(w, axis=0)  # one Euclidean norm per column
    # Row indices 0..rows-1, repeated once for every pruned column.
    row_indices = tf.tile(tf.range(tf.shape(w)[0]), [k])
    _, col_indices = tf.nn.top_k(tf.negative(norm), k, sorted=True)
    # Repeat each selected column index once per row, so it lines up with row_indices.
    col_indices = tf.reshape(
        tf.tile(tf.reshape(col_indices, [-1, 1]), [1, tf.shape(w)[0]]), [-1]
    )
    indices = tf.stack([row_indices, col_indices], axis=1)

    return w.assign(
        tf.tensor_scatter_nd_update(w, indices, tf.zeros([tf.shape(w)[0] * k], tf.float32))
    )
```
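As with the weight-level version, a NumPy re-implementation can help check the behaviour. Again, this is a sketch and not part of the answer's method; `unit_pruning_np` is a hypothetical name. It assumes the Keras `Dense` convention, where the kernel has shape `(inputs, units)`, so each column holds the incoming weights of one neuron.

```python
import numpy as np

def unit_pruning_np(w: np.ndarray, k: float) -> np.ndarray:
    """Return a copy of w with the round(k * n_columns) lowest-L2-norm columns set to 0."""
    n_prune = int(round(w.shape[1] * k))
    col_norms = np.linalg.norm(w, axis=0)       # one Euclidean norm per column (unit)
    prune_cols = np.argsort(col_norms, kind="stable")[:n_prune]
    pruned = w.copy()
    pruned[:, prune_cols] = 0.0                 # zero every incoming weight of those units
    return pruned

w = np.array([[3.0, 0.1, 1.0, 0.0],
              [4.0, 0.2, 1.0, 0.5]])
# column norms are 5.0, ~0.224, ~1.414, 0.5, so columns 1 and 3 are pruned
print(unit_pruning_np(w, 0.5))
```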

Finally, this GitHub repository goes through the pruning methods explained here and performs experiments on the MNIST dataset.

gorjan

If you add a mask, then only a subset of your weights will contribute to the computation, and hence your model will be pruned. For instance, autoregressive models use a mask to zero out the weights that refer to future data, so that the output at time step t only depends on time steps 0, 1, ..., t-1.
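A minimal NumPy sketch of such an autoregressive mask follows. It assumes a single linear layer `y = x @ (W * M)` over a sequence of length `T`; the names and sizes are illustrative. Entry `M[i, t]` is 1 only when `i < t`, so output `t` cannot see inputs at time `t` or later.

```python
import numpy as np

T = 5
rng = np.random.default_rng(0)
W = rng.normal(size=(T, T))
M = np.triu(np.ones((T, T)), k=1)   # M[i, t] = 1 iff i < t (strictly upper triangular)

def masked_linear(x):
    return x @ (W * M)

x = rng.normal(size=T)
x_changed = x.copy()
x_changed[3] += 10.0                 # perturb the input at time step 3
diff = masked_linear(x_changed) - masked_linear(x)
print(np.nonzero(diff)[0])           # only outputs at later time steps (here: step 4) change
```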

In your case, since you have a simple fully connected layer, it is better to use dropout. It randomly turns off some neurons at each iteration step, so it reduces the computational complexity. However, the main reason dropout was invented is to tackle overfitting: by randomly turning off some neurons, you reduce co-dependencies between neurons, i.e. you prevent some neurons from relying on others. Moreover, at each iteration your model will be different (a different number of active neurons and different connections between them), hence your final model can be interpreted as an ensemble (collection) of several different models, each specialized (we hope) in understanding a specific subset of the input space.
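For reference, here is a minimal NumPy sketch of (inverted) dropout, the variant that `tf.keras.layers.Dropout` applies during training. Each activation is zeroed with probability `rate`, and the survivors are scaled by `1 / (1 - rate)` so the expected activation is unchanged. At inference the layer is the identity. The function name `dropout` and the rate are illustrative.

```python
import numpy as np

def dropout(h: np.ndarray, rate: float, training: bool, rng: np.random.Generator) -> np.ndarray:
    """Inverted dropout: zero each unit with probability `rate` during training."""
    if not training or rate == 0.0:
        return h                                   # inference: all neurons active, no scaling
    keep = rng.random(h.shape) >= rate             # True with probability 1 - rate
    return h * keep / (1.0 - rate)                 # rescale so E[output] == h

rng = np.random.default_rng(0)
h = np.ones((2, 8))
print(dropout(h, rate=0.5, training=True, rng=rng))   # each entry is either 0.0 or 2.0
print(dropout(h, rate=0.5, training=False, rng=rng))  # unchanged
```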

Neb
  • Yes. But my goal is to speed up my inference and reduce model size. If I do use a mask, I will still be storing all the layer's weights and still be performing the entire W.X + b (with some of the W_ij set to 0). – Illuminati0x5B May 24 '19 at 20:30
  • If your task is to reduce model size, then there's no way you can achieve this with a dynamic mask. If the mask is static, then simply remove the weights you are not interested in learning. Your network will become sparser. – Neb May 24 '19 at 20:40
  • Using a mask does speed up computation. Consider a mask that filters out the first 3 columns of a matrix `W`. Then you can implement it as `W[:, 3:]`. In this way, the computation will be done only on the remaining part of the matrix. For more complex masks (non-contiguous, etc.), you still get some advantage because the gradients will not be computed for weights equal to 0. – Neb May 24 '19 at 20:44
  • But, again, the reason behind masks is in general not to speed up training. – Neb May 24 '19 at 20:44
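For the static case, here is a minimal NumPy sketch of actually removing pruned units from two consecutive dense layers. It follows the `W[:, 3:]`-style slicing idea from the comments above. It assumes the Keras `x @ W + b` convention and a ReLU activation; both are assumptions, and `remove_dead_units` is a hypothetical name. Units whose incoming column is entirely zero are sliced out of `W1`/`b1` and out of the matching rows of `W2`. Such a unit still outputs the constant `relu(b1[j])`, so that contribution is folded into `b2` first to keep the network's output unchanged.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def remove_dead_units(W1, b1, W2, b2):
    """Drop hidden units whose incoming weights are all zero; returns smaller (W1, b1, W2, b2)."""
    dead = np.all(W1 == 0.0, axis=0)               # one flag per hidden unit (column of W1)
    # a dead unit always outputs relu(b1[j]); fold that constant into the next layer's bias
    b2 = b2 + relu(b1[dead]) @ W2[dead]
    keep = ~dead
    return W1[:, keep], b1[keep], W2[keep], b2

def forward(x, W1, b1, W2, b2):
    return relu(x @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(4, 6)), rng.normal(size=6)
W2, b2 = rng.normal(size=(6, 3)), rng.normal(size=3)
W1[:, [1, 4]] = 0.0                                # pretend unit pruning zeroed two columns
small = remove_dead_units(W1, b1, W2, b2)
x = rng.normal(size=(5, 4))
print(small[0].shape)                              # (4, 4): two hidden units removed
print(np.allclose(forward(x, *small), forward(x, W1, b1, W2, b2)))  # True
```

Unlike a mask, this genuinely shrinks both the stored matrices and the matrix multiplications. However, it only applies when whole units are pruned.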