
I am training a model with TensorFlow Estimator, and my data is imbalanced. I want to correct for this by weighting each training example.

In raw TensorFlow one might do it like this. Is there an easy way to do this with an Estimator? Perhaps by building a custom input_fn?

Maxim
opus111
    Possible duplicate of [Upweight a Category in Tensorflow](https://stackoverflow.com/questions/48098951/upweight-a-category-in-tensorflow) – Maxim Jan 04 '18 at 16:28

2 Answers


I assume you're doing classification. If so, use tf.estimator.DNNClassifier and its `weight_column` argument. From the documentation:

> `weight_column`: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing weights. It is used to down weight or boost examples during training. It will be multiplied by the loss of the example. If it is a string, it is used as a key to fetch weight tensor from the features. If it is a `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then `weight_column.normalizer_fn` is applied on it to get weight tensor.
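A minimal sketch of the input_fn side, in plain NumPy rather than TensorFlow. It assumes "balanced" inverse-frequency weights, n_samples / (n_classes * count[c]) (the formula scikit-learn uses for `class_weight='balanced'`); that weighting choice, the toy data, the helper `inverse_frequency_weights`, and the feature names `"x"` and `"weight"` are illustrative assumptions. The resulting dict is what an input_fn would return as its features; with `weight_column="weight"`, the classifier fetches the weights under that key.

```python
import numpy as np

def inverse_frequency_weights(labels):
    """Hypothetical helper: per-example weight = n_samples / (n_classes * count[class]).

    Assumption: 'balanced' weighting; any other per-example scheme works too.
    """
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    class_weight = len(labels) / (len(classes) * counts)
    # Map each example's label to the weight of its class.
    lookup = dict(zip(classes.tolist(), class_weight.tolist()))
    return np.array([lookup[c] for c in labels.tolist()], dtype=np.float32)

# Made-up, imbalanced toy data: six examples of class 0, two of class 1.
x = np.arange(8, dtype=np.float32).reshape(8, 1)
y = np.array([0, 0, 0, 0, 0, 0, 1, 1])

# Features as an input_fn would return them; "weight" is the key that
# weight_column="weight" would look up.
features = {"x": x, "weight": inverse_frequency_weights(y)}
print(features["weight"])  # class 0 -> 8/12, class 1 -> 2.0
```

Each class ends up with the same total weight (4.0 here), so the minority class is no longer swamped. In a TensorFlow version that still includes `tf.estimator`, these arrays could be wrapped with `tf.data.Dataset.from_tensor_slices((features, y))` inside the input_fn.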

Maxim

If you are building a custom Estimator, pass each example's class weight to your model_fn as a feature. Then, when you define the loss op, pass that feature as the loss function's `weights` argument.

Example:

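    # Inside model_fn (TF 1.x): each example's loss is multiplied by its weight.
    loss = tf.losses.softmax_cross_entropy(onehot_labels=target, logits=logits, weights=features['weight'])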
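To make the effect of `weights=features['weight']` concrete, here is a NumPy re-implementation of the arithmetic, not TensorFlow's own code. It assumes TF 1.x semantics: `tf.losses.softmax_cross_entropy` multiplies each example's loss by its weight and, under its default `Reduction.SUM_BY_NONZERO_WEIGHTS`, divides the weighted sum by the number of non-zero weights. The batch, logits, and weights are made up for illustration.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row-wise max before exponentiating, for numerical stability.
    z = logits - logits.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def weighted_softmax_cross_entropy(onehot_labels, logits, weights):
    """Sketch of weighted softmax cross-entropy with a
    SUM_BY_NONZERO_WEIGHTS-style reduction (assumed TF 1.x default)."""
    per_example = -(onehot_labels * log_softmax(logits)).sum(axis=1)
    weighted = per_example * weights          # each example's loss scaled by its weight
    n_nonzero = np.count_nonzero(weights)
    # Guard against dividing by zero when every weight is 0.
    return weighted.sum() / n_nonzero if n_nonzero else 0.0

# Made-up batch: 4 examples, 2 classes, all-zero logits (each loss is log 2).
logits = np.zeros((4, 2))
onehot = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=float)
weights = np.array([1.0, 1.0, 3.0, 0.0])      # plays the role of features['weight']

loss = weighted_softmax_cross_entropy(onehot, logits, weights)
print(loss)  # (1 + 1 + 3 + 0) * log(2) / 3, about 1.155
```

The example with weight 0 contributes nothing to the sum and is also left out of the denominator, while the example with weight 3 counts three times as much as an unweighted one.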
Rodrigo Pereira
  • Sadly, the weights parameter is no longer available. However, because of how cross entropy is computed, you get the same effect by changing the target value from 1.0 to a lower number to downweight an example, or to a higher number to upweight it. – chrishmorris Jan 19 '20 at 12:08
  • If you are using Keras and the tf.data.Dataset API, you can specify sample weights by having your dataset yield 3-tuples (input, target, weight). – Rodrigo Pereira Jan 20 '20 at 14:09
  • Hi, do you have a code example for TF2 that implements custom weights in a model_fn? – David Bacelj Sep 12 '20 at 13:27
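chrishmorris's target-scaling suggestion can be checked numerically. The NumPy sketch below (made-up logits, weight 3.0) confirms that scaling the one-hot target by w gives the same loss value as multiplying the loss by w, and that the exact gradient with respect to the logits matches as well. One caveat: the TensorFlow documentation for `tf.nn.softmax_cross_entropy_with_logits` asks that each row of labels be a valid probability distribution and warns that the gradient is otherwise incorrect. `grad_fused` below assumes the op uses the common "softmax minus labels" gradient shortcut; under that assumption, the trick would not reproduce the weighted gradient.

```python
import numpy as np

def log_softmax(z):
    # Stable log-softmax for a single example.
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

z = np.array([2.0, 0.5, -1.0])      # made-up logits for one example
onehot = np.array([1.0, 0.0, 0.0])  # true class is 0
w = 3.0                              # made-up example weight
p = np.exp(log_softmax(z))

# Loss values: weighting the loss vs. scaling the target.
loss_weighted = w * -(onehot * log_softmax(z)).sum()
loss_scaled = -((w * onehot) * log_softmax(z)).sum()

# Gradients with respect to the logits.
grad_weighted = w * (p - onehot)     # gradient of w * CE(onehot, z)
y = w * onehot                       # scaled target
grad_exact_scaled = y.sum() * p - y  # exact gradient of CE(y, z) for any y
grad_fused = p - y                   # shortcut that assumes y sums to 1

print(np.isclose(loss_weighted, loss_scaled))         # True
print(np.allclose(grad_weighted, grad_exact_scaled))  # True
print(np.allclose(grad_weighted, grad_fused))         # False
```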