Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/ja/tfx/tutorials/transform/census.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

TensorFlow Transform を使用したデータの前処理

TensorFlow Extended(TFX)の特徴量エンジニアリングコンポーネント

このコラボノートブックの例は、TensorFlow Transformtf.Transform)を使用して、データを前処理する方法のやや高度な例を提供します。モデルのトレーニングと本番環境での推論のサービングの両方に同じコードを使用します。

TensorFlow Transform は、トレーニングデータセットのフルパスを必要とする機能の作成など、TensorFlow の入力データを前処理するためのライブラリです。たとえば、TensorFlow Transform を使用すると、次のことができます。

  • 平均と標準偏差を使用して入力値を正規化する

  • すべての入力値に対して語彙を生成することにより、文字列を整数に変換する

  • 観測されたデータ分布に基づいて、浮動小数点数をバケットに割り当てることにより、浮動小数点数を整数に変換する

TensorFlow には、単一のサンプルまたはサンプルのバッチに対する操作のサポートが組み込まれています。tf.Transform は、これらの機能を拡張して、トレーニングデータセット全体のフルパスをサポートします。

tf.Transform の出力は、トレーニングとサービングの両方に使用できる TensorFlow グラフとしてエクスポートされます。トレーニングとサービングの両方に同じグラフを使用すると、両方の段階で同じ変換が適用されるため、スキューを防ぐことができます。

重要なポイント: tf.Transform とそれが Apache Beam でどのように機能するかを理解するには、Apache Beam についての知識が少し必要です。Apache Beam の基本的なコンセプトについては Beam プログラミングガイドを参照してください。

##この例で何が行われているのか

この例では、国勢調査データを含む広く使用されているデータセットを処理し、分類を行うためのモデルをトレーニングします。また、tf.Transform を使用してデータを変換します。

重要なポイント: モデラーおよび開発者の皆さんは、このデータがどのように使用されるか、モデルの予測が引き起こす可能性のある潜在的メリット・デメリットについて考えてください。このようなモデルは、社会的バイアスと格差を拡大する可能性があります。特徴量は解決しようとする問題に関連していますか、それともバイアスを導入しますか?詳細については、機械学習における公平性についてご一読ください。

注意: TensorFlow Model Analysis は、モデルが社会的バイアスや格差をどのように強化するかを理解するなど、モデルがデータのさまざまなセグメントをどの程度適切に予測するかを理解するための強力なツールです。

TensorFlow Transform のインストール

!pip install tensorflow-transform
# This cell is only necessary because packages were installed while python was # running. It avoids the need to restart the runtime when running in Colab. import pkg_resources import importlib importlib.reload(pkg_resources)

インポートとグローバル

まず、必要なものをインポートします。

import math import os import pprint import pandas as pd import matplotlib.pyplot as plt import tensorflow as tf print('TF: {}'.format(tf.__version__)) import apache_beam as beam print('Beam: {}'.format(beam.__version__)) import tensorflow_transform as tft import tensorflow_transform.beam as tft_beam print('Transform: {}'.format(tft.__version__)) from tfx_bsl.public import tfxio from tfx_bsl.coders.example_coder import RecordBatchToExamplesEncoder

次に、データファイルをダウンロードします。

!wget https://storage.googleapis.com/artifacts.tfx-oss-public.appspot.com/datasets/census/adult.data !wget https://storage.googleapis.com/artifacts.tfx-oss-public.appspot.com/datasets/census/adult.test train_path = './adult.data' test_path = './adult.test'

列に名前を付ける

データセットの列を参照するための便利なリストをいくつか作成します。

CATEGORICAL_FEATURE_KEYS = [ 'workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country', ] NUMERIC_FEATURE_KEYS = [ 'age', 'capital-gain', 'capital-loss', 'hours-per-week', 'education-num' ] ORDERED_CSV_COLUMNS = [ 'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'label' ] LABEL_KEY = 'label'

以下は、データの簡易プレビューです。

pandas_train = pd.read_csv(train_path, header=None, names=ORDERED_CSV_COLUMNS) pandas_train.head(5)
one_row = dict(pandas_train.loc[0])
COLUMN_DEFAULTS = [ '' if isinstance(v, str) else 0.0 for v in dict(pandas_train.loc[1]).values()]

テストデータには、スキップする必要のあるヘッダー行が 1 行と、各行の末尾に "." があります。

pandas_test = pd.read_csv(test_path, header=1, names=ORDERED_CSV_COLUMNS) pandas_test.head(5)
testing = os.getenv("WEB_TEST_BROWSER", False) if testing: pandas_train = pandas_train.loc[:1] pandas_test = pandas_test.loc[:1]

###特徴量とスキーマを定義します。 入力の列の型に基づいてスキーマを定義します。 これはそれらを正しくインポートするのに役立ちます。

RAW_DATA_FEATURE_SPEC = dict( [(name, tf.io.FixedLenFeature([], tf.string)) for name in CATEGORICAL_FEATURE_KEYS] + [(name, tf.io.FixedLenFeature([], tf.float32)) for name in NUMERIC_FEATURE_KEYS] + [(LABEL_KEY, tf.io.FixedLenFeature([], tf.string))] ) SCHEMA = tft.DatasetMetadata.from_feature_spec(RAW_DATA_FEATURE_SPEC).schema

[オプション]tf.train.Example proto のエンコードとデコード

このチュートリアルでは、いくつかの場所で、データセットの Example を tf.train.Example との間で変換する必要があります。

以下の非表示の encode_example 関数は、データセットの特徴量のディクショナリを tf.train.Example に変換します。

#@title def encode_example(input_features): input_features = dict(input_features) output_features = {} for key in CATEGORICAL_FEATURE_KEYS: value = input_features[key] feature = tf.train.Feature( bytes_list=tf.train.BytesList(value=[value.strip().encode()])) output_features[key] = feature for key in NUMERIC_FEATURE_KEYS: value = input_features[key] feature = tf.train.Feature( float_list=tf.train.FloatList(value=[value])) output_features[key] = feature label_value = input_features.get(LABEL_KEY, None) if label_value is not None: output_features[LABEL_KEY] = tf.train.Feature( bytes_list = tf.train.BytesList(value=[label_value.strip().encode()])) example = tf.train.Example( features = tf.train.Features(feature=output_features) ) return example

次に、データセットの Example を Example proto に変換できます。

tf_example = encode_example(pandas_train.loc[0]) tf_example.features.feature['age']
serialized_example_batch = tf.constant([ encode_example(pandas_train.loc[i]).SerializeToString() for i in range(3) ]) serialized_example_batch

また、シリアル化された Example proto のバッチをテンソルのディクショナリに変換することもできます。

decoded_tensors = tf.io.parse_example( serialized_example_batch, features=RAW_DATA_FEATURE_SPEC )

一部のケースでは、ラベルが渡されないことがあるため、ラベルがオプションとなるようにエンコード関数を記述します。

features_dict = dict(pandas_train.loc[0]) features_dict.pop(LABEL_KEY) LABEL_KEY in features_dict

Example proto を作成する際には、単にラベルキーが含まれません。

no_label_example = encode_example(features_dict) LABEL_KEY in no_label_example.features.feature.keys()

###ハイパーパラメータの設定と基本的なハウスキーピング

以下は、トレーニングに使用される定数とハイパーパラメータです。

NUM_OOV_BUCKETS = 1 EPOCH_SPLITS = 10 TRAIN_NUM_EPOCHS = 2*EPOCH_SPLITS NUM_TRAIN_INSTANCES = len(pandas_train) NUM_TEST_INSTANCES = len(pandas_test) BATCH_SIZE = 128 STEPS_PER_TRAIN_EPOCH = tf.math.ceil(NUM_TRAIN_INSTANCES/BATCH_SIZE/EPOCH_SPLITS) EVALUATION_STEPS = tf.math.ceil(NUM_TEST_INSTANCES/BATCH_SIZE) # Names of temp files TRANSFORMED_TRAIN_DATA_FILEBASE = 'train_transformed' TRANSFORMED_TEST_DATA_FILEBASE = 'test_transformed' EXPORTED_MODEL_DIR = 'exported_model_dir'
if testing: TRAIN_NUM_EPOCHS = 1

##tf.Transform による前処理

###tf.Transform preprocessing_fnを作成します。前処理関数は、tf.Transform の最も重要な概念です。前処理関数では、データセットの変換が実際に行われます。テンソルのディクショナリーを受け入れて返します。ここで、テンソルは Tensor または SparseTensor を意味します。通常、前処理関数の中心となる API 呼び出しには 2 つの主要なグループがあります。

  1. TensorFlow 演算子: テンソルを受け入れて返す関数。通常は TensorFlow 演算子を意味します。これらは、生データを一度に 1 つの特徴量ベクトルで変換されたデータに変換するグラフに TensorFlow 演算子を追加します。これらは、トレーニングとサービングの両方で、すべての例で実行されます。

  2. Tensorflow Transform アナライザー/マッパー: tf.Transform によって提供されるアナライザー/マッパーのいずれか。これらもテンソルを受け入れて返し、通常は Tensorflow 演算子と Beam 計算の組み合わせを含みますが、TensorFlow 演算子とは異なり、分析中はビームパイプラインでのみ実行され、トレーニングデータセット全体を通じた処理が必要になります。Beam 計算は(トレーニング前、分析中に)1 回だけ実行され、通常はトレーニングデータセット全体を処理します。tf.constant テンソルが作成され、グラフに追加されます。たとえば、 tft.min は、トレーニングデータセットのテンソルの最小値を計算します。

注意: 前処理関数をサービング推論に適用する場合、トレーニング中にアナライザーにより作成された定数は変更されません。データに傾向または季節性の要素がある場合は、それに応じて計画します。

以下は、このデータセットの preprocessing_fn です。以下のことを実行します。

  1. tft.scale_to_0_1 を使用して、数値特徴量を [0,1] の範囲にスケーリングします。

  2. tft.compute_and_apply_vocabulary を使って、カテゴリカル特徴量ごとの語彙を計算し、各入力の整数 ID を tf.int64 として返します。これは、文字列と整数のどちらのカテゴリカル入力にも適用されます。

  3. 標準の TensorFlow 演算を使ってデータに手動変換を適用します。ここでは、ラベルに対して演算を適用しますが、特徴量も変換することが可能です。TensorFlow 演算は以下を実行します。

    • ラベルのルックアップテーブルをビルドします(tf.init_scope は、関数が初めて呼び出された時にのみテーブルを作成します)。

    • ラベルのテキストを正規化します。

    • ラベルをワンホットに変換します。

def preprocessing_fn(inputs): """Preprocess input columns into transformed columns.""" # Since we are modifying some features and leaving others unchanged, we # start by setting `outputs` to a copy of `inputs. outputs = inputs.copy() # Scale numeric columns to have range [0, 1]. for key in NUMERIC_FEATURE_KEYS: outputs[key] = tft.scale_to_0_1(inputs[key]) # For all categorical columns except the label column, we generate a # vocabulary but do not modify the feature. This vocabulary is instead # used in the trainer, by means of a feature column, to convert the feature # from a string to an integer id. for key in CATEGORICAL_FEATURE_KEYS: outputs[key] = tft.compute_and_apply_vocabulary( tf.strings.strip(inputs[key]), num_oov_buckets=NUM_OOV_BUCKETS, vocab_filename=key) # For the label column we provide the mapping from string to index. table_keys = ['>50K', '<=50K'] with tf.init_scope(): initializer = tf.lookup.KeyValueTensorInitializer( keys=table_keys, values=tf.cast(tf.range(len(table_keys)), tf.int64), key_dtype=tf.string, value_dtype=tf.int64) table = tf.lookup.StaticHashTable(initializer, default_value=-1) # Remove trailing periods for test data when the data is read with tf.data. # label_str = tf.sparse.to_dense(inputs[LABEL_KEY]) label_str = inputs[LABEL_KEY] label_str = tf.strings.regex_replace(label_str, r'\.$', '') label_str = tf.strings.strip(label_str) data_labels = table.lookup(label_str) transformed_label = tf.one_hot( indices=data_labels, depth=len(table_keys), on_value=1.0, off_value=0.0) outputs[LABEL_KEY] = tf.reshape(transformed_label, [-1, len(table_keys)]) return outputs

構文

これで、すべてをまとめて Apache Beam を使用して実行する準備がほぼ整いました。

Apache Beam は、特別な構文を使用して変換を定義および呼び出します。たとえば、次の行をご覧ください。

result = pass_this | 'name this step' >> to_this_call

メソッド to_this_call が呼び出され、pass_this というオブジェクトが渡されます。この演算は、スタックトレースで name this step と呼ばれますto_this_call の呼び出しの結果は、result に返されます。 頻繁にパイプラインのステージは次のようにチェーンされます。

result = apache_beam.Pipeline() | 'first step' >> do_this_first() | 'second step' >> do_this_last()

そして、新しいパイプラインで始まったので、以下のように続行できます。

next_result = result | 'doing more stuff' >> another_function()

データを変換する

Apache Beam パイプラインでデータを変換し始める準備が整いました。

  1. tfxio.CsvTFXIO CSV リーダーを使用してデータを読み取ります(パイプラインでテキストの行を処理するには、代わりに tfxio.BeamRecordCsvTFXIO を使用します)。

  2. 上記で定義した preprocessing_fn を使ってデータの分析と変換を行います。

  3. 結果を Example プロトの TFRecord として書き出します。これは、後でモデルのトレーニングに使用します。

def transform_data(train_data_file, test_data_file, working_dir): """Transform the data and write out as a TFRecord of Example protos. Read in the data using the CSV reader, and transform it using a preprocessing pipeline that scales numeric data and converts categorical data from strings to int64 values indices, by creating a vocabulary for each category. Args: train_data_file: File containing training data test_data_file: File containing test data working_dir: Directory to write transformed data and metadata to """ # The "with" block will create a pipeline, and run that pipeline at the exit # of the block. with beam.Pipeline() as pipeline: with tft_beam.Context(temp_dir=tempfile.mkdtemp()): # Create a TFXIO to read the census data with the schema. To do this we # need to list all columns in order since the schema doesn't specify the # order of columns in the csv. # We first read CSV files and use BeamRecordCsvTFXIO whose .BeamSource() # accepts a PCollection[bytes] because we need to patch the records first # (see "FixCommasTrainData" below). Otherwise, tfxio.CsvTFXIO can be used # to both read the CSV files and parse them to TFT inputs: # csv_tfxio = tfxio.CsvTFXIO(...) # raw_data = (pipeline | 'ToRecordBatches' >> csv_tfxio.BeamSource()) train_csv_tfxio = tfxio.CsvTFXIO( file_pattern=train_data_file, telemetry_descriptors=[], column_names=ORDERED_CSV_COLUMNS, schema=SCHEMA) # Read in raw data and convert using CSV TFXIO. raw_data = ( pipeline | 'ReadTrainCsv' >> train_csv_tfxio.BeamSource()) # Combine data and schema into a dataset tuple. Note that we already used # the schema to read the CSV data, but we also need it to interpret # raw_data. cfg = train_csv_tfxio.TensorAdapterConfig() raw_dataset = (raw_data, cfg) # The TFXIO output format is chosen for improved performance. transformed_dataset, transform_fn = ( raw_dataset | tft_beam.AnalyzeAndTransformDataset( preprocessing_fn, output_record_batches=True)) # Transformed metadata is not necessary for encoding. transformed_data, _ = transformed_dataset # Extract transformed RecordBatches, encode and write them to the given # directory. coder = RecordBatchToExamplesEncoder() _ = ( transformed_data | 'EncodeTrainData' >> beam.FlatMapTuple(lambda batch, _: coder.encode(batch)) | 'WriteTrainData' >> beam.io.WriteToTFRecord( os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE))) # Now apply transform function to test data. In this case we remove the # trailing period at the end of each line, and also ignore the header line # that is present in the test data file. test_csv_tfxio = tfxio.CsvTFXIO( file_pattern=test_data_file, skip_header_lines=1, telemetry_descriptors=[], column_names=ORDERED_CSV_COLUMNS, schema=SCHEMA) raw_test_data = ( pipeline | 'ReadTestCsv' >> test_csv_tfxio.BeamSource()) raw_test_dataset = (raw_test_data, test_csv_tfxio.TensorAdapterConfig()) # The TFXIO output format is chosen for improved performance. transformed_test_dataset = ( (raw_test_dataset, transform_fn) | tft_beam.TransformDataset(output_record_batches=True)) # Transformed metadata is not necessary for encoding. transformed_test_data, _ = transformed_test_dataset # Extract transformed RecordBatches, encode and write them to the given # directory. _ = ( transformed_test_data | 'EncodeTestData' >> beam.FlatMapTuple(lambda batch, _: coder.encode(batch)) | 'WriteTestData' >> beam.io.WriteToTFRecord( os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE))) # Will write a SavedModel and metadata to working_dir, which can then # be read by the tft.TFTransformOutput class. _ = ( transform_fn | 'WriteTransformFn' >> tft_beam.WriteTransformFn(working_dir))

パイプラインを実行します。

import tempfile import pathlib output_dir = os.path.join(tempfile.mkdtemp(), 'keras') transform_data(train_path, test_path, output_dir)

出力ディレクトリを tft.TFTransformOutput としてラップします。

tf_transform_output = tft.TFTransformOutput(output_dir)
tf_transform_output.transformed_feature_spec()

ディレクトリを確認すると、以下の 3 つの項目があります。

  1. train_transformedtest_transformed データファイル

  2. transform_fn ディレクトリ(tf.saved_model

  3. transformed_metadata

次のセクションでは、これらのアーティファクトを使ってモデルをトレーニングする方法を説明します。

!ls -l {output_dir}

##前処理されたデータを使用して、tf.keras を使用してモデルをトレーニングします

tf.Transform を使用して、トレーニングとサービングの両方に同じコードを使用し、スキューを防ぐ方法を示すためにモデルをトレーニングします。モデルをトレーニングし、トレーニングしたモデルを本番用に準備するには、入力関数を作成する必要があります。トレーニング入力関数とサービング入力関数の主な違いは、トレーニングデータにはラベルが含まれ、本番環境のデータには含まれないことです。引数と戻り値も多少異なります。

###トレーニング用の入力関数を作成します

前のセクションのパイプラインを実行すると、変換済みのデータを含む TFRecord ファイルが作成されました。

次のコードは、tf.data.experimental.make_batched_features_datasettft.TFTransformOutput.transformed_feature_spec を使用して、データファイルを tf.data.Dataset として読み取ります。

def _make_training_input_fn(tf_transform_output, train_file_pattern, batch_size): """An input function reading from transformed data, converting to model input. Args: tf_transform_output: Wrapper around output of tf.Transform. transformed_examples: Base filename of examples. batch_size: Batch size. Returns: The input data for training or eval, in the form of k. """ def input_fn(): return tf.data.experimental.make_batched_features_dataset( file_pattern=train_file_pattern, batch_size=batch_size, features=tf_transform_output.transformed_feature_spec(), reader=tf.data.TFRecordDataset, label_key=LABEL_KEY, shuffle=True) return input_fn
train_file_pattern = pathlib.Path(output_dir)/f'{TRANSFORMED_TRAIN_DATA_FILEBASE}*' input_fn = _make_training_input_fn( tf_transform_output=tf_transform_output, train_file_pattern = str(train_file_pattern), batch_size = 10 )

以下では、変換済みのデータサンプルを確認できます。education-numhourd-per-week などの数値カラムは範囲 [0,1] の浮動小数点数に、文字列カラムは ID に変換されていることに注目してください。

for example, label in input_fn().take(1): break pd.DataFrame(example)
label

モデルのトレーニングと評価

モデルを構築する

def build_keras_model(working_dir): inputs = build_keras_inputs(working_dir) encoded_inputs = encode_inputs(inputs) stacked_inputs = tf.concat(tf.nest.flatten(encoded_inputs), axis=1) output = tf.keras.layers.Dense(100, activation='relu')(stacked_inputs) output = tf.keras.layers.Dense(50, activation='relu')(output) output = tf.keras.layers.Dense(2)(output) model = tf.keras.Model(inputs=inputs, outputs=output) return model
def build_keras_inputs(working_dir): tf_transform_output = tft.TFTransformOutput(working_dir) feature_spec = tf_transform_output.transformed_feature_spec().copy() feature_spec.pop(LABEL_KEY) # Build the `keras.Input` objects. inputs = {} for key, spec in feature_spec.items(): if isinstance(spec, tf.io.VarLenFeature): inputs[key] = tf.keras.layers.Input( shape=[None], name=key, dtype=spec.dtype, sparse=True) elif isinstance(spec, tf.io.FixedLenFeature): inputs[key] = tf.keras.layers.Input( shape=spec.shape, name=key, dtype=spec.dtype) else: raise ValueError('Spec type is not supported: ', key, spec) return inputs
def encode_inputs(inputs): encoded_inputs = {} for key in inputs: feature = tf.expand_dims(inputs[key], -1) if key in CATEGORICAL_FEATURE_KEYS: num_buckets = tf_transform_output.num_buckets_for_transformed_feature(key) encoding_layer = ( tf.keras.layers.CategoryEncoding( num_tokens=num_buckets, output_mode='binary', sparse=False)) encoded_inputs[key] = encoding_layer(feature) else: encoded_inputs[key] = feature return encoded_inputs
model = build_keras_model(output_dir) tf.keras.utils.plot_model(model,rankdir='LR', show_shapes=True)

データセットを構築します。

def get_dataset(working_dir, filebase): tf_transform_output = tft.TFTransformOutput(working_dir) data_path_pattern = os.path.join( working_dir, filebase + '*') input_fn = _make_training_input_fn( tf_transform_output, data_path_pattern, batch_size=BATCH_SIZE) dataset = input_fn() return dataset

モデルをトレーニングして評価します。

def train_and_evaluate( model, working_dir): """Train the model on training data and evaluate on test data. Args: working_dir: The location of the Transform output. num_train_instances: Number of instances in train set num_test_instances: Number of instances in test set Returns: The results from the estimator's 'evaluate' method """ train_dataset = get_dataset(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE) validation_dataset = get_dataset(working_dir, TRANSFORMED_TEST_DATA_FILEBASE) model = build_keras_model(working_dir) history = train_model(model, train_dataset, validation_dataset) metric_values = model.evaluate(validation_dataset, steps=EVALUATION_STEPS, return_dict=True) return model, history, metric_values
def train_model(model, train_dataset, validation_dataset): model.compile(optimizer='adam', loss=tf.losses.CategoricalCrossentropy(from_logits=True), metrics=['accuracy']) history = model.fit(train_dataset, validation_data=validation_dataset, epochs=TRAIN_NUM_EPOCHS, steps_per_epoch=STEPS_PER_TRAIN_EPOCH, validation_steps=EVALUATION_STEPS) return history
model, history, metric_values = train_and_evaluate(model, output_dir)
plt.plot(history.history['loss'], label='Train') plt.plot(history.history['val_loss'], label='Eval') plt.ylim(0,max(plt.ylim())) plt.legend() plt.title('Loss');

新しいデータを変換する

前のセクションのトレーニングプロセスでは、transform_dataset 関数の tft_beam.AnalyzeAndTransformDataset によって生成された変換済みデータのハードコピーを使用しました。

新しいデータを操作するには、tft_beam.WriteTransformFn が保存した最終バージョンの preprocessing_fn を読み込む必要があります。

TFTransformOutput.transform_features_layer メソッドは、出力ディレクトリから preprocessing_fn SavedModel を読み込みます。

以下は、新しい未加工のバッチをソースファイルから読み込む関数です。

def read_csv(file_name, batch_size): return tf.data.experimental.make_csv_dataset( file_pattern=file_name, batch_size=batch_size, column_names=ORDERED_CSV_COLUMNS, column_defaults=COLUMN_DEFAULTS, prefetch_buffer_size=0, ignore_errors=True)
for ex in read_csv(test_path, batch_size=5): break pd.DataFrame(ex)

tft.TransformFeaturesLayer を読み込んで、このデータを preprocessing_fn で変換します。

ex2 = ex.copy() ex2.pop('fnlwgt') tft_layer = tf_transform_output.transform_features_layer() t_ex = tft_layer(ex2) label = t_ex.pop(LABEL_KEY) pd.DataFrame(t_ex)

tft_layer は、特徴量のサブセットのみが渡された場合でも変換を実行できるほどスマートな関数です。たとえば、2 つの特徴量のみを渡しても、変換済みの特徴量を得ることができます。

ex2 = pd.DataFrame(ex)[['education', 'hours-per-week']] ex2
pd.DataFrame(tft_layer(dict(ex2)))

以下はより堅牢なバージョンで、特徴量の仕様に含まれない特徴量をドロップし、提供された特徴量にラベルが存在する場合に (features, label) ペアを返します。

class Transform(tf.Module): def __init__(self, working_dir): self.working_dir = working_dir self.tf_transform_output = tft.TFTransformOutput(working_dir) self.tft_layer = tf_transform_output.transform_features_layer() @tf.function def __call__(self, features): raw_features = {} for key, val in features.items(): # Skip unused keys if key not in RAW_DATA_FEATURE_SPEC: continue raw_features[key] = val # Apply the `preprocessing_fn`. transformed_features = tft_layer(raw_features) if LABEL_KEY in transformed_features: # Pop the label and return a (features, labels) pair. data_labels = transformed_features.pop(LABEL_KEY) return (transformed_features, data_labels) else: return transformed_features
transform = Transform(output_dir)
t_ex, t_label = transform(ex)
pd.DataFrame(t_ex)

次に、Dataset.map を使用して、その変換をオンザフライで新しいデータに適用できます。

model.evaluate( read_csv(test_path, batch_size=5).map(transform), steps=EVALUATION_STEPS, return_dict=True )

モデルをエクスポートする

トレーニング済みのモデルと、新しいデータに preporcessing_fn を適用するメソッドの準備ができました。これらを、シリアル化された tf.train.Example proto を入力として受け取る新しいモデルにまとめます。

class ServingModel(tf.Module): def __init__(self, model, working_dir): self.model = model self.working_dir = working_dir self.transform = Transform(working_dir) @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)]) def __call__(self, serialized_tf_examples): # parse the tf.train.Example feature_spec = RAW_DATA_FEATURE_SPEC.copy() feature_spec.pop(LABEL_KEY) parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec) # Apply the `preprocessing_fn` transformed_features = self.transform(parsed_features) # Run the model outputs = self.model(transformed_features) # Format the output classes_names = tf.constant([['0', '1']]) classes = tf.tile(classes_names, [tf.shape(outputs)[0], 1]) return {'classes': classes, 'scores': outputs} def export(self, output_dir): # Increment the directory number. This is required in order to make this # model servable with model_server. save_model_dir = pathlib.Path(output_dir)/'model' number_dirs = [int(p.name) for p in save_model_dir.glob('*') if p.name.isdigit()] id = max([0] + number_dirs)+1 save_model_dir = save_model_dir/str(id) # Set the signature to make it visible for serving. concrete_serving_fn = self.__call__.get_concrete_function() signatures = {'serving_default': concrete_serving_fn} # Export the model. tf.saved_model.save( self, str(save_model_dir), signatures=signatures) return save_model_dir

モデルをビルドし、シリアル化した Example のバッチでテストランを実行します。

serving_model = ServingModel(model, output_dir) serving_model(serialized_example_batch)

モデルを SavedModel としてエクスポートします。

saved_model_dir = serving_model.export(output_dir) saved_model_dir

モデルを再読み込みし、同じ Example のバッチでテストします。

reloaded = tf.saved_model.load(str(saved_model_dir)) run_model = reloaded.signatures['serving_default']
run_model(serialized_example_batch)

##この例では、tf.Transform を使用して国勢調査データのデータセットを前処理し、クリーンアップおよび変換されたデータを使用してモデルをトレーニングしました。また、トレーニング済みモデルを本番環境にデプロイして推論を実行する際に使用する入力関数も作成しました。トレーニングと推論の両方に同じコードを使用することで、データのスキューに関する問題を回避します。その過程で、データのクリーンアップに必要な変換を実行するための Apache Beam 変換の作成について学習しました。また、この変換されたデータを使用して、tf.keras を使用してモデルをトレーニングする方法も確認しました。これは、TensorFlow Transform でできることのほんの一部です。tf.Transform についての知識を深めることをお勧めします。

[オプション]前処理されたデータを使用して tf.estimator でモデルをトレーニングする

警告: 新しいコードには Estimators は推奨されません。Estimators は v1.Session スタイルのコードを実行しますが、これは正しく記述するのはより難しく、特に TF 2 コードと組み合わせると予期しない動作をする可能性があります。Estimators は、互換性保証の対象となりますが、セキュリティの脆弱性以外の修正は行われません。詳細については、移行ガイドを参照してください。

###トレーニング用の入力関数を作成します

def _make_training_input_fn(tf_transform_output, transformed_examples, batch_size): """Creates an input function reading from transformed data. Args: tf_transform_output: Wrapper around output of tf.Transform. transformed_examples: Base filename of examples. batch_size: Batch size. Returns: The input function for training or eval. """ def input_fn(): """Input function for training and eval.""" dataset = tf.data.experimental.make_batched_features_dataset( file_pattern=transformed_examples, batch_size=batch_size, features=tf_transform_output.transformed_feature_spec(), reader=tf.data.TFRecordDataset, shuffle=True) transformed_features = tf.compat.v1.data.make_one_shot_iterator( dataset).get_next() # Extract features and label from the transformed tensors. transformed_labels = tf.where( tf.equal(transformed_features.pop(LABEL_KEY), 1)) return transformed_features, transformed_labels[:,1] return input_fn

###サービングするための入力関数を作成します

本番環境で使用できる入力関数を作成し、トレーニング済みのモデルをサービングできるように準備します。

def _make_serving_input_fn(tf_transform_output): """Creates an input function reading from raw data. Args: tf_transform_output: Wrapper around output of tf.Transform. Returns: The serving input function. """ raw_feature_spec = RAW_DATA_FEATURE_SPEC.copy() # Remove label since it is not available during serving. raw_feature_spec.pop(LABEL_KEY) def serving_input_fn(): """Input function for serving.""" # Get raw features by generating the basic serving input_fn and calling it. # Here we generate an input_fn that expects a parsed Example proto to be fed # to the model at serving time. See also # tf.estimator.export.build_raw_serving_input_receiver_fn. raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn( raw_feature_spec, default_batch_size=None) serving_input_receiver = raw_input_fn() # Apply the transform function that was used to generate the materialized # data. raw_features = serving_input_receiver.features transformed_features = tf_transform_output.transform_raw_features( raw_features) return tf.estimator.export.ServingInputReceiver( transformed_features, serving_input_receiver.receiver_tensors) return serving_input_fn

###入力データを FeatureColumns でラップします。モデルは TensorFlow FeatureColumns でデータを期待します。

def get_feature_columns(tf_transform_output): """Returns the FeatureColumns for the model. Args: tf_transform_output: A `TFTransformOutput` object. Returns: A list of FeatureColumns. """ # Wrap scalars as real valued columns. real_valued_columns = [tf.feature_column.numeric_column(key, shape=()) for key in NUMERIC_FEATURE_KEYS] # Wrap categorical columns. one_hot_columns = [ tf.feature_column.indicator_column( tf.feature_column.categorical_column_with_identity( key=key, num_buckets=(NUM_OOV_BUCKETS + tf_transform_output.vocabulary_size_by_name( vocab_filename=key)))) for key in CATEGORICAL_FEATURE_KEYS] return real_valued_columns + one_hot_columns

###モデルをトレーニング、評価、エクスポートします

def train_and_evaluate(working_dir, num_train_instances=NUM_TRAIN_INSTANCES, num_test_instances=NUM_TEST_INSTANCES): """Train the model on training data and evaluate on test data. Args: working_dir: Directory to read transformed data and metadata from and to write exported model to. num_train_instances: Number of instances in train set num_test_instances: Number of instances in test set Returns: The results from the estimator's 'evaluate' method """ tf_transform_output = tft.TFTransformOutput(working_dir) run_config = tf.estimator.RunConfig() estimator = tf.estimator.LinearClassifier( feature_columns=get_feature_columns(tf_transform_output), config=run_config, loss_reduction=tf.losses.Reduction.SUM) # Fit the model using the default optimizer. train_input_fn = _make_training_input_fn( tf_transform_output, os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE + '*'), batch_size=BATCH_SIZE) estimator.train( input_fn=train_input_fn, max_steps=TRAIN_NUM_EPOCHS * num_train_instances / BATCH_SIZE) # Evaluate model on test dataset. eval_input_fn = _make_training_input_fn( tf_transform_output, os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE + '*'), batch_size=1) # Export the model. serving_input_fn = _make_serving_input_fn(tf_transform_output) exported_model_dir = os.path.join(working_dir, EXPORTED_MODEL_DIR) estimator.export_saved_model(exported_model_dir, serving_input_fn) return estimator.evaluate(input_fn=eval_input_fn, steps=num_test_instances)

###すべてをまとめます。以上で国勢調査データを前処理し、モデルをトレーニングして、サービングする準備が完了しました。次に実行します。

注意: このセルからの出力をスクロールして、プロセス全体を表示します。結果は一番下に表示されます。

import tempfile temp = temp = os.path.join(tempfile.mkdtemp(),'estimator') transform_data(train_path, test_path, temp) results = train_and_evaluate(temp)
pprint.pprint(results)