Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/zh-cn/io/tutorials/avro.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

Avro 数据集 API

文本特征向量

Avro 数据集 API 的目标是将 Avro 格式的数据作为 TensorFlow 数据集原生加载到 TensorFlow 中。Avro 是一个类似于 Protocol Buffers 的数据序列化系统。它广泛用于 Apache Hadoop,可以提供持久数据的序列化格式和 Hadoop 节点之间通信的有线格式。Avro 数据是一种面向行的压缩二进制数据格式。它依赖于存储为单独 JSON 文件的架构。有关 Avro 格式和架构声明的规范,请参阅官方手册

安装软件包

安装所需的 tensorflow-io 软件包

!pip install tensorflow-io

导入软件包

import tensorflow as tf import tensorflow_io as tfio

验证 tf 和 tfio 导入

print("tensorflow-io version: {}".format(tfio.__version__)) print("tensorflow version: {}".format(tf.__version__))

用法

探索数据集

为了实现本教程的目的,我们来下载示例 Avro 数据集。

下载示例 Avro 文件:

!curl -OL https://github.com/tensorflow/io/raw/master/docs/tutorials/avro/train.avro !ls -l train.avro

下载示例 Avro 文件的相应架构文件:

!curl -OL https://github.com/tensorflow/io/raw/master/docs/tutorials/avro/train.avsc !ls -l train.avsc

在上面的示例中,基于 MNIST 数据集创建了一个测试 Avro 数据集。TFRecord 格式的原始 MNIST 数据集从 TF 命名数据集生成。但是,作为演示数据集,MNIST 数据集过大。为简单起见,我们修剪了大部分内容,只保留前几条记录。此外,对原始 MNIST 数据集中的 image 字段进行了额外的修剪,并将其映射到 Avro 中的 features 字段。因此,Avro 文件 train.avro 有 4 条记录,每条记录有 3 个字段,分别为:features(整数的数组)、label(整数或 null 的数组)和 dataType(枚举)。要查看解码的 train.avro(请注意,原始 Avro 数据文件非人类可读,因为 Avro 是压缩格式),请执行以下操作:

安装读取 Avro 文件所需的包:

!pip install avro

要以人类可读的格式读取和打印 Avro 文件,请运行以下代码:

from avro.io import DatumReader from avro.datafile import DataFileReader import json def print_avro(avro_file, max_record_num=None): if max_record_num is not None and max_record_num <= 0: return with open(avro_file, 'rb') as avro_handler: reader = DataFileReader(avro_handler, DatumReader()) record_count = 0 for record in reader: record_count = record_count+1 print(record) if max_record_num is not None and record_count == max_record_num: break print_avro(avro_file='train.avro')

train.avsc 表示的 train.avro 的架构是一个 JSON 格式的文件。查看train.avsc

def print_schema(avro_schema_file): with open(avro_schema_file, 'r') as handle: parsed = json.load(handle) print(json.dumps(parsed, indent=4, sort_keys=True)) print_schema('train.avsc')

准备数据集

使用 Avro 数据集 API 将 train.avro 加载为 TensorFlow 数据集:

features = { 'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32), 'label': tf.io.FixedLenFeature(shape=[], dtype=tf.int32, default_value=-100), 'dataType': tf.io.FixedLenFeature(shape=[], dtype=tf.string) } schema = tf.io.gfile.GFile('train.avsc').read() dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'], reader_schema=schema, features=features, shuffle=False, batch_size=3, num_epochs=1) for record in dataset: print(record['features[*]']) print(record['label']) print(record['dataType']) print("--------------------")

上面的示例将 train.avro 转换为 TensorFlow 数据集。数据集的每个元素都是一个字典,其关键字为特征名称,值为转换后的稀疏或密集张量。例如,它会将 featureslabeldataType 字段分别转换为 VarLenFeature(SparseTensor)、FixedLenFeature(DenseTensor) 和 FixLenFeature(DenseTensor)。由于 batch_size 为 3,它会将 train.avro 中的 3 条记录强制转换为结果数据集中的一个元素。对于 train.avro 中标签为 null 的第一条记录,Avro 读取器会将其替换为指定的默认值 (-100)。在本例中,train.avro 中总共有 4 条记录。由于批次大小为 3,结果数据集包含 3 个元素,最后一个元素的批次大小为 1。但是,如果大小小于批次大小,用户也可以通过启用 drop_final_batch 丢弃最后一个批次。例如:

dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'], reader_schema=schema, features=features, shuffle=False, batch_size=3, drop_final_batch=True, num_epochs=1) for record in dataset: print(record)

此外,还可以增加 num_parallel_reads 以通过提高 Avro 解析/读取并行性来加速 Avro 数据处理。

dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'], reader_schema=schema, features=features, shuffle=False, num_parallel_reads=16, batch_size=3, drop_final_batch=True, num_epochs=1) for record in dataset: print(record)

有关 make_avro_record_dataset 的详细用法,请参阅 API 文档

使用 Avro 数据集训练 tf.keras 模型

现在,我们来看一个端到端示例,该示例基于 MNIST 数据集使用 Avro 数据集来训练 tf.keras 模型。

使用 Avro 数据集 API 将 train.avro 加载为 TensorFlow 数据集:

features = { 'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32), 'label': tf.io.FixedLenFeature(shape=[], dtype=tf.int32, default_value=-100), } schema = tf.io.gfile.GFile('train.avsc').read() dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'], reader_schema=schema, features=features, shuffle=False, batch_size=1, num_epochs=1)

定义一个简单的 Keras 模型:

def build_and_compile_cnn_model(): model = tf.keras.Sequential() model.compile(optimizer='sgd', loss='mse') return model model = build_and_compile_cnn_model()

使用 Avro 数据集训练 Keras 模型:

def extract_label(feature): label = feature.pop('label') return tf.sparse.to_dense(feature['features[*]']), label model.fit(x=dataset.map(extract_label), epochs=1, steps_per_epoch=1, verbose=1)

Avro 数据集可以解析任何 Avro 数据并将其强制转换为 TensorFlow 张量,包括记录、映射、数组、分支和枚举中的记录。解析信息作为映射传递到 Avro 数据集实现中,其中关键字用于编码如何解析数据,值用于编码如何将数据强制转换为 TensorFlow 张量 – 决定基元类型(例如 bool、int、long、float、double、string)以及张量类型(例如稀疏或密集)。下面提供了 TensorFlow 解析器类型(见表 1)和基元类型强制转换(表 2)的清单。

表 1 支持的 TensorFlow 解析器类型:

TensorFlow 解析器类型TensorFlow 张量解释
tf.FixedLenFeature([], tf.int32)密集张量解析固定长度的特征;也就是说,所有行都具有相同的恒定数量元素,例如,只有一个元素或每行始终具有相同数量元素的数组
tf.SparseFeature(index_key=['key_1st_index', 'key_2nd_index'], value_key='key_value', dtype=tf.int64, size=[20, 50])稀疏张量解析稀疏特征,其中每行都有一个可变长度的索引和值清单。'index_key' 标识索引。'value_key' 标识值。'dtype' 为数据类型。'size' 为每个索引条目的预期最大索引值
tfio.experimental.columnar.VarLenFeatureWithRank([],tf.int64)稀疏张量解析可变长度特征;这意味着每个数据行可以具有可变数量的元素,例如,第一行有 5 个元素,第二行有 7 个元素

表 2 支持的 Avro 类型到 TensorFlow 类型的转换:

Avro 基元类型TensorFlow 基元类型
bool:二进制值tf.bool
byte:8 位无符号字节序列tf.string
double:双精度 64 位 IEEE 浮点数tf.float64
enum:枚举类型使用符号名称的 tf.string
float:单精度 32 位 IEEE 浮点数tf.float32
int:32 位有符号整数tf.int32
long:64 位有符号整数tf.int64
null:没有值使用默认值
string:unicode 字符序列tf.string

测试中提供了一组全面的 Avro 数据集 API 示例。