Use of Callbacks in Keras
Modules and methods
Modules
experimental module: Public API for tf.keras.callbacks.experimental namespace.
Classes
class BaseLogger: Callback that accumulates epoch averages of metrics.
class CSVLogger: Callback that streams epoch results to a CSV file.
class Callback: Abstract base class used to build new callbacks.
class CallbackList: Container abstracting a list of callbacks.
class EarlyStopping: Stop training when a monitored metric has stopped improving.
class History: Callback that records events into a History object.
class LambdaCallback: Callback for creating simple, custom callbacks on-the-fly.
class LearningRateScheduler: Learning rate scheduler.
class ModelCheckpoint: Callback to save the Keras model or model weights at some frequency.
class ProgbarLogger: Callback that prints metrics to stdout.
class ReduceLROnPlateau: Reduce learning rate when a metric has stopped improving.
class RemoteMonitor: Callback used to stream events to a server.
class TensorBoard: Enable visualizations for TensorBoard.
class TerminateOnNaN: Callback that terminates training when a NaN loss is encountered.
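As a quick illustration before looking at ModelCheckpoint in detail, here is a minimal sketch of how several of these callbacks can be combined in a single `model.fit` call. The model and data names (`model`, `train_x`, `train_y`, `val_x`, `val_y`) are placeholders, not defined in this section.

```python
import tensorflow as tf

# Assumed placeholders: `model`, `train_x`, `train_y`, `val_x`, `val_y`
callbacks = [
    # Stop training if val_loss has not improved for 10 consecutive epochs
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10),
    # Halve the learning rate after 5 epochs without improvement
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5),
    # Write logs that TensorBoard can visualize
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    # Abort training as soon as the loss becomes NaN
    tf.keras.callbacks.TerminateOnNaN(),
]

history = model.fit(train_x, train_y,
                    epochs=100,
                    validation_data=(val_x, val_y),
                    callbacks=callbacks)
```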
Commonly used callbacks
ModelCheckpoint
Function: Callback to save the Keras model or model weights at some frequency.
This callback saves the full Keras model, or only its weights, at a chosen frequency (by default, at the end of every epoch).
Example 1
```python
import tensorflow as tf
from tensorflow import keras

# Initialize the model
model = Classifier()

# Set the save path
checkpoint_filepath = 'E:/Python_Workspace/Saved_models/checkpoint'

# Configure the ModelCheckpoint callback: keep only the weights of the
# model that achieves the highest validation accuracy
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Compile the model
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),
              loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])

# Train the model (callbacks must be passed as a list)
history = model.fit(train_x, train_y,
                    epochs=500,
                    batch_size=256,
                    validation_data=(val_x, val_y),
                    callbacks=[model_checkpoint_callback])

# Restore the model with the saved (best) weights
model.load_weights(checkpoint_filepath)
```
Example 2 trains a new model and saves a uniquely named checkpoint every five epochs:
```python
import os
import tensorflow as tf

# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

batch_size = 32

# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=5*batch_size)

# Create a new model instance
model = create_model()

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the new callback
model.fit(train_images, train_labels,
          epochs=50,
          batch_size=batch_size,
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)

# Find the most recent checkpoint in the directory
latest = tf.train.latest_checkpoint(checkpoint_dir)

# Create a new model instance
model = create_model()

# Load the previously saved weights
model.load_weights(latest)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
```
Parameter description
```python
tf.keras.callbacks.ModelCheckpoint(
    filepath,
    monitor='val_loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    save_freq='epoch',
    options=None,
    **kwargs
)
```
filepath: the path where the model file (or the weights) will be saved.
monitor: the name of the metric to monitor, typically loss, val_loss, accuracy, or val_accuracy. After training with history = model.fit(...), the keys of history.history show exactly which metric names are available (see the short sketch after this parameter list).
verbose: verbosity mode for the callback, 0 or 1. With 0 the callback is silent; with 1 it prints a message each time it saves a checkpoint. (This is distinct from the verbose argument of model.fit, where 0 is silent, 1 shows a progress bar, and 2 prints one line per epoch.)
save_best_only: if True, only the model that achieves the best value of the monitored metric is kept; if False, the model is saved at every save point.
mode: one of {'auto', 'min', 'max'}; whether an improvement of the monitored quantity means it decreases (min) or increases (max). With 'auto', the direction is inferred from the metric name.
save_weights_only: if True, only the model's weights are saved; if False, the complete model is saved.
save_freq: 'epoch' by default, meaning the model or weights are saved at the end of every epoch. If set to an integer N, saving happens after every N batches.
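As noted for the monitor parameter, the metric names you can monitor are exactly the keys of history.history. A small sketch (the model and data names are placeholders, and the printed keys are illustrative):

```python
# Train briefly, then list the metric names available for `monitor`
history = model.fit(train_x, train_y,
                    epochs=5,
                    validation_data=(val_x, val_y))
print(history.history.keys())
# e.g. dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
```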
Saved files
If you save only the weights, a set of checkpoint files appears in the target directory (a `checkpoint` file plus `.index` and `.data-*` shard files).
What are these files?
The code above stores the weights in a collection of checkpoint-format files; these files contain only the trained weights, in binary format. A checkpoint consists of:
- One or more shards that contain the model's weights.
- An index file that records which weights are stored in which shard.
If you train the model on a single machine, you will get one shard with the suffix `.data-00000-of-00001`.
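If you want to see which weights a checkpoint actually stores, the shard files can be inspected. A small sketch, assuming `checkpoint_dir` is the directory used by the ModelCheckpoint callback above:

```python
import tensorflow as tf

# Path of the newest checkpoint in the directory
latest = tf.train.latest_checkpoint(checkpoint_dir)

# List the variable names and shapes stored across the shard files
for name, shape in tf.train.list_variables(latest):
    print(name, shape)
```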
Save the entire model
The entire model can be saved in two different file formats: SavedModel and HDF5. The TensorFlow SavedModel format is the default in TF 2.x, but the model can also be saved in HDF5 format.
Keras provides a basic save format using the HDF5 standard.
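A minimal sketch of saving the entire model in both formats (assuming `model` is a compiled, trained Keras model); either result can later be reloaded with tf.keras.models.load_model:

```python
# TensorFlow SavedModel format (default in TF 2.x): writes a directory
model.save('saved_model/my_model')

# HDF5 format: a single .h5 file (requires the h5py package)
model.save('my_model.h5')
```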
A saved model can then be reloaded; for example, for the HDF5 file:
```python
new_model = tf.keras.models.load_model('my_model.h5')
```
Saving models with custom objects is covered in the TensorFlow guide: [Save custom objects](https://tensorflow.google.cn/tutorials/keras/save_and_load?hl=zh-cn#%E4%BF%9D%E5%AD%98%E8%87%AA%E5%AE%9A%E4%B9%89%E5%AF%B9%E8%B1%A1) (to be covered in a follow-up).
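If the model contains custom layers, losses, or metrics, load_model needs to be told how to reconstruct them via the custom_objects argument. A hedged sketch, in which MyCustomLayer is a hypothetical class name defined elsewhere:

```python
# `MyCustomLayer` is a hypothetical custom layer defined elsewhere;
# map its name to the class so load_model can deserialize it.
new_model = tf.keras.models.load_model(
    'my_model.h5',
    custom_objects={'MyCustomLayer': MyCustomLayer})
```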