Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/pt-br/probability/examples/Distributed_Inference_with_JAX.ipynb
25118 views
Kernel: Python 3

Licensed under the Apache License, Version 2.0 (the "License");

#@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" } # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

Agora, o TensorFlow Probability (TFP) no JAX conta com ferramentas para computação numérica distribuída. Para aumentar a escala para um grande número de aceleradores, as ferramentas foram programadas segundo o paradigma "único programa, múltiplos dados" ou SPMD, na sigla em inglês.

Neste notebook, veremos como "raciocinar em SPMD" e apresentaremos as novas abstrações do TFP para aumentar a escala para diversas configurações, como pods de TPU ou clusters de GPU. Se você for executar este código por contra própria, lembre-se de selecionar um runtime de GPU.

Primeiro, vamos instalar a última versão do TFP, JAX e TF.

#@title Installs !pip install jaxlib --upgrade -q 2>&1 1> /dev/null !pip install tfp-nightly[jax] --upgrade -q 2>&1 1> /dev/null !pip install tf-nightly-cpu -q -I 2>&1 1> /dev/null !pip install jax -I -q --upgrade 2>&1 1>/dev/null
ERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.4.0 which is incompatible. ERROR: tensorflow 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible. ERROR: tensorflow 2.4.1 has requirement h5py~=2.10.0, but you'll have h5py 3.1.0 which is incompatible. ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible. ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible. ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible. ERROR: tf-nightly-cpu 2.6.0.dev20210401 has requirement numpy~=1.19.2, but you'll have numpy 1.20.2 which is incompatible. ERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.4.0 which is incompatible. ERROR: tensorflow 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible. ERROR: tensorflow 2.4.1 has requirement h5py~=2.10.0, but you'll have h5py 3.1.0 which is incompatible. ERROR: tensorflow 2.4.1 has requirement numpy~=1.19.2, but you'll have numpy 1.20.2 which is incompatible. ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible. ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible. ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.

Vamos importar algumas bibliotecas gerais e alguns utilitários do JAX.

#@title Setup and Imports import functools import collections import contextlib import jax import jax.numpy as jnp from jax import lax from jax import random import jax.numpy as jnp import numpy as np import matplotlib.pyplot as plt import seaborn as sns import pandas as pd import tensorflow_datasets as tfds from tensorflow_probability.substrates import jax as tfp sns.set(style='white')
INFO:tensorflow:Enabling eager execution INFO:tensorflow:Enabling v2 tensorshape INFO:tensorflow:Enabling resource variables INFO:tensorflow:Enabling tensor equality INFO:tensorflow:Enabling control flow v2

Também vamos configurar alguns alias do TFP muito úteis. As novas abstrações são fornecidas por tfp.experimental.distribute e tfp.experimental.mcmc.

tfd = tfp.distributions tfb = tfp.bijectors tfm = tfp.mcmc tfed = tfp.experimental.distribute tfde = tfp.experimental.distributions tfem = tfp.experimental.mcmc Root = tfed.JointDistributionCoroutine.Root

Para conectar o notebook a uma TPU, usamos o seguinte helper do JAX. Para confirmar que a conexão foi estabelecida, exibimos o número de dispositivos via print, que deve ser igual a oito.

from jax.tools import colab_tpu colab_tpu.setup_tpu() print(f'Found {jax.device_count()} devices')
Found 8 devices

Introdução rápida a jax.pmap

Após estabelecermos a conexão a uma TPU, temos acesso a oito dispositivos. Porém, quando executamos o código do JAX de forma eager, o JAX executa as computações em apenas um, por padrão.

A maneira mais simples de executar uma computação em diversos dispositivos é mapear uma função, fazendo cada dispositivo executar um índice do mapa. O JAX conta com a transformação jax.pmap ("mapa paralelo"), que transforma uma função em outra que mapeia a função em diversos dispositivos.

No exemplo abaixo, criamos um array de tamanho 8 (para coincidir com o número de dispositivos disponíveis) e mapeamos uma função que adiciona 5.

xs = jnp.arange(8.) out = jax.pmap(lambda x: x + 5.)(xs) print(type(out), out)
<class 'jax.interpreters.pxla.ShardedDeviceArray'> [ 5. 6. 7. 8. 9. 10. 11. 12.]

Observe que recebemos de volta um tipo ShardedDeviceArray, indicando que o array de saída está dividido fisicamente entre os dispositivos.

jax.pmap funciona semanticamente como um mapa, mas conta com algumas opções importantes que modificam seu comportamento. Por padrão, pmap pressupõe que todas as entradas da função estão sendo mapeadas, mas podemos modificar esse comportamento com o argumento in_axes.

xs = jnp.arange(8.) y = 5. # Map over the 0-axis of `xs` and don't map over `y` out = jax.pmap(lambda x, y: x + y, in_axes=(0, None))(xs, y) print(out)
[ 5. 6. 7. 8. 9. 10. 11. 12.]

De maneira análoga, o argumento out_axes de pmap determina se os valores devem ser retornados em cada dispositivo ou não. Ao definir out_axes como None, o valor é retornado automaticamente para o primeiro dispositivo e somente deve ser usado se estivermos confiantes de que os valores são os mesmos em todos os dispositivos.

xs = jnp.ones(8) # Value is the same on each device out = jax.pmap(lambda x: x + 1, out_axes=None)(xs) print(out)
2.0

O que acontece quando o que quisermos fazer não puder ser facilmente expresso como uma função mapeada pura? Por exemplo: e se quisermos fazer uma soma ao longo do eixo que estamos mapeando? O JAX conta com funções "coletivas" que se comunicam entre os dispositivos para permitir a criação de programas distribuídos mais interessantes e complexos. Para entender como funcionam exatamente, vamos apresentar o SPMD.

O que é o SPMD?

O modelo programa único, múltiplos dados (SPMD) é um modelo de programação concorrente, em que um único programa (isto é, o mesmo código) é executado simultaneamente em diversos dispositivos, mas as entradas de cada um dos programas em execução podem diferir.

Se o nosso programa for uma função simples de sua entrada (por exemplo, x + 5), executar um programa em SPMD significa simplesmente mapeá-lo em diferentes dados, como fizemos com jax.pmap anteriormente. Porém, podemos fazer muito mais do que somente "mapear" uma função. O JAX conta com "coletivos", que são funções que se comunicam entre os dispositivos.

Por exemplo: talvez a gente deseje fazer a soma de uma quantidade em todos os dispositivos. Antes de fazermos isso, precisamos atribuir um nome ao eixo que estamos mapeando no pmap. Em seguida, usamos a função lax.psum ("soma paralela") para fazer uma soma entre os dispositivos, assegurando a identificação do eixo nomeado que está sendo somado.

def f(x): out = lax.psum(x, axis_name='i') return out xs = jnp.arange(8.) # Length of array matches number of devices jax.pmap(f, axis_name='i')(xs)
ShardedDeviceArray([28., 28., 28., 28., 28., 28., 28., 28.], dtype=float32)

O coletivo psum agrega o valor de x em cada dispositivo e sincroniza seu valor em todo o mapa, ou seja, out é 28 em todos os dispositivos. Não estamos mais realizando um simples "mapeamento", mas sim executando um programa SPMD, em que agora a computação de cada dispositivo pode interagir com a mesma computação nos outros, embora de forma limitada, usando coletivos. Neste cenário, podemos usar out_axes = None, pois psum sincronizará o valor.

def f(x): out = lax.psum(x, axis_name='i') return out jax.pmap(f, axis_name='i', out_axes=None)(jnp.arange(8.))
ShardedDeviceArray(28., dtype=float32)

Com o SPMD, podemos escrever um único programa que é executado em cada dispositivo em qualquer configuração de TPUs simultaneamente. O mesmo código usado para fazer aprendizado de máquina em 8 núcleos de TPU pode ser usado em um pod de TPUs com centenas de milhares de núcleos! Confira mais detalhes sobre jax.pmap e SPMD no tutorial de conceitos básicos do JAX.

MCMC em grande escala

Neste notebook, vamos nos concentrar em usar os métodos de Monte Carlo via Cadeias de Markov (MCMC) para inferência bayesiana. Há diversas formas de usar vários dispositivos para MCMC, mas, neste notebook, nos concentraremos em duas:

  1. Executar cadeias de Markov independentes em dispositivos diferentes. Este caso é bem simples e possível de se fazer com o TFP padrão.

  2. Fragmentar um dataset em vários dispositivos. Este caso é um pouco mais complexo e requer funcionalidades do TFP adicionadas recentemente.

Cadeias independentes

Vamos supor que queiramos fazer inferência bayesiana para um problema usando MCMC e queiramos executar diversas cadeias em paralelo em diversos dispositivos (digamos, duas em cada dispositivo). Isso acaba sendo um programa que podemos simplesmente "mapear" nos dispositivos, ou seja, um programa que não precisa de coletivos. Para assegurar que cada programa execute uma cadeia de Markov diferente (em vez de executar a mesma), passamos um valor de semente aleatória diferente para cada dispositivo.

Vamos fazer um teste em um problema simples de amostragem de uma distribuição gaussiana bidimensional. Podemos usar a funcionalidade de MCMC existente no TFP padrão, Em geral, tentamos colocar a maior parte da lógica dentro da função mapeada para distinguir explicitamente o que é executado em todos os dispositivos e o que é executado somente no primeiro.

def run(seed): target_log_prob = tfd.Sample(tfd.Normal(0., 1.), 2).log_prob initial_state = jnp.zeros([2, 2]) # 2 chains kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-1, 10) def trace_fn(state, pkr): return target_log_prob(state) states, log_prob = tfm.sample_chain( num_results=1000, num_burnin_steps=1000, kernel=kernel, current_state=initial_state, trace_fn=trace_fn, seed=seed ) return states, log_prob

Isoladamente, a função run recebe uma semente aleatória stateless (para ver como a aleatoriedade stateless funciona, leia o notebook TFP no JAX ou confira o tutorial de noções básicas do JAX). O mapeamento de run em diferentes sementes resultará na execução de várias cadeias de Markov independentes.

states, log_probs = jax.pmap(run)(random.split(random.PRNGKey(0), 8)) print(states.shape, log_probs.shape) # states is (8 devices, 1000 samples, 2 chains, 2 dimensions) # log_prob is (8 devices, 1000 samples, 2 chains)
(8, 1000, 2, 2) (8, 1000, 2)

Observe como agora temos um eixo extra correspondente a cada dispositivo. Podemos reorganizar as dimensões e reduzi-las para obter um eixo para as 16 cadeias.

states = states.transpose([0, 2, 1, 3]).reshape([-1, 1000, 2]) log_probs = log_probs.transpose([0, 2, 1]).reshape([-1, 1000])
fig, ax = plt.subplots(1, 2, figsize=(10, 5)) ax[0].plot(log_probs.T, alpha=0.4) ax[1].scatter(*states.reshape([-1, 2]).T, alpha=0.1) plt.show()
Image in a Jupyter notebook

Ao executar cadeias independentes em diversos dispositivos, basta usar pmap para criar uma função que use tfp.mcmc, assegurando a passagem de valores diferentes para a semente aleatória de cada dispositivo.

Fragmentação de dados (sharding)

Ao fazermos o MCMC, a distribuição alvo geralmente é uma distribuição posterior obtida pelo condicionamento de um dataset, e a computação de uma densidade logarítmica não normalizada envolve fazer a soma da verossimilhança para cada dado observado.

Com datasets muito grandes, pode ser proibitivamente caro em termos de recursos executar uma cadeia em um único dispositivo. Entretanto, quando temos acesso a diversos dispositivos, podemos dividir o dataset entre os dispositivos para aproveitar melhor a computação disponível.

Se quisermos fazer o MCMC com um dataset fragmentado, precisamos assegurar que a densidade logarítmica não normalizada que computarmos em cada dispositivo represente o total, ou seja, a densidade para todos os dados. Caso contrário, cada dispositivo fará o MCMC com sua própria distribuição alvo incorreta. Para isso, agora o TFP conta com novas ferramentas (tfp.experimental.distribute e tfp.experimental.mcmc), que permitem computar probabilidades logarítmicas "fragmentadas" e fazer MCMC com elas.

Distribuições fragmentadas

Agora, o TFP de abstração core que fornece a computação de probabilidades logarítmicas fragmentadas é a metadistribuição Sharded, que recebe uma distribuição como entrada e retorna uma nova distribuição com propriedades específicas quando a execução é feita em um contexto de SPMD. Sharded reside em tfp.experimental.distribute.

Intuitivamente, uma distribuição Sharded corresponde a um conjunto de variáveis aleatórias que foram "divididas" entre dispositivos. Em cada dispositivo, elas geram amostras diferentes e podem ter individualmente densidades logarítmicas diferentes. Alternativamente, uma distribuição Sharded corresponde a uma "placa" no jargão de modelo gráfico, em que o tamanho da placa é o número de dispositivos.

Amostragem de uma distribuição Sharded

Se fizermos a amostragem de uma distribuição Normal que virou uma mapa paralelo pmap usando a mesma semente em cada dispositivo, teremos a mesma amostra em todos os dispositivos. Podemos pensar na seguinte função como a amostragem de uma única variável aleatória que é sincronizada entre os dispositivos.

# `pmap` expects at least one value to be mapped over, so we provide a dummy one def f(seed, _): return tfd.Normal(0., 1.).sample(seed=seed) jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236], dtype=float32)

Se encapsularmos tfd.Normal(0., 1.) com tfed.Sharded, agora temos logicamente oito variáveis aleatórias diferentes (uma em cada dispositivo) e, portanto, será gerada uma amostra diferente para cada uma, apesar de passarmos a mesma semente.

def f(seed, _): return tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i').sample(seed=seed) jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([ 1.2152631 , 0.7818249 , 0.32549605, 0.6828047 , 1.3973192 , -0.57830244, 0.37862757, 2.7706041 ], dtype=float32)

Uma representação equivalente dessa distribuição em um único dispositivo é somente 8 amostras de uma distribuição normal independente. Embora o valor da amostra seja diferente (tfed.Sharded faz uma geração de número pseudoaleatório ligeiramente diferente), ambas representam a mesma distribuição.

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count()) dist.sample(seed=random.PRNGKey(0))
DeviceArray([ 0.08086783, -0.38624594, -0.3756545 , 1.668957 , -1.2758069 , 2.1192007 , -0.85821325, 1.1305912 ], dtype=float32)

Recebendo a densidade logarítmica de uma distribuição Sharded

Vamos ver o que acontece quando computamos a densidade logarítmica de uma amostra a partir de uma distribuição regular em um contexto de SPMD.

def f(seed, _): dist = tfd.Normal(0., 1.) x = dist.sample(seed=seed) return x, dist.log_prob(x) jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
(ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236], dtype=float32), ShardedDeviceArray([-0.94012403, -0.94012403, -0.94012403, -0.94012403, -0.94012403, -0.94012403, -0.94012403, -0.94012403], dtype=float32))

Cada amostra está no mesmo dispositivo, então também computamos a mesma densidade em cada dispositivo. Intuitivamente, temos aqui somente uma distribuição ao longo de uma única variável com distribuição normal.

Com uma distribuição Sharded, temos uma distribuição ao longo de 8 variáveis aleatórias, portanto, quanto computamos a log_prob de uma amostra, fazemos a soma entre os dispositivos de cada uma das densidades logarítmicas individuais (talvez você note que o valor log_prob total seja maior do que o valor log_prob isolado computado acima).

def f(seed, _): dist = tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i') x = dist.sample(seed=seed) return x, dist.log_prob(x) sample, log_prob = jax.pmap(f, in_axes=(None, 0), axis_name='i')( random.PRNGKey(0), jnp.arange(8.)) print('Sample:', sample) print('Log Prob:', log_prob)
Sample: [ 1.2152631 0.7818249 0.32549605 0.6828047 1.3973192 -0.57830244 0.37862757 2.7706041 ] Log Prob: [-13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205]

A distribuição equivalente "não fragmentada" produz a mesma densidade logarítmica.

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count()) dist.log_prob(sample)
DeviceArray(-13.7349205, dtype=float32)

Uma distribuição Sharded gera valores diferentes de sample em cada dispositivo, mas obtém o mesmo valor para log_prob em cada dispositivo. O que acontece aqui? Uma distribuição Sharded faz um psum internamente para assegurar que os valores de log_prob estejam sincronizados entre os dispositivos. Por que desejamos esse comportamento? Se estivermos executando a mesma cadeia do MCMC em cada dispositivo, queremos que a target_log_prob seja a mesma em cada dispositivo, mesmo que algumas variáveis aleatórias na computação sejam fragmentadas entre os dispositivos.

Além disso, uma distribuição Sharded assegura que os gradientes entre os dispositivos estejam corretos para garantir que os algoritmos, como o HMC, que recebem gradientes da função de densidade logarítmica funcionem como parte da função de transição e gerem amostras adequadas.

JointDistributions fragmentadas

Podemos criar modelos com diversas variáveis aleatórias Sharded usando JointDistributions (JDs – distribuições conjuntas). Infelizmente, distribuições Sharded não podem ser usadas com segurança com as tfd.JointDistributions padrão, mas tfp.experimental.distribute exporta JDs "modificadas" que se comportarão como distribuições Sharded.

def f(seed, _): dist = tfed.JointDistributionSequential([ tfd.Normal(0., 1.), tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i'), ]) x = dist.sample(seed=seed) return x, dist.log_prob(x) jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
([ShardedDeviceArray([1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525], dtype=float32), ShardedDeviceArray([ 0.8690128 , -0.83167845, 1.2209264 , 0.88412696, 0.76478404, -0.66208494, -0.0129658 , 0.7391483 ], dtype=float32)], ShardedDeviceArray([-12.214451, -12.214451, -12.214451, -12.214451, -12.214451, -12.214451, -12.214451, -12.214451], dtype=float32))

Essas JDs padrão podem ter tanto distribuições Sharded quanto distribuições padrão do TFP como componentes. Para as distribuições não fragmentadas, obtemos a mesma amostra em cada dispositivo. Já para distribuições fragmentadas, obtemos amostras diferentes. A log_prob em cada dispositivo também é sincronizada.

MCMC com distribuições Sharded

Como podemos pensar nas distribuições Sharded no contexto do MCMC? Se tivermos um modelo generativo que possa ser expressado como uma JointDistribution, podemos escolher algum eixo desse modelo no qual "fragmentar". Tipicamente, uma variável aleatória do modelo corresponderá aos dados observados e, se tivermos um dataset grande que desejamos fragmentar entre dispositivos, queremos que as variáveis que estejam associadas aos pontos de dados também sejam fragmentadas. Além disso, também podemos ter variáveis aleatórias "locais" com correspondência um-para-um com as observações que estamos fragmentando, então também teremos que fragmentar essas variáveis aleatórias.

Nesta seção, veremos exemplos de uso das distribuições Sharded com o MCMC do TFP. Começaremos com um exemplo mais simples de regressão logística bayesiana e concluiremos com um exemplo de fatoração de matriz, com o objetivo de demonstrar alguns casos de uso da biblioteca distribute.

Exemplo: regressão logística bayesiana para MNIST

Gostaríamos de fazer uma regressão logística bayesiana em um dataset grande. O modelo tem um prior p(θ)p(\theta) ao longo dos pesos de regressão e uma verossimilhança p(yiθ,xi)p(y_i | \theta, x_i) que é somada para todos os dados xi,yii=1N{x_i, y_i}_{i = 1}^N para obter a densidade logarítmica conjunta total. Se fragmentarmos os dados, fragmentaríamos as variáveis aleatórias observadas xix_i e yiy_i em nosso modelo.

Usamos o seguinte modelo de regressão logística bayesiana para classificação MNIST: wN(0,1)bN(0,1)yiw,b,xiCategorical(wTxi+b) \begin{align*} w &\sim \mathcal{N}(0, 1) \\ b &\sim \mathcal{N}(0, 1) \\ y_i | w, b, x_i &\sim \textrm{Categorical}(w^T x_i + b) \end{align*}

Carregaremos o MNIST usando os TensorFlow Datasets.

mnist = tfds.as_numpy(tfds.load('mnist', batch_size=-1)) raw_train_images, train_labels = mnist['train']['image'], mnist['train']['label'] train_images = raw_train_images.reshape([raw_train_images.shape[0], -1]) / 255. raw_test_images, test_labels = mnist['test']['image'], mnist['test']['label'] test_images = raw_test_images.reshape([raw_test_images.shape[0], -1]) / 255.
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data directory. If you'd instead prefer to read directly from our public GCS bucket (recommended if you're running on GCP), you can instead pass `try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

Temos 60 mil imagens de treinamento, mas usaremos os 8 núcleos disponíveis e as dividiremos em 8 conjuntos. Usaremos a função utilitária shard, que é muito útil.

def shard_value(x): x = x.reshape((jax.device_count(), -1, *x.shape[1:])) return jax.pmap(lambda x: x)(x) # pmap will physically place values on devices shard = functools.partial(jax.tree_map, shard_value)
sharded_train_images, sharded_train_labels = shard((train_images, train_labels)) print(sharded_train_images.shape, sharded_train_labels.shape)
(8, 7500, 784) (8, 7500)

Antes de continuarmos, vamos falar brevemente sobre a precisão em TPUs e seu impacto no HMC. As TPUs executam as multiplicações de matrizes usando bfloat16 de baixa precisão por questões de velocidade. Geralmente, as multiplicações de matrizes bfloat16 são suficientes para diversas aplicações de aprendizado profundo, mas, quando usadas com o HMC, descobrimos empiricamente que a precisão baixa pode levar a trajetórias divergentes, causando rejeições. Podemos usar multiplicações de matrizes de precisão maior, com um certo aumento do custo computacional.

Para aumentar a precisão de matmul, podemos usar o decorador jax.default_matmul_precision com precisão "tensorfloat32" (para uma precisão ainda maior, poderíamos usar "float32").

Agora, vamos definir a função run, que receberá uma semente aleatória (que será a mesma em cada dispositivo) e um fragmento de MNIST. A função implementará o modelo mencionado acima, e depois usarmos a funcionalidade de MCMC padrão do TFP para executar uma única cadeia. Vamos decorar run com o decorador jax.default_matmul_precision para assegurar que a multiplicação de matrizes seja executada com precisão maior, embora, no exemplo específico abaixo, poderíamos muito bem usar jnp.dot(images, w, precision=lax.Precision.HIGH).

# We can use `out_axes=None` in the `pmap` because the results will be the same # on every device. @functools.partial(jax.pmap, axis_name='data', in_axes=(None, 0), out_axes=None) @jax.default_matmul_precision('tensorfloat32') def run(seed, data): images, labels = data # a sharded dataset num_examples, dim = images.shape num_classes = 10 def model_fn(): w = yield Root(tfd.Sample(tfd.Normal(0., 1.), [dim, num_classes])) b = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_classes])) logits = jnp.dot(images, w) + b yield tfed.Sharded(tfd.Independent(tfd.Categorical(logits=logits), 1), shard_axis_name='data') model = tfed.JointDistributionCoroutine(model_fn) init_seed, sample_seed = random.split(seed) initial_state = model.sample(seed=init_seed)[:-1] # throw away `y` def target_log_prob(*state): return model.log_prob((*state, labels)) def accuracy(w, b): logits = images.dot(w) + b preds = logits.argmax(axis=-1) # We take the average accuracy across devices by using `lax.pmean` return lax.pmean((preds == labels).mean(), 'data') kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-2, 100) kernel = tfm.DualAveragingStepSizeAdaptation(kernel, 500) def trace_fn(state, pkr): return ( target_log_prob(*state), accuracy(*state), pkr.new_step_size) states, trace = tfm.sample_chain( num_results=1000, num_burnin_steps=1000, current_state=initial_state, kernel=kernel, trace_fn=trace_fn, seed=sample_seed ) return states, trace

jax.pmap inclui uma compilação JIT, mas é feito cache da função compilada após a primeira chamada. Vamos chamar run e ignorar a saída para fazer cache da compilação.

%%time output = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels)) jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 24.5 s, sys: 48.2 s, total: 1min 12s Wall time: 1min 54s

Agora vamos chamar run novamente para ver quanto tempo a execução leva.

%%time states, trace = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels)) jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 13.1 s, sys: 45.2 s, total: 58.3 s Wall time: 1min 43s

Estamos executando 200 mil passos do método de leapfrog, em que cada um computa um gradiente para o dataset inteiro. Ao dividir a computação em 8 núcleos, podemos computar o equivalente a 200 mil épocas de treinamento em cerca de 95 segundos, aproximadamente 2.100 épocas por segundo!

Vamos plotar a densidade logarítmica e a exatidão de cada amostra.

fig, ax = plt.subplots(1, 3, figsize=(15, 5)) ax[0].plot(trace[0]) ax[0].set_title('Log Prob') ax[1].plot(trace[1]) ax[1].set_title('Accuracy') ax[2].plot(trace[2]) ax[2].set_title('Step Size') plt.show()
Image in a Jupyter notebook

Se agruparmos a amostras, poderemos computar uma média de modelo bayesiano para aumentar o desempenho.

@functools.partial(jax.pmap, axis_name='data', in_axes=(0, None), out_axes=None) def bayesian_model_average(data, states): images, labels = data logits = jax.vmap(lambda w, b: images.dot(w) + b)(*states) probs = jax.nn.softmax(logits, axis=-1) bma_accuracy = (probs.mean(axis=0).argmax(axis=-1) == labels).mean() avg_accuracy = (probs.argmax(axis=-1) == labels).mean() return lax.pmean(bma_accuracy, axis_name='data'), lax.pmean(avg_accuracy, axis_name='data') sharded_test_images, sharded_test_labels = shard((test_images, test_labels)) bma_acc, avg_acc = bayesian_model_average((sharded_test_images, sharded_test_labels), states) print(f'Average Accuracy: {avg_acc}') print(f'BMA Accuracy: {bma_acc}') print(f'Accuracy Improvement: {bma_acc - avg_acc}')
Average Accuracy: 0.9188529253005981 BMA Accuracy: 0.9264000058174133 Accuracy Improvement: 0.0075470805168151855

Uma média de modelo bayesiano aumenta a exatidão em quase 1 ponto percentual!

Exemplo: sistema de recomendações MovieLens

Agora, vamos tentar fazer a inferência com o dataset de recomendações MovieLens, que é um conjunto de usuários e suas classificações de diversos filmes. Especificamente, podemos representar o MovieLens como uma matriz de observação WW N×MN \times M, em que NN é o número de usuários e MM é o número de filmes; esperamos que N>MN > M. As entradas de WijW_{ij} são um booleano que indica se o usuário ii viu ou não o filme jj. Observe que o MovieLens fornece as classificações dos usuários, que são ignoradas para simplificar o problema.

Primeiro, vamos carregar o dataset. Usaremos a versão com 1 milhão de classificações.

movielens = tfds.as_numpy(tfds.load('movielens/1m-ratings', batch_size=-1)) GENRES = ['Action', 'Adventure', 'Animation', 'Children', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'IMAX', 'Musical', 'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'Unknown', 'War', 'Western', '(no genres listed)']
Downloading and preparing dataset movielens/1m-ratings/0.1.0 (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0...
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0.incompleteYKA3TG/movielens-train.tfrecord
HBox(children=(FloatProgress(value=0.0, max=1000209.0), HTML(value='')))
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0. Subsequent calls will reuse this data.

Faremos um pré-processamento do dataset para obter a matriz de observação WW.

raw_movie_ids = movielens['train']['movie_id'] raw_user_ids = movielens['train']['user_id'] genres = movielens['train']['movie_genres'] movie_ids, movie_labels = pd.factorize(movielens['train']['movie_id']) user_ids, user_labels = pd.factorize(movielens['train']['user_id']) num_movies = movie_ids.max() + 1 num_users = user_ids.max() + 1 movie_titles = dict(zip(movielens['train']['movie_id'], movielens['train']['movie_title'])) movie_genres = dict(zip(movielens['train']['movie_id'], genres)) movie_id_to_title = [movie_titles[movie_labels[id]].decode('utf-8') for id in range(num_movies)] movie_id_to_genre = [GENRES[movie_genres[movie_labels[id]][0]] for id in range(num_movies)] watch_matrix = np.zeros((num_users, num_movies), bool) watch_matrix[user_ids, movie_ids] = True print(watch_matrix.shape)
(6040, 3706)

Podemos definir um modelo generativo para WW usando um modelo simples de fatoração de matriz probabilística. Vamos supor uma matriz de usuário latente UU N×DN \times D e uma matriz de filmes latente VV M×DM \times D que, quando multiplicadas, produzem os logits de uma distribuição de Bernoulli para a matriz de observação WW. Também incluiremos um vetor de bias para usuários e filmes, uu e vv, respectivamente. UN(0,1)uN(0,1)VN(0,1)vN(0,1)WijBernoulli(σ((UVT)ij+ui+vj)) \begin{align*} U &\sim \mathcal{N}(0, 1) \quad u \sim \mathcal{N}(0, 1)\\ V &\sim \mathcal{N}(0, 1) \quad v \sim \mathcal{N}(0, 1)\\ W_{ij} &\sim \textrm{Bernoulli}\left(\sigma\left(\left(UV^T\right)_{ij} + u_i + v_j\right)\right) \end{align*}

Essa matriz é muito grande: 6.040 usuários e 3.706 filmes levam a uma matriz com mais de 22 milhões de entradas. Qual seria a estratégia de fragmentação desse modelo? Se pressupormos que N>MN > M (isto é, há mais usuários do que filmes), então faria sentido fragmentar a matriz de observação ao longo do eixo de usuários para que cada dispositivo tenha uma parte da matriz de observação correspondente a um subconjunto de usuários. Porém, diferentemente do modelo anterior, também teremos que fragmentar a matriz UU, já que ela tem um embedding para cada usuário, então cada dispositivo será responsável por um fragmento de UU e um fragmento de WW. Por outro lado, VV não será fragmentado e será sincronizado entre os dispositivos.

sharded_watch_matrix = shard(watch_matrix)

Antes de escrevemos run, vamos falar brevemente sobre os desafios adicionais ao fragmentar a variável aleatória local UU. Ao executar o HMC, o kernel tfp.mcmc.HamiltonianMonteCarlo padrão fará a amostragem dos momentos para cada elemento do estado da cadeia. Anteriormente, somente as variáveis aleatórias não fragmentadas faziam parte desse estado, e os momentos eram os mesmos em cada dispositivo. Como agora temos UU fragmentado, precisamos fazer a amostragem de momentos diferentes em cada dispositivo para UU, e também precisamos fazer a amostragem dos mesmos momentos para VV. Para fazer isso, podemos usar tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo com uma distribuição de momento Sharded. À medida que aprimorarmos a computação paralela, poderemos simplificar isso, levando um indicador de nível de fragmentação ao kernel do HMC, por exemplo.

def make_run(*, axis_name, dim=20, num_chains=2, prior_variance=1., step_size=1e-2, num_leapfrog_steps=100, num_burnin_steps=1000, num_results=500, ): @functools.partial(jax.pmap, in_axes=(None, 0), axis_name=axis_name) @jax.default_matmul_precision('tensorfloat32') def run(key, watch_matrix): num_users, num_movies = watch_matrix.shape Sharded = functools.partial(tfed.Sharded, shard_axis_name=axis_name) def prior_fn(): user_embeddings = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users, dim]), name='user_embeddings')) user_bias = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users]), name='user_bias')) movie_embeddings = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies, dim], name='movie_embeddings')) movie_bias = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies], name='movie_bias')) return (user_embeddings, user_bias, movie_embeddings, movie_bias) prior = tfed.JointDistributionCoroutine(prior_fn) def model_fn(): user_embeddings, user_bias, movie_embeddings, movie_bias = yield from prior_fn() logits = (jnp.einsum('...nd,...md->...nm', user_embeddings, movie_embeddings) + user_bias[..., :, None] + movie_bias[..., None, :]) yield Sharded(tfd.Independent(tfd.Bernoulli(logits=logits), 2), name='watch') model = tfed.JointDistributionCoroutine(model_fn) init_key, sample_key = random.split(key) initial_state = prior.sample(seed=init_key, sample_shape=num_chains) def target_log_prob(*state): return model.log_prob((*state, watch_matrix)) momentum_distribution = tfed.JointDistributionSequential([ Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users, dim]), 1.), 2)), Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users]), 1.), 1)), tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies, dim]), 1.), 2), tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies]), 1.), 1), ]) # We pass in momentum_distribution here to ensure that the momenta for # user_embeddings and user_bias are also sharded kernel = tfem.PreconditionedHamiltonianMonteCarlo(target_log_prob, step_size, num_leapfrog_steps, momentum_distribution=momentum_distribution) num_adaptation_steps = int(0.8 * num_burnin_steps) kernel = tfm.DualAveragingStepSizeAdaptation(kernel, num_adaptation_steps) def trace_fn(state, pkr): return { 'log_prob': target_log_prob(*state), 'log_accept_ratio': pkr.inner_results.log_accept_ratio, } return tfm.sample_chain( num_results, initial_state, kernel=kernel, num_burnin_steps=num_burnin_steps, trace_fn=trace_fn, seed=sample_key) return run

Vamos executar uma vez para fazer cache da função run compilada.

%%time run = make_run(axis_name='data') output = run(random.PRNGKey(0), sharded_watch_matrix) jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 56 s, sys: 1min 24s, total: 2min 20s Wall time: 3min 35s

Agora vamos executar novamente sem a sobrecarga de compilação.

%%time states, trace = run(random.PRNGKey(0), sharded_watch_matrix) jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 28.8 s, sys: 1min 16s, total: 1min 44s Wall time: 3min 1s

Podemos ver que concluímos cerca de 150 mil passos do método de leapfrog em aproximadamente 3 minutos, portanto, cerca de 83 passos do método de leapfrog por segundo! Vamos plotar a proporção de aceitação e a densidade logarítmica das amostras.

fig, axs = plt.subplots(1, len(trace), figsize=(5 * len(trace), 5)) for ax, (key, val) in zip(axs, trace.items()): ax.plot(val[0]) # Indexing into a sharded array, each element is the same ax.set_title(key);
Image in a Jupyter notebook

Agora que temos algumas amostras da cadeia de Markov, vamos usá-las para fazer algumas previsões. Primeiro, vamos extrair cada um dos componentes. Lembre-se de que user_embeddings e user_bias são divididos entre os dispositivos, então precisamos concatenar ShardedArray para obter todos eles. Por outro lado, movie_embeddings e movie_bias estão no mesmo dispositivo, então podemos simplesmente pegar o valor no primeiro fragmento. Vamos usar o numpy comum para copiar os valores das TPUs para a CPU.

user_embeddings = np.concatenate(np.array(states.user_embeddings, np.float32), axis=2) user_bias = np.concatenate(np.array(states.user_bias, np.float32), axis=2) movie_embeddings = np.array(states.movie_embeddings[0], dtype=np.float32) movie_bias = np.array(states.movie_bias[0], dtype=np.float32) samples = (user_embeddings, user_bias, movie_embeddings, movie_bias) print(f'User embeddings: {user_embeddings.shape}') print(f'User bias: {user_bias.shape}') print(f'Movie embeddings: {movie_embeddings.shape}') print(f'Movie bias: {movie_bias.shape}')
User embeddings: (500, 2, 6040, 20) User bias: (500, 2, 6040) Movie embeddings: (500, 2, 3706, 20) Movie bias: (500, 2, 3706)

Vamos tentar criar um sistema de recomendação simples que utilize a incerteza capturada nessas amostras. Primeiro, vamos escrever uma função que classifique os filmes de acordo com a probabilidade de serem assistidos.

@jax.jit def recommend(sample, user_id): user_embeddings, user_bias, movie_embeddings, movie_bias = sample movie_logits = ( jnp.einsum('d,md->m', user_embeddings[user_id], movie_embeddings) + user_bias[user_id] + movie_bias) return movie_logits.argsort()[::-1]

Agora, podemos escrever uma função que percorre todas as amostras em um loop e, para cada uma, escolhe o filme com maior classificação ao qual o usuário ainda não assistiu. Em seguida, podemos ver a contagem de todos os filmes recomendados entre as amostras.

def get_recommendations(user_id): movie_ids = [] already_watched = set(jnp.arange(num_movies)[watch_matrix[user_id] == 1]) for i in range(500): for j in range(2): sample = jax.tree_map(lambda x: x[i, j], samples) ranking = recommend(sample, user_id) for movie_id in ranking: if int(movie_id) not in already_watched: movie_ids.append(movie_id) break return movie_ids def plot_recommendations(movie_ids, ax=None): titles = collections.Counter([movie_id_to_title[i] for i in movie_ids]) ax = ax or plt.gca() names, counts = zip(*sorted(titles.items(), key=lambda x: -x[1])) ax.bar(names, counts) ax.set_xticklabels(names, rotation=90)

Vamos obter o usuário que viu a maior quantidade de filmes e o que viu a menor quantidade.

user_watch_counts = watch_matrix.sum(axis=1) user_most = user_watch_counts.argmax() user_least = user_watch_counts.argmin() print(user_watch_counts[user_most], user_watch_counts[user_least])
2314 20

Esperamos que o sistema tenha uma certeza maior sobre user_most do que sobre user_least, já que temos mais informações sobre quais tipos de filmes user_most está mais propenso a assistir.

fig, ax = plt.subplots(1, 2, figsize=(20, 10)) most_recommendations = get_recommendations(user_most) plot_recommendations(most_recommendations, ax=ax[0]) ax[0].set_title('Recommendation for user_most') least_recommendations = get_recommendations(user_least) plot_recommendations(least_recommendations, ax=ax[1]) ax[1].set_title('Recommendation for user_least');
Image in a Jupyter notebook

Podemos ver que há uma maior variância nas recomendações para user_least, o que reflete a incerteza adicional sobre as preferências de filmes.

Também podemos conferir os gêneros dos filmes recomendados.

most_genres = collections.Counter([movie_id_to_genre[i] for i in most_recommendations]) least_genres = collections.Counter([movie_id_to_genre[i] for i in least_recommendations]) fig, ax = plt.subplots(1, 2, figsize=(20, 10)) ax[0].bar(most_genres.keys(), most_genres.values()) ax[0].set_title('Genres recommended for user_most') ax[1].bar(least_genres.keys(), least_genres.values()) ax[1].set_title('Genres recommended for user_least');
Image in a Jupyter notebook

user_most viu muitos filmes e recebeu como recomendação gêneros de nicho, como mistério e crime, enquanto user_least não assistiu a muitos filmes e recebeu como recomendação filmes de gênero mais comum, como comédia e ação.