Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/es-419/probability/examples/Distributed_Inference_with_JAX.ipynb
25118 views
Kernel: Python 3

Licensed under the Apache License, Version 2.0 (the "License");

#@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" } # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

TensorFlow Probability (TFP) en JAX ahora tiene herramientas para computación numérica distribuida. Para escalar a una gran cantidad de aceleradores, las herramientas se basan en la escritura de código utilizando el paradigma de "un programa, múltiples datos", o SPMD para abreviar.

En este bloc de notas, veremos cómo "pensar en SPMD" e introduciremos las nuevas abstracciones de TFP para escalar a configuraciones como módulos de TPU o grupos de GPU. Si ejecuta este código por sí mismo, asegúrese de seleccionar un tiempo de ejecución de TPU.

Primero, instalaremos las últimas versiones de TFP, JAX y TF.

#@title Installs !pip install jaxlib --upgrade -q 2>&1 1> /dev/null !pip install tfp-nightly[jax] --upgrade -q 2>&1 1> /dev/null !pip install tf-nightly-cpu -q -I 2>&1 1> /dev/null !pip install jax -I -q --upgrade 2>&1 1>/dev/null
ERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.4.0 which is incompatible. ERROR: tensorflow 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible. ERROR: tensorflow 2.4.1 has requirement h5py~=2.10.0, but you'll have h5py 3.1.0 which is incompatible. ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible. ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible. ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible. ERROR: tf-nightly-cpu 2.6.0.dev20210401 has requirement numpy~=1.19.2, but you'll have numpy 1.20.2 which is incompatible. ERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.4.0 which is incompatible. ERROR: tensorflow 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible. ERROR: tensorflow 2.4.1 has requirement h5py~=2.10.0, but you'll have h5py 3.1.0 which is incompatible. ERROR: tensorflow 2.4.1 has requirement numpy~=1.19.2, but you'll have numpy 1.20.2 which is incompatible. ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible. ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible. ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.

Luego, importaremos algunas bibliotecas generales, junto con algunas utilidades de JAX.

#@title Setup and Imports import functools import collections import contextlib import jax import jax.numpy as jnp from jax import lax from jax import random import jax.numpy as jnp import numpy as np import matplotlib.pyplot as plt import seaborn as sns import pandas as pd import tensorflow_datasets as tfds from tensorflow_probability.substrates import jax as tfp sns.set(style='white')
INFO:tensorflow:Enabling eager execution INFO:tensorflow:Enabling v2 tensorshape INFO:tensorflow:Enabling resource variables INFO:tensorflow:Enabling tensor equality INFO:tensorflow:Enabling control flow v2

También configuraremos algunos alias útiles de TFP. Las nuevas abstracciones se proporcionan actualmente en tfp.experimental.distribute y tfp.experimental.mcmc.

tfd = tfp.distributions tfb = tfp.bijectors tfm = tfp.mcmc tfed = tfp.experimental.distribute tfde = tfp.experimental.distributions tfem = tfp.experimental.mcmc Root = tfed.JointDistributionCoroutine.Root

Para conectar el bloc de notas a una TPU, utilizamos el siguiente ayudante de JAX. Para confirmar que estamos conectados imprimimos el número de dispositivos, que debería ser ocho.

from jax.tools import colab_tpu colab_tpu.setup_tpu() print(f'Found {jax.device_count()} devices')
Found 8 devices

Una introducción rápida a jax.pmap

Tras conectarnos a una TPU, tenemos acceso a ocho dispositivos. Sin embargo, cuando ejecutamos código JAX en modo eager, JAX por defecto ejecuta cálculos en solo uno.

La forma más sencilla de ejecutar un cálculo en muchos dispositivos es asignar una función, haciendo que cada dispositivo ejecute un índice de la asignación. JAX proporciona la transformación jax.pmap ("mapa paralelo") que convierte una función en una que asigna la función a varios dispositivos.

En el siguiente ejemplo, creamos un arreglo de tamaño 8 (para que coincida con la cantidad de dispositivos disponibles) y asignamos una función que suma 5 en ella.

xs = jnp.arange(8.) out = jax.pmap(lambda x: x + 5.)(xs) print(type(out), out)
<class 'jax.interpreters.pxla.ShardedDeviceArray'> [ 5. 6. 7. 8. 9. 10. 11. 12.]

Tenga en cuenta que recibimos un tipo ShardedDeviceArray, lo que indica que el arreglo de salida está dividido físicamente entre dispositivos.

jax.pmap actúa semánticamente como un mapa, pero tiene algunas opciones importantes que modifican su comportamiento. De forma predeterminada, pmap asume que todas las entradas a la función se están asignando, pero podemos modificar este comportamiento con ayuda del argumento in_axes.

xs = jnp.arange(8.) y = 5. # Map over the 0-axis of `xs` and don't map over `y` out = jax.pmap(lambda x, y: x + y, in_axes=(0, None))(xs, y) print(out)
[ 5. 6. 7. 8. 9. 10. 11. 12.]

De manera análoga, el argumento out_axes de pmap determina si se devuelven o no los valores en cada dispositivo. Establecer out_axes en None devuelve automáticamente el valor en el primer dispositivo y solo debe usarse si estamos seguros de que los valores son los mismos en todos los dispositivos.

xs = jnp.ones(8) # Value is the same on each device out = jax.pmap(lambda x: x + 1, out_axes=None)(xs) print(out)
2.0

¿Qué sucede cuando lo que nos gustaría hacer no se puede expresar fácilmente como una función pura asignada? Por ejemplo, ¿qué pasaría si quisiéramos hacer una suma en el eje sobre el que estamos asignando? JAX ofrece "colectivos", funciones que se comunican entre dispositivos, para que podamos escribir programas distribuidos más interesantes y complejos. Para comprender cómo funcionan exactamente, presentaremos SPMD.

¿Qué es SPMD?

Un programa y múltiples datos (SPMD) es un modelo de programación concurrente en el que un único programa (es decir, el mismo código) se ejecuta simultáneamente en todos los dispositivos, pero las entradas a cada uno de los programas en ejecución pueden diferir.

Si nuestro programa es una función simple de sus entradas (es decir, algo así como x + 5), ejecutar un programa en SPMD es simplemente asignarlo sobre diferentes datos, como hicimos anteriormente con jax.pmap. Sin embargo, podemos hacer más que simplemente "asignar" una función. JAX ofrece "colectivos", que son funciones que se comunican entre dispositivos.

Por ejemplo, tal vez nos gustaría tomar la suma de una cantidad en todos nuestros dispositivos. Antes de hacer eso, debemos asignar un nombre al eje que estamos mapeando en pmap. Luego usamos la función lax.psum ("suma paralela") para realizar una suma entre dispositivos, asegurándonos de identificar el eje con nombre que estamos sumando.

def f(x): out = lax.psum(x, axis_name='i') return out xs = jnp.arange(8.) # Length of array matches number of devices jax.pmap(f, axis_name='i')(xs)
ShardedDeviceArray([28., 28., 28., 28., 28., 28., 28., 28.], dtype=float32)

El colectivo psum agrega el valor de x en cada dispositivo y sincroniza su valor en todo el mapa, es out es 28. en cada dispositivo. Ya no estamos ejecutando un simple "mapa", sino que estamos ejecutando un programa SPMD donde el cálculo de cada dispositivo ahora puede interactuar con el mismo cálculo en otros dispositivos, aunque de forma limitada mediante el uso de colectivos. En este escenario, podemos usar out_axes = None, porque psum sincronizará el valor.

def f(x): out = lax.psum(x, axis_name='i') return out jax.pmap(f, axis_name='i', out_axes=None)(jnp.arange(8.))
ShardedDeviceArray(28., dtype=float32)

SPMD nos permite escribir un programa que se ejecuta en cada dispositivo en cualquier configuración de TPU simultáneamente. ¡El mismo código que se usa para ejecutar el aprendizaje automático en 8 núcleos de TPU se puede usar en un módulo de TPU que puede tener de cientos a miles de núcleos! Para obtener un tutorial más detallado sobre jax.pmap y SPMD, puede consultar el tutorial introductorio de JAX.

MCMC a escala

En este bloc de notas, nos centramos en el uso de los métodos de Monte Carlo basados en cadenas de Markov (MCMC) para la inferencia bayesiana. Hay muchas formas de usar muchos dispositivos para el MCMC, pero en este bloc de notas, nos centraremos en dos:

  1. Ejecutar cadenas de Markov independientes en diferentes dispositivos. Este caso es bastante simple y es posible hacerlo con TFP básico.

  2. Fragmentar un conjunto de datos entre dispositivos. Este caso es un poco más complejo y requiere maquinaria de TFP recientemente agregada.

Cadenas independientes

Digamos que nos gustaría hacer una inferencia bayesiana sobre un problema con MCMC y nos gustaría ejecutar varias cadenas en paralelo en varios dispositivos (por ejemplo, 2 en cada dispositivo). Se trata de un programa que podemos "mapear" sin problemas entre dispositivos, es decir, uno que no necesita colectivos. Para asegurarse de que cada programa ejecute una cadena de Markov diferente (en lugar de ejecutar la misma), pasamos un valor diferente para la semilla aleatoria a cada dispositivo.

Probémoslo con un problema de juguete de muestreo a partir de una distribución gaussiana bidimensional. Podemos usar la funcionalidad MCMC existente de TFP lista para usar. En general, intentamos poner la mayor parte de la lógica dentro de nuestra función asignada para distinguir más explícitamente entre lo que se ejecuta en todos los dispositivos y solo en el primero.

def run(seed): target_log_prob = tfd.Sample(tfd.Normal(0., 1.), 2).log_prob initial_state = jnp.zeros([2, 2]) # 2 chains kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-1, 10) def trace_fn(state, pkr): return target_log_prob(state) states, log_prob = tfm.sample_chain( num_results=1000, num_burnin_steps=1000, kernel=kernel, current_state=initial_state, trace_fn=trace_fn, seed=seed ) return states, log_prob

Por sí sola, la función run toma una semilla aleatoria sin estado (para ver cómo funciona la aleatoriedad sin estado, puede leer el bloc de notas TFP en AX o ver el tutorial introductorio de JAX). Asignar run en diferentes semillas dará como resultado la ejecución de varias cadenas de Markov independientes.

states, log_probs = jax.pmap(run)(random.split(random.PRNGKey(0), 8)) print(states.shape, log_probs.shape) # states is (8 devices, 1000 samples, 2 chains, 2 dimensions) # log_prob is (8 devices, 1000 samples, 2 chains)
(8, 1000, 2, 2) (8, 1000, 2)

Observe que ahora tenemos un eje adicional correspondiente a cada dispositivo. Podemos reorganizar las dimensiones y aplanarlas para obtener un eje para las 16 cadenas.

states = states.transpose([0, 2, 1, 3]).reshape([-1, 1000, 2]) log_probs = log_probs.transpose([0, 2, 1]).reshape([-1, 1000])
fig, ax = plt.subplots(1, 2, figsize=(10, 5)) ax[0].plot(log_probs.T, alpha=0.4) ax[1].scatter(*states.reshape([-1, 2]).T, alpha=0.1) plt.show()
Image in a Jupyter notebook

Cuando se ejecutan cadenas independientes en muchos dispositivos, es tan fácil como ejecutar pmap en una función que usa tfp.mcmc, lo que nos garantiza que pasamos diferentes valores para la semilla aleatoria a cada dispositivo.

Fragmentación de datos

Cuando aplicamos el método MCMC, la distribución objetivo es a menudo una distribución posterior obtenida mediante el condicionamiento de un conjunto de datos, y el cálculo de una densidad logarítmica no normalizada implica la suma de probabilidades para cada dato observado.

Con conjuntos de datos muy grandes, puede resultar prohibitivamente costoso incluso ejecutar una cadena en un solo dispositivo. Sin embargo, cuando tenemos acceso a varios dispositivos, podemos dividir el conjunto de datos entre los dispositivos para aprovechar mejor la computación que tenemos disponible.

Si queremos aplicar el método MCMC con un conjunto de datos fragmentados, debemos asegurarnos de que la densidad logarítmica no normalizada que calculamos en cada dispositivo represente el total, es decir, la densidad de todos los datos; de lo contrario, cada dispositivo aplicará MCMC con su propia distribución objetivo incorrecta. Con este fin, TFP ahora cuenta con nuevas herramientas (es decir, tfp.experimental.distribute y tfp.experimental.mcmc) que permiten calcular probabilidades logarítmicas "fragmentadas" y aplicar MCMC con ellas.

Distribuciones fragmentadas

La abstracción central que TFP ahora proporciona para calcular las probabilidades logarítmicas fragmentadas es la metadistribución Sharded, que toma una distribución como entrada y devuelve una nueva distribución que tiene propiedades específicas cuando se ejecuta en un contexto SPMD. Sharded vive en tfp.experimental.distribute.

Intuitivamente, una distribución Sharded corresponde a un conjunto de variables aleatorias que se han "dividido" entre dispositivos. En cada dispositivo, producirán diferentes muestras y pueden tener individualmente diferentes densidades logarítmicas. Alternativamente, una distribución Sharded corresponde a una "placa" en el lenguaje del modelo gráfico, donde el tamaño de la placa es la cantidad de dispositivos.

Muestreo de una distribución Sharded

Si tomamos muestras de una distribución Normal en un programa que ejecuta pmap con la misma semilla en cada dispositivo, obtendremos la misma muestra en cada dispositivo. Podemos pensar en la siguiente función como un muestreo de una única variable aleatoria que está sincronizada entre dispositivos.

# `pmap` expects at least one value to be mapped over, so we provide a dummy one def f(seed, _): return tfd.Normal(0., 1.).sample(seed=seed) jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236], dtype=float32)

Si envolvemos tfd.Normal(0., 1.) con tfed.Sharded, lógicamente ahora tenemos ocho variables aleatorias diferentes (una en cada dispositivo) y, por lo tanto, produciremos una muestra diferente para cada una, a pesar de pasar la misma semilla.

def f(seed, _): return tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i').sample(seed=seed) jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([ 1.2152631 , 0.7818249 , 0.32549605, 0.6828047 , 1.3973192 , -0.57830244, 0.37862757, 2.7706041 ], dtype=float32)

Una representación equivalente de esta distribución en un solo dispositivo son solo 8 muestras normales independientes. Aunque el valor de la muestra será diferente (tfed.Sharded genera números pseudoaleatorios de manera ligeramente diferente), ambos representan la misma distribución.

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count()) dist.sample(seed=random.PRNGKey(0))
DeviceArray([ 0.08086783, -0.38624594, -0.3756545 , 1.668957 , -1.2758069 , 2.1192007 , -0.85821325, 1.1305912 ], dtype=float32)

Cómo tomar la densidad logarítmica de una distribución Sharded

Veamos qué sucede cuando calculamos la densidad logarítmica de una muestra de una distribución regular en un contexto SPMD.

def f(seed, _): dist = tfd.Normal(0., 1.) x = dist.sample(seed=seed) return x, dist.log_prob(x) jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
(ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236, -0.20584236], dtype=float32), ShardedDeviceArray([-0.94012403, -0.94012403, -0.94012403, -0.94012403, -0.94012403, -0.94012403, -0.94012403, -0.94012403], dtype=float32))

Cada muestra es la misma en cada dispositivo, por lo que también calculamos la misma densidad en cada dispositivo. Intuitivamente, aquí solo tenemos una distribución sobre una única variable distribuida normalmente.

Con una distribución Sharded, tenemos una distribución de 8 variables aleatorias, por lo que cuando calculamos el log_prob de una muestra, sumamos, entre dispositivos, cada una de las densidades logarítmicas individuales. (Es posible que observe que este valor log_prob total es mayor que el log_prob singleton calculado anteriormente).

def f(seed, _): dist = tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i') x = dist.sample(seed=seed) return x, dist.log_prob(x) sample, log_prob = jax.pmap(f, in_axes=(None, 0), axis_name='i')( random.PRNGKey(0), jnp.arange(8.)) print('Sample:', sample) print('Log Prob:', log_prob)
Sample: [ 1.2152631 0.7818249 0.32549605 0.6828047 1.3973192 -0.57830244 0.37862757 2.7706041 ] Log Prob: [-13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205]

La distribución equivalente, "no fragmentada", produce la misma densidad logarítmica.

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count()) dist.log_prob(sample)
DeviceArray(-13.7349205, dtype=float32)

Una distribución Sharded produce valores diferentes de sample en cada dispositivo, pero obtiene el mismo valor para log_prob en cada dispositivo. ¿Que está pasando aquí? Una distribución Sharded ejecuta una psum internamente para garantizar que los valores de log_prob estén sincronizados en todos los dispositivos. ¿Por qué querríamos este comportamiento? Si ejecutamos la misma cadena MCMC en cada dispositivo, nos gustaría que target_log_prob sea el mismo en todos los dispositivos, incluso si algunas variables aleatorias en el cálculo están divididas entre dispositivos.

Además, una distribución Sharded garantiza que los gradientes entre dispositivos sean correctos, para garantizar que algoritmos como el HMC, que toman gradientes de la función de densidad logarítmica como parte de la función de transición, produzcan muestras adecuadas.

JointDistribution fragmentadas

Podemos crear modelos con múltiples variables aleatorias Sharded si usamos JointDistribution (JD). Desafortunadamente, las distribuciones Sharded no se pueden usar de forma segura con tfd.JointDistribution básicas, pero tfp.experimental.distribute exporta JD "parcheadas" que se comportarán como distribuciones Sharded.

def f(seed, _): dist = tfed.JointDistributionSequential([ tfd.Normal(0., 1.), tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i'), ]) x = dist.sample(seed=seed) return x, dist.log_prob(x) jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
([ShardedDeviceArray([1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525], dtype=float32), ShardedDeviceArray([ 0.8690128 , -0.83167845, 1.2209264 , 0.88412696, 0.76478404, -0.66208494, -0.0129658 , 0.7391483 ], dtype=float32)], ShardedDeviceArray([-12.214451, -12.214451, -12.214451, -12.214451, -12.214451, -12.214451, -12.214451, -12.214451], dtype=float32))

Estas JD fragmentadas pueden tener como componentes distribuciones TFP Sharded y estándar. Para las distribuciones no fragmentadas, obtenemos la misma muestra en cada dispositivo y para las distribuciones fragmentadas, obtenemos muestras diferentes. La log_prob de cada dispositivo también está sincronizada.

MCMC con distribuciones Sharded

¿Cómo pensamos en las distribuciones Sharded en el contexto de MCMC? Si tenemos un modelo generativo que se puede expresar como JointDistribution, podemos elegir algún eje de ese modelo para "fragmentar". Normalmente, una variable aleatoria en el modelo corresponderá a los datos observados y, si tenemos un conjunto de datos grande que nos gustaría fragmentar entre dispositivos, queremos que las variables asociadas a los puntos de datos también se fragmenten. También podemos tener variables aleatorias "locales" que se correspondan con las observaciones que estamos fragmentando, por lo que tendremos que fragmentar adicionalmente esas variables aleatorias.

En esta sección, repasaremos ejemplos del uso de distribuciones Sharded con MCMC de TFP. Comenzaremos con un ejemplo de regresión logística bayesiana más simple y concluiremos con un ejemplo de factorización matricial, con el objetivo de demostrar algunos casos de uso para la biblioteca distribute.

Ejemplo: regresión logística bayesiana para MNIST

Nos gustaría realizar una regresión logística bayesiana en un conjunto de datos grande; el modelo tiene un p(θ)p(\theta) previo sobre los pesos de regresión, y una probabilidad p(yiθ,xi)p(y_i | \theta, x_i) que se suma a todos los datos {xi,yi}i=1N\{x_i, y_i\}_{i = 1}^N para obtener la densidad logarítmica conjunta total. Si fragmentamos nuestros datos, fragmentaremos las variables aleatorias observadas xix_i y yiy_i en nuestro modelo.

Usamos el siguiente modelo de regresión logística bayesiana para la clasificación MNIST: wN(0,1)bN(0,1)yiw,b,xiCategorical(wTxi+b) \begin{align*} w &\sim \mathcal{N}(0, 1) \\ b &\sim \mathcal{N}(0, 1) \\ y_i | w, b, x_i &\sim \textrm{Categorical}(w^T x_i + b) \end{align*}

Carguemos MNIST a partir de conjuntos de datos de TensorFlow.

mnist = tfds.as_numpy(tfds.load('mnist', batch_size=-1)) raw_train_images, train_labels = mnist['train']['image'], mnist['train']['label'] train_images = raw_train_images.reshape([raw_train_images.shape[0], -1]) / 255. raw_test_images, test_labels = mnist['test']['image'], mnist['test']['label'] test_images = raw_test_images.reshape([raw_test_images.shape[0], -1]) / 255.
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data directory. If you'd instead prefer to read directly from our public GCS bucket (recommended if you're running on GCP), you can instead pass `try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

Tenemos 60 000 imágenes de entrenamiento, pero aprovechemos nuestros 8 núcleos disponibles y dividámoslos en 8 formas. Usaremos la práctica función de utilidad shard.

def shard_value(x): x = x.reshape((jax.device_count(), -1, *x.shape[1:])) return jax.pmap(lambda x: x)(x) # pmap will physically place values on devices shard = functools.partial(jax.tree_map, shard_value)
sharded_train_images, sharded_train_labels = shard((train_images, train_labels)) print(sharded_train_images.shape, sharded_train_labels.shape)
(8, 7500, 784) (8, 7500)

Antes de continuar, analicemos rápidamente la precisión de las TPU y su impacto en HMC. Las TPU ejecutan multiplicaciones matriciales mediante el uso de una precisión bfloat16 baja para mayor velocidad. Las multiplicaciones matriciales bfloat16 suelen ser suficientes para muchas aplicaciones de aprendizaje profundo, pero cuando se usan con el HMC, hemos descubierto empíricamente que la menor precisión puede generar trayectorias divergentes y provocar rechazos. Podemos usar multiplicaciones matriciales de mayor precisión, pero a cambio de algunos cálculos adicionales.

Para aumentar nuestra precisión matmul, podemos usar el decorador jax.default_matmul_precision con precisión "tensorfloat32" (para una precisión aún mayor, podríamos usar precisión "float32").

Ahora definamos nuestra función run, que aceptará una semilla aleatoria (que será la misma en cada dispositivo) y un fragmento de MNIST. La función implementará el modelo antes mencionado y luego usaremos la funcionalidad MCMC básica de TFP para ejecutar una sola cadena. Nos aseguraremos de decorar run con el decorador jax.default_matmul_precision para asegurarnos de que la multiplicación matricial se ejecute con mayor precisión, aunque, en el ejemplo particular a continuación, también podríamos usar jnp.dot(images, w, precision=lax.Precision.HIGH).

# We can use `out_axes=None` in the `pmap` because the results will be the same # on every device. @functools.partial(jax.pmap, axis_name='data', in_axes=(None, 0), out_axes=None) @jax.default_matmul_precision('tensorfloat32') def run(seed, data): images, labels = data # a sharded dataset num_examples, dim = images.shape num_classes = 10 def model_fn(): w = yield Root(tfd.Sample(tfd.Normal(0., 1.), [dim, num_classes])) b = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_classes])) logits = jnp.dot(images, w) + b yield tfed.Sharded(tfd.Independent(tfd.Categorical(logits=logits), 1), shard_axis_name='data') model = tfed.JointDistributionCoroutine(model_fn) init_seed, sample_seed = random.split(seed) initial_state = model.sample(seed=init_seed)[:-1] # throw away `y` def target_log_prob(*state): return model.log_prob((*state, labels)) def accuracy(w, b): logits = images.dot(w) + b preds = logits.argmax(axis=-1) # We take the average accuracy across devices by using `lax.pmean` return lax.pmean((preds == labels).mean(), 'data') kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-2, 100) kernel = tfm.DualAveragingStepSizeAdaptation(kernel, 500) def trace_fn(state, pkr): return ( target_log_prob(*state), accuracy(*state), pkr.new_step_size) states, trace = tfm.sample_chain( num_results=1000, num_burnin_steps=1000, current_state=initial_state, kernel=kernel, trace_fn=trace_fn, seed=sample_seed ) return states, trace

jax.pmap incluye una compilación JIT pero la función compilada se almacena en caché después de la primera llamada. Llamaremos a run e ignoraremos el resultado para almacenar en caché la compilación.

%%time output = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels)) jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 24.5 s, sys: 48.2 s, total: 1min 12s Wall time: 1min 54s

Ahora llamaremos a run nuevamente para ver cuánto tiempo lleva la ejecución real.

%%time states, trace = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels)) jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 13.1 s, sys: 45.2 s, total: 58.3 s Wall time: 1min 43s

Estamos ejecutando 200 000 pasos de salto, cada uno de los cuales calcula un gradiente en todo el conjunto de datos. Dividir el cálculo en 8 núcleos nos permite calcular el equivalente a 200 000 épocas de entrenamiento en aproximadamente 95 segundos, ¡aproximadamente 2100 épocas por segundo!

Tracemos la densidad logarítmica de cada muestra y la precisión de cada muestra:

fig, ax = plt.subplots(1, 3, figsize=(15, 5)) ax[0].plot(trace[0]) ax[0].set_title('Log Prob') ax[1].plot(trace[1]) ax[1].set_title('Accuracy') ax[2].plot(trace[2]) ax[2].set_title('Step Size') plt.show()
Image in a Jupyter notebook

Si agrupamos las muestras, podemos calcular el promedio del modelo bayesiano para mejorar nuestro rendimiento.

@functools.partial(jax.pmap, axis_name='data', in_axes=(0, None), out_axes=None) def bayesian_model_average(data, states): images, labels = data logits = jax.vmap(lambda w, b: images.dot(w) + b)(*states) probs = jax.nn.softmax(logits, axis=-1) bma_accuracy = (probs.mean(axis=0).argmax(axis=-1) == labels).mean() avg_accuracy = (probs.argmax(axis=-1) == labels).mean() return lax.pmean(bma_accuracy, axis_name='data'), lax.pmean(avg_accuracy, axis_name='data') sharded_test_images, sharded_test_labels = shard((test_images, test_labels)) bma_acc, avg_acc = bayesian_model_average((sharded_test_images, sharded_test_labels), states) print(f'Average Accuracy: {avg_acc}') print(f'BMA Accuracy: {bma_acc}') print(f'Accuracy Improvement: {bma_acc - avg_acc}')
Average Accuracy: 0.9188529253005981 BMA Accuracy: 0.9264000058174133 Accuracy Improvement: 0.0075470805168151855

¡Un promedio de modelo bayesiano aumenta nuestra precisión en casi un 1 %!

Ejemplo: sistema de recomendación MovieLens

Ahora intentemos hacer inferencias con el conjunto de datos de recomendaciones MovieLens, que es una colección de usuarios y sus calificaciones de varias películas. Específicamente, podemos representar MovieLens como una matriz de visualización N×MN \times M WW donde NN es el número de usuarios y MM es el número de películas; esperamos ParseError: KaTeX parse error: Expected 'EOF', got '&' at position 3: N &̲gt; M. Las entradas de WijW_{ij} son un booleano que indican si el usuario ii vio o no la película jj. Tenga en cuenta que MovieLens proporciona calificaciones de los usuarios, pero las ignoramos para simplificar el problema.

Primero, cargaremos el conjunto de datos. Usaremos la versión con 1 millón de calificaciones.

movielens = tfds.as_numpy(tfds.load('movielens/1m-ratings', batch_size=-1)) GENRES = ['Action', 'Adventure', 'Animation', 'Children', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'IMAX', 'Musical', 'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'Unknown', 'War', 'Western', '(no genres listed)']
Downloading and preparing dataset movielens/1m-ratings/0.1.0 (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0...
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0.incompleteYKA3TG/movielens-train.tfrecord
HBox(children=(FloatProgress(value=0.0, max=1000209.0), HTML(value='')))
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0. Subsequent calls will reuse this data.

Haremos un preprocesamiento del conjunto de datos para obtener la matriz de vigilancia WW.

raw_movie_ids = movielens['train']['movie_id'] raw_user_ids = movielens['train']['user_id'] genres = movielens['train']['movie_genres'] movie_ids, movie_labels = pd.factorize(movielens['train']['movie_id']) user_ids, user_labels = pd.factorize(movielens['train']['user_id']) num_movies = movie_ids.max() + 1 num_users = user_ids.max() + 1 movie_titles = dict(zip(movielens['train']['movie_id'], movielens['train']['movie_title'])) movie_genres = dict(zip(movielens['train']['movie_id'], genres)) movie_id_to_title = [movie_titles[movie_labels[id]].decode('utf-8') for id in range(num_movies)] movie_id_to_genre = [GENRES[movie_genres[movie_labels[id]][0]] for id in range(num_movies)] watch_matrix = np.zeros((num_users, num_movies), bool) watch_matrix[user_ids, movie_ids] = True print(watch_matrix.shape)
(6040, 3706)

Podemos definir un modelo generativo para WW, utilizando un modelo de factorización matricial probabilística simple. Se asume que se trata de una matriz de usuario UU latente N×DN \times D y una matriz de película latente M×DM \times D VV, que cuando se multiplican producen los logits de un Bernoulli para la matriz de vigilancia WW. También incluiremos vectores de sesgo para usuarios y películas, uu y vv. UN(0,1)uN(0,1)VN(0,1)vN(0,1)WijBernoulli(σ((UVT)ij+ui+vj)) \begin{align*} U &\sim \mathcal{N}(0, 1) \quad u \sim \mathcal{N}(0, 1)\\ V &\sim \mathcal{N}(0, 1) \quad v \sim \mathcal{N}(0, 1)\\ W_{ij} &\sim \textrm{Bernoulli}\left(\sigma\left(\left(UV^T\right)_{ij} + u_i + v_j\right)\right) \end{align*}

Esta es una matriz bastante grande; 6040 usuarios y 3706 películas conducen a una matriz con más de 22 millones de entradas. ¿Cómo abordamos la fragmentación de este modelo? Bueno, si asumimos que ParseError: KaTeX parse error: Expected 'EOF', got '&' at position 3: N &̲gt; M (es decir, hay más usuarios que películas), entonces tendría sentido fragmentar la matriz de visualización a través del eje del usuario, de modo que cada dispositivo tenga un fragmento de matriz de visualización correspondiente a un subconjunto de usuarios. Sin embargo, a diferencia del ejemplo anterior, también tendremos que fragmentar la matriz UU, ya que tiene una incrustación para cada usuario, por lo que cada dispositivo será responsable de un fragmento de UU y un fragmento de WW. Por otro lado, VV se desfragmentará y se sincronizará entre dispositivos.

sharded_watch_matrix = shard(watch_matrix)

Antes de escribir nuestra run, analicemos rápidamente los desafíos adicionales que implica fragmentar la variable aleatoria local UU. Al ejecutar el HMC, el núcleo básico tfp.mcmc.HamiltonianMonteCarlo tomará muestras de los momentos para cada elemento del estado de la cadena. Anteriormente, solo las variables aleatorias sin fragmentar formaban parte de ese estado, y los momentos eran los mismos en cada dispositivo. Ahora que tenemos un UU fragmentado, necesitamos muestrear diferentes momentos en cada dispositivo para UU, mientras muestreamos los mismos momentos para VV. Para lograr esto, podemos usar tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo con una distribución de momento Sharded. A medida que continuamos haciendo que el cálculo paralelo sea de primera clase, podemos simplificarlo, por ejemplo, llevando un indicador de fragmentación al núcleo del HMC.

def make_run(*, axis_name, dim=20, num_chains=2, prior_variance=1., step_size=1e-2, num_leapfrog_steps=100, num_burnin_steps=1000, num_results=500, ): @functools.partial(jax.pmap, in_axes=(None, 0), axis_name=axis_name) @jax.default_matmul_precision('tensorfloat32') def run(key, watch_matrix): num_users, num_movies = watch_matrix.shape Sharded = functools.partial(tfed.Sharded, shard_axis_name=axis_name) def prior_fn(): user_embeddings = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users, dim]), name='user_embeddings')) user_bias = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users]), name='user_bias')) movie_embeddings = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies, dim], name='movie_embeddings')) movie_bias = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies], name='movie_bias')) return (user_embeddings, user_bias, movie_embeddings, movie_bias) prior = tfed.JointDistributionCoroutine(prior_fn) def model_fn(): user_embeddings, user_bias, movie_embeddings, movie_bias = yield from prior_fn() logits = (jnp.einsum('...nd,...md->...nm', user_embeddings, movie_embeddings) + user_bias[..., :, None] + movie_bias[..., None, :]) yield Sharded(tfd.Independent(tfd.Bernoulli(logits=logits), 2), name='watch') model = tfed.JointDistributionCoroutine(model_fn) init_key, sample_key = random.split(key) initial_state = prior.sample(seed=init_key, sample_shape=num_chains) def target_log_prob(*state): return model.log_prob((*state, watch_matrix)) momentum_distribution = tfed.JointDistributionSequential([ Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users, dim]), 1.), 2)), Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users]), 1.), 1)), tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies, dim]), 1.), 2), tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies]), 1.), 1), ]) # We pass in momentum_distribution here to ensure that the momenta for # user_embeddings and user_bias are also sharded kernel = tfem.PreconditionedHamiltonianMonteCarlo(target_log_prob, step_size, num_leapfrog_steps, momentum_distribution=momentum_distribution) num_adaptation_steps = int(0.8 * num_burnin_steps) kernel = tfm.DualAveragingStepSizeAdaptation(kernel, num_adaptation_steps) def trace_fn(state, pkr): return { 'log_prob': target_log_prob(*state), 'log_accept_ratio': pkr.inner_results.log_accept_ratio, } return tfm.sample_chain( num_results, initial_state, kernel=kernel, num_burnin_steps=num_burnin_steps, trace_fn=trace_fn, seed=sample_key) return run

Lo ejecutaremos nuevamente una vez para almacenar en caché la run compilada.

%%time run = make_run(axis_name='data') output = run(random.PRNGKey(0), sharded_watch_matrix) jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 56 s, sys: 1min 24s, total: 2min 20s Wall time: 3min 35s

Ahora lo ejecutaremos nuevamente sin la sobrecarga de compilación.

%%time states, trace = run(random.PRNGKey(0), sharded_watch_matrix) jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 28.8 s, sys: 1min 16s, total: 1min 44s Wall time: 3min 1s

Parece que completamos alrededor de 150 000 pasos de salto en aproximadamente 3 minutos, ¡es decir, alrededor de 83 pasos de salto por segundo! Tracemos la relación de aceptación y la densidad logarítmica de nuestras muestras.

fig, axs = plt.subplots(1, len(trace), figsize=(5 * len(trace), 5)) for ax, (key, val) in zip(axs, trace.items()): ax.plot(val[0]) # Indexing into a sharded array, each element is the same ax.set_title(key);
Image in a Jupyter notebook

Ahora que tenemos algunas muestras de nuestra cadena de Markov, usémoslas para hacer algunas predicciones. Primero, extraigamos cada uno de los componentes. Recuerde que user_embeddings y user_bias se dividen entre dispositivos, por lo que debemos concatenar nuestro ShardedArray para obtenerlos todos. Por otro lado, movie_embeddings y movie_bias son iguales en todos los dispositivos, por lo que podemos elegir el valor del primer fragmento. Usaremos numpy normal para copiar los valores de las TPU a la CPU.

user_embeddings = np.concatenate(np.array(states.user_embeddings, np.float32), axis=2) user_bias = np.concatenate(np.array(states.user_bias, np.float32), axis=2) movie_embeddings = np.array(states.movie_embeddings[0], dtype=np.float32) movie_bias = np.array(states.movie_bias[0], dtype=np.float32) samples = (user_embeddings, user_bias, movie_embeddings, movie_bias) print(f'User embeddings: {user_embeddings.shape}') print(f'User bias: {user_bias.shape}') print(f'Movie embeddings: {movie_embeddings.shape}') print(f'Movie bias: {movie_bias.shape}')
User embeddings: (500, 2, 6040, 20) User bias: (500, 2, 6040) Movie embeddings: (500, 2, 3706, 20) Movie bias: (500, 2, 3706)

Tratemos de construir un sistema de recomendación simple que use la incertidumbre capturada en estas muestras. Primero, escribamos una función que clasifique las películas según la probabilidad de visualización.

@jax.jit def recommend(sample, user_id): user_embeddings, user_bias, movie_embeddings, movie_bias = sample movie_logits = ( jnp.einsum('d,md->m', user_embeddings[user_id], movie_embeddings) + user_bias[user_id] + movie_bias) return movie_logits.argsort()[::-1]

Ahora podemos escribir una función que recorra todas las muestras y, para cada una, elija la película mejor clasificada que el usuario aún no haya visto. Luego podemos ver los recuentos de todas las películas recomendadas en las muestras.

def get_recommendations(user_id): movie_ids = [] already_watched = set(jnp.arange(num_movies)[watch_matrix[user_id] == 1]) for i in range(500): for j in range(2): sample = jax.tree_map(lambda x: x[i, j], samples) ranking = recommend(sample, user_id) for movie_id in ranking: if int(movie_id) not in already_watched: movie_ids.append(movie_id) break return movie_ids def plot_recommendations(movie_ids, ax=None): titles = collections.Counter([movie_id_to_title[i] for i in movie_ids]) ax = ax or plt.gca() names, counts = zip(*sorted(titles.items(), key=lambda x: -x[1])) ax.bar(names, counts) ax.set_xticklabels(names, rotation=90)

Comparemos el usuario que ha visto más películas con el que ha visto menos.

user_watch_counts = watch_matrix.sum(axis=1) user_most = user_watch_counts.argmax() user_least = user_watch_counts.argmin() print(user_watch_counts[user_most], user_watch_counts[user_least])
2314 20

Esperamos que nuestro sistema tenga más certeza sobre user_most que user_least, dado que tenemos más información sobre qué tipo de películas es más probable que vea user_most.

fig, ax = plt.subplots(1, 2, figsize=(20, 10)) most_recommendations = get_recommendations(user_most) plot_recommendations(most_recommendations, ax=ax[0]) ax[0].set_title('Recommendation for user_most') least_recommendations = get_recommendations(user_least) plot_recommendations(least_recommendations, ax=ax[1]) ax[1].set_title('Recommendation for user_least');
Image in a Jupyter notebook

Vemos que hay más variación en nuestras recomendaciones para user_least, lo que refleja nuestra incertidumbre adicional en sus preferencias de visualización.

También podemos ver los géneros de las películas recomendadas.

most_genres = collections.Counter([movie_id_to_genre[i] for i in most_recommendations]) least_genres = collections.Counter([movie_id_to_genre[i] for i in least_recommendations]) fig, ax = plt.subplots(1, 2, figsize=(20, 10)) ax[0].bar(most_genres.keys(), most_genres.values()) ax[0].set_title('Genres recommended for user_most') ax[1].bar(least_genres.keys(), least_genres.values()) ax[1].set_title('Genres recommended for user_least');
Image in a Jupyter notebook

user_most ha visto muchas películas y se le han recomendado géneros más especializados, como misterio y crimen, mientras que user_least no ha visto muchas películas y se le han recomendado películas más convencionales, que se inclinan por la comedia y la acción.