I want to pass per-class weights for my dataset to a neural network, so I tried to compute them like this:
from sklearn.utils import class_weight
class_weights = class_weight.compute_class_weight('balanced',
                                                  np.unique(y_train),
                                                  y_train)
However, I get the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-93-9452aecf4030> in <module>
2 class_weights = class_weight.compute_class_weight('balanced',
3 np.unique(y_train),
----> 4 y_train)
~\AppData\Roaming\Python\Python36\site-packages\sklearn\utils\class_weight.py in compute_class_weight(class_weight, classes, y)
39
40 if set(y) - set(classes):
---> 41 raise ValueError("classes should include all valid labels that can "
42 "be in y")
43 if class_weight is None or len(class_weight) == 0:
ValueError: classes should include all valid labels that can be in y
I don't understand. Here is part of my y_train dataset:
        grade_A  grade_B  grade_C  grade_D  grade_E  grade_F  grade_G
689526        0        1        0        0        0        0        0
523913        1        0        0        0        0        0        0
266122        0        0        1        0        0        0        0
362552        0        0        0        1        0        0        0
As far as I can tell, the classes [A, B, C, D, E, F, G] include all valid labels that can appear in y!
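A small sketch of what the check in the traceback (`set(y) - set(classes)`) appears to compare when y is a DataFrame. It uses a hypothetical toy frame, `y_toy`, standing in for y_train, with only three grade columns; the names `labels_seen` and `classes_seen` are mine, not from sklearn.

```python
import numpy as np
import pandas as pd

# Hypothetical toy stand-in for y_train: one-hot columns, one row per sample.
y_toy = pd.DataFrame(
    [[0, 1, 0], [1, 0, 0], [0, 0, 1]],
    columns=["grade_A", "grade_B", "grade_C"],
)

# Iterating over a DataFrame yields its column labels, not its rows.
labels_seen = set(y_toy)

# np.unique flattens the underlying values, so it only sees 0 and 1.
classes_seen = set(np.unique(y_toy).tolist())

# A non-empty difference is the condition that raises the ValueError.
print(labels_seen - classes_seen)
```

If this toy case reflects my real data, the "labels" being checked are the column names, while the "classes" are just 0 and 1.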
Update
I tried using .values on the DataFrame:
from sklearn.utils import class_weight
class_weights = class_weight.compute_class_weight('balanced',
np.unique(y_train.values),
y_train.values)
However, it raised:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-25-c2342f04abd9> in <module>
2 class_weights = class_weight.compute_class_weight('balanced',
3 np.unique(y_train.values),
----> 4 y_train.values)
~\AppData\Roaming\Python\Python36\site-packages\sklearn\utils\class_weight.py in compute_class_weight(class_weight, classes, y)
38 from ..preprocessing import LabelEncoder
39
---> 40 if set(y) - set(classes):
41 raise ValueError("classes should include all valid labels that can "
42 "be in y")
TypeError: unhashable type: 'numpy.ndarray'
If I run print(type(y_train)), I get the following output:
<class 'pandas.core.frame.DataFrame'>
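Since y_train is a one-hot encoded DataFrame rather than a 1-D array of labels, one possible direction is to collapse each row to a single label before computing the weights. This is a minimal sketch under that assumption, again using a hypothetical toy frame `y_toy` in place of y_train; `y_labels` and `class_weight_by_label` are illustrative names. Keyword arguments are used because newer scikit-learn releases require them for `classes` and `y`.

```python
import numpy as np
import pandas as pd
from sklearn.utils import class_weight

# Hypothetical toy stand-in for y_train (one-hot, grade_A appears twice).
y_toy = pd.DataFrame(
    [[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]],
    columns=["grade_A", "grade_B", "grade_C"],
)

# Collapse each one-hot row to the name of its hot column (1-D labels).
y_labels = y_toy.idxmax(axis=1).values

classes = np.unique(y_labels)
weights = class_weight.compute_class_weight(
    class_weight="balanced", classes=classes, y=y_labels
)

# Map each label to its weight; rarer classes get larger weights.
class_weight_by_label = dict(zip(classes, weights))
print(class_weight_by_label)
```

If the training framework expects integer class indices instead of label names (I have not confirmed this for my setup), the dictionary could be built with `dict(enumerate(weights))` instead.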