Understanding tfds take
What does train_ds.take(2) do?
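train_ds is one of the datasets returned by tfds.load:
import tensorflow_datasets as tfds

(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    data_dir=DATA_PATH,  # DATA_PATH: directory where the dataset is stored
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,  # yield (image, label) tuples
)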
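take(n) returns a new dataset containing at most the first n elements. On this unbatched dataset each element is an (image, label) pair, so iterating over take(1) yields one image and its label:
for image, label in train_ds.take(1):
    print("image.shape=", image.shape, " label=", int(label))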
Output:
image.shape= (333, 500, 3) label= 2
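The following is a minimal sketch of the same behaviour on a toy dataset, so it runs without downloading tf_flowers. The name toy_ds and the 8x8 all-zero images are made up for illustration; only take() itself comes from the text above.

```python
import tensorflow as tf

# Toy stand-in for an unbatched (image, label) dataset.
# The shapes and values are invented for this illustration.
images = tf.zeros([5, 8, 8, 3], dtype=tf.uint8)
labels = tf.constant([0, 1, 2, 3, 4], dtype=tf.int64)
toy_ds = tf.data.Dataset.from_tensor_slices((images, labels))

# take(2) does not return a tuple; it returns another Dataset
# that stops after the first two elements.
first_two = toy_ds.take(2)
taken = [(tuple(img.shape.as_list()), int(lbl)) for img, lbl in first_two]
print(taken)

# Asking for more elements than exist simply stops at the end of the dataset.
n_all = sum(1 for _ in toy_ds.take(100))
print(n_all)
```

Because take() returns a dataset, it can be passed anywhere a dataset is accepted, or looped over as shown.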
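If batch() has been applied, each element of the dataset is a whole batch, so take(n) yields n batches, that is, up to (batch_size * n) images. The count is lower only when the dataset runs out of data or its final batch is smaller than batch_size.
Example: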
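import numpy as np

batch_size = 32
# train_dataset is the (image, label) dataset; model is assumed to be an already trained Keras classifier
train_dataset = train_dataset.cache().batch(batch_size).prefetch(buffer_size=10)
probabilities = model.predict(train_dataset.take(10))  # 10 batches of 32 images each
predict = np.argmax(probabilities, axis=1)  # predicted class index for each image
print(predict.shape)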
Output:
(320,)
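The batch arithmetic can be checked without a model. This is a small sketch on a made-up dataset of 100 integers (toy_ds is a hypothetical name) that counts how many elements take() actually delivers after batch().

```python
import tensorflow as tf

# 100 scalar elements stand in for 100 images (illustrative data only).
toy_ds = tf.data.Dataset.range(100)
batched = toy_ds.batch(32)  # batch sizes: 32, 32, 32, 4

# take(2) yields two full batches.
sizes_take2 = [int(b.shape[0]) for b in batched.take(2)]
print(sizes_take2, sum(sizes_take2))

# take(10) asks for more batches than exist, so it yields all four,
# including the smaller final batch.
sizes_take10 = [int(b.shape[0]) for b in batched.take(10)]
print(sizes_take10, sum(sizes_take10))
```

So batch_size * take_parameter is best read as an upper bound; it is exact when enough full batches are available, as in the 320-image example above.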
Correct Use of map() function and image viewing
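tf.image.resize returns float values, so resizing alone leaves the RGB values as floats in the range 0 to 255. That is out of range for viewing with matplotlib, which expects float RGB values between 0 and 1: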
train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, (150,150)), y))
Example: [before/after image comparison]
Dividing the RGB values by 255.0 after resizing scales them into the range 0 to 1, which allows the images to be viewed, for example with matplotlib:
train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, (150,150))/ 255.0, y))
Example: [before/after image comparison]
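The value ranges involved can be inspected directly. This is a minimal sketch using a made-up 4x4 image with every channel set to 200; the variable names are illustrative and not part of the original pipeline.

```python
import tensorflow as tf

# Hypothetical uint8 image: 4x4 pixels, every channel value is 200.
image = tf.cast(tf.fill([4, 4, 3], 200), tf.uint8)

# With the default (bilinear) method, resize returns float32,
# and the values stay on the 0 to 255 scale.
resized = tf.image.resize(image, (2, 2))
resized_max = float(tf.reduce_max(resized))
print(resized.dtype, resized_max)

# Dividing by 255.0 brings the values into the 0 to 1 range
# that matplotlib expects for float RGB data.
scaled = resized / 255.0
scaled_max = float(tf.reduce_max(scaled))
print(scaled.dtype, scaled_max)
```

An alternative for display purposes only would be to cast the resized image back with tf.cast(resized, tf.uint8), since matplotlib also accepts integer RGB values from 0 to 255. Dividing by 255.0 has the advantage that the same tensors can also be fed to a model that expects inputs between 0 and 1.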
Sources:
Brave search "tfds take"
and
https://tf.wiki/en/appendix/tfds.html
batch() function and image viewing
Before batch():
print(train_ds)
returns
_PrefetchDataset element_spec=(TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), TensorSpec(shape=(),
dtype=tf.int64, name=None))
Plot call:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(2)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")
batch() function:
batch_size = 32
train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
After batch():
Using the same plot call now raises an error:
TypeError: Invalid shape (32, 32, 32, 3) for image data
Notice that the shape now has 4 dimensions: there is one extra leading dimension with value 32, the same as the batch size.
take() now returns whole batches, so to view an image we need to index into the batch:
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(2)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image[0])       # first image of the batch
    plt.title(int(label[0]))   # label of that image
    plt.axis("off")
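Indexing with [0] shows only the first image of each batch. If the goal is to look at several individual images, one option (not used in the original notes) is to undo the batching with unbatch() before calling take(). The sketch below uses a made-up dataset of 64 blank 32x32 images; batched_ds and the other names are illustrative.

```python
import tensorflow as tf

# Toy batched dataset: 64 blank 32x32 RGB images with labels 0..63
# (invented data, used only to show the shapes).
images = tf.zeros([64, 32, 32, 3], dtype=tf.uint8)
labels = tf.range(64, dtype=tf.int64)
batched_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(32)

# One element of the batched dataset is a whole batch.
for image_batch, label_batch in batched_ds.take(1):
    batch_shape = tuple(image_batch.shape.as_list())
    print(batch_shape, tuple(label_batch.shape.as_list()))

# unbatch() splits the batches back into single (image, label) elements,
# so take(3) yields three individual images again.
singles = [
    (tuple(img.shape.as_list()), int(lbl))
    for img, lbl in batched_ds.unbatch().take(3)
]
print(singles)
```

With the unbatched elements, the original plot call (plt.imshow(image) and plt.title(int(label))) should work unchanged, because each image again has three dimensions.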