为了方便大家学习, tensorflow自身提供了多种数据集,并且将这些数据集进行了封装,方便我们直接使用。它使用 tensorflow-datasets
这个module封装了我们常用的各种公共数据集,在进行模型的学习过程中可以直接用该模块加载我们需要的数据集进行训练,挺方便的。本文介绍该模块的简单使用。
安装
1 | pip install tensorflow-datasets |
使用
1 | import tensorflow as tf |
查看可用数据集
tensorflow-datasets
提供了我们常用的数据集,可以通过下面的命令查看有哪些可用的数据集.
1 | tfds.list_builders() |
了解数据集
确定了可用数据集以后我们可以选择一个来使用, 首先需要了解一下这个数据集的基本信息。下面以 mnist为例。
1 | mnist_builder = tfds.builder("mnist") |
通过该信息我们可以知道 数据集共有 70000条数据,分为训练集60000和测试集10000; 没条数据包含 ‘image’, ‘label’ 两部分。 label 包含10个类别 等重要信息。
同时在代码里我们可以直接调用这些信息
1 | print(info.features) |
使用数据集
使用mnist_builder 的as_dataset 函数即可获取我们的数据,其定义如下:
1 | def as_dataset(self, |
举例:
1 | # get dataset |
获得 tf.data.Dataset 对象 (mnist_train 和 mnist_test 都是的) 以后我们就可以使用 tf.data
API 提供的api来操作数据集用于模型驯良了
tensorflow-datasets 对于数据的使用提供了更高级的封装 load
函数, 如下所示
1 | mnist_train = tfds.load(name="mnist", split=tfds.Split.TRAIN) |