I want to fit a dataframe to a sequential deep learning model whose final layer has multiple units. I'm new to deep learning.
Code:
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.preprocessing import LabelEncoder
# Label encode the categorical targets (LabelEncoder expects a 1-D array)
encoder = LabelEncoder()
df["survival"] = encoder.fit_transform(df["survival"])
df["type"] = encoder.fit_transform(df["type"])
df["subtype"] = encoder.fit_transform(df["subtype"])
# Define Sequential model
def get_model():
    model = keras.Sequential(
        [
            layers.Dense(10, activation="relu", name="layer1"),
            layers.Dense(10, activation="relu", name="layer2"),
            layers.Dense(2, name="layer3"),
        ]
    )
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    return model
model = get_model()
X2 = df.iloc[:, 7:9]           # retrieve 2 features (cg00000029, cg00000165)
X2 = tf.convert_to_tensor(X2)  # convert to tensor
normalizer = tf.keras.layers.Normalization(axis=-1)
normalizer.adapt(X2)
model.fit(X2, df[["survival"]], epochs=15, batch_size=BATCH_SIZE)
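For context, this is roughly how I sanity-checked the shapes going into fit (the exact number of rows isn't important here, only the second dimension):
print(X2.shape)                # (num_samples, 2) -> the two feature columns
print(df[["survival"]].shape)  # (num_samples, 1) -> a single label column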
Traceback:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_17/2890126290.py in <module>
3 normalizer = tf.keras.layers.Normalization(axis=-1)
4 normalizer.adapt(X2)
----> 5 model.fit(X2, df[["survival"]], epochs=15, batch_size=BATCH_SIZE)
6 model.fit(X2, df[["type"]], epochs=15, batch_size=BATCH_SIZE)
7 model.fit(X2, df[["subtype"]], epochs=15, batch_size=BATCH_SIZE)
/opt/conda/lib/python3.7/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1182 _r=1):
1183 callbacks.on_train_batch_begin(step)
-> 1184 tmp_logs = self.train_function(iterator)
1185 if data_handler.should_sync:
1186 context.async_wait()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
883
884 with OptionalXlaContext(self._jit_compile):
--> 885 result = self._call(*args, **kwds)
886
887 new_tracing_count = self.experimental_get_tracing_count()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
931 # This is the first call of __call__, so we have to initialize.
932 initializers = []
--> 933 self._initialize(args, kwds, add_initializers_to=initializers)
934 finally:
935 # At this point we know that the initialization is complete (or less
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
758 self._concrete_stateful_fn = (
759 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 760 *args, **kwds))
761
762 def invalid_creator_scope(*unused_args, **unused_kwds):
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
3064 args, kwargs = None, None
3065 with self._lock:
-> 3066 graph_function, _ = self._maybe_define_function(args, kwargs)
3067 return graph_function
3068
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3461
3462 self._function_cache.missed.add(call_context_key)
-> 3463 graph_function = self._create_graph_function(args, kwargs)
3464 self._function_cache.primary[cache_key] = graph_function
3465
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3306 arg_names=arg_names,
3307 override_flat_arg_shapes=override_flat_arg_shapes,
-> 3308 capture_by_value=self._capture_by_value),
3309 self._function_attributes,
3310 function_spec=self.function_spec,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
1005 _, original_func = tf_decorator.unwrap(python_func)
1006
-> 1007 func_outputs = python_func(*func_args, **func_kwargs)
1008
1009 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
666 # the function a weak reference to itself to avoid a reference cycle.
667 with OptionalXlaContext(compile_with_xla):
--> 668 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
669 return out
670
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
992 except Exception as e: # pylint:disable=broad-except
993 if hasattr(e, "ag_error_metadata"):
--> 994 raise e.ag_error_metadata.to_exception(e)
995 else:
996 raise
ValueError: in user code:
/opt/conda/lib/python3.7/site-packages/keras/engine/training.py:853 train_function *
return step_function(self, iterator)
/opt/conda/lib/python3.7/site-packages/keras/engine/training.py:842 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:1286 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:2849 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:3632 _call_for_each_replica
return fn(*args, **kwargs)
/opt/conda/lib/python3.7/site-packages/keras/engine/training.py:835 run_step **
outputs = model.train_step(data)
/opt/conda/lib/python3.7/site-packages/keras/engine/training.py:789 train_step
y, y_pred, sample_weight, regularization_losses=self.losses)
/opt/conda/lib/python3.7/site-packages/keras/engine/compile_utils.py:201 __call__
loss_value = loss_obj(y_t, y_p, sample_weight=sw)
/opt/conda/lib/python3.7/site-packages/keras/losses.py:141 __call__
losses = call_fn(y_true, y_pred)
/opt/conda/lib/python3.7/site-packages/keras/losses.py:245 call **
return ag_fn(y_true, y_pred, **self._fn_kwargs)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/opt/conda/lib/python3.7/site-packages/keras/losses.py:1809 binary_crossentropy
backend.binary_crossentropy(y_true, y_pred, from_logits=from_logits),
/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/opt/conda/lib/python3.7/site-packages/keras/backend.py:5000 binary_crossentropy
return tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:246 sigmoid_cross_entropy_with_logits_v2
logits=logits, labels=labels, name=name)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:133 sigmoid_cross_entropy_with_logits
(logits.get_shape(), labels.get_shape()))
ValueError: logits and labels must have the same shape ((None, 2) vs (None, 1))
The model only fits if I set the final layer to 1 unit:
model = keras.Sequential(
    [
        layers.Dense(10, activation="relu", name="layer1"),
        layers.Dense(10, activation="relu", name="layer2"),
        layers.Dense(1, name="layer3"),
    ]
)
How do I use multiple units in the final layer of the model? I tried reshape(-1, 1) on the labels and still get the same error.
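This is a sketch of that reshape attempt (my exact variable names may have differed):
y = df["survival"].values.reshape(-1, 1)  # reshape labels to (n_samples, 1)
model.fit(X2, y, epochs=15, batch_size=BATCH_SIZE)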
Data:
df.head().to_dict()
{'admin.disease_code': {'TCGA-2K-A9WE-01A': 'KIRP',
'TCGA-2Z-A9J1-01A': 'KIRP',
'TCGA-2Z-A9J2-01A': 'KIRP',
'TCGA-2Z-A9J3-01A': 'KIRP',
'TCGA-2Z-A9J5-01A': 'KIRP'},
'days_to_death': {'TCGA-2K-A9WE-01A': nan,
'TCGA-2Z-A9J1-01A': nan,
'TCGA-2Z-A9J2-01A': nan,
'TCGA-2Z-A9J3-01A': 1771.0,
'TCGA-2Z-A9J5-01A': nan},
'vital_status': {'TCGA-2K-A9WE-01A': 'alive',
'TCGA-2Z-A9J1-01A': 'alive',
'TCGA-2Z-A9J2-01A': 'alive',
'TCGA-2Z-A9J3-01A': 'dead',
'TCGA-2Z-A9J5-01A': 'alive'},
'age_at_initial_pathologic_diagnosis': {'TCGA-2K-A9WE-01A': 53.0,
'TCGA-2Z-A9J1-01A': 71.0,
'TCGA-2Z-A9J2-01A': 71.0,
'TCGA-2Z-A9J3-01A': 67.0,
'TCGA-2Z-A9J5-01A': 80.0},
'gender': {'TCGA-2K-A9WE-01A': 'male',
'TCGA-2Z-A9J1-01A': 'male',
'TCGA-2Z-A9J2-01A': 'female',
'TCGA-2Z-A9J3-01A': 'male',
'TCGA-2Z-A9J5-01A': 'male'},
'karnofsky_performance_score': {'TCGA-2K-A9WE-01A': nan,
'TCGA-2Z-A9J1-01A': nan,
'TCGA-2Z-A9J2-01A': nan,
'TCGA-2Z-A9J3-01A': nan,
'TCGA-2Z-A9J5-01A': nan},
'survival': {'TCGA-2K-A9WE-01A': 'lts',
'TCGA-2Z-A9J1-01A': 'lts',
'TCGA-2Z-A9J2-01A': 'lts',
'TCGA-2Z-A9J3-01A': 'lts',
'TCGA-2Z-A9J5-01A': 'lts'},
'cg00000029': {'TCGA-2K-A9WE-01A': 0.461440642939772,
'TCGA-2Z-A9J1-01A': 0.595894468074615,
'TCGA-2Z-A9J2-01A': 0.481304782143526,
'TCGA-2Z-A9J3-01A': 0.553849599144766,
'TCGA-2Z-A9J5-01A': 0.184349035247422},
'cg00000165': {'TCGA-2K-A9WE-01A': 0.143910373119058,
'TCGA-2Z-A9J1-01A': 0.0807243779293262,
'TCGA-2Z-A9J2-01A': 0.437447195378987,
'TCGA-2Z-A9J3-01A': 0.0642332527783939,
'TCGA-2Z-A9J5-01A': 0.126118535539944},
'cg00000236': {'TCGA-2K-A9WE-01A': 0.847164847154162,
'TCGA-2Z-A9J1-01A': 0.867305510246114,
'TCGA-2Z-A9J2-01A': 0.898927359292032,
'TCGA-2Z-A9J3-01A': 0.917290578229414,
'TCGA-2Z-A9J5-01A': 0.928017823091886},
'cg00000289': {'TCGA-2K-A9WE-01A': 0.737361955793681,
'TCGA-2Z-A9J1-01A': 0.70680600651273,
'TCGA-2Z-A9J2-01A': 0.758108726247342,
'TCGA-2Z-A9J3-01A': 0.675537604266578,
'TCGA-2Z-A9J5-01A': 0.677846427070521},
'cg00000292': {'TCGA-2K-A9WE-01A': 0.716794733144112,
'TCGA-2Z-A9J1-01A': 0.217862460492399,
'TCGA-2Z-A9J2-01A': 0.868604834806246,
'TCGA-2Z-A9J3-01A': 0.543087013952312,
'TCGA-2Z-A9J5-01A': 0.850473788130218},
'cg00000321': {'TCGA-2K-A9WE-01A': 0.351877113536983,
'TCGA-2Z-A9J1-01A': 0.169408257004071,
'TCGA-2Z-A9J2-01A': 0.577744851436078,
'TCGA-2Z-A9J3-01A': 0.85044433769089,
'TCGA-2Z-A9J5-01A': 0.44473521937132},
'cg00000363': {'TCGA-2K-A9WE-01A': 0.248986769373366,
'TCGA-2Z-A9J1-01A': 0.173115013795265,
'TCGA-2Z-A9J2-01A': 0.567241575633452,
'TCGA-2Z-A9J3-01A': 0.470810530680518,
'TCGA-2Z-A9J5-01A': 0.204529155293748},
'cg00000622': {'TCGA-2K-A9WE-01A': 0.0121360989202765,
'TCGA-2Z-A9J1-01A': 0.0108902025634162,
'TCGA-2Z-A9J2-01A': 0.0122683781097633,
'TCGA-2Z-A9J3-01A': 0.0125681212511168,
'TCGA-2Z-A9J5-01A': 0.0122330126903632},
'cg00000658': {'TCGA-2K-A9WE-01A': 0.876303885229884,
'TCGA-2Z-A9J1-01A': 0.813866558997356,
'TCGA-2Z-A9J2-01A': 0.881366097769717,
'TCGA-2Z-A9J3-01A': 0.870735609192125,
'TCGA-2Z-A9J5-01A': 0.906102120405464},
'cg00000721': {'TCGA-2K-A9WE-01A': 0.944311384947134,
'TCGA-2Z-A9J1-01A': 0.938576461648791,
'TCGA-2Z-A9J2-01A': 0.936584647488041,
'TCGA-2Z-A9J3-01A': 0.956356142020249,
'TCGA-2Z-A9J5-01A': 0.938145301973259},
'cg00000734': {'TCGA-2K-A9WE-01A': 0.0490407302658151,
'TCGA-2Z-A9J1-01A': 0.0426568318037534,
'TCGA-2Z-A9J2-01A': 0.0428379760439674,
'TCGA-2Z-A9J3-01A': 0.0577007291016598,
'TCGA-2Z-A9J5-01A': 0.0491650645308977},
'cg00000769': {'TCGA-2K-A9WE-01A': 0.0200484962577958,
'TCGA-2Z-A9J1-01A': 0.0133187057875756,
'TCGA-2Z-A9J2-01A': 0.0193220859926812,
'TCGA-2Z-A9J3-01A': 0.017072120017994,
'TCGA-2Z-A9J5-01A': 0.0184242706692516},
'cg00000905': {'TCGA-2K-A9WE-01A': 0.0623434271852525,
'TCGA-2Z-A9J1-01A': 0.0540543120983417,
'TCGA-2Z-A9J2-01A': 0.0551810635627895,
'TCGA-2Z-A9J3-01A': 0.055021036675329,
'TCGA-2Z-A9J5-01A': 0.0565152834168852},
'cg00000924': {'TCGA-2K-A9WE-01A': 0.489865398138095,
'TCGA-2Z-A9J1-01A': 0.317547629906197,
'TCGA-2Z-A9J2-01A': 0.5065017863301,
'TCGA-2Z-A9J3-01A': 0.504135768615145,
'TCGA-2Z-A9J5-01A': 0.466643054300025},
'cg00000948': {'TCGA-2K-A9WE-01A': 0.920994933496615,
'TCGA-2Z-A9J1-01A': 0.89911570032979,
'TCGA-2Z-A9J2-01A': 0.855015243009544,
'TCGA-2Z-A9J3-01A': 0.911116565506201,
'TCGA-2Z-A9J5-01A': 0.934397425301759},
'cg00000957': {'TCGA-2K-A9WE-01A': 0.92663932531651,
'TCGA-2Z-A9J1-01A': 0.525131175543627,
'TCGA-2Z-A9J2-01A': 0.86481442167794,
'TCGA-2Z-A9J3-01A': 0.855796126141919,
'TCGA-2Z-A9J5-01A': 0.907979957948096},
'cg00001245': {'TCGA-2K-A9WE-01A': 0.0149191766670711,
'TCGA-2Z-A9J1-01A': 0.0152198596492253,
'TCGA-2Z-A9J2-01A': 0.0154433022292077,
'TCGA-2Z-A9J3-01A': 0.0158006072782886,
'TCGA-2Z-A9J5-01A': 0.0149090955954903},
'type': {'TCGA-2K-A9WE-01A': 'tumor',
'TCGA-2Z-A9J1-01A': 'tumor',
'TCGA-2Z-A9J2-01A': 'tumor',
'TCGA-2Z-A9J3-01A': 'tumor',
'TCGA-2Z-A9J5-01A': 'tumor'},
'subtype': {'TCGA-2K-A9WE-01A': 'KIRP',
'TCGA-2Z-A9J1-01A': 'KIRP',
'TCGA-2Z-A9J2-01A': 'KIRP',
'TCGA-2Z-A9J3-01A': 'KIRP',
'TCGA-2Z-A9J5-01A': 'KIRP'}}