I am trying to use the few shot learning example provided by Keras using the Reptile algorithm (https://keras.io/examples/vision/reptile/) with my own database. But after loading the database I get an error when i want to visualize some examples from the database.
the line which I have a problem is:
sample_keys = list(train_dataset())
it produces the following: TypeError: 'Dataset' object is not callable
can somebody help me fix this as I am new in the ML space and still learning.
my database 3 classes all separated in folders with names. each class has 20 RGB pictures so a total of 60.
The Code with my changes:
learning_rate = 0.003
meta_step_size = 0.25
inner_batch_size = 15
eval_batch_size = 15
meta_iters = 2000
eval_iters = 5
inner_iters = 4
eval_interval = 1
train_shots = 20
shots = 5
classes = 3
class Dataset:
def __init__(self, training):
# Download the tfrecord files containing the omniglot data and convert to a
# dataset.
split = "train" if training else "test"
images_ds = tf.data.Dataset.list_files('packages/*/*', shuffle=False)
image_count = len (images_ds)
#label=["Box","Cardboard","Plastic_bag"]
ds = images_ds.take(image_count)
# Iterate over the dataset to get each individual image and its class,
# and put that data into a dictionary.
self.data = {}
def get_label(file_path):
return tf.strings.split(file_path, os.path.sep)[-2]
def extraction(file_path):
# This function will shrink the Omniglot images to the desired size,
# scale pixel values and convert the RGB image to grayscale
label = get_label(file_path)
img = tf.io.read_file(file_path)
image = tf.image.decode_jpeg(img)
image = tf.image.convert_image_dtype(image, tf.float32)
#image = tf.image.rgb_to_grayscale(image)
image = tf.image.resize(image, [800, 800])
return image, label
for image, label in images_ds.map(extraction):
image = image.numpy()
label = str(label.numpy())
if label not in self.data:
self.data[label] = []
self.data[label].append(image)
self.labels = list(self.data.keys())
def get_mini_dataset(
self, batch_size, repetitions, shots, num_classes, split=False
):
temp_labels = np.zeros(shape=(num_classes * shots))
temp_images = np.zeros(shape=(num_classes * shots, 800, 800, 1))
if split:
test_labels = np.zeros(shape=(num_classes))
test_images = np.zeros(shape=(num_classes, 800, 800, 1))
# Get a random subset of labels from the entire label set.
label_subset = random.choices(self.labels, k=num_classes)
for class_idx, class_obj in enumerate(label_subset):
# Use enumerated index value as a temporary label for mini-batch in
# few shot learning.
temp_labels[class_idx * shots : (class_idx + 1) * shots] = class_idx
# If creating a split dataset for testing, select an extra sample from each
# label to create the test dataset.
if split:
test_labels[class_idx] = class_idx
images_to_split = random.choices(
self.data[label_subset[class_idx]], k=shots + 1
)
test_images[class_idx] = images_to_split[-1]
temp_images[
class_idx * shots : (class_idx + 1) * shots
] = images_to_split[:-1]
else:
# For each index in the randomly selected label_subset, sample the
# necessary number of images.
temp_images[
class_idx * shots : (class_idx + 1) * shots
] = random.choices(self.data[label_subset[class_idx]], k=shots)
dataset = tf.data.Dataset.from_tensor_slices(
(temp_images.astype(np.float32), temp_labels.astype(np.int32))
)
dataset = dataset.shuffle(60).batch(batch_size).repeat(repetitions)
if split:
return dataset, test_images, test_labels
return dataset
import urllib3
urllib3.disable_warnings() # Disable SSL warnings that may happen during download.
train_dataset = Dataset(training=True)
test_dataset = Dataset(training=False)
_, axarr = plt.subplots(nrows=5, ncols=5, figsize=(20, 20))
sample_keys = list(train_dataset()) #Here i get the error
for a in range(5):
for b in range(5):
temp_image = train_dataset[sample_keys[a]][b]
temp_image = np.stack((temp_image[:, :, 0],) * 3, axis=2)
temp_image *= 255
temp_image = np.clip(temp_image, 0, 255).astype("uint8")
if b == 2:
axarr[a, b].set_title("Class : " + sample_keys[a])
axarr[a, b].imshow(temp_image)#, cmap="rgb")
axarr[a, b].xaxis.set_visible(False)
axarr[a, b].yaxis.set_visible(False)
plt.show()
The error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[23], line 5
1 ###################### Visualize some examples from the dataset ####################################
3 _, axarr = plt.subplots(nrows=5, ncols=5, figsize=(20, 20))
----> 5 sample_keys = list(train_dataset())
7 for a in range(5):
8 for b in range(5):
TypeError: 'Dataset' object is not callable
i tried replacing also
label = str(label.numpy())
with
label = label as it already is in an string format but then I get an error as well which i am still trying to figure out why