tensorflow datasets

为了方便大家学习, tensorflow自身提供了多种数据集,并且将这些数据集进行了封装,方便我们直接使用。它使用 tensorflow-datasets 这个module封装了我们常用的各种公共数据集,在进行模型的学习过程中可以直接用该模块加载我们需要的数据集进行训练,挺方便的。本文介绍该模块的简单使用。

安装

1
pip install tensorflow-datasets

使用

1
2
import tensorflow as tf
import tensorflow_datasets as tfds

查看可用数据集

tensorflow-datasets 提供了我们常用的数据集,可以通过下面的命令查看有哪些可用的数据集.

1
2
3
tfds.list_builders()
# 目前我这里可以的数据集如下
# ['bair_robot_pushing_small', 'cats_vs_dogs', 'celeb_a', 'celeb_a_hq', 'cifar10', 'cifar100', 'coco2014', 'diabetic_retinopathy_detection', 'dummy_dataset_shared_generator', 'dummy_mnist', 'fashion_mnist', 'image_label_folder', 'imagenet2012', 'imdb_reviews', 'lm1b', 'lsun', 'mnist', 'moving_mnist', 'nsynth', 'omniglot', 'open_images_v4', 'quickdraw_bitmap', 'squad', 'starcraft_video', 'svhn_cropped', 'tf_flowers', 'wmt_translate_ende', 'wmt_translate_enfr']

了解数据集

确定了可用数据集以后我们可以选择一个来使用, 首先需要了解一下这个数据集的基本信息。下面以 mnist为例。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
mnist_builder = tfds.builder("mnist")
mnist_builder.download_and_prepare() # 下载数据集
info = mnist_builder.info
print(info)

# tfds.core.DatasetInfo(
# name='mnist',
# version=1.0.0,
# description='The MNIST database of handwritten digits.',
# urls=['http://yann.lecun.com/exdb/mnist/'],
# features=FeaturesDict({
# 'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
# 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10)
# },
# total_num_examples=70000,
# splits={
# 'test': <tfds.core.SplitInfo num_examples=10000>,
# 'train': <tfds.core.SplitInfo num_examples=60000>
# },
# supervised_keys=('image', 'label'),
# citation='"""
# @article{lecun2010mnist,
# title={MNIST handwritten digit database},
# author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
# journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
# volume={2},
# year={2010}
# }
#
# """',
# )

通过该信息我们可以知道 数据集共有 70000条数据,分为训练集60000和测试集10000; 没条数据包含 ‘image’, ‘label’ 两部分。 label 包含10个类别 等重要信息。

同时在代码里我们可以直接调用这些信息

1
2
3
4
5
6
7
print(info.features)
print(info.features["label"].num_classes)
print(info.features["label"].names)

# FeaturesDict({'image': Image(shape=(28, 28, 1), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10)})
# 10
# ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

使用数据集

使用mnist_builder 的as_dataset 函数即可获取我们的数据,其定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def as_dataset(self,
split=None,
batch_size=1,
shuffle_files=None,
as_supervised=False):
"""Constructs a `tf.data.Dataset`.

Callers must pass arguments as keyword arguments.

Args:
split: `tfds.core.SplitBase`, which subset(s) of the data to read. If None
(default), returns all splits in a dict
`<key: tfds.Split, value: tf.data.Dataset>`.
batch_size: `int`, batch size. Note that variable-length features will
be 0-padded if `batch_size > 1`. Users that want more custom behavior
should use `batch_size=1` and use the `tf.data` API to construct a
custom pipeline. If `batch_size == -1`, will return feature
dictionaries of the whole dataset with `tf.Tensor`s instead of a
`tf.data.Dataset`.
shuffle_files: `bool`, whether to shuffle the input files.
Defaults to `True` if `split == tfds.Split.TRAIN` and `False` otherwise.
as_supervised: `bool`, if `True`, the returned `tf.data.Dataset`
will have a 2-tuple structure `(input, label)` according to
`builder.info.supervised_keys`. If `False`, the default,
the returned `tf.data.Dataset` will have a dictionary with all the
features.

Returns:
`tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
tfds.data.Dataset>`.

If `batch_size` is -1, will return feature dictionaries containing
the entire dataset in `tf.Tensor`s instead of a `tf.data.Dataset`.
"""

举例:

1
2
3
4
5
6
7
8
# get dataset
mnist_train = mnist_builder.as_dataset(split=tfds.Split.TRAIN)
mnist_test = mnist_builder.as_dataset(split=tfds.Split.TEST)
print(mnist_train)
print(mnist_test)

# <DatasetV1Adapter shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>
# <DatasetV1Adapter shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

获得 tf.data.Dataset 对象 (mnist_train 和 mnist_test 都是的) 以后我们就可以使用 tf.data API 提供的api来操作数据集用于模型驯良了

tensorflow-datasets 对于数据的使用提供了更高级的封装 load 函数, 如下所示

1
2
3
4
5
6
mnist_train = tfds.load(name="mnist", split=tfds.Split.TRAIN)
assert isinstance(mnist_train, tf.data.Dataset)
print(mnist_train)

# 下载数据
# <DatasetV1Adapter shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

参考

https://www.tensorflow.org/datasets/overview

api 接口 https://www.tensorflow.org/datasets/api_docs/python/