Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/ko/probability/examples/Distributed_Inference_with_JAX.ipynb
25118 views
Kernel: Python 3

아파치 라이선스, 버전 2.0 (the "License")에 의거함;

#@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" } # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

JAX의 TensorFlow Probability(TFP)에 이제 분산된 수치적 컴퓨팅을 위한 도구가 있습니다. 수많은 가속기로 확장하기 위해, "단일-프로그램 다중-데이터" 패러다임(SPMD)을 사용하여 코드를 작성하는 것을 중심으로 도구가 구축됩니다.

이 노트북에서는, "SPMD으로 사고"하는 방법과 TPU 팟 또는 GPU의 클러스터와 같은 구성으로 확장하기 위해 새로운 TFP 추상화를 도입하는 내용을 다룹니다. 이 코드를 직접 실행하고 있다면, TPU 런타임을 선택합니다.

우선 최신 버전의 TFP, JAX 및 TF를 설치하겠습니다.

#@title Installs !pip install jaxlib --upgrade -q 2>&1 1> /dev/null !pip install tfp-nightly[jax] --upgrade -q 2>&1 1> /dev/null !pip install tf-nightly-cpu -q -I 2>&1 1> /dev/null !pip install jax -I -q --upgrade 2>&1 1>/dev/null
ERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.4.0 which is incompatible. ERROR: tensorflow 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible. ERROR: tensorflow 2.4.1 has requirement h5py~=2.10.0, but you'll have h5py 3.1.0 which is incompatible. ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible. ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible. ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible. ERROR: tf-nightly-cpu 2.6.0.dev20210401 has requirement numpy~=1.19.2, but you'll have numpy 1.20.2 which is incompatible. ERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.4.0 which is incompatible. ERROR: tensorflow 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible. ERROR: tensorflow 2.4.1 has requirement h5py~=2.10.0, but you'll have h5py 3.1.0 which is incompatible. ERROR: tensorflow 2.4.1 has requirement numpy~=1.19.2, but you'll have numpy 1.20.2 which is incompatible. ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible. ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible. ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.

몇몇 JAX 유틸리티와 함께 몇몇 일반 라이브러리를 가져오겠습니다.

#@title Setup and Imports import functools import collections import contextlib import jax import jax.numpy as jnp from jax import lax from jax import random import jax.numpy as jnp import numpy as np import matplotlib.pyplot as plt import seaborn as sns import pandas as pd import tensorflow_datasets as tfds from tensorflow_probability.substrates import jax as tfp sns.set(style='white')
INFO:tensorflow:Enabling eager execution INFO:tensorflow:Enabling v2 tensorshape INFO:tensorflow:Enabling resource variables INFO:tensorflow:Enabling tensor equality INFO:tensorflow:Enabling control flow v2

또한 몇몇 유용한 TFP 에일리어스도 설정하겠습니다. 새로운 추상화는 현재 tfp.experimental.distributetfp.experimental.mcmc에 제공됩니다.

tfd = tfp.distributions tfb = tfp.bijectors tfm = tfp.mcmc tfed = tfp.experimental.distribute tfde = tfp.experimental.distributions tfem = tfp.experimental.mcmc Root = tfed.JointDistributionCoroutine.Root

노트북을 TPU에 연결하기 위해, JAX에서 다음 헬퍼를 사용합니다. 연결되었는지 확인하기 위해, 기기의 수를 프린트합니다. 이는 8이어야 합니다.

from jax.tools import colab_tpu colab_tpu.setup_tpu() print(f'Found {jax.device_count()} devices')
Found 8 devices

jax.pmap에 대한 짧은 소개

TPU에 연결한 후, 8개 기기에 액세스했습니다. 하지만, JAX 코드를 열심히 실행할 때, JAX는 기본적으로 하나의 컴퓨팅에서만 실행됩니다.

여러 기기에 걸친 컴퓨팅을 하는 가장 간단한 방법은 각 기기가 맵의 인덱스 1개를 실행하도록 하는 함수를 매핑하는 것입니다. JAX는 jax.pmap("병렬 맵") 함수를 여러 기기에 걸쳐 함수를 매핑하는 함수로 변환하는 변환을 제공합니다.

다음 예시에서, 크기가 8(사용 가능한 기기 수 일치를 위함)인 배열을 생성하고 5를 이에 추가하는 함수를 매핑합니다.

xs = jnp.arange(8.) out = jax.pmap(lambda x: x + 5.)(xs) print(type(out), out)
<class 'jax.interpreters.pxla.ShardedDeviceArray'> [ 5. 6. 7. 8. 9. 10. 11. 12.]

출력 행렬이 물리적으로 기기에 분할된 것을 나타내는 ShardedDeviceArray 유형을 반환받았습니다.

jax.pmap는 의미론적으로 맵처럼 동작하지만, 동작을 수정하는 중요한 옵션이 몇 가지 있습니다. 기본적으로, pmap는 매핑되고 있는 함수에 대한 모든 입력을 추정하지만, 이 동작을 in_axes 인수로 수정할 수 있습니다.

xs = jnp.arange(8.) y = 5. # Map over the 0-axis of `xs` and don't map over `y` out = jax.pmap(lambda x, y: x + y, in_axes=(0, None))(xs, y) print(out)
[ 5. 6. 7. 8. 9. 10. 11. 12.]

마찬가지로, pmap에 대한 out_axes 인수가 모든 기기에 대한 값을 반환할지 여부를 결정합니다. out_axesNone로 설정하는 중 자동으로 첫 번째 기기에 대한 값을 반환하기 때문에 기기에 대한 값이 동일하다는 확신이 있을 때에만 사용해야 합니다.

xs = jnp.ones(8) # Value is the same on each device out = jax.pmap(lambda x: x + 1, out_axes=None)(xs) print(out)
2.0

하고자 하는 작업을 매핑된 순수 함수로 쉽게 표현할 수 없다면 어떻게 될까요? 예를 들어, 매핑하고 있는 축 전체의 합을 컴퓨팅하고 싶다면 어떨까요? 더욱 흥미롭고 복잡한 분산 프로그램을 작성할 수 있도록 JAX는 모든 기기와 통신하는 함수인 "collectives"를 제공합니다. 실제 작동 방법에 대한 이해를 위해, SPMD를 소개하겠습니다.

SPMD란?

단일-프로그램 다중-데이터(SPMD)는 단일 프로그램(즉 동일한 코드)이 기기 간 동시에 실행되지만 실행 프로그램의 각각에 대한 입력은 다를 수 있는 동시 프로그래밍 모델입니다.

프로그램이 입력(예: x + 5)의 단순한 함수인 경우, SPMD에서 프로그램을 실행하는 것은 그저 jax.pmap로 이전에 했던 것과 같이 다른 데이터에 매핑하는 것입니다. 하지만, 그저 함수를 "매핑"하는 것 그 이상의 작업이 가능합니다. JAX는 기기 간 통신하는 함수인 "collectives"를 제공합니다.

예를 들어, 모든 기기의 수량의 합계를 내려고 할 수 있습니다. 그러기 전에, pmap에서 매핑할 축에 이름을 할당해야 합니다. 그런 다음 lax.psum ("병렬 합계") 함수를 사용해 기기 간 합을 수행하여, 합계를 컴퓨팅할 명명된 축을 식별합니다.

def f(x): out = lax.psum(x, axis_name='i') return out xs = jnp.arange(8.) # Length of array matches number of devices jax.pmap(f, axis_name='i')(xs)
ShardedDeviceArray([28., 28., 28., 28., 28., 28., 28., 28.], dtype=float32)

psum 집합은 각 기기에 대한 x 값을 종합하며 맵 전체에 값을 동기화합니다. 각 기기의 out28입니다. 더 이상 단순한 "매핑"을 하지는 않지만, 집합을 이용하는 데 한계가 있을지라도 각 기기의 컴퓨팅이 다른 기기의 동일한 컴퓨팅과 상호 작용할 수 있는 SMPD 프로그램을 실행합니다. 이 시나리오에서, psum가 값을 동기화할 것이기 때문에 out_axes = None를 사용할 수 있습니다.

def f(x): out = lax.psum(x, axis_name='i') return out jax.pmap(f, axis_name='i', out_axes=None)(jnp.arange(8.))
ShardedDeviceArray(28., dtype=float32)

SPMD를 사용하여 TPU 구성의 모든 기기에서 동시에 실행되는 하나의 프로그램을 작성할 수 있습니다. 8개의 TPU 코어에 대해 기계 학습을 수행하는 데 사용되는 동일한 코드를 수백 개에서 수천 개의 코어를 가진 TPU 포드에서 사용할 수 있습니다! jax.pmap 및 SPMD에 관한 더욱 자세한 튜토리얼은 JAX 101 tutorial를 참조할 수 있습니다.

대규모 MCMC

이 노트북에서는, 베이지안을 위해 마르코프 연쇄 몬테카를로(MCMC) 메서드를 사용하는 데 중점을 둡니다. MCMC를 위해 많은 기기를 활용하는 방법이 있을 수 있지만, 이 노트북에서는 다음과 같은 두 가지에 집중합니다.

  1. 다른 기기에서 독립적인 마르코프 연쇄 실행. 이 경우는 상당히 단순하며 Vanilla TFP로 가능합니다.

  2. Sharding a dataset across devices. 모든 기기에 데이터세트를 샤딩. 이 경우는 좀 더 복잡하고 최근 추가된 TFP 장치가 필요합니다.

독립적인 연쇄

MCMC를 사용하여 문제에 대한 베이지안 추론을 하고자 하며 여러 기기(예: 각 기기에서 2개)에서 병렬로 여러 연쇄를 실행하고자 한다고 가정해 봅시다. 이것은 전체 기기, 즉 집합이 필요 없는 기기 전체에 "매핑"할 수 있는 프로그램인 것으로 밝혀졌습니다. 각 프로그램이 서로 다른 마르코프 연쇄를 실행하도록 하기 위해 (동일한 것을 실행하는 것과 반대로), 각 기기에 무작위 시드에 대한 서로 다른 값을 전달합니다.

2D 가우시안 분포 샘플링의 장난감 문제에 이를 적용해 봅시다. TFP의 기존 MCMC 기능을 바로 사용할 수 있습니다. 일반적으로, 모든 기기에서 실행되는 것과 첫 번째 기기에서 실행되는 것을 더욱 명확하게 구분하기 위해 매핑된 함수의 내부에 로직의 대부분을 넣으려고 합니다.

def run(seed): target_log_prob = tfd.Sample(tfd.Normal(0., 1.), 2).log_prob initial_state = jnp.zeros([2, 2]) # 2 chains kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-1, 10) def trace_fn(state, pkr): return target_log_prob(state) states, log_prob = tfm.sample_chain( num_results=1000, num_burnin_steps=1000, kernel=kernel, current_state=initial_state, trace_fn=trace_fn, seed=seed ) return states, log_prob

run 함수는 스스로 상태 비저장 무작위 시드를 받아들입니다(상태 비저장 무작위가 어떻게 작동하는지 확인하려면, JAX의 TFP 노트북 또는 JAX 101 튜토리얼을 참조합니다). 다른 시드에 대해 run을 매핑하면 여러 독립적인 마르코프 연쇄가 실행됩니다.

states, log_probs = jax.pmap(run)(random.split(random.PRNGKey(0), 8)) print(states.shape, log_probs.shape) # states is (8 devices, 1000 samples, 2 chains, 2 dimensions) # log_prob is (8 devices, 1000 samples, 2 chains)
(8, 1000, 2, 2) (8, 1000, 2)

이제 각 기기에 해당하는 추가 축이 어떻게 있는지 알아봅니다. 차원을 재배열할 수 있으며 평평하게 하여 16개 연쇄에 대한 축을 얻을 수 있습니다.

states = states.transpose([0, 2, 1, 3]).reshape([-1, 1000, 2]) log_probs = log_probs.transpose([0, 2, 1]).reshape([-1, 1000])
fig, ax = plt.subplots(1, 2, figsize=(10, 5)) ax[0].plot(log_probs.T, alpha=0.4) ax[1].scatter(*states.reshape([-1, 2]).T, alpha=0.1) plt.show()
Image in a Jupyter notebook

많은 기기에서 독립적인 연쇄를 실행하는 경우, 이는 tfp.mcmc를 사용하는 함수에 대해 pmap-ing 을 사용하는 것만큼 쉬워 각 기기에 무작위 시드에 대한 다른 값을 전달합니다.

데이터 공유

MCMC를 수행하는 경우, 대상 분산은 종종 데이터세트에 조건을 적용하여 얻은 사후 분산이며, 각 관찰된 데이터에 대한 가능성을 합하여 비정규화된 로그 밀도를 컴퓨팅합니다.

매우 큰 데이터세트의 경우, 하나의 기기에서 하나의 연쇄를 실행하는 것조차 엄두가 나지 않을 만큼 비용이 많이 들 수 있습니다. 하지만, 여러 기기에 액세스하는 경우, 사용 가능한 컴퓨팅을 더욱 잘 활용하기 위해 기기 전체의 데이터세트를 분할할 수 있습니다.

샤딩된 데이터세트로 MCMC를 수행하려는 경우, 각 기기에서 컴퓨팅한 비정규화된 로그-밀도가 total, 즉 모든 데이터에 대한 밀도를 나타내는지 확인해야 합니다. 그렇지 않으면 각 기기는 올바르지 않은 대상 분포로 MCMC를 수행하게 될 것입니다. 이를 위해, TFP는 이제 "샤딩된" 로그 가능성을 컴퓨팅하고 이로 MCMC를 수행할 수 있는 새로운 툴(예: tfp.experimental.distributetfp.experimental.mcmc)을 제공합니다.

샤딩된 분포

TFP가 이제 샤딩된 로그 가능성 컴퓨팅을 위해 제공하는 핵심 추상화는 Sharded이며, 이는 분포를 입력으로 취하고 SPME 컨텍스트에서 실행될 때 특정 속성을 가지는 새로운 분포를 반환합니다. Shardedtfp.experimental.distribute에 있습니다.

직관적으로 Shared 분포는 장치 간에 "분할"된 랜덤 변수 집합에 해당합니다. 각 기기에서, 다른 샘플을 생성하며 개별적으로 다른 로그-밀도를 가질 수 있습니다. 또는, Sharded 분포는 플레이트 크기가 기기의 수인 그래픽 모델 용어의 "플레이트"에 해당합니다.

Sharded 분포 샘플링

각 기기에 동일한 시드를 사용하여 pmap-ed 프로그램의 Normal 분포에서 샘플링하는 경우, 각 기기에 동일한 샘플을 얻게 됩니다. 다음 함수를 기기 간 동기화된 단일 무작위 함수 샘플링으로 간주할 수 있습니다.

# `pmap` expects at least one value to be mapped over, so we provide a dummy one def f(seed, _): return tfd.Normal(0., 1.).sample(seed=seed) jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236], dtype=float32)

tfed.Shardedtfd.Normal(0., 1.)를 래핑 하는 경우, 이제 논리적으로 8개의 다른 무작위 변수(각 기기에 하나)를 가기 때문에 동일한 시드에서 전달함에도 불구하고 각각의 변수에 대해 다른 샘플을 생성합니다.

def f(seed, _): return tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i').sample(seed=seed) jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([ 1.2152631 , 0.7818249 , 0.32549605, 0.6828047 , 1.3973192 , -0.57830244, 0.37862757, 2.7706041 ], dtype=float32)

단일 기기에 대한 이 분포의 동등한 표현은 8개의 독립적인 정규 샘플일 뿐입니다. 샘플 값이 다르더라도(tfed.Sharded는 의사 난수 생성을 약간 다르게 함), 모두 동일한 분포를 나타냅니다.

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count()) dist.sample(seed=random.PRNGKey(0))
DeviceArray([ 0.08086783, -0.38624594, -0.3756545 , 1.668957 , -1.2758069 , 2.1192007 , -0.85821325, 1.1305912 ], dtype=float32)

Sharded 분포의 로그-밀도 가져오기

SPMD 컨텍스트에서 정규 분포 샘플의 로그-밀집을 컴퓨팅할 때 무슨 일이 일어나는지 확인합니다.

def f(seed, _): dist = tfd.Normal(0., 1.) x = dist.sample(seed=seed) return x, dist.log_prob(x) jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
(ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236], dtype=float32), ShardedDeviceArray([-0.94012403, -0.94012403, -0.94012403, -0.94012403, -0.94012403, -0.94012403, -0.94012403, -0.94012403], dtype=float32))

각 샘플은 각 장치에서 동일하므로 각 장치에서 동일한 밀도를 계산합니다. 직관적으로, 여기에는 단일 정규 분포 변수에 대한 하나의 분포만 있게 됩니다.

Sharded 분포를 사용하여, 8개의 무작위 변수에 대한 분포를 가지게 되어, 샘플 log_prob를 컴퓨팅 하는 경우 기기 간의 개별 로그 밀집 각각에 대해 합계를 냅니다. 이 총 log_prob 값이 위에서 컴퓨팅 된 싱글톤 log_prob보다 크다는 것을 알 수 있습니다.

def f(seed, _): dist = tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i') x = dist.sample(seed=seed) return x, dist.log_prob(x) sample, log_prob = jax.pmap(f, in_axes=(None, 0), axis_name='i')( random.PRNGKey(0), jnp.arange(8.)) print('Sample:', sample) print('Log Prob:', log_prob)
Sample: [ 1.2152631 0.7818249 0.32549605 0.6828047 1.3973192 -0.57830244 0.37862757 2.7706041 ] Log Prob: [-13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205]

동등한 "샤딩되지 않은" 분포는 동일한 로그 밀도를 생성합니다.

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count()) dist.log_prob(sample)
DeviceArray(-13.7349205, dtype=float32)

Sharded 분포는 각 기기에서 sample의 다른 값을 생성하지만, 각 기기에서 log_prob에 대한 동일한 값을 얻습니다. 여기서 무슨 일이 일어나고 있습니까? Sharded 분포는 내부적으로 psum를 수행하여 기기 간 log_prob이 동기화되도록 합니다. 이러한 동작을 원하는 이유는 무엇입니까? 각 기기에 동일한 MCMC 연쇄를 실행하는 경우, 일부 무작위 변수가 기기 간 샤딩되더라도 target_log_prob는 각 기기에서 동일해야 합니다.

또한, Sharded 분포는 기기 간 그래디언트가 정확하도록 보장하고, 변환 함수의 일부로서 로그-밀도 함수의 그래디언트를 취하는 HMC와 같은 알고리즘이 적절한 샘플을 생성하도록 합니다.

샤딩된 JointDistribution

JointDistribution(JD)를 사용하여 여러 Sharded 무작위 변수로 모델을 생성할 수 있습니다. 유감스럽게도, Sharded 분포는 Vanilla tfd.JointDistribution과 안전하게 사용할 수 없지만, tfp.experimental.distributeSharded 분포와 같이 동작하는 "패치된" JD를 내보내기 합니다.

def f(seed, _): dist = tfed.JointDistributionSequential([ tfd.Normal(0., 1.), tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i'), ]) x = dist.sample(seed=seed) return x, dist.log_prob(x) jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
([ShardedDeviceArray([1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525], dtype=float32), ShardedDeviceArray([ 0.8690128 , -0.83167845, 1.2209264 , 0.88412696, 0.76478404, -0.66208494, -0.0129658 , 0.7391483 ], dtype=float32)], ShardedDeviceArray([-12.214451, -12.214451, -12.214451, -12.214451, -12.214451, -12.214451, -12.214451, -12.214451], dtype=float32))

이러한 샤딩된 JD는 Sharded 및 Vanilla TFP 분산을 구성 요소로 모두 가질 수 있습니다. 샤딩되지 않은 분포의 경우, 각 기기에서 동일한 샘플을 얻고, 샤딩된 분포의 경우, 다른 샘플을 얻습니다. 각 기기의 log_prob 역시 동기화됩니다.

Sharded 분포를 사용한 MCMC

MCMC의 컨텍스트 내 Sharded 분포에 대해 어떻게 생각하십니까? JointDistribution로 표현될 수 있는 생성 모델이 있다면, 해당 모델의 일부 축을 선별하여 전체를 "샤딩"할 수 있습니다. 일반적으로, 모델의 하나의 무작위 변수는 관찰된 데이터에 해당하며, 기기 간 공유하고자 하는 큰 데이터세트가 있다면 데이터 지점과 관련된 변수도 샤딩되기를 원합니다. 또한 샤딩 중인 관찰과 일대일인 "로컬" 무작위 변수가 있을 수 있으므로 이러한 무작위 변수를 추가적으로 샤딩해야 합니다.

이 섹션에서 TFP MCMC를 사용한 Sharded 분포의 사용에 대한 예시를 다루겠습니다. 더욱 단순한 베이지안 로지스틱 회귀 예시를 사용하여 시작하고 distribute 라이브러리를 위한 몇몇 활용 사례를 시연하는 것을 목표로 행렬 인수 분해 예시를 사용하여 마무리할 것입니다.

예시: MNIST를 위한 베이지안 로지스틱 회귀

대규모 데이터세트에서 베이지안 로지스틱 회귀를 하고자 합니다. 모델에는 총 조인트 로그 밀도를 얻기 위한 회귀 가중치에 대한 이전 p(θ)p(\theta)와 전체 데이터 xi,yii=1N{x_i, y_i}_{i = 1}^N에 대해 합산되는 가능성 p(yiθ,xi)p(y_i | \theta, x_i)가 있습니다. 데이터를 샤딩하는 경우, 모델의 관측된 무작위 변수 xix_iyiy_i를 샤딩합니다.

MNIST 분류를 위해 다음 베이지안 로지스틱 회귀 모델을 사용합니다: wN(0,1) bN(0,1) yiw,b,xiCategorical(wTxi+b) \begin{align*} w &\sim \mathcal{N}(0, 1) \ b &\sim \mathcal{N}(0, 1) \ y_i | w, b, x_i &\sim \textrm{Categorical}(w^T x_i + b) \end{align*}

데이터세트를 사용하는 MNIST를 로드하겠습니다

mnist = tfds.as_numpy(tfds.load('mnist', batch_size=-1)) raw_train_images, train_labels = mnist['train']['image'], mnist['train']['label'] train_images = raw_train_images.reshape([raw_train_images.shape[0], -1]) / 255. raw_test_images, test_labels = mnist['test']['image'], mnist['test']['label'] test_images = raw_test_images.reshape([raw_test_images.shape[0], -1]) / 255.
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data directory. If you'd instead prefer to read directly from our public GCS bucket (recommended if you're running on GCP), you can instead pass `try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

60,000개의 훈련 이미지가 있지만 8개의 사용 가능한 코어를 활용해서 8가지의 방식으로 분할하겠습니다. 이 편리한 shard 활용 함수를 사용하겠습니다.

def shard_value(x): x = x.reshape((jax.device_count(), -1, *x.shape[1:])) return jax.pmap(lambda x: x)(x) # pmap will physically place values on devices shard = functools.partial(jax.tree_map, shard_value)
sharded_train_images, sharded_train_labels = shard((train_images, train_labels)) print(sharded_train_images.shape, sharded_train_labels.shape)
(8, 7500, 784) (8, 7500)

계속 진행하기 전에, TPU의 정확도와 TPU가 HMC에 미치는 영향을 신속히 논의하겠습니다. PU는 속도를 위해 낮은 bfloat16 정확도를 사용해 행렬 곱셈을 수행합니다. bfloat16 행렬 곱셈은 종종 많은 딥 러닝 애플리케이션에는 충분하지만, HMC와 사용하는 경우, 낮은 정확성으로 인해 궤적이 분산되어 거부 반응이 유발될 수 있다는 것을 실증적으로 발견했습니다. 약간의 추가적인 컴퓨팅 비용으로 더욱 높은 정확도의 행렬 곱셈을 사용힐 수 있습니다.

행렬곱 정확도를 향상하기 위해, "tensorfloat32" 정확도로 jax.default_matmul_precision 데코레이터를 사용할 수 있습니다(더 높은 정확도의 경우 "float32" 정확도를 사용할 수 있습니다).

이제 무작위 시드 및 MNIST의 샤드를 포함하는 run 함수를 정의해 봅니다. 이 함수는 앞서 언급한 모델을 구현한 다음 단일 연쇄를 실행하기 위해 TFP의 Vanilla MCMC 기능을 사용합니다. 행렬 곱셈이 더욱 높은 정확성으로 실행되도록 jax.default_matmul_precision 데코레이터로 run를 데코레이트할 것입니다. 아래의 특정 예시에서는 jnp.dot(images, w, precision=lax.Precision.HIGH)를 사용할 수도 있습니다.

# We can use `out_axes=None` in the `pmap` because the results will be the same # on every device. @functools.partial(jax.pmap, axis_name='data', in_axes=(None, 0), out_axes=None) @jax.default_matmul_precision('tensorfloat32') def run(seed, data): images, labels = data # a sharded dataset num_examples, dim = images.shape num_classes = 10 def model_fn(): w = yield Root(tfd.Sample(tfd.Normal(0., 1.), [dim, num_classes])) b = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_classes])) logits = jnp.dot(images, w) + b yield tfed.Sharded(tfd.Independent(tfd.Categorical(logits=logits), 1), shard_axis_name='data') model = tfed.JointDistributionCoroutine(model_fn) init_seed, sample_seed = random.split(seed) initial_state = model.sample(seed=init_seed)[:-1] # throw away `y` def target_log_prob(*state): return model.log_prob((*state, labels)) def accuracy(w, b): logits = images.dot(w) + b preds = logits.argmax(axis=-1) # We take the average accuracy across devices by using `lax.pmean` return lax.pmean((preds == labels).mean(), 'data') kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-2, 100) kernel = tfm.DualAveragingStepSizeAdaptation(kernel, 500) def trace_fn(state, pkr): return ( target_log_prob(*state), accuracy(*state), pkr.new_step_size) states, trace = tfm.sample_chain( num_results=1000, num_burnin_steps=1000, current_state=initial_state, kernel=kernel, trace_fn=trace_fn, seed=sample_seed ) return states, trace

jax.pmap는 JIT 컴파일을 포함하지만 컴파일된 함수는 첫 번째 호출 후에 캐시 됩니다. run를 호출하고 출력을 무시하여 컴파일을 캐시 하겠습니다.

%%time output = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels)) jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 24.5 s, sys: 48.2 s, total: 1min 12s Wall time: 1min 54s

이제 실제 실행에 소요되는 시간을 확인하기 위해 run을 다시 호출하겠습니다.

%%time states, trace = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels)) jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 13.1 s, sys: 45.2 s, total: 58.3 s Wall time: 1min 43s

각각의 200,000개의 도약 단계를 수행 중이며, 각 단계는 전체 데이터세트에 대한 그래디언트를 컴퓨팅합니다. 8개의 코어에 대한 컴퓨팅을 분할하면 약 95초에 200,000개의 epoch 훈련, 초당 2,100개의 epoch 훈련과 동등하게 컴퓨팅할 수 있습니다!

각 샘플의 로그-밀도 및 각 샘플의 정확도를 표시해 봅시다.

fig, ax = plt.subplots(1, 3, figsize=(15, 5)) ax[0].plot(trace[0]) ax[0].set_title('Log Prob') ax[1].plot(trace[1]) ax[1].set_title('Accuracy') ax[2].plot(trace[2]) ax[2].set_title('Step Size') plt.show()
Image in a Jupyter notebook

샘플을 앙상블하는 경우, 베이지안 모델 평균을 컴퓨팅하여 성능을 개선할 수 있습니다.

@functools.partial(jax.pmap, axis_name='data', in_axes=(0, None), out_axes=None) def bayesian_model_average(data, states): images, labels = data logits = jax.vmap(lambda w, b: images.dot(w) + b)(*states) probs = jax.nn.softmax(logits, axis=-1) bma_accuracy = (probs.mean(axis=0).argmax(axis=-1) == labels).mean() avg_accuracy = (probs.argmax(axis=-1) == labels).mean() return lax.pmean(bma_accuracy, axis_name='data'), lax.pmean(avg_accuracy, axis_name='data') sharded_test_images, sharded_test_labels = shard((test_images, test_labels)) bma_acc, avg_acc = bayesian_model_average((sharded_test_images, sharded_test_labels), states) print(f'Average Accuracy: {avg_acc}') print(f'BMA Accuracy: {bma_acc}') print(f'Accuracy Improvement: {bma_acc - avg_acc}')
Average Accuracy: 0.9188529253005981 BMA Accuracy: 0.9264000058174133 Accuracy Improvement: 0.0075470805168151855

베이지안 모델 평균은 정확도를 거의 1% 향상합니다!

예시: MovieLens 추천 시스템

사용자와 여러 영화에 대한 평점을 모은 MovieLens 추천 데이터세트를 사용하여 추론을 시도해 보겠습니다. 구체적으로 말하면, MovieLens를 NN은 사용자의 수이고 MM은 영화의 수인 N×MN \times M watch matrix WW로 나타낼 수 있습니다. ParseError: KaTeX parse error: Expected 'EOF', got '&' at position 3: N &̲gt; M일 것으로 예상합니다. WijW_{ij}의 입력값은 사용자 ii이 영화 jj를 봤는지 표시하는 불리언입니다. MovieLens는 사용자 평점을 제공하지만 문제를 단순화하기 위해 무시하겠습니다.

우선, 데이터세트를 로드하겠습니다. 평점이 1백만이 넘는 버전을 사용하겠습니다.

movielens = tfds.as_numpy(tfds.load('movielens/1m-ratings', batch_size=-1)) GENRES = ['Action', 'Adventure', 'Animation', 'Children', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'IMAX', 'Musical', 'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'Unknown', 'War', 'Western', '(no genres listed)']
Downloading and preparing dataset movielens/1m-ratings/0.1.0 (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0...
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0.incompleteYKA3TG/movielens-train.tfrecord
HBox(children=(FloatProgress(value=0.0, max=1000209.0), HTML(value='')))
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0. Subsequent calls will reuse this data.

행렬 WW를 보기 위해 데이터세트를 약간 전처리하겠습니다.

raw_movie_ids = movielens['train']['movie_id'] raw_user_ids = movielens['train']['user_id'] genres = movielens['train']['movie_genres'] movie_ids, movie_labels = pd.factorize(movielens['train']['movie_id']) user_ids, user_labels = pd.factorize(movielens['train']['user_id']) num_movies = movie_ids.max() + 1 num_users = user_ids.max() + 1 movie_titles = dict(zip(movielens['train']['movie_id'], movielens['train']['movie_title'])) movie_genres = dict(zip(movielens['train']['movie_id'], genres)) movie_id_to_title = [movie_titles[movie_labels[id]].decode('utf-8') for id in range(num_movies)] movie_id_to_genre = [GENRES[movie_genres[movie_labels[id]][0]] for id in range(num_movies)] watch_matrix = np.zeros((num_users, num_movies), bool) watch_matrix[user_ids, movie_ids] = True print(watch_matrix.shape)
(6040, 3706)

간단한 확률적 행렬 인수 분해 모델을 사용하여 WW에 대한 생성 모델을 정의할 수 있습니다. 잠재 N×DN \times D 사용자 행렬 UU 및 잠재 M×DM \times D 영화 행렬 VV를 추정합니다. 이 행렬을 곱하면 시청 행렬 WW에 대한 베르누이 로짓을 생성합니다. 또한 사용자 및 영화, uu and vv에 대한 편향 벡터를 포함하겠습니다. UN(0,1)uN(0,1) VN(0,1)vN(0,1) WijBernoulli(σ((UVT)ij+ui+vj)) \begin{align*} U &\sim \mathcal{N}(0, 1) \quad u \sim \mathcal{N}(0, 1)\ V &\sim \mathcal{N}(0, 1) \quad v \sim \mathcal{N}(0, 1)\ W_{ij} &\sim \textrm{Bernoulli}\left(\sigma\left(\left(UV^T\right)_{ij} + u_i + v_j\right)\right) \end{align*}

이것은 상당히 큰 행렬입니다. 6,040명의 사용자와 3,706개의 영화가 2천2백만 개가 넘는 입력값을 가진 행렬로 이어집니다. 이 모델을 샤딩하는 데 어떻게 접근해야 할까요? ParseError: KaTeX parse error: Expected 'EOF', got '&' at position 3: N &̲gt; M(사용자가 영화보다 많음)일 것이라 가정하면 사용자 축 전체에 걸쳐 시청 행렬을 샤딩하는 것이 의미가 있으며, 따라서 각 기기는 사용자의 하위 집합에 해당하는 시청 패트릭스 청크를 가지게 됩니다. 하지만 이전의 예시와는 다르게, UU 행렬은 각 사용자에 대한 임베딩이 있기 때문에 샤딩해야 하며, 따라서 각 기기는 UU의 샤드 및 WW의 샤드에 대한 책임이 있습니다. 반면에, VV는 기기 전체에 샤딩되지 않으며 동기화됩니다.

sharded_watch_matrix = shard(watch_matrix)

run을 작성하기 전에, 로컬 무작위 변수 UU를 샤딩하는 것과 관련된 추가적인 어려움에 대해 짧게 논의해 보겠습니다. HMC를 실행하는 경우, Vanilla tfp.mcmc.HamiltonianMonteCarlo 커널이 각 연쇄 상태의 요소에 대한 모멘텀을 샘플링할 것입니다. 이전에는, 샤딩되지 않은 무작위 변수들만이 해당 상태의 일부였으며 모멘텀은 각 기기에서 동일했습니다. 이제 샤딩된 UU가 있는 경우, VV에 대한 동일한 모멘텀을 샘플링하는 동시에 UU에 대한 각 기기의 다른 모멘텀을 샘플링해야 합니다. 이를 해내기 위해서, Sharded 모멘텀 분산과 tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo를 사용할 수 있습니다. 예를 들어, HMC 커널에 샤딩 지표를 가져옴으로써 병렬 컴퓨팅을 계속 1등급으로 만들어 이를 단순화할 수 있습니다.

def make_run(*, axis_name, dim=20, num_chains=2, prior_variance=1., step_size=1e-2, num_leapfrog_steps=100, num_burnin_steps=1000, num_results=500, ): @functools.partial(jax.pmap, in_axes=(None, 0), axis_name=axis_name) @jax.default_matmul_precision('tensorfloat32') def run(key, watch_matrix): num_users, num_movies = watch_matrix.shape Sharded = functools.partial(tfed.Sharded, shard_axis_name=axis_name) def prior_fn(): user_embeddings = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users, dim]), name='user_embeddings')) user_bias = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users]), name='user_bias')) movie_embeddings = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies, dim], name='movie_embeddings')) movie_bias = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies], name='movie_bias')) return (user_embeddings, user_bias, movie_embeddings, movie_bias) prior = tfed.JointDistributionCoroutine(prior_fn) def model_fn(): user_embeddings, user_bias, movie_embeddings, movie_bias = yield from prior_fn() logits = (jnp.einsum('...nd,...md->...nm', user_embeddings, movie_embeddings) + user_bias[..., :, None] + movie_bias[..., None, :]) yield Sharded(tfd.Independent(tfd.Bernoulli(logits=logits), 2), name='watch') model = tfed.JointDistributionCoroutine(model_fn) init_key, sample_key = random.split(key) initial_state = prior.sample(seed=init_key, sample_shape=num_chains) def target_log_prob(*state): return model.log_prob((*state, watch_matrix)) momentum_distribution = tfed.JointDistributionSequential([ Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users, dim]), 1.), 2)), Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users]), 1.), 1)), tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies, dim]), 1.), 2), tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies]), 1.), 1), ]) # We pass in momentum_distribution here to ensure that the momenta for # user_embeddings and user_bias are also sharded kernel = tfem.PreconditionedHamiltonianMonteCarlo(target_log_prob, step_size, num_leapfrog_steps, momentum_distribution=momentum_distribution) num_adaptation_steps = int(0.8 * num_burnin_steps) kernel = tfm.DualAveragingStepSizeAdaptation(kernel, num_adaptation_steps) def trace_fn(state, pkr): return { 'log_prob': target_log_prob(*state), 'log_accept_ratio': pkr.inner_results.log_accept_ratio, } return tfm.sample_chain( num_results, initial_state, kernel=kernel, num_burnin_steps=num_burnin_steps, trace_fn=trace_fn, seed=sample_key) return run

이를 다시 한 번 실행하여 컴파일된 run을 캐시합니다.

%%time run = make_run(axis_name='data') output = run(random.PRNGKey(0), sharded_watch_matrix) jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 56 s, sys: 1min 24s, total: 2min 20s Wall time: 3min 35s

이를 이제 컴파일 오버헤드 없이 다시 실행합니다.

%%time states, trace = run(random.PRNGKey(0), sharded_watch_matrix) jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 28.8 s, sys: 1min 16s, total: 1min 44s Wall time: 3min 1s

3분 동안 10,000개의 도약 단계를 마친 것으로 보이므로, 초당 83개의 도약 단계를 마쳤습니다! 샘플의 허용 비율 및 로그 밀도를 표시해 보겠습니다.

fig, axs = plt.subplots(1, len(trace), figsize=(5 * len(trace), 5)) for ax, (key, val) in zip(axs, trace.items()): ax.plot(val[0]) # Indexing into a sharded array, each element is the same ax.set_title(key);
Image in a Jupyter notebook

이제 마르코프 연쇄에서 샘플을 얻었습니다. 이를 사용해 예측을 해보겠습니다. 우선, 구성 요소 각각을 추출해 봅시다. user_embeddingsuser_bias는 기기 간 분할되어 있으므로, ShardedArray를 연결하여 이를 모두 획득해야 한다는 점을 잊지 마십시오. 반면, 모든 기기에서 movie_embeddingsmovie_bias는 동일하므로, 첫 번째 샤드에서 값을 선별할 수 있습니다. 정규 numpy를 사용하여 TPU에서 값을 CPU에 다시 복사하겠습니다.

user_embeddings = np.concatenate(np.array(states.user_embeddings, np.float32), axis=2) user_bias = np.concatenate(np.array(states.user_bias, np.float32), axis=2) movie_embeddings = np.array(states.movie_embeddings[0], dtype=np.float32) movie_bias = np.array(states.movie_bias[0], dtype=np.float32) samples = (user_embeddings, user_bias, movie_embeddings, movie_bias) print(f'User embeddings: {user_embeddings.shape}') print(f'User bias: {user_bias.shape}') print(f'Movie embeddings: {movie_embeddings.shape}') print(f'Movie bias: {movie_bias.shape}')
User embeddings: (500, 2, 6040, 20) User bias: (500, 2, 6040) Movie embeddings: (500, 2, 3706, 20) Movie bias: (500, 2, 3706)

이러한 샘플에서 캡처된 불확실성을 활용하는 단순한 추천 시스템을 구축해 보겠습니다. 우선 시청 확률에 따른 영화의 순위를 매기는 함수를 작성하겠습니다.

@jax.jit def recommend(sample, user_id): user_embeddings, user_bias, movie_embeddings, movie_bias = sample movie_logits = ( jnp.einsum('d,md->m', user_embeddings[user_id], movie_embeddings) + user_bias[user_id] + movie_bias) return movie_logits.argsort()[::-1]

이제 모든 샘플 및 각각을 순환하고, 사용자가 아직 시청하지 않은 상위 랭크된 영화를 선별하는 함수를 작성할 수 있습니다. 그런 다음 전체 샘플의 모든 추천된 영화 수를 볼 수 있습니다.

def get_recommendations(user_id): movie_ids = [] already_watched = set(jnp.arange(num_movies)[watch_matrix[user_id] == 1]) for i in range(500): for j in range(2): sample = jax.tree_map(lambda x: x[i, j], samples) ranking = recommend(sample, user_id) for movie_id in ranking: if int(movie_id) not in already_watched: movie_ids.append(movie_id) break return movie_ids def plot_recommendations(movie_ids, ax=None): titles = collections.Counter([movie_id_to_title[i] for i in movie_ids]) ax = ax or plt.gca() names, counts = zip(*sorted(titles.items(), key=lambda x: -x[1])) ax.bar(names, counts) ax.set_xticklabels(names, rotation=90)

가장 많은 영화를 본 사용자 대 가장 적게 영화를 본 사용자를 살펴보겠습니다.

user_watch_counts = watch_matrix.sum(axis=1) user_most = user_watch_counts.argmax() user_least = user_watch_counts.argmin() print(user_watch_counts[user_most], user_watch_counts[user_least])
2314 20

user_most가 더 많이 시청할 것 같은 영화의 종류가 무엇인지에 대한 정보가 더 많다는 점을 고려하면 시스템이 user_least보다 user_most에 대한 확실성이 더 있길 바랍니다.

fig, ax = plt.subplots(1, 2, figsize=(20, 10)) most_recommendations = get_recommendations(user_most) plot_recommendations(most_recommendations, ax=ax[0]) ax[0].set_title('Recommendation for user_most') least_recommendations = get_recommendations(user_least) plot_recommendations(least_recommendations, ax=ax[1]) ax[1].set_title('Recommendation for user_least');
Image in a Jupyter notebook

시청 선호도의 추가적인 불확실성을 반영하는 user_least에 대한 추천에 더 많은 변수가 있다는 것을 확인할 수 있습니다.

추천 영화의 장르도 확인할 수 있습니다.

most_genres = collections.Counter([movie_id_to_genre[i] for i in most_recommendations]) least_genres = collections.Counter([movie_id_to_genre[i] for i in least_recommendations]) fig, ax = plt.subplots(1, 2, figsize=(20, 10)) ax[0].bar(most_genres.keys(), most_genres.values()) ax[0].set_title('Genres recommended for user_most') ax[1].bar(least_genres.keys(), least_genres.values()) ax[1].set_title('Genres recommended for user_least');
Image in a Jupyter notebook

user_least가 많은 영화를 보지 않았으며 코미디와 액션을 왜곡하는 주류 영화를 더 많이 추천받은 반면, user_most는 영화를 더 많이 보았고 미스터리 및 범죄와 같은 틈새 장르를 추천받았습니다.