Layout com tiles
Atenção: o layout com tiles é um pré-lançamento este documento descreve como ele deve funcionar. Erros poderãi ser ignorados silenciosamente.
Figura 1
A Figura 1 mostra como um array F32[3,5] é disposto na memória com tiles (ladrilhos) nas dimensões 2x2. Um formato com este layout é escrito como F32[3,5]{1,0:T(2,2)}, onde 1,0 se refere à ordem física das dimensões (campo minor_to_major em Layout) enquanto (2,2) depois dos dois pontos indica o tiling das dimensões físicas por um tile de dimensões 2x2.
Intuitivamente, os tiles são dispostos para cobrir o formato e, em seguida, dentro de cada tile, os elementos são dispostos sem usar tiling, como no exemplo acima, onde a parte direita do exemplo mostra o layout na memória, incluindo os elementos de preenchimento (padding) que são adicionados para ter tiles 2x2 completos, mesmo que os limites do array original não sejam iguais.
Os elementos extras no preenchimento não precisam conter nenhum valor específico.
Fórmulas de índice linear para o tiling, dados um formato e um tile
Sem tiling, um elemento e=(e n , e n-1, ... , e 1 ) num array com limites de array d=(d n , d n-1, ... , d 1) (d1 é a menor dimensão) é disposta na ordem de maior para menor na posição:
linear_index(e, d)
= linear_index((en, en-1, ... , e1), (dn, dn-1, ... , d1))
= endn-1...d1 + en-1dn-2...d1 + ... + e1
Para simplificar a notação neste documento, assumimos que um tile tem o mesmo número de dimensões que o array. Na implementação de tiling do XLA, isto é generalizado para tiles com menos dimensões, deixando, inicialmente, as dimensões maiores inalteradas e aplicando o tiling apenas às dimensões menores, de modo que o tiling especificado mencione um sufixo das dimensões físicas do formato no qual se está aplicando o tiling.
Quando o tiling de dimensões (t n, t n-1, ... , t 1) é usado, um elemento no array com índices (en, en-1, ... , e1) é mapeado para esta posição no layout final:
linear_index_with_tile(e, d, t)
= linear_index((⌊e/t⌋, e mod t), (⌈d/t⌉, t)) (aritmética é elemento por elemento, (a,b) é concatenação)
= linear_index((⌊en/tn⌋, ... , ⌊e1/t1⌋, en mod tn, ... , e1 mod t1), (⌈dn/tn⌉, ... , ⌈d1/t1⌉, tn, tn-1, ... , t1))
= linear_index((⌊en/tn⌋, ... , ⌊e1/t1⌋), (⌈dn/tn⌉, ... , ⌈d1/t1⌉))∙tntn-1...t1 + linear_index((en mod tn, ... , e1 mod t1), (tn, tn-1, ... , t1))
Pode-se considerar o layout como tendo duas partes: (⌊e n/tn⌋, ... , ⌊e 1 /t1⌋), que corresponde a um índice de tiles num array de tiles de tamanho (⌈d n/tn ⌉, ... , ⌈d 1/t1 ⌉), e (e nmod tn , ... , e 1 mod t 1 ), que corresponde a um índice dentro do tile. A função ceil aparece em ⌈d i/ti ⌉ porque se os blocos ultrapassarem os limites do array maior, o preenchimento será inserido como na Figura 1. Tanto os blocos quanto os elementos dentro dos blocos são dispostos recursivamente sem tiling.
Para o exemplo da Figura 1, o elemento (2,3) tem o índice do tile (1,1) e índice dentro do tile (0,1), para um vetor de coordenadas combinado de (1, 1, 0, 1). Os índices dos tiles têm limites (2, 3) e o próprio tile é (2, 2) para um vetor combinado de (2, 3, 2, 2). Assim, o índice linear com tile para o elemento com índice (2, 3) na forma lógica é
linear_index_with_tile((2,3), (3,5), (2,2))
= linear_index((1,1,0,1), (2,3,2,2))
= linear_index((1,1), (2,3)) ∙ 2 ∙ 2 + linear_index((0,1), (2,2))
= (1 ∙ 3 + 1) ∙ 2 ∙ 2 + (0 ∙ 2 + 1)
= 17.
Tiling como pad-reshape-transpose
O layout baseado em tiling funciona da seguinte maneira: (dn, dn-1, ... , d1) (d1 é a dimensão menor). Quando é disposto com um tiling de dimensões (tn, tn-1, ... , t1) (t1 é a dimensão menor), esse tiling pode ser descrito em termos de pad-reshape-transpose da seguinte forma
O array é preenchido para (⌈dn/tn⌉∙tn, ... , ⌈d1/t1⌉∙t1).
Cada dimensão i é dividida em (⌈di/ti⌉, ti), ou seja, o array é reformatado para
(⌈dn/tn⌉, tn, ... , ⌈d1/t1⌉, t1).
Não há nenhuma mudança de layout físico nesta alteração por si só, então esta aleração é um bitcast. Se não estivermos pensando explicitamente em um tile, essa remodelação poderia expressar qualquer formato com o mesmo número de elementos que o formato com preenchimento - o exemplo aqui é sobre como expressar um tile dessa maneira.Uma transposição (transpose) acontece movendo t n, ... , t 1 para as dimensões menores, mantendo sua ordem relativa, de modo que a maioria das ordens de dimensões do maior para o menor se torne (⌈dn/tn⌉, ... , ⌈d1/t1⌉, tn, ... , t1).
O formato final tem o prefixo
(⌈dn/tn⌉, ... , ⌈d1/t1⌉), que descreve o número de tiles em cada dimensão. Um elemento na matriz (e n, ... , e 1) é mapeado para este elemento no formato final: (⌊en/tn⌋, ... , ⌊e0/t0⌋, en mod tn, ... , e1 mod t1). É fácil perceber que o índice linear do elemento segue a fórmula acima conforme o esperado.
Tiles repetidos
O tiling do XLA torna-se ainda mais flexível ao aplicá-lo de forma repetida.
Figura 2
A Figura 2 mostra como um array de tamanho 4x8 é dividido em dois níveis de tiling (primeiro 2x4 e depois 2x1). Representamos esse tiling repetido como (2,4)(2,1). Cada cor indica um tile 2x4 e cada caixa de borda vermelha é um tile 2x1. Os números indicam o índice linear na memória desse elemento no formato do tiling. Este formato corresponde ao formato usado para BF16 na TPU, exceto que o tile inicial é maior, ou seja, o tiling é (8,128)(2,1), onde o objetivo do segundo tiling por 2x1 é coletar dois valores de 16 bits para formar um valor de 32 bits de forma que se alinhe com a arquitetura de uma TPU.
Observe que um segundo tile ou tile posterior pode se referir às dimensões menores dentro do tile, que apenas reorganiza os dados dentro dele, como neste exemplo com (8,128)(2,1), mas também pode se referir às dimensões principais entre tiles obtidas do tiling anterior.
Combinando dimensões usando tiles
O tiling do XLA também suporta a combinação de dimensões. Por exemplo, ele pode combinar dimensões em F32[2,7,8,11,10]{4,3,2,1,0} para F32[112,110]{1,0} antes de fazer o tiling com (2,3 ). O tile usado é (∗,∗,2,∗,3). Aqui, um asterisco em um tile implica pegar essa dimensão e combiná-la com a próxima dimensão menor. Múltiplas dimensões adjacentes podem ser agrupadas numa dimensão. Uma dimensão agrupada é representada por um valor de tile de -1 naquela dimensão do tile, que de outra forma não seria válido num tile como tamanho de dimensão.
Mais precisamente, se a dimensão i do formato for eliminada por meio de um asterisco no tile, então, antes da definição anterior de tiling ser aplicada, essa dimensão será removida tanto do formato que está sendo disposto usando tiling, quanto do vetor do tile, e o que era a dimensão i-1 do formato tem seu limite de array aumentado de di-1 para didi-1. Esse passo é repetido para cada asterisco no vetor de tiles.