Path: blob/master/site/pt-br/guide/checkpoint.ipynb
25115 views
Copyright 2018 The TensorFlow Authors.
Treinando checkpoints
A frase "Salvando um modelo do TensorFlow" normalmente significa uma das duas coisas a seguir:
Checkpoints, OU
SavedModel.
Os checkpoints capturam o valor exato de todos os parâmetros (objetos tf.Variable
) usados por um modelo. Os checkpoints não contêm nenhuma descrição da computação definida pelo modelo e, portanto, normalmente são úteis apenas quando o código-fonte que usará os valores de parâmetro salvos estiver disponível.
O formato SavedModel, por outro lado, inclui uma descrição serializada da computação definida pelo modelo, além dos valores dos parâmetros (checkpoint). Os modelos neste formato são independentes do código-fonte que criou o modelo. Eles são, portanto, adequados para implantação via TensorFlow Serving, TensorFlow Lite, TensorFlow.js ou programas em outras linguagens de programação (C, C++, Java, Go, Rust, C# etc. APIs do TensorFlow).
Este guia trata de APIs para escrever e ler checkpoints.
Configuração
Salvando de APIs de treinamento tf.keras
Veja o Guia tf.keras
sobre como salvar e restaurar.
O tf.keras.Model.save_weights
salva um checkpoint do TensorFlow.
Escrevendo checkpoints
O estado persistente de um modelo TensorFlow é armazenado em objetos tf.Variable
. Eles podem ser construídos diretamente, mas geralmente são criados via APIs de alto nível, como tf.keras.layers
ou tf.keras.Model
.
A maneira mais fácil de gerenciar variáveis é anexá-las a objetos Python e, em seguida, fazer referência a esses objetos.
Subclasses de tf.train.Checkpoint
, tf.keras.layers.Layer
e tf.keras.Model
rastreiam automaticamente as variáveis atribuídas a seus atributos. O exemplo a seguir constrói um modelo linear simples e, em seguida, grava checkpoints que contêm valores para todas as variáveis do modelo.
Você pode salvar um checkpoint de modelo facilmente com Model.save_weights
.
Definição manual de checkpoints
Configuração
Para ajudar a demonstrar todos os recursos de tf.train.Checkpoint
, defina um dataset de brinquedo e uma etapa de otimização:
Crie os objetos do checkpoint
Use um objeto tf.train.Checkpoint
para criar um checkpoint manualmente, onde os objetos que você deseja verificar com o checkpoint são definidos como atributos no objeto.
Um tf.train.CheckpointManager
também pode ser útil para gerenciar múltiplos checkpoints.
Treine o modelo e aplique checkpoints
O loop de treinamento a seguir cria uma instância do modelo e de um otimizador e os reúne num objeto tf.train.Checkpoint
. Ele chama a etapa de treinamento dentro de um loop em cada lote de dados e grava checkpoints periodicamente no disco.
Restaure e continue treinando
Depois do primeiro ciclo de treinamento, você pode passar por um novo modelo e gerente, mas retome o treinamento exatamente de onde parou:
O objeto tf.train.CheckpointManager
exclui checkpoints antigos. Acima, ele está configurado para manter apenas os três checkpoints mais recentes.
Esses caminhos, por exemplo './tf_ckpts/ckpt-10'
, não são arquivos no disco. Na verdade eles são prefixos para um arquivo index
e um ou mais arquivos de dados que contêm os valores das variáveis. Esses prefixos são agrupados num único arquivo checkpoint
('./tf_ckpts/checkpoint'
) onde o CheckpointManager
salva seu estado.
Mecânica de carregamento
O TensorFlow combina variáveis com valores dos checkpoints percorrendo um grafo direcionado com arestas nomeadas, começando pelo objeto que está sendo carregado. Nomes de arestas geralmente vêm de nomes de atributos dos objetos, por exemplo, o "l1"
em self.l1 = tf.keras.layers.Dense(5)
. tf.train.Checkpoint
usa nomes de argumento de palavras-chave, como "step"
em tf.train.Checkpoint(step=...)
.
O grafo de dependência do exemplo está mostrado a seguir:
O otimizador aparece em vermelho, as variáveis regulares em azul e as variáveis de slot do otimizador em laranja. Os outros nós — por exemplo, representando o tf.train.Checkpoint
— estão em preto.
As variáveis de slot fazem parte do estado do otimizador, mas são criadas para uma variável específica. Por exemplo, as arestas 'm'
acima correspondem ao momento, que o otimizador Adam rastreia para cada variável. As variáveis de slot só são salvas num checkpoint se a variável e também o otimizador forem salvos, por isso as bordas tracejadas.
Chamar restore
num objeto tf.train.Checkpoint
enfileira as restaurações solicitadas, restaurando valores das variáveis assim que houver um caminho correspondente do objeto Checkpoint
. Por exemplo, você pode carregar apenas o bias do modelo definido acima reconstruindo um caminho para ele através da rede e da camada.
O grafo de dependência para esses novos objetos é um subgráfico muito menor do checkpoint maior que você escreveu acima. Ele inclui apenas o bias e um contador de salvamento que tf.train.Checkpoint
usa para numerar os checkpoints.
restore
retorna um objeto de status, que contém asserções opcionais. Todos os objetos criados no novo Checkpoint
foram restaurados, então status.assert_existing_objects_matched
passa.
Existem muitos objetos no checkpoint que não correspondem, incluindo o kernel da camada e as variáveis do otimizador. status.assert_consumed
passaria apenas se houvesse uma correspondência exata entre o checkpoint e o programa, e causaria o lançamento de uma exceção nesse ponto.
Restaurações adiadas
Objetos Layer
no TensorFlow podem adiar a criação de variáveis para sua primeira chamada, quando os formatos de entrada estiverem disponíveis. Por exemplo, o formato do kernel de uma camada Dense
depende dos formatos de entrada e saída da camada e, portanto, o formato de saída necessário como um argumento do construtor não seria informação suficiente para criar a variável. Como a chamada de um Layer
também lê o valor da variável, uma restauração deve acontecer entre a criação da variável e seu primeiro uso.
Para oferecer suporte a esse padrão, tf.train.Checkpoint
adia restaurações que ainda não possuem uma variável correspondente.
Inspeção manual de checkpoints
tf.train.load_checkpoint
retorna um CheckpointReader
que fornece acesso de nível inferior ao conteúdo do checkpoint. Ele contém mapeamentos da chave de cada variável, para o formato e para o dtype de cada variável no checkpoint. A chave de uma variável é o caminho do objeto, como nos grafos exibidos acima.
Observação: Não existe estrutura de nível superior para o checkpoint. Ele conhece apenas os caminhos e valores das variáveis, e não tem noção de models
(modelos), layers
(camadas) ou como eles estão conectados.
Portanto, se você tiver interesse no valor de net.l1.kernel
, poderá obter o valor com o seguinte código:
Ele também fornece um método get_tensor
que permite inspecionar o valor de uma variável:
Rastreamento de objetos
Os checkpoints salvam e restauram os valores dos objetos tf.Variable
"rastreando" qualquer variável ou objeto rastreável definido em um dos seus atributos. Ao executar um salvamento, as variáveis são obtidas recursivamente de todos os objetos rastreados que forem alcançáveis.
Tal como acontece com as atribuições diretas de atributos como self.l1 = tf.keras.layers.Dense(5)
, atribuir listas e dicionários a atributos rastreará seu conteúdo.
Você poderá perceber objetos wrapper para listas e dicionários. Esses wrappers são versões passíveis de verificação com checkpoints das estruturas de dados subjacentes. Assim como o carregamento baseado em atributos, esses wrappers restauram o valor de uma variável assim que ela é adicionada ao container.
Objetos rastreáveis incluem tf.train.Checkpoint
, tf.Module
e suas subclasses (por exemplo, keras.layers.Layer
e keras.Model
) além de containers Python reconhecidos:
dict
(ecollections.OrderedDict
)list
tuple
(ecollections.namedtuple
,typing.NamedTuple
)
Outros tipos de container não são suportados, incluindo:
collections.defaultdict
set
Todos os outros objetos Python são ignorados, incluindo:
int
string
float
Resumo
Os objetos TensorFlow fornecem um mecanismo automático fácil de usar para salvar e restaurar os valores das variáveis que eles usam.