Path: blob/master/site/pt-br/probability/examples/Distributed_Inference_with_JAX.ipynb
25118 views
Copyright 2020 The TensorFlow Probability Authors.
Licensed under the Apache License, Version 2.0 (the "License");
Inferência distribuída com JAX
Agora, o TensorFlow Probability (TFP) no JAX conta com ferramentas para computação numérica distribuída. Para aumentar a escala para um grande número de aceleradores, as ferramentas foram programadas segundo o paradigma "único programa, múltiplos dados" ou SPMD, na sigla em inglês.
Neste notebook, veremos como "raciocinar em SPMD" e apresentaremos as novas abstrações do TFP para aumentar a escala para diversas configurações, como pods de TPU ou clusters de GPU. Se você for executar este código por contra própria, lembre-se de selecionar um runtime de GPU.
Primeiro, vamos instalar a última versão do TFP, JAX e TF.
ERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.4.0 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement h5py~=2.10.0, but you'll have h5py 3.1.0 which is incompatible.
ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible.
ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.
ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.
ERROR: tf-nightly-cpu 2.6.0.dev20210401 has requirement numpy~=1.19.2, but you'll have numpy 1.20.2 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.4.0 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement h5py~=2.10.0, but you'll have h5py 3.1.0 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement numpy~=1.19.2, but you'll have numpy 1.20.2 which is incompatible.
ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible.
ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.
ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.
Vamos importar algumas bibliotecas gerais e alguns utilitários do JAX.
Também vamos configurar alguns alias do TFP muito úteis. As novas abstrações são fornecidas por tfp.experimental.distribute
e tfp.experimental.mcmc
.
Para conectar o notebook a uma TPU, usamos o seguinte helper do JAX. Para confirmar que a conexão foi estabelecida, exibimos o número de dispositivos via print, que deve ser igual a oito.
Introdução rápida a jax.pmap
Após estabelecermos a conexão a uma TPU, temos acesso a oito dispositivos. Porém, quando executamos o código do JAX de forma eager, o JAX executa as computações em apenas um, por padrão.
A maneira mais simples de executar uma computação em diversos dispositivos é mapear uma função, fazendo cada dispositivo executar um índice do mapa. O JAX conta com a transformação jax.pmap
("mapa paralelo"), que transforma uma função em outra que mapeia a função em diversos dispositivos.
No exemplo abaixo, criamos um array de tamanho 8 (para coincidir com o número de dispositivos disponíveis) e mapeamos uma função que adiciona 5.
Observe que recebemos de volta um tipo ShardedDeviceArray
, indicando que o array de saída está dividido fisicamente entre os dispositivos.
jax.pmap
funciona semanticamente como um mapa, mas conta com algumas opções importantes que modificam seu comportamento. Por padrão, pmap
pressupõe que todas as entradas da função estão sendo mapeadas, mas podemos modificar esse comportamento com o argumento in_axes
.
De maneira análoga, o argumento out_axes
de pmap
determina se os valores devem ser retornados em cada dispositivo ou não. Ao definir out_axes
como None
, o valor é retornado automaticamente para o primeiro dispositivo e somente deve ser usado se estivermos confiantes de que os valores são os mesmos em todos os dispositivos.
O que acontece quando o que quisermos fazer não puder ser facilmente expresso como uma função mapeada pura? Por exemplo: e se quisermos fazer uma soma ao longo do eixo que estamos mapeando? O JAX conta com funções "coletivas" que se comunicam entre os dispositivos para permitir a criação de programas distribuídos mais interessantes e complexos. Para entender como funcionam exatamente, vamos apresentar o SPMD.
O que é o SPMD?
O modelo programa único, múltiplos dados (SPMD) é um modelo de programação concorrente, em que um único programa (isto é, o mesmo código) é executado simultaneamente em diversos dispositivos, mas as entradas de cada um dos programas em execução podem diferir.
Se o nosso programa for uma função simples de sua entrada (por exemplo, x + 5
), executar um programa em SPMD significa simplesmente mapeá-lo em diferentes dados, como fizemos com jax.pmap
anteriormente. Porém, podemos fazer muito mais do que somente "mapear" uma função. O JAX conta com "coletivos", que são funções que se comunicam entre os dispositivos.
Por exemplo: talvez a gente deseje fazer a soma de uma quantidade em todos os dispositivos. Antes de fazermos isso, precisamos atribuir um nome ao eixo que estamos mapeando no pmap
. Em seguida, usamos a função lax.psum
("soma paralela") para fazer uma soma entre os dispositivos, assegurando a identificação do eixo nomeado que está sendo somado.
O coletivo psum
agrega o valor de x
em cada dispositivo e sincroniza seu valor em todo o mapa, ou seja, out
é 28
em todos os dispositivos. Não estamos mais realizando um simples "mapeamento", mas sim executando um programa SPMD, em que agora a computação de cada dispositivo pode interagir com a mesma computação nos outros, embora de forma limitada, usando coletivos. Neste cenário, podemos usar out_axes = None
, pois psum
sincronizará o valor.
Com o SPMD, podemos escrever um único programa que é executado em cada dispositivo em qualquer configuração de TPUs simultaneamente. O mesmo código usado para fazer aprendizado de máquina em 8 núcleos de TPU pode ser usado em um pod de TPUs com centenas de milhares de núcleos! Confira mais detalhes sobre jax.pmap
e SPMD no tutorial de conceitos básicos do JAX.
MCMC em grande escala
Neste notebook, vamos nos concentrar em usar os métodos de Monte Carlo via Cadeias de Markov (MCMC) para inferência bayesiana. Há diversas formas de usar vários dispositivos para MCMC, mas, neste notebook, nos concentraremos em duas:
Executar cadeias de Markov independentes em dispositivos diferentes. Este caso é bem simples e possível de se fazer com o TFP padrão.
Fragmentar um dataset em vários dispositivos. Este caso é um pouco mais complexo e requer funcionalidades do TFP adicionadas recentemente.
Cadeias independentes
Vamos supor que queiramos fazer inferência bayesiana para um problema usando MCMC e queiramos executar diversas cadeias em paralelo em diversos dispositivos (digamos, duas em cada dispositivo). Isso acaba sendo um programa que podemos simplesmente "mapear" nos dispositivos, ou seja, um programa que não precisa de coletivos. Para assegurar que cada programa execute uma cadeia de Markov diferente (em vez de executar a mesma), passamos um valor de semente aleatória diferente para cada dispositivo.
Vamos fazer um teste em um problema simples de amostragem de uma distribuição gaussiana bidimensional. Podemos usar a funcionalidade de MCMC existente no TFP padrão, Em geral, tentamos colocar a maior parte da lógica dentro da função mapeada para distinguir explicitamente o que é executado em todos os dispositivos e o que é executado somente no primeiro.
Isoladamente, a função run
recebe uma semente aleatória stateless (para ver como a aleatoriedade stateless funciona, leia o notebook TFP no JAX ou confira o tutorial de noções básicas do JAX). O mapeamento de run
em diferentes sementes resultará na execução de várias cadeias de Markov independentes.
Observe como agora temos um eixo extra correspondente a cada dispositivo. Podemos reorganizar as dimensões e reduzi-las para obter um eixo para as 16 cadeias.
Ao executar cadeias independentes em diversos dispositivos, basta usar pmap
para criar uma função que use tfp.mcmc
, assegurando a passagem de valores diferentes para a semente aleatória de cada dispositivo.
Fragmentação de dados (sharding)
Ao fazermos o MCMC, a distribuição alvo geralmente é uma distribuição posterior obtida pelo condicionamento de um dataset, e a computação de uma densidade logarítmica não normalizada envolve fazer a soma da verossimilhança para cada dado observado.
Com datasets muito grandes, pode ser proibitivamente caro em termos de recursos executar uma cadeia em um único dispositivo. Entretanto, quando temos acesso a diversos dispositivos, podemos dividir o dataset entre os dispositivos para aproveitar melhor a computação disponível.
Se quisermos fazer o MCMC com um dataset fragmentado, precisamos assegurar que a densidade logarítmica não normalizada que computarmos em cada dispositivo represente o total, ou seja, a densidade para todos os dados. Caso contrário, cada dispositivo fará o MCMC com sua própria distribuição alvo incorreta. Para isso, agora o TFP conta com novas ferramentas (tfp.experimental.distribute
e tfp.experimental.mcmc
), que permitem computar probabilidades logarítmicas "fragmentadas" e fazer MCMC com elas.
Distribuições fragmentadas
Agora, o TFP de abstração core que fornece a computação de probabilidades logarítmicas fragmentadas é a metadistribuição Sharded
, que recebe uma distribuição como entrada e retorna uma nova distribuição com propriedades específicas quando a execução é feita em um contexto de SPMD. Sharded
reside em tfp.experimental.distribute
.
Intuitivamente, uma distribuição Sharded
corresponde a um conjunto de variáveis aleatórias que foram "divididas" entre dispositivos. Em cada dispositivo, elas geram amostras diferentes e podem ter individualmente densidades logarítmicas diferentes. Alternativamente, uma distribuição Sharded
corresponde a uma "placa" no jargão de modelo gráfico, em que o tamanho da placa é o número de dispositivos.
Amostragem de uma distribuição Sharded
Se fizermos a amostragem de uma distribuição Normal
que virou uma mapa paralelo pmap
usando a mesma semente em cada dispositivo, teremos a mesma amostra em todos os dispositivos. Podemos pensar na seguinte função como a amostragem de uma única variável aleatória que é sincronizada entre os dispositivos.
Se encapsularmos tfd.Normal(0., 1.)
com tfed.Sharded
, agora temos logicamente oito variáveis aleatórias diferentes (uma em cada dispositivo) e, portanto, será gerada uma amostra diferente para cada uma, apesar de passarmos a mesma semente.
Uma representação equivalente dessa distribuição em um único dispositivo é somente 8 amostras de uma distribuição normal independente. Embora o valor da amostra seja diferente (tfed.Sharded
faz uma geração de número pseudoaleatório ligeiramente diferente), ambas representam a mesma distribuição.
Recebendo a densidade logarítmica de uma distribuição Sharded
Vamos ver o que acontece quando computamos a densidade logarítmica de uma amostra a partir de uma distribuição regular em um contexto de SPMD.
Cada amostra está no mesmo dispositivo, então também computamos a mesma densidade em cada dispositivo. Intuitivamente, temos aqui somente uma distribuição ao longo de uma única variável com distribuição normal.
Com uma distribuição Sharded
, temos uma distribuição ao longo de 8 variáveis aleatórias, portanto, quanto computamos a log_prob
de uma amostra, fazemos a soma entre os dispositivos de cada uma das densidades logarítmicas individuais (talvez você note que o valor log_prob total seja maior do que o valor log_prob isolado computado acima).
A distribuição equivalente "não fragmentada" produz a mesma densidade logarítmica.
Uma distribuição Sharded
gera valores diferentes de sample
em cada dispositivo, mas obtém o mesmo valor para log_prob
em cada dispositivo. O que acontece aqui? Uma distribuição Sharded
faz um psum
internamente para assegurar que os valores de log_prob
estejam sincronizados entre os dispositivos. Por que desejamos esse comportamento? Se estivermos executando a mesma cadeia do MCMC em cada dispositivo, queremos que a target_log_prob
seja a mesma em cada dispositivo, mesmo que algumas variáveis aleatórias na computação sejam fragmentadas entre os dispositivos.
Além disso, uma distribuição Sharded
assegura que os gradientes entre os dispositivos estejam corretos para garantir que os algoritmos, como o HMC, que recebem gradientes da função de densidade logarítmica funcionem como parte da função de transição e gerem amostras adequadas.
JointDistribution
s fragmentadas
Podemos criar modelos com diversas variáveis aleatórias Sharded
usando JointDistribution
s (JDs – distribuições conjuntas). Infelizmente, distribuições Sharded
não podem ser usadas com segurança com as tfd.JointDistribution
s padrão, mas tfp.experimental.distribute
exporta JDs "modificadas" que se comportarão como distribuições Sharded
.
Essas JDs padrão podem ter tanto distribuições Sharded
quanto distribuições padrão do TFP como componentes. Para as distribuições não fragmentadas, obtemos a mesma amostra em cada dispositivo. Já para distribuições fragmentadas, obtemos amostras diferentes. A log_prob
em cada dispositivo também é sincronizada.
MCMC com distribuições Sharded
Como podemos pensar nas distribuições Sharded
no contexto do MCMC? Se tivermos um modelo generativo que possa ser expressado como uma JointDistribution
, podemos escolher algum eixo desse modelo no qual "fragmentar". Tipicamente, uma variável aleatória do modelo corresponderá aos dados observados e, se tivermos um dataset grande que desejamos fragmentar entre dispositivos, queremos que as variáveis que estejam associadas aos pontos de dados também sejam fragmentadas. Além disso, também podemos ter variáveis aleatórias "locais" com correspondência um-para-um com as observações que estamos fragmentando, então também teremos que fragmentar essas variáveis aleatórias.
Nesta seção, veremos exemplos de uso das distribuições Sharded
com o MCMC do TFP. Começaremos com um exemplo mais simples de regressão logística bayesiana e concluiremos com um exemplo de fatoração de matriz, com o objetivo de demonstrar alguns casos de uso da biblioteca distribute
.
Exemplo: regressão logística bayesiana para MNIST
Gostaríamos de fazer uma regressão logística bayesiana em um dataset grande. O modelo tem um prior ao longo dos pesos de regressão e uma verossimilhança que é somada para todos os dados para obter a densidade logarítmica conjunta total. Se fragmentarmos os dados, fragmentaríamos as variáveis aleatórias observadas e em nosso modelo.
Usamos o seguinte modelo de regressão logística bayesiana para classificação MNIST:
Carregaremos o MNIST usando os TensorFlow Datasets.
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Temos 60 mil imagens de treinamento, mas usaremos os 8 núcleos disponíveis e as dividiremos em 8 conjuntos. Usaremos a função utilitária shard
, que é muito útil.
Antes de continuarmos, vamos falar brevemente sobre a precisão em TPUs e seu impacto no HMC. As TPUs executam as multiplicações de matrizes usando bfloat16
de baixa precisão por questões de velocidade. Geralmente, as multiplicações de matrizes bfloat16
são suficientes para diversas aplicações de aprendizado profundo, mas, quando usadas com o HMC, descobrimos empiricamente que a precisão baixa pode levar a trajetórias divergentes, causando rejeições. Podemos usar multiplicações de matrizes de precisão maior, com um certo aumento do custo computacional.
Para aumentar a precisão de matmul, podemos usar o decorador jax.default_matmul_precision
com precisão "tensorfloat32"
(para uma precisão ainda maior, poderíamos usar "float32"
).
Agora, vamos definir a função run
, que receberá uma semente aleatória (que será a mesma em cada dispositivo) e um fragmento de MNIST. A função implementará o modelo mencionado acima, e depois usarmos a funcionalidade de MCMC padrão do TFP para executar uma única cadeia. Vamos decorar run
com o decorador jax.default_matmul_precision
para assegurar que a multiplicação de matrizes seja executada com precisão maior, embora, no exemplo específico abaixo, poderíamos muito bem usar jnp.dot(images, w, precision=lax.Precision.HIGH)
.
jax.pmap
inclui uma compilação JIT, mas é feito cache da função compilada após a primeira chamada. Vamos chamar run
e ignorar a saída para fazer cache da compilação.
Agora vamos chamar run
novamente para ver quanto tempo a execução leva.
Estamos executando 200 mil passos do método de leapfrog, em que cada um computa um gradiente para o dataset inteiro. Ao dividir a computação em 8 núcleos, podemos computar o equivalente a 200 mil épocas de treinamento em cerca de 95 segundos, aproximadamente 2.100 épocas por segundo!
Vamos plotar a densidade logarítmica e a exatidão de cada amostra.
Se agruparmos a amostras, poderemos computar uma média de modelo bayesiano para aumentar o desempenho.
Uma média de modelo bayesiano aumenta a exatidão em quase 1 ponto percentual!
Exemplo: sistema de recomendações MovieLens
Agora, vamos tentar fazer a inferência com o dataset de recomendações MovieLens, que é um conjunto de usuários e suas classificações de diversos filmes. Especificamente, podemos representar o MovieLens como uma matriz de observação , em que é o número de usuários e é o número de filmes; esperamos que . As entradas de são um booleano que indica se o usuário viu ou não o filme . Observe que o MovieLens fornece as classificações dos usuários, que são ignoradas para simplificar o problema.
Primeiro, vamos carregar o dataset. Usaremos a versão com 1 milhão de classificações.
Downloading and preparing dataset movielens/1m-ratings/0.1.0 (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0...
Shuffling and writing examples to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0.incompleteYKA3TG/movielens-train.tfrecord
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0. Subsequent calls will reuse this data.
Faremos um pré-processamento do dataset para obter a matriz de observação .
Podemos definir um modelo generativo para usando um modelo simples de fatoração de matriz probabilística. Vamos supor uma matriz de usuário latente e uma matriz de filmes latente que, quando multiplicadas, produzem os logits de uma distribuição de Bernoulli para a matriz de observação . Também incluiremos um vetor de bias para usuários e filmes, e , respectivamente.
Essa matriz é muito grande: 6.040 usuários e 3.706 filmes levam a uma matriz com mais de 22 milhões de entradas. Qual seria a estratégia de fragmentação desse modelo? Se pressupormos que (isto é, há mais usuários do que filmes), então faria sentido fragmentar a matriz de observação ao longo do eixo de usuários para que cada dispositivo tenha uma parte da matriz de observação correspondente a um subconjunto de usuários. Porém, diferentemente do modelo anterior, também teremos que fragmentar a matriz , já que ela tem um embedding para cada usuário, então cada dispositivo será responsável por um fragmento de e um fragmento de . Por outro lado, não será fragmentado e será sincronizado entre os dispositivos.
Antes de escrevemos run
, vamos falar brevemente sobre os desafios adicionais ao fragmentar a variável aleatória local . Ao executar o HMC, o kernel tfp.mcmc.HamiltonianMonteCarlo
padrão fará a amostragem dos momentos para cada elemento do estado da cadeia. Anteriormente, somente as variáveis aleatórias não fragmentadas faziam parte desse estado, e os momentos eram os mesmos em cada dispositivo. Como agora temos fragmentado, precisamos fazer a amostragem de momentos diferentes em cada dispositivo para , e também precisamos fazer a amostragem dos mesmos momentos para . Para fazer isso, podemos usar tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo
com uma distribuição de momento Sharded
. À medida que aprimorarmos a computação paralela, poderemos simplificar isso, levando um indicador de nível de fragmentação ao kernel do HMC, por exemplo.
Vamos executar uma vez para fazer cache da função run
compilada.
Agora vamos executar novamente sem a sobrecarga de compilação.
Podemos ver que concluímos cerca de 150 mil passos do método de leapfrog em aproximadamente 3 minutos, portanto, cerca de 83 passos do método de leapfrog por segundo! Vamos plotar a proporção de aceitação e a densidade logarítmica das amostras.
Agora que temos algumas amostras da cadeia de Markov, vamos usá-las para fazer algumas previsões. Primeiro, vamos extrair cada um dos componentes. Lembre-se de que user_embeddings
e user_bias
são divididos entre os dispositivos, então precisamos concatenar ShardedArray
para obter todos eles. Por outro lado, movie_embeddings
e movie_bias
estão no mesmo dispositivo, então podemos simplesmente pegar o valor no primeiro fragmento. Vamos usar o numpy
comum para copiar os valores das TPUs para a CPU.
Vamos tentar criar um sistema de recomendação simples que utilize a incerteza capturada nessas amostras. Primeiro, vamos escrever uma função que classifique os filmes de acordo com a probabilidade de serem assistidos.
Agora, podemos escrever uma função que percorre todas as amostras em um loop e, para cada uma, escolhe o filme com maior classificação ao qual o usuário ainda não assistiu. Em seguida, podemos ver a contagem de todos os filmes recomendados entre as amostras.
Vamos obter o usuário que viu a maior quantidade de filmes e o que viu a menor quantidade.
Esperamos que o sistema tenha uma certeza maior sobre user_most
do que sobre user_least
, já que temos mais informações sobre quais tipos de filmes user_most
está mais propenso a assistir.
Podemos ver que há uma maior variância nas recomendações para user_least
, o que reflete a incerteza adicional sobre as preferências de filmes.
Também podemos conferir os gêneros dos filmes recomendados.
user_most
viu muitos filmes e recebeu como recomendação gêneros de nicho, como mistério e crime, enquanto user_least
não assistiu a muitos filmes e recebeu como recomendação filmes de gênero mais comum, como comédia e ação.