Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/pt-br/guide/checkpoint.ipynb
25115 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

Treinando checkpoints

A frase "Salvando um modelo do TensorFlow" normalmente significa uma das duas coisas a seguir:

  1. Checkpoints, OU

  2. SavedModel.

Os checkpoints capturam o valor exato de todos os parâmetros (objetos tf.Variable) usados ​​por um modelo. Os checkpoints não contêm nenhuma descrição da computação definida pelo modelo e, portanto, normalmente são úteis apenas quando o código-fonte que usará os valores de parâmetro salvos estiver disponível.

O formato SavedModel, por outro lado, inclui uma descrição serializada da computação definida pelo modelo, além dos valores dos parâmetros (checkpoint). Os modelos neste formato são independentes do código-fonte que criou o modelo. Eles são, portanto, adequados para implantação via TensorFlow Serving, TensorFlow Lite, TensorFlow.js ou programas em outras linguagens de programação (C, C++, Java, Go, Rust, C# etc. APIs do TensorFlow).

Este guia trata de APIs para escrever e ler checkpoints.

Configuração

import tensorflow as tf
class Net(tf.keras.Model): """A simple linear model.""" def __init__(self): super(Net, self).__init__() self.l1 = tf.keras.layers.Dense(5) def call(self, x): return self.l1(x)
net = Net()

Salvando de APIs de treinamento tf.keras

Veja o Guia tf.keras sobre como salvar e restaurar.

O tf.keras.Model.save_weights salva um checkpoint do TensorFlow.

net.save_weights('easy_checkpoint')

Escrevendo checkpoints

O estado persistente de um modelo TensorFlow é armazenado em objetos tf.Variable. Eles podem ser construídos diretamente, mas geralmente são criados via APIs de alto nível, como tf.keras.layers ou tf.keras.Model.

A maneira mais fácil de gerenciar variáveis ​​é anexá-las a objetos Python e, em seguida, fazer referência a esses objetos.

Subclasses de tf.train.Checkpoint, tf.keras.layers.Layer e tf.keras.Model rastreiam automaticamente as variáveis ​​atribuídas a seus atributos. O exemplo a seguir constrói um modelo linear simples e, em seguida, grava checkpoints que contêm valores para todas as variáveis ​​do modelo.

Você pode salvar um checkpoint de modelo facilmente com Model.save_weights.

Definição manual de checkpoints

Configuração

Para ajudar a demonstrar todos os recursos de tf.train.Checkpoint, defina um dataset de brinquedo e uma etapa de otimização:

def toy_dataset(): inputs = tf.range(10.)[:, None] labels = inputs * 5. + tf.range(5.)[None, :] return tf.data.Dataset.from_tensor_slices( dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer): """Trains `net` on `example` using `optimizer`.""" with tf.GradientTape() as tape: output = net(example['x']) loss = tf.reduce_mean(tf.abs(output - example['y'])) variables = net.trainable_variables gradients = tape.gradient(loss, variables) optimizer.apply_gradients(zip(gradients, variables)) return loss

Crie os objetos do checkpoint

Use um objeto tf.train.Checkpoint para criar um checkpoint manualmente, onde os objetos que você deseja verificar com o checkpoint são definidos como atributos no objeto.

Um tf.train.CheckpointManager também pode ser útil para gerenciar múltiplos checkpoints.

opt = tf.keras.optimizers.Adam(0.1) dataset = toy_dataset() iterator = iter(dataset) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator) manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Treine o modelo e aplique checkpoints

O loop de treinamento a seguir cria uma instância do modelo e de um otimizador e os reúne num objeto tf.train.Checkpoint. Ele chama a etapa de treinamento dentro de um loop em cada lote de dados e grava checkpoints periodicamente no disco.

def train_and_checkpoint(net, manager): ckpt.restore(manager.latest_checkpoint) if manager.latest_checkpoint: print("Restored from {}".format(manager.latest_checkpoint)) else: print("Initializing from scratch.") for _ in range(50): example = next(iterator) loss = train_step(net, example, opt) ckpt.step.assign_add(1) if int(ckpt.step) % 10 == 0: save_path = manager.save() print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path)) print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)

Restaure e continue treinando

Depois do primeiro ciclo de treinamento, você pode passar por um novo modelo e gerente, mas retome o treinamento exatamente de onde parou:

opt = tf.keras.optimizers.Adam(0.1) net = Net() dataset = toy_dataset() iterator = iter(dataset) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator) manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3) train_and_checkpoint(net, manager)

O objeto tf.train.CheckpointManager exclui checkpoints antigos. Acima, ele está configurado para manter apenas os três checkpoints mais recentes.

print(manager.checkpoints) # List the three remaining checkpoints

Esses caminhos, por exemplo './tf_ckpts/ckpt-10', não são arquivos no disco. Na verdade eles são prefixos para um arquivo index e um ou mais arquivos de dados que contêm os valores das variáveis. Esses prefixos são agrupados num único arquivo checkpoint ('./tf_ckpts/checkpoint') onde o CheckpointManager salva seu estado.

!ls ./tf_ckpts

Mecânica de carregamento

O TensorFlow combina variáveis ​​com valores dos checkpoints percorrendo um grafo direcionado com arestas nomeadas, começando pelo objeto que está sendo carregado. Nomes de arestas geralmente vêm de nomes de atributos dos objetos, por exemplo, o "l1" em self.l1 = tf.keras.layers.Dense(5). tf.train.Checkpoint usa nomes de argumento de palavras-chave, como "step" em tf.train.Checkpoint(step=...).

O grafo de dependência do exemplo está mostrado a seguir:

Visualization of the dependency graph for the example training loop

O otimizador aparece em vermelho, as variáveis ​​regulares em azul e as variáveis ​​de slot do otimizador em laranja. Os outros nós — por exemplo, representando o tf.train.Checkpoint — estão em preto.

As variáveis ​​de slot fazem parte do estado do otimizador, mas são criadas para uma variável específica. Por exemplo, as arestas 'm' acima correspondem ao momento, que o otimizador Adam rastreia para cada variável. As variáveis ​​de slot só são salvas num checkpoint se a variável e também o otimizador forem salvos, por isso as bordas tracejadas.

Chamar restore num objeto tf.train.Checkpoint enfileira as restaurações solicitadas, restaurando valores das variáveis ​​assim que houver um caminho correspondente do objeto Checkpoint. Por exemplo, você pode carregar apenas o bias do modelo definido acima reconstruindo um caminho para ele através da rede e da camada.

to_restore = tf.Variable(tf.zeros([5])) print(to_restore.numpy()) # All zeros fake_layer = tf.train.Checkpoint(bias=to_restore) fake_net = tf.train.Checkpoint(l1=fake_layer) new_root = tf.train.Checkpoint(net=fake_net) status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/')) print(to_restore.numpy()) # This gets the restored value.

O grafo de dependência para esses novos objetos é um subgráfico muito menor do checkpoint maior que você escreveu acima. Ele inclui apenas o bias e um contador de salvamento que tf.train.Checkpoint usa para numerar os checkpoints.

Visualization of a subgraph for the bias variable

restore retorna um objeto de status, que contém asserções opcionais. Todos os objetos criados no novo Checkpoint foram restaurados, então status.assert_existing_objects_matched passa.

status.assert_existing_objects_matched()

Existem muitos objetos no checkpoint que não correspondem, incluindo o kernel da camada e as variáveis ​​do otimizador. status.assert_consumed passaria apenas se houvesse uma correspondência exata entre o checkpoint e o programa, e causaria o lançamento de uma exceção nesse ponto.

Restaurações adiadas

Objetos Layer no TensorFlow podem adiar a criação de variáveis ​​para sua primeira chamada, quando os formatos de entrada estiverem disponíveis. Por exemplo, o formato do kernel de uma camada Dense depende dos formatos de entrada e saída da camada e, portanto, o formato de saída necessário como um argumento do construtor não seria informação suficiente para criar a variável. Como a chamada de um Layer também lê o valor da variável, uma restauração deve acontecer entre a criação da variável e seu primeiro uso.

Para oferecer suporte a esse padrão, tf.train.Checkpoint adia restaurações que ainda não possuem uma variável correspondente.

deferred_restore = tf.Variable(tf.zeros([1, 5])) print(deferred_restore.numpy()) # Not restored; still zeros fake_layer.kernel = deferred_restore print(deferred_restore.numpy()) # Restored

Inspeção manual de checkpoints

tf.train.load_checkpoint retorna um CheckpointReader que fornece acesso de nível inferior ao conteúdo do checkpoint. Ele contém mapeamentos da chave de cada variável, para o formato e para o dtype de cada variável no checkpoint. A chave de uma variável é o caminho do objeto, como nos grafos exibidos acima.

Observação: Não existe estrutura de nível superior para o checkpoint. Ele conhece apenas os caminhos e valores das variáveis, e não tem noção de models (modelos), layers (camadas) ou como eles estão conectados.

reader = tf.train.load_checkpoint('./tf_ckpts/') shape_from_key = reader.get_variable_to_shape_map() dtype_from_key = reader.get_variable_to_dtype_map() sorted(shape_from_key.keys())

Portanto, se você tiver interesse no valor de net.l1.kernel, poderá obter o valor com o seguinte código:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE' print("Shape:", shape_from_key[key]) print("Dtype:", dtype_from_key[key].name)

Ele também fornece um método get_tensor que permite inspecionar o valor de uma variável:

reader.get_tensor(key)

Rastreamento de objetos

Os checkpoints salvam e restauram os valores dos objetos tf.Variable "rastreando" qualquer variável ou objeto rastreável definido em um dos seus atributos. Ao executar um salvamento, as variáveis ​​são obtidas recursivamente de todos os objetos rastreados que forem alcançáveis.

Tal como acontece com as atribuições diretas de atributos como self.l1 = tf.keras.layers.Dense(5), atribuir listas e dicionários a atributos rastreará seu conteúdo.

save = tf.train.Checkpoint() save.listed = [tf.Variable(1.)] save.listed.append(tf.Variable(2.)) save.mapped = {'one': save.listed[0]} save.mapped['two'] = save.listed[1] save_path = save.save('./tf_list_example') restore = tf.train.Checkpoint() v2 = tf.Variable(0.) assert 0. == v2.numpy() # Not restored yet restore.mapped = {'two': v2} restore.restore(save_path) assert 2. == v2.numpy()

Você poderá perceber objetos wrapper para listas e dicionários. Esses wrappers são versões passíveis de verificação com checkpoints das estruturas de dados subjacentes. Assim como o carregamento baseado em atributos, esses wrappers restauram o valor de uma variável assim que ela é adicionada ao container.

restore.listed = [] print(restore.listed) # ListWrapper([]) v1 = tf.Variable(0.) restore.listed.append(v1) # Restores v1, from restore() in the previous cell assert 1. == v1.numpy()

Objetos rastreáveis ​​incluem tf.train.Checkpoint, tf.Module e suas subclasses (por exemplo, keras.layers.Layer e keras.Model) além de containers Python reconhecidos:

  • dict (e collections.OrderedDict)

  • list

  • tuple (e collections.namedtuple, typing.NamedTuple)

Outros tipos de container não são suportados, incluindo:

  • collections.defaultdict

  • set

Todos os outros objetos Python são ignorados, incluindo:

  • int

  • string

  • float

Resumo

Os objetos TensorFlow fornecem um mecanismo automático fácil de usar para salvar e restaurar os valores das variáveis ​​que eles usam.