
Hi, sorry for the cringe "tell me what this means" question, but I can't seem to figure it out. Here's the code:

from imageai.Classification.Custom import ClassificationModelTrainer

model_trainer = ClassificationModelTrainer()
model_trainer.setModelTypeAsResNet50()
model_trainer.setDataDirectory("images")
model_trainer.trainModel(num_objects=10, num_experiments=200, enhance_data=True, batch_size=8, show_network_summary=True)

Running it yields this scary error:

Traceback (most recent call last):
  File "marco.py", line 15, in <module>
    model_trainer.trainModel(num_objects=10, num_experiments=200, enhance_data=True, batch_size=8, show_network_summary=True)
  File "/home/lollo/.local/lib/python3.8/site-packages/imageai/Classification/Custom/__init__.py", line 393, in trainModel
    model.fit_generator(train_generator, steps_per_epoch=int(num_train / batch_size), epochs=self.__num_epochs,
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1847, in fit_generator
    return self.fit(
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1100, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 888, in _call
    return self._stateless_fn(*args, **kwds)
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2942, in __call__
    return graph_function._call_flat(
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
    outputs = execute.execute(
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError:  logits and labels must be broadcastable: logits_size=[8,10] labels_size=[8,2]
     [[node categorical_crossentropy/softmax_cross_entropy_with_logits (defined at /home/lollo/.local/lib/python3.8/site-packages/imageai/Classification/Custom/__init__.py:393) ]] [Op:__inference_train_function_11908]

Function call stack:
train_function

Thank you for any suggestions <3
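The key part of the error is logits_size=[8,10] labels_size=[8,2]. For each batch of 8 images, the network produces 10 outputs (one per class, because num_objects=10), but each label is a vector of only 2 entries. The minimal NumPy sketch below reproduces that shape mismatch. It assumes the labels are one-hot encoded over the class folders the data generator actually found, and it uses NumPy's broadcasting check only as an analogy for the shape check that TensorFlow's cross-entropy op performs.

```python
import numpy as np

batch_size = 8
num_objects = 10   # value passed to trainModel(num_objects=...)
num_folders = 2    # class folders actually present in the data directory (assumed)

# The final layer emits one logit per declared class: shape (batch, num_objects).
logits = np.zeros((batch_size, num_objects))

# Assumed: each label is one-hot encoded over the folders found: shape (batch, num_folders).
labels = np.eye(num_folders)[np.arange(batch_size) % num_folders]


def shapes_compatible(a, b):
    """Return True if NumPy can broadcast the two arrays together."""
    try:
        np.broadcast_shapes(a.shape, b.shape)
        return True
    except ValueError:
        return False


print(logits.shape, labels.shape, shapes_compatible(logits, labels))
# prints: (8, 10) (8, 2) False
```

The trailing dimensions 10 and 2 differ and neither is 1, so the shapes cannot be reconciled. That is the same mismatch the traceback reports.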

  • Hi @doggianomaster, could you please share the number of classes in your dataset? Is it 10? If so, please make sure that the images directory contains 10 different class folders. Thanks! –  May 09 '22 at 03:42
  • Hi! Sorry, I forgot to reply o.o You're right, I only had 2 classes! Thank you! – doggianomaster Jun 25 '22 at 16:15
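The fix is to make num_objects equal to the number of class folders. One way to keep the two in sync is to count the folders instead of hard-coding the number. The sketch below uses a hypothetical helper, count_class_folders, that is not part of ImageAI. It assumes the layout ImageAI's custom training expects: images/train/<class_name>/ and images/test/<class_name>/, with one subfolder per class.

```python
import os


def count_class_folders(data_dir, split="train"):
    """Count the class subfolders under data_dir/split (hypothetical helper).

    Assumes a layout like data_dir/train/<class_name>/<images>,
    with exactly one subfolder per class. Plain files are ignored.
    """
    split_dir = os.path.join(data_dir, split)
    return sum(
        1
        for entry in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, entry))
    )
```

With that helper, the training call would become model_trainer.trainModel(num_objects=count_class_folders("images"), num_experiments=200, enhance_data=True, batch_size=8, show_network_summary=True). With two class folders, that passes 2 instead of 10, which matches labels_size=[8,2].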
