Добавил:
Опубликованный материал нарушает ваши авторские права? Сообщите нам.
Вуз: Предмет: Файл:
Шолле Ф. - Глубокое обучение на Python (Библиотека программиста) - 2023.pdf
Скачиваний:
3
Добавлен:
07.04.2024
Размер:
11.34 Mб
Скачать

256    Глава 7. Работа с Keras: глубокое погружение

7.4.4. Ускорение вычислений с помощью tf.function

Возможно,.вы.заметили,.что.реализованные.вами.циклы.работают.значительно. медленнее,.чем.встроенные.функции.fit() .и.evaluate(),.несмотря.на.то.что. фактически.реализуют.ту.же.логику..Причина.в.том,.что.по.умолчанию.код. TensorFlow.выполняется.построчно.и.немедленно,.подобно.коду.NumPy.или. обычному.коду.Python..Немедленное.выполнение.упрощает.отладку,.но.с.точки. зрения.производительности.далеко.не.оптимально.

Более.полезным.для.производительности.будет.скомпилировать .код.Tensor­ Flow .в .граф вычислений, .который .можно .оптимизировать .глобально, .что. не.получится.сделать.при.построчной.интерпретации.кода..Синтаксис.применения.такой.оптимизации.прост:.добавьте.@tf.function .к.любой.функции,. которую.нужно.скомпилировать.перед.выполнением,.как.показано.в.следу­ ющем.листинге.

Листинг 7.25. Добавление декоратора @tf.function к функции оценки

@tf.function def test_step(inputs, targets):

predictions = model(inputs, training=False) loss = loss_fn(targets, predictions)

logs = {}

for metric in metrics: metric.update_state(targets, predictions) logs["val_" + metric.name] = metric.result()

loss_tracking_metric.update_state(loss) logs["val_loss"] = loss_tracking_metric.result() return logs

Единственная новая строка

val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels)) val_dataset = val_dataset.batch(32)

reset_metrics()

for inputs_batch, targets_batch in val_dataset: logs = test_step(inputs_batch, targets_batch)

print("Evaluation results:") for key, value in logs.items():

print(f"...{key}: {value:.4f}")

В.Colab.время.выполнения.цикла.оценки.уменьшилось.с.1,8.до.0,8.секунды..

Теперь.он.выполняется.намного.быстрее!

Помните,.что.в.процессе.отладки.код.лучше.запускать.без.декоратора.@tf.func­- tion..Так.проще.находить.и.устранять.ошибки..Закончив.отладку,.код.можно. ускорить,.добавив.декоратор.@tf.function .перед.функциями,.реализующими. шаг.обучения.и.шаг.оценки,.или.любыми.другими.функциями,.для.которых. важна.высокая.производительность.

7.4. Разработка своего цикла обучения и оценки    257

7.4.5. Использование fit() с нестандартным циклом обучения

Ранее.мы.с.нуля.написали.полный.цикл.обучения..Этот.подход.дает.максимальную.гибкость,.но.не.только.требует.написать.много.кода,.но.и.лишает.множества. удобных.возможностей.fit(),.таких.как.обратные.вызовы.или.встроенная.поддержка.распределенного.обучения.

А.получится.ли.применить.свой.алгоритм.обучения.и.сохранить.всю.мощь.встроенной.логики.обучения.Keras?.На.самом.деле.существует.золотая.середина.между. использованием.fit() .и.реализацией.своего.цикла.обучения:.можно.написать. свою.функцию.шага.обучения,.а.все.остальные.задачи.переложить.на.фреймворк.

Для.этого.достаточно.переопределить.метод.train_step().класса.Model,.который. вызывается.функцией.fit() .для.обработки.каждого.пакета.данных,.и.использовать.fit() .как.обычно,.а.функция.будет.запускать.ваш.алгоритм.обучения.

Вот.простой.пример:

. создадим.новый.класс,.наследующий.класс.keras.Model;

.переопределим.метод.train_step(self, data),.почти.полностью.повторив. все,.что.мы.написали.выше..Теперь.метод.будет.возвращать.словарь,.отображающий.имена.метрик.(включая.метрику.потерь).в.их.текущие.значения;

.реализуем.свойство.metrics .для.отслеживания.экземпляров.класса.Metric . в.модели..Это.позволит.модели.автоматически.вызывать.reset_state() .для. метрик.в.начале.каждой.эпохи.и.в.начале.вызова.функции.evaluate(),.чтобы. не.делать.этого.вручную.

Листинг 7.26. Реализация своего шага обучения для использования с fit()

loss_fn

= keras.losses.SparseCategoricalCrossentropy()

 

 

loss_tracker = keras.metrics.Mean(name="loss")

 

 

 

 

class CustomModel(keras.Model):

Мы переопределяем

 

def

train_step(self, data):

 

 

 

 

 

метод train_step

 

inputs, targets = data

 

 

 

 

 

 

 

with tf.GradientTape() as tape:

 

 

 

 

predictions = self(inputs, training=True)

 

loss = loss_fn(targets, predictions)

Данный объект метрики будет использоваться для слежения за средним значением потерь на пакетах в ходе обучения и оценки

Здесь вместо model(inputs, training=True) используется self(inputs, training=True), потому что моделью является сам экземпляр класса

gradients = tape.gradient(loss, model.trainable_weights) optimizer.apply_gradients(zip(gradients, model.trainable_weights))

loss_tracker.update_state(loss)

return {"loss": loss_tracker.result()}

@property

 

Список всех метрик, которые

def metrics(self):

 

 

должны сбрасываться

return [loss_tracker]

 

 

в исходное состояние

 

 

 

 

в начале каждой эпохи

Обновить метрику потерь, в которой хранится среднее значение потерь

Вернуть среднее значение потерь, получившееся к данному моменту, обратившись к экземпляру метрики loss_tracker

258    Глава 7. Работа с Keras: глубокое погружение

Теперь.можно.создать.экземпляр.модели,.скомпилировать.ее.(в.данном.случае. мы.передаем.только.оптимизатор,.потому.что.потери.определены.вне.модели). и.обучить,.используя.fit() .как.обычно:

inputs = keras.Input(shape=(28 * 28,))

features = layers.Dense(512, activation="relu")(inputs) features = layers.Dropout(0.5)(features)

outputs = layers.Dense(10, activation="softmax")(features) model = CustomModel(inputs, outputs)

model.compile(optimizer=keras.optimizers.RMSprop()) model.fit(train_images, train_labels, epochs=3)

Отметим.несколько.важных.моментов:

.данный.подход.можно.использовать.также.при.построении.моделей.с.по­ мощью.функционального.API.—.он.не.зависит.от.способа.построения.модели:. с.применением.класса.Sequential,.функционального.API.или.наследованием. класса.Model;

.при.переопределении.метода.train_step .не.нужно.использовать.декоратор. @tf.function .—.фреймворк.сделает.это.автоматически.

А.что.насчет.метрик.и.функции.потерь,.которые.настраиваются.с.помощью. compile()?.После.вызова.compile() .вы.получаете.доступ.к:

. self.compiled_loss .—.функции.потерь,.переданной.в.вызов.compile();

.self.compiled_metrics .—.обертке.для.списка.метрик,.которая.позволяет.вы- звать.self.compiled_metrics.update_state().и.обновить.сразу.все.метрики;

.self.metrics.—.фактическому.списку.метрик,.переданному.в.вызов.compile().. Обратите.внимание,.что.он.также.включает.метрику,.предназначенную.для. отслеживания.потерь,.подобно.тому.как.мы.делали.это.вручную.с.помощью. нашей.метрики.loss_tracking_metric.

То.есть.мы.можем.написать.такой.класс: class CustomModel(keras.Model):

def train_step(self, data):

Вычислить величину потерь

 

 

inputs, targets = data

 

вызовом self.compiled_loss

 

with tf.GradientTape() as tape:

 

 

 

 

predictions = self(inputs, training=True)

 

loss = self.compiled_loss(targets, predictions)

 

 

 

 

gradients = tape.gradient(loss, model.trainable_weights) optimizer.apply_gradients(zip(gradients, model.trainable_weights)) self.compiled_metrics.update_state(targets, predictions)

return {m.name: m.result() for m in self.metrics}

Обновить метрики модели

Вернуть словарь,

с помощью обертки

отображающий имена метрик

self.compiled_metrics

в их текущие значения