Path: blob/master/site/es-419/probability/examples/Distributed_Inference_with_JAX.ipynb
25118 views
Copyright 2020 The TensorFlow Probability Authors.
Licensed under the Apache License, Version 2.0 (the "License");
Inferencia distribuida con JAX
TensorFlow Probability (TFP) en JAX ahora tiene herramientas para computación numérica distribuida. Para escalar a una gran cantidad de aceleradores, las herramientas se basan en la escritura de código utilizando el paradigma de "un programa, múltiples datos", o SPMD para abreviar.
En este bloc de notas, veremos cómo "pensar en SPMD" e introduciremos las nuevas abstracciones de TFP para escalar a configuraciones como módulos de TPU o grupos de GPU. Si ejecuta este código por sí mismo, asegúrese de seleccionar un tiempo de ejecución de TPU.
Primero, instalaremos las últimas versiones de TFP, JAX y TF.
ERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.4.0 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement h5py~=2.10.0, but you'll have h5py 3.1.0 which is incompatible.
ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible.
ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.
ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.
ERROR: tf-nightly-cpu 2.6.0.dev20210401 has requirement numpy~=1.19.2, but you'll have numpy 1.20.2 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.4.0 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement h5py~=2.10.0, but you'll have h5py 3.1.0 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement numpy~=1.19.2, but you'll have numpy 1.20.2 which is incompatible.
ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible.
ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.
ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.
Luego, importaremos algunas bibliotecas generales, junto con algunas utilidades de JAX.
También configuraremos algunos alias útiles de TFP. Las nuevas abstracciones se proporcionan actualmente en tfp.experimental.distribute
y tfp.experimental.mcmc
.
Para conectar el bloc de notas a una TPU, utilizamos el siguiente ayudante de JAX. Para confirmar que estamos conectados imprimimos el número de dispositivos, que debería ser ocho.
Una introducción rápida a jax.pmap
Tras conectarnos a una TPU, tenemos acceso a ocho dispositivos. Sin embargo, cuando ejecutamos código JAX en modo eager, JAX por defecto ejecuta cálculos en solo uno.
La forma más sencilla de ejecutar un cálculo en muchos dispositivos es asignar una función, haciendo que cada dispositivo ejecute un índice de la asignación. JAX proporciona la transformación jax.pmap
("mapa paralelo") que convierte una función en una que asigna la función a varios dispositivos.
En el siguiente ejemplo, creamos un arreglo de tamaño 8 (para que coincida con la cantidad de dispositivos disponibles) y asignamos una función que suma 5 en ella.
Tenga en cuenta que recibimos un tipo ShardedDeviceArray
, lo que indica que el arreglo de salida está dividido físicamente entre dispositivos.
jax.pmap
actúa semánticamente como un mapa, pero tiene algunas opciones importantes que modifican su comportamiento. De forma predeterminada, pmap
asume que todas las entradas a la función se están asignando, pero podemos modificar este comportamiento con ayuda del argumento in_axes
.
De manera análoga, el argumento out_axes
de pmap
determina si se devuelven o no los valores en cada dispositivo. Establecer out_axes
en None
devuelve automáticamente el valor en el primer dispositivo y solo debe usarse si estamos seguros de que los valores son los mismos en todos los dispositivos.
¿Qué sucede cuando lo que nos gustaría hacer no se puede expresar fácilmente como una función pura asignada? Por ejemplo, ¿qué pasaría si quisiéramos hacer una suma en el eje sobre el que estamos asignando? JAX ofrece "colectivos", funciones que se comunican entre dispositivos, para que podamos escribir programas distribuidos más interesantes y complejos. Para comprender cómo funcionan exactamente, presentaremos SPMD.
¿Qué es SPMD?
Un programa y múltiples datos (SPMD) es un modelo de programación concurrente en el que un único programa (es decir, el mismo código) se ejecuta simultáneamente en todos los dispositivos, pero las entradas a cada uno de los programas en ejecución pueden diferir.
Si nuestro programa es una función simple de sus entradas (es decir, algo así como x + 5
), ejecutar un programa en SPMD es simplemente asignarlo sobre diferentes datos, como hicimos anteriormente con jax.pmap
. Sin embargo, podemos hacer más que simplemente "asignar" una función. JAX ofrece "colectivos", que son funciones que se comunican entre dispositivos.
Por ejemplo, tal vez nos gustaría tomar la suma de una cantidad en todos nuestros dispositivos. Antes de hacer eso, debemos asignar un nombre al eje que estamos mapeando en pmap
. Luego usamos la función lax.psum
("suma paralela") para realizar una suma entre dispositivos, asegurándonos de identificar el eje con nombre que estamos sumando.
El colectivo psum
agrega el valor de x
en cada dispositivo y sincroniza su valor en todo el mapa, es out
es 28.
en cada dispositivo. Ya no estamos ejecutando un simple "mapa", sino que estamos ejecutando un programa SPMD donde el cálculo de cada dispositivo ahora puede interactuar con el mismo cálculo en otros dispositivos, aunque de forma limitada mediante el uso de colectivos. En este escenario, podemos usar out_axes = None
, porque psum
sincronizará el valor.
SPMD nos permite escribir un programa que se ejecuta en cada dispositivo en cualquier configuración de TPU simultáneamente. ¡El mismo código que se usa para ejecutar el aprendizaje automático en 8 núcleos de TPU se puede usar en un módulo de TPU que puede tener de cientos a miles de núcleos! Para obtener un tutorial más detallado sobre jax.pmap
y SPMD, puede consultar el tutorial introductorio de JAX.
MCMC a escala
En este bloc de notas, nos centramos en el uso de los métodos de Monte Carlo basados en cadenas de Markov (MCMC) para la inferencia bayesiana. Hay muchas formas de usar muchos dispositivos para el MCMC, pero en este bloc de notas, nos centraremos en dos:
Ejecutar cadenas de Markov independientes en diferentes dispositivos. Este caso es bastante simple y es posible hacerlo con TFP básico.
Fragmentar un conjunto de datos entre dispositivos. Este caso es un poco más complejo y requiere maquinaria de TFP recientemente agregada.
Cadenas independientes
Digamos que nos gustaría hacer una inferencia bayesiana sobre un problema con MCMC y nos gustaría ejecutar varias cadenas en paralelo en varios dispositivos (por ejemplo, 2 en cada dispositivo). Se trata de un programa que podemos "mapear" sin problemas entre dispositivos, es decir, uno que no necesita colectivos. Para asegurarse de que cada programa ejecute una cadena de Markov diferente (en lugar de ejecutar la misma), pasamos un valor diferente para la semilla aleatoria a cada dispositivo.
Probémoslo con un problema de juguete de muestreo a partir de una distribución gaussiana bidimensional. Podemos usar la funcionalidad MCMC existente de TFP lista para usar. En general, intentamos poner la mayor parte de la lógica dentro de nuestra función asignada para distinguir más explícitamente entre lo que se ejecuta en todos los dispositivos y solo en el primero.
Por sí sola, la función run
toma una semilla aleatoria sin estado (para ver cómo funciona la aleatoriedad sin estado, puede leer el bloc de notas TFP en AX o ver el tutorial introductorio de JAX). Asignar run
en diferentes semillas dará como resultado la ejecución de varias cadenas de Markov independientes.
Observe que ahora tenemos un eje adicional correspondiente a cada dispositivo. Podemos reorganizar las dimensiones y aplanarlas para obtener un eje para las 16 cadenas.
Cuando se ejecutan cadenas independientes en muchos dispositivos, es tan fácil como ejecutar pmap
en una función que usa tfp.mcmc
, lo que nos garantiza que pasamos diferentes valores para la semilla aleatoria a cada dispositivo.
Fragmentación de datos
Cuando aplicamos el método MCMC, la distribución objetivo es a menudo una distribución posterior obtenida mediante el condicionamiento de un conjunto de datos, y el cálculo de una densidad logarítmica no normalizada implica la suma de probabilidades para cada dato observado.
Con conjuntos de datos muy grandes, puede resultar prohibitivamente costoso incluso ejecutar una cadena en un solo dispositivo. Sin embargo, cuando tenemos acceso a varios dispositivos, podemos dividir el conjunto de datos entre los dispositivos para aprovechar mejor la computación que tenemos disponible.
Si queremos aplicar el método MCMC con un conjunto de datos fragmentados, debemos asegurarnos de que la densidad logarítmica no normalizada que calculamos en cada dispositivo represente el total, es decir, la densidad de todos los datos; de lo contrario, cada dispositivo aplicará MCMC con su propia distribución objetivo incorrecta. Con este fin, TFP ahora cuenta con nuevas herramientas (es decir, tfp.experimental.distribute
y tfp.experimental.mcmc
) que permiten calcular probabilidades logarítmicas "fragmentadas" y aplicar MCMC con ellas.
Distribuciones fragmentadas
La abstracción central que TFP ahora proporciona para calcular las probabilidades logarítmicas fragmentadas es la metadistribución Sharded
, que toma una distribución como entrada y devuelve una nueva distribución que tiene propiedades específicas cuando se ejecuta en un contexto SPMD. Sharded
vive en tfp.experimental.distribute
.
Intuitivamente, una distribución Sharded
corresponde a un conjunto de variables aleatorias que se han "dividido" entre dispositivos. En cada dispositivo, producirán diferentes muestras y pueden tener individualmente diferentes densidades logarítmicas. Alternativamente, una distribución Sharded
corresponde a una "placa" en el lenguaje del modelo gráfico, donde el tamaño de la placa es la cantidad de dispositivos.
Muestreo de una distribución Sharded
Si tomamos muestras de una distribución Normal
en un programa que ejecuta pmap
con la misma semilla en cada dispositivo, obtendremos la misma muestra en cada dispositivo. Podemos pensar en la siguiente función como un muestreo de una única variable aleatoria que está sincronizada entre dispositivos.
Si envolvemos tfd.Normal(0., 1.)
con tfed.Sharded
, lógicamente ahora tenemos ocho variables aleatorias diferentes (una en cada dispositivo) y, por lo tanto, produciremos una muestra diferente para cada una, a pesar de pasar la misma semilla.
Una representación equivalente de esta distribución en un solo dispositivo son solo 8 muestras normales independientes. Aunque el valor de la muestra será diferente (tfed.Sharded
genera números pseudoaleatorios de manera ligeramente diferente), ambos representan la misma distribución.
Cómo tomar la densidad logarítmica de una distribución Sharded
Veamos qué sucede cuando calculamos la densidad logarítmica de una muestra de una distribución regular en un contexto SPMD.
Cada muestra es la misma en cada dispositivo, por lo que también calculamos la misma densidad en cada dispositivo. Intuitivamente, aquí solo tenemos una distribución sobre una única variable distribuida normalmente.
Con una distribución Sharded
, tenemos una distribución de 8 variables aleatorias, por lo que cuando calculamos el log_prob
de una muestra, sumamos, entre dispositivos, cada una de las densidades logarítmicas individuales. (Es posible que observe que este valor log_prob total es mayor que el log_prob singleton calculado anteriormente).
La distribución equivalente, "no fragmentada", produce la misma densidad logarítmica.
Una distribución Sharded
produce valores diferentes de sample
en cada dispositivo, pero obtiene el mismo valor para log_prob
en cada dispositivo. ¿Que está pasando aquí? Una distribución Sharded
ejecuta una psum
internamente para garantizar que los valores de log_prob
estén sincronizados en todos los dispositivos. ¿Por qué querríamos este comportamiento? Si ejecutamos la misma cadena MCMC en cada dispositivo, nos gustaría que target_log_prob
sea el mismo en todos los dispositivos, incluso si algunas variables aleatorias en el cálculo están divididas entre dispositivos.
Además, una distribución Sharded
garantiza que los gradientes entre dispositivos sean correctos, para garantizar que algoritmos como el HMC, que toman gradientes de la función de densidad logarítmica como parte de la función de transición, produzcan muestras adecuadas.
JointDistribution
fragmentadas
Podemos crear modelos con múltiples variables aleatorias Sharded
si usamos JointDistribution
(JD). Desafortunadamente, las distribuciones Sharded
no se pueden usar de forma segura con tfd.JointDistribution
básicas, pero tfp.experimental.distribute
exporta JD "parcheadas" que se comportarán como distribuciones Sharded
.
Estas JD fragmentadas pueden tener como componentes distribuciones TFP Sharded
y estándar. Para las distribuciones no fragmentadas, obtenemos la misma muestra en cada dispositivo y para las distribuciones fragmentadas, obtenemos muestras diferentes. La log_prob
de cada dispositivo también está sincronizada.
MCMC con distribuciones Sharded
¿Cómo pensamos en las distribuciones Sharded
en el contexto de MCMC? Si tenemos un modelo generativo que se puede expresar como JointDistribution
, podemos elegir algún eje de ese modelo para "fragmentar". Normalmente, una variable aleatoria en el modelo corresponderá a los datos observados y, si tenemos un conjunto de datos grande que nos gustaría fragmentar entre dispositivos, queremos que las variables asociadas a los puntos de datos también se fragmenten. También podemos tener variables aleatorias "locales" que se correspondan con las observaciones que estamos fragmentando, por lo que tendremos que fragmentar adicionalmente esas variables aleatorias.
En esta sección, repasaremos ejemplos del uso de distribuciones Sharded
con MCMC de TFP. Comenzaremos con un ejemplo de regresión logística bayesiana más simple y concluiremos con un ejemplo de factorización matricial, con el objetivo de demostrar algunos casos de uso para la biblioteca distribute
.
Ejemplo: regresión logística bayesiana para MNIST
Nos gustaría realizar una regresión logística bayesiana en un conjunto de datos grande; el modelo tiene un previo sobre los pesos de regresión, y una probabilidad que se suma a todos los datos para obtener la densidad logarítmica conjunta total. Si fragmentamos nuestros datos, fragmentaremos las variables aleatorias observadas y en nuestro modelo.
Usamos el siguiente modelo de regresión logística bayesiana para la clasificación MNIST:
Carguemos MNIST a partir de conjuntos de datos de TensorFlow.
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Tenemos 60 000 imágenes de entrenamiento, pero aprovechemos nuestros 8 núcleos disponibles y dividámoslos en 8 formas. Usaremos la práctica función de utilidad shard
.
Antes de continuar, analicemos rápidamente la precisión de las TPU y su impacto en HMC. Las TPU ejecutan multiplicaciones matriciales mediante el uso de una precisión bfloat16
baja para mayor velocidad. Las multiplicaciones matriciales bfloat16
suelen ser suficientes para muchas aplicaciones de aprendizaje profundo, pero cuando se usan con el HMC, hemos descubierto empíricamente que la menor precisión puede generar trayectorias divergentes y provocar rechazos. Podemos usar multiplicaciones matriciales de mayor precisión, pero a cambio de algunos cálculos adicionales.
Para aumentar nuestra precisión matmul, podemos usar el decorador jax.default_matmul_precision
con precisión "tensorfloat32"
(para una precisión aún mayor, podríamos usar precisión "float32"
).
Ahora definamos nuestra función run
, que aceptará una semilla aleatoria (que será la misma en cada dispositivo) y un fragmento de MNIST. La función implementará el modelo antes mencionado y luego usaremos la funcionalidad MCMC básica de TFP para ejecutar una sola cadena. Nos aseguraremos de decorar run
con el decorador jax.default_matmul_precision
para asegurarnos de que la multiplicación matricial se ejecute con mayor precisión, aunque, en el ejemplo particular a continuación, también podríamos usar jnp.dot(images, w, precision=lax.Precision.HIGH)
.
jax.pmap
incluye una compilación JIT pero la función compilada se almacena en caché después de la primera llamada. Llamaremos a run
e ignoraremos el resultado para almacenar en caché la compilación.
Ahora llamaremos a run
nuevamente para ver cuánto tiempo lleva la ejecución real.
Estamos ejecutando 200 000 pasos de salto, cada uno de los cuales calcula un gradiente en todo el conjunto de datos. Dividir el cálculo en 8 núcleos nos permite calcular el equivalente a 200 000 épocas de entrenamiento en aproximadamente 95 segundos, ¡aproximadamente 2100 épocas por segundo!
Tracemos la densidad logarítmica de cada muestra y la precisión de cada muestra:
Si agrupamos las muestras, podemos calcular el promedio del modelo bayesiano para mejorar nuestro rendimiento.
¡Un promedio de modelo bayesiano aumenta nuestra precisión en casi un 1 %!
Ejemplo: sistema de recomendación MovieLens
Ahora intentemos hacer inferencias con el conjunto de datos de recomendaciones MovieLens, que es una colección de usuarios y sus calificaciones de varias películas. Específicamente, podemos representar MovieLens como una matriz de visualización donde es el número de usuarios y es el número de películas; esperamos ParseError: KaTeX parse error: Expected 'EOF', got '&' at position 3: N &̲gt; M. Las entradas de son un booleano que indican si el usuario vio o no la película . Tenga en cuenta que MovieLens proporciona calificaciones de los usuarios, pero las ignoramos para simplificar el problema.
Primero, cargaremos el conjunto de datos. Usaremos la versión con 1 millón de calificaciones.
Downloading and preparing dataset movielens/1m-ratings/0.1.0 (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0...
Shuffling and writing examples to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0.incompleteYKA3TG/movielens-train.tfrecord
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0. Subsequent calls will reuse this data.
Haremos un preprocesamiento del conjunto de datos para obtener la matriz de vigilancia .
Podemos definir un modelo generativo para , utilizando un modelo de factorización matricial probabilística simple. Se asume que se trata de una matriz de usuario latente y una matriz de película latente , que cuando se multiplican producen los logits de un Bernoulli para la matriz de vigilancia . También incluiremos vectores de sesgo para usuarios y películas, y .
Esta es una matriz bastante grande; 6040 usuarios y 3706 películas conducen a una matriz con más de 22 millones de entradas. ¿Cómo abordamos la fragmentación de este modelo? Bueno, si asumimos que ParseError: KaTeX parse error: Expected 'EOF', got '&' at position 3: N &̲gt; M (es decir, hay más usuarios que películas), entonces tendría sentido fragmentar la matriz de visualización a través del eje del usuario, de modo que cada dispositivo tenga un fragmento de matriz de visualización correspondiente a un subconjunto de usuarios. Sin embargo, a diferencia del ejemplo anterior, también tendremos que fragmentar la matriz , ya que tiene una incrustación para cada usuario, por lo que cada dispositivo será responsable de un fragmento de y un fragmento de . Por otro lado, se desfragmentará y se sincronizará entre dispositivos.
Antes de escribir nuestra run
, analicemos rápidamente los desafíos adicionales que implica fragmentar la variable aleatoria local . Al ejecutar el HMC, el núcleo básico tfp.mcmc.HamiltonianMonteCarlo
tomará muestras de los momentos para cada elemento del estado de la cadena. Anteriormente, solo las variables aleatorias sin fragmentar formaban parte de ese estado, y los momentos eran los mismos en cada dispositivo. Ahora que tenemos un fragmentado, necesitamos muestrear diferentes momentos en cada dispositivo para , mientras muestreamos los mismos momentos para . Para lograr esto, podemos usar tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo
con una distribución de momento Sharded
. A medida que continuamos haciendo que el cálculo paralelo sea de primera clase, podemos simplificarlo, por ejemplo, llevando un indicador de fragmentación al núcleo del HMC.
Lo ejecutaremos nuevamente una vez para almacenar en caché la run
compilada.
Ahora lo ejecutaremos nuevamente sin la sobrecarga de compilación.
Parece que completamos alrededor de 150 000 pasos de salto en aproximadamente 3 minutos, ¡es decir, alrededor de 83 pasos de salto por segundo! Tracemos la relación de aceptación y la densidad logarítmica de nuestras muestras.
Ahora que tenemos algunas muestras de nuestra cadena de Markov, usémoslas para hacer algunas predicciones. Primero, extraigamos cada uno de los componentes. Recuerde que user_embeddings
y user_bias
se dividen entre dispositivos, por lo que debemos concatenar nuestro ShardedArray
para obtenerlos todos. Por otro lado, movie_embeddings
y movie_bias
son iguales en todos los dispositivos, por lo que podemos elegir el valor del primer fragmento. Usaremos numpy
normal para copiar los valores de las TPU a la CPU.
Tratemos de construir un sistema de recomendación simple que use la incertidumbre capturada en estas muestras. Primero, escribamos una función que clasifique las películas según la probabilidad de visualización.
Ahora podemos escribir una función que recorra todas las muestras y, para cada una, elija la película mejor clasificada que el usuario aún no haya visto. Luego podemos ver los recuentos de todas las películas recomendadas en las muestras.
Comparemos el usuario que ha visto más películas con el que ha visto menos.
Esperamos que nuestro sistema tenga más certeza sobre user_most
que user_least
, dado que tenemos más información sobre qué tipo de películas es más probable que vea user_most
.
Vemos que hay más variación en nuestras recomendaciones para user_least
, lo que refleja nuestra incertidumbre adicional en sus preferencias de visualización.
También podemos ver los géneros de las películas recomendadas.
user_most
ha visto muchas películas y se le han recomendado géneros más especializados, como misterio y crimen, mientras que user_least
no ha visto muchas películas y se le han recomendado películas más convencionales, que se inclinan por la comedia y la acción.