Path: blob/master/site/ko/lite/examples/jax_conversion/overview.ipynb
25118 views
Kernel: Python 3
Copyright 2021 The TensorFlow Authors.
In [ ]:
TFLite용 Jax 모델 변환
개요
참고: 이 API는 새로운 것이며 pip install tf-nightly를 통해서만 사용할 수 있습니다. TensorFlow 버전 2.7에서 사용할 수 있습니다. 또한 API는 아직 실험적이며 변경될 수 있습니다.
이 CodeLab은 Jax를 사용하여 MNIST 인식을 위한 모델을 구축하는 방법과 이를 TensorFlow Lite로 변환하는 방법을 보여줍니다. 이 코드랩은 또한 훈련 후 양자화를 사용하여 Jax 변환 TFLite 모델을 최적화하는 방법을 보여줍니다.
전제 조건
최신 TensorFlow 야간 pip 빌드에서 이 기능을 사용하는 것이 좋습니다.
In [ ]:
데이터 준비
Keras 데이터셋으로 MNIST 데이터를 다운로드하고 전처리합니다.
In [ ]:
In [ ]:
Jax로 MNIST 모델 빌드
In [ ]:
모델 학습 및 평가
In [ ]:
TFLite 모델로 변환합니다.
참고로 우리는
params를
functools.partial
predict
func에 인라인합니다.jnp.zeros
빌드합니다. 이것은 Jax가 모델을 추적하는 데 사용되는 "자리 표시자" 텐서입니다.experimental_from_jax
호출합니다.
serving_func
는 목록으로 래핑됩니다.입력은 지정된 이름과 연결되고 목록에 래핑된 배열로 전달됩니다.
In [ ]:
변환된 TFLite 모델 확인
변환된 모델의 결과를 Jax 모델과 비교하십시오.
In [ ]:
모델 최적화
우리는 제공 할 것입니다 representative_dataset
모델을 최적화하기 위해 훈련 후 quantiztion을 할 수 있습니다.
In [ ]:
최적화된 모델 평가
In [ ]:
양자화된 모델 크기 비교
양자화된 모델이 원래 모델보다 4배 더 작은 것을 볼 수 있어야 합니다.
In [ ]: