Path: blob/master/site/ko/federated/tutorials/jax_support.ipynb
25118 views
Copyright 2021 The TensorFlow Authors.
TFF에서 JAX에 대한 실험적 지원
TensorFlow 에코시스템의 일부가 되는 것 외에도 TFF는 다른 프런트엔드 및 백엔드 ML 프레임워크와의 상호 운용성을 지원하는 데 목표를 두고 있습니다. 현재, 다른 ML 프레임워크에 대한 지원은 아직 인큐베이션 단계에 있으며 지원되는 API 및 기능은 변경될 수 있습니다(대부분 TFF 사용자의 요구에 따라 결정됨). 이 튜토리얼에서는 JAX와 함께 TFF를 대체 ML 프런트엔드로 사용하고 XLA 컴파일러를 대체 백엔드로 사용하는 방법을 설명합니다. 여기에 표시된 예는 전적으로 기본 JAX/XLA 스택을 기반으로 합니다. 프레임워크 간에 코드를 혼합하는 가능성(예: JAX와 TensorFlow)은 향후 튜토리얼에서 논의할 예정입니다.
언제나처럼 여러분의 기여를 환영합니다. JAX/XLA에 대한 지원 또는 다른 ML 프레임워크와 상호 운용하는 것이 중요한 경우 이러한 기능을 TFF의 나머지 부분과 동등하게 발전시킬 수 있도록 도움을 주시기 바랍니다.
시작하기 전에
TFF 문서 본문에서 환경을 구성하는 방법을 참조하세요. 이 튜토리얼을 실행하는 위치에 따라 아래 코드의 일부 또는 전체를 주석 해제하고 실행할 수 있습니다.
이 튜토리얼은 또한 TFF의 기본 TensorFlow 튜토리얼을 검토했으며 핵심 TFF 개념에 익숙하다고 가정합니다. 아직 이 작업을 수행하지 않았다면 이 중 하나 이상을 검토하는 것이 좋습니다.
JAX 계산
TFF에서 JAX에 대한 지원은 가져오기부터 시작하여 TFF가 TensorFlow와 상호 운용되는 방식과 대칭적이도록 설계되었습니다.
또한 TensorFlow와 마찬가지로 TFF 코드를 표현하기 위한 기반은 로컬에서 실행되는 논리입니다. @tff.jax_computation
래퍼를 사용하여 아래와 같이 JAX에서 이 논리를 표현할 수 있습니다. 이것은 지금쯤 여러분에게 친숙해 있을 @tff.tf_computation
과 유사하게 작동합니다. 두 개의 정수를 더하는 계산과 같이 간단한 내용부터 시작해 보겠습니다.
일반적으로 TFF 계산을 사용하는 것처럼 위에서 정의한 JAX 계산을 사용할 수 있습니다. 예를 들어 다음과 같이 형식 서명을 확인할 수 있습니다.
인수 유형을 정의하기 위해 np.int32
를 사용했다는 점에 주목하세요. TFF는 Numpy 형식(예: np.int32
)과 TensorFlow 형식(예: tf.int32
)을 구분하지 않습니다. TFF의 관점에서 이것들은 같은 내용을 참조하는 방법들일 뿐입니다.
이제 TFF가 Python이 아님을 상기하세요(그리고 이것이 이해되지 않으면 사용자 정의 알고리즘과 같은 이전 튜토리얼 중 일부 내용을 검토하세요). 추적하고 직렬화할 수 있는 JAX 코드 즉, 평상시 XLA로 컴파일 될 것으로 예상되는 @jax.jit
로 주석 처리했을 JAX 코드와 함께 @tff.jax_computation
래퍼를 사용할 수 있습니다(하지만 실제로 @jax.jit
주석을 사용하여 TFF에서 JAX 코드를 임베딩할 필요는 없음).
실제로 막후에서 TFF는 JAX 계산을 XLA로 즉시 컴파일합니다. 다음과 같이 add_numbers
에서 직렬화된 XLA 코드를 수동으로 추출하고 인쇄하여 이를 직접 확인할 수 있습니다.
XLA 코드로 JAX 계산을 표현하는 것이 TensorFlow로 표현된 계산에 대한 tf.GraphDef
와 기능적으로 동등하다고 생각하세요. tf.GraphDef
가 모든 TensorFlow 런타임에서 실행될 수 있는 것처럼 이것은 XLA를 지원하는 다양한 환경에서 이식 가능하고 실행 가능합니다.
TFF는 XLA 컴파일러를 기반으로 하는 런타임 스택을 백엔드로 제공합니다. 다음과 같이 활성화할 수 있습니다.
이제 위에서 정의한 계산을 실행할 수 있습니다.
전혀 어렵지 않습니다. 탄력을 받았으니 MNIST와 같은 더 복잡한 작업을 수행해 보겠습니다.
미리 준비된 API를 사용한 MNIST 훈련의 예
평소와 같이 데이터 배치와 모델에 대해 TFF 유형을 정의하는 것으로 시작합니다(TFF는 강력한 형식의 프레임워크임을 기억할 것).
이제 JAX에서 모델에 대한 손실 함수를 정의하고 모델과 단일 데이터 배치를 매개변수로 사용합니다.
이제 한 가지 방법은 미리 준비된 API를 사용하는 것입니다. 다음은 방금 정의한 손실 함수를 기반으로 하는 훈련 프로세스를 생성하기 위해 API를 사용하는 방법의 예를 보여줍니다.
TensorFlow의 tf.Keras
모델에서 트레이너 빌드를 사용하는 것처럼 위의 내용을 사용할 수 있습니다. 예를 들어 훈련을 위한 초기 모델을 만드는 방법은 다음과 같습니다.
실제 훈련을 수행하려면 몇 가지 데이터가 필요합니다. 단순하게 하기 위해 임의의 데이터를 만들어 보겠습니다. 데이터가 무작위이기 때문에 훈련 데이터에 대해 평가할 것입니다. 그렇지 않고 무작위 평가 데이터를 사용하면 모델이 제대로 작동할 것으로 기대하기 어렵기 때문입니다. 또한 이 자그마한 데모에서는 클라이언트를 무작위로 샘플링하는 문제에 대해 신경 쓰지 않을 것입니다(다른 튜토리얼의 템플릿을 따라 이러한 유형의 변경을 탐구하도록 사용자에게 연습으로 남겨둠).
이를 통해 다음과 같이 단일 단계의 훈련을 수행할 수 있습니다.
훈련 단계의 결과를 평가해 보겠습니다. 쉽게 하기 위해 중앙 집중식으로 평가할 수 있습니다.
손실이 감소하고 있습니다. 훌륭합니다! 이제 여러 번 실행해 보겠습니다.
보시다시피 TFF와 함께 JAX를 사용하는 것도 별반 다르지 않지만, 실험적 API는 아직 기능면에서 TensorFlow API와 동등하지 않습니다.
배경
미리 준비된 API를 사용하지 않으려면 TensorFlow에 대한 사용자 지정 알고리즘 튜토리얼에서 본 것과 같은 방식으로 자체 사용자 지정 계산을 구현할 수 있습니다. 단, 경사 하강법에 대한 JAX 메커니즘을 사용한다는 점만 다릅니다. 예를 들어, 다음은 단일 미니배치에서 모델을 업데이트하는 JAX 계산을 정의하는 방법입니다.
이것이 작동하는지 테스트하는 방법은 다음과 같습니다.
JAX로 작업할 때 한 가지 주의할 점은 tf.data.Dataset
에 해당하는 것을 제공하지 않는다는 것입니다. 따라서 데이터세트를 반복하려면 아래와 같이 시퀀스 작업에 TFF의 선언적 구문을 사용해야 합니다.
제대로 작동하는지 보겠습니다.
단일 훈련 라운드를 수행하는 계산은 TensorFlow 튜토리얼에서 본 것과 같습니다.
제대로 작동하는지 보겠습니다.
보다시피 미리 준비된 API를 통해, 또는 저수준 TFF 구문을 직접 사용하여 TFF에서 JAX를 사용하는 것은 TensorFlow와 함께 TFF를 사용하는 것과 비슷합니다. 향후 업데이트를 기대해 주세요. 전체 ML 프레임워크에서 보다 나은 상호 운용성 지원을 원하면 언제든지 풀 리퀘스트를 보내주세요!