Path: blob/master/site/ko/datasets/keras_example.ipynb
25115 views
Training a neural network on MNIST with Keras
This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.
Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0
데이터 세트로드
다음 인수를 사용하여 MNIST 데이터세트를 로드합니다.
shuffle_files=True
: MNIST 데이터는 단일 파일에만 저장되지만 디스크에 여러 파일이 있는 더 큰 데이터세트의 경우 훈련할 때 셔플하는 것이 좋습니다.as_supervised=True
: 사전{'image': img, 'label': label}
대신 튜플(img, label)
을 반환합니다.
훈련 파이프라인 구축하기
다음 변환을 적용합니다.
tf.data.Dataset.map
: TFDS는tf.uint8
유형의 이미지를 제공하는 반면, 모델은tf.float32
를 기대합니다. 따라서 이미지를 정규화해야 합니다.tf.data.Dataset.cache
데이터세트를 메모리에 피팅할 때 성능 개선을 위해 셔플 전에 캐시하세요.
참고: 캐싱 후에 임의 변환을 적용해야 합니다.tf.data.Dataset.shuffle
: 진정한 무작위성을 위해 셔플 버퍼를 전체 데이터세트 크기로 설정합니다.
참고: 메모리에 들어갈 수 없는 큰 데이터세트의 경우 시스템에서 허용하는 경우buffer_size=1000
을 사용합니다.tf.data.Dataset.batch
: 각 epoch에서 고유한 배치를 얻기 위해 셔플한 후 데이터세트의 요소를 배치로 만듭니다.tf.data.Dataset.prefetch
: 성능을 위해 프리페치하여 파이프라인을 종료하는 것이 좋습니다.
평가 파이프라인 구축하기
테스트 파이프라인은 약간의 차이는 있지만 학습 파이프라인과 유사합니다.
tf.data.Dataset.shuffle
을 호출할 필요가 없습니다.배치는 epoch 간에 같을 수 있으므로 일괄 처리 후에 캐싱이 수행됩니다.
2 단계: 모델 생성 및 훈련하기
TFDS 입력 파이프라인을 간단한 Keras 모델에 연결한 다음, 모델을 컴파일하고 훈련합니다.