Giới thiệu về Transfer learning và Fine-tuning (Phần 2)

Tiếp nối bài trước về Transfer Learning, hôm nay chúng ta cùng tìm hiểu về Fine Tuning.

Mở đầu

Fine tuning : Thuật ngữ này có thể được dịch là “Tinh chỉnh” – là một quá trình sử dụng một mô hình mạng đã được huấn luyện cho một nhiệm vụ nhất định để thực hiện một nhiệm vụ tương tự. Sở dĩ cách lý giải này có phần giống Transfer Learning – bởi Fine Tuning là một kỹ thuật Transfer Learning mà ! Hãy cùng tìm hiểu xem cụ thể nó là thế nào nhé.

Khi mô hình của bạn đã hội tụ trên dữ liệu mới, bạn có thể cố gắng giải phóng toàn bộ hoặc một phần của mô hình cơ sở và đào tạo lại toàn bộ mô hình từ đầu đến cuối với tỷ lệ học tập rất thấp.

Đây là bước cuối cùng tùy chọn có thể mang lại cho bạn những cải tiến gia tăng. Nó cũng có thể dẫn đến tình trạng overfitting – hãy cân nhắc điều đó.

Điều quan trọng là chỉ thực hiện bước này sau khi mô hình với các lớp đông lạnh đã được huấn luyện để hội tụ. Nếu bạn trộn các lớp trainable được khởi tạo ngẫu nhiên với các lớp trainable chứa các tính năng đã được huấn luyện trước, các lớp được khởi tạo ngẫu nhiên sẽ gây ra các cập nhật gradient rất lớn trong quá trình huấn luyện, điều này sẽ phá hủy các tính năng đã được huấn luyện trước của bạn.

Một vấn đề quan trọng nữa là là sử dụng tỷ lệ học tập rất thấp ở giai đoạn này, bởi vì bạn đang đào tạo một mô hình lớn hơn nhiều so với trong vòng đào tạo đầu tiên, trên một tập dữ liệu thường rất nhỏ. Do đó, bạn có nguy cơ bị overfitting rất nhanh nếu áp dụng các biện pháp cập nhật trọng lượng lớn. Ở đây, bạn chỉ muốn đọc các trọng số được huấn luyện trước theo cách tăng dần.

Đây là cách implement fine-tuning toàn bộ mô hình cơ sở:

# Hủy đóng băng mô hình cơ sở
base_model.trainable = True

# Quan trọng là phải biên dịch lại mô hình của bạn sau khi thực hiện bất kỳ thay đổi nào đối với thuộc tính `trainable` của bất kỳ lớp bên trong nào
# Để các thay đổi của bạn được tính đến
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Tỉ lệ học rất thấp
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train. Cẩn thận để dừng lại trước khi bị overfit
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Lưu ý quan trọng về compile()trainable

Việc gọi compile() trên một mô hình có nghĩa là "đóng băng" hành vi của mô hình đó. Điều này ngụ ý rằng các giá trị thuộc tính trainable tại thời điểm mô hình được biên dịch nên được bảo toàn trong suốt thời gian tồn tại của mô hình đó, cho đến khi quá trình biên dịch được gọi lại. Do đó, nếu bạn thay đổi bất kỳ giá trị có thể đào tạo nào, hãy đảm bảo gọi lại compile () trên mô hình của bạn để các thay đổi của bạn được tính đến.

Lưu ý quan trọng về lớp BatchNormalization

Nhiều mô hình hình ảnh chứa các lớp BatchNormalization. Lớp đó là một trường hợp đặc biệt trên mọi số lượng có thể tưởng tượng được. Dưới đây là một số điều cần ghi nhớ.

  • BatchNormalization chứa 2 trọng lượng không thể đào tạo được cập nhật trong quá trình đào tạo. Đây là các biến theo dõi giá trị trung bình và phương sai của các yếu tố đầu vào.
  • Khi bạn đặt bn_layer.trainable = False, lớp BatchNormalization sẽ chạy ở chế độ suy luận và sẽ không cập nhật thống kê trung bình & phương sai của nó. Điều này không đúng với các lớp khác nói chung, vì khả năng tập tạ & chế độ suy luận / huấn luyện là hai khái niệm trực giao. Nhưng cả hai được gắn với nhau trong trường hợp của lớp BatchNormalization.
  • Khi bạn giải phóng một mô hình có chứa các lớp BatchNormalization để thực hiện tinh chỉnh, bạn nên giữ các lớp BatchNormalization ở chế độ suy luận bằng cách chuyển training = False khi gọi mô hình cơ sở. Nếu không, các bản cập nhật được áp dụng cho các trọng lượng không thể đào tạo sẽ đột ngột phá hủy những gì mà mô hình đã học được.

Bạn sẽ thấy mẫu này hoạt động trong ví dụ end-to-end ở cuối hướng dẫn này.

Transfer learning & fine-tuning với một vòng lặp đào tạo tùy chỉnh

Nếu thay vì fit(), bạn đang sử dụng vòng lặp đào tạo cấp thấp của riêng mình, thì quy trình làm việc về cơ bản vẫn giữ nguyên. Bạn nên cẩn thận chỉ tính đến danh sách model.trainable_weights khi áp dụng cập nhật gradient:

# Khởi tạo mô hình cơ sở
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Đóng băng mô hình cơ sở
base_model.trainable = False

# Khởi tạo một mô hình mới on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Lặp lại các lô của tập dữ liệu.
for inputs, targets in new_dataset:
    # Mở GradientTape.
    with tf.GradientTape() as tape:
        # Chuyển tiếp
        predictions = model(inputs)
        # Tính toán giá trị tổn thất cho lô này.
        loss_value = loss_fn(targets, predictions)

    # Lấy gradients với độ giảm của trọng số *trainable*.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Cập nhật trọng số của mô hình
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

Một ví dụ từ đầu đến cuối: tinh chỉnh mô hình phân loại hình ảnh trên tập dữ liệu mèo và chó

Để củng cố những khái niệm này, hãy hướng dẫn bạn qua một ví dụ cụ thể về học tập và tinh chỉnh chuyển giao từ đầu đến cuối. Chúng tôi sẽ tải mô hình Xception, được đào tạo trước trên ImageNet và sử dụng nó trên tập dữ liệu phân loại Kaggle "mèo so với chó".

Lấy dữ liệu

Đầu tiên, hãy tìm nạp tập dữ liệu mèo và chó bằng TFDS. Nếu bạn có tập dữ liệu của riêng mình, có thể bạn sẽ muốn sử dụng tiện ích tf.keras.preprocessing.image_dataset_from_directory để tạo các đối tượng tập dữ liệu có nhãn tương tự từ một tập hợp các hình ảnh trên đĩa được lưu trữ vào các thư mục dành riêng cho lớp.

Học chuyển giao hữu ích nhất khi làm việc với các tập dữ liệu rất nhỏ. Để giữ cho tập dữ liệu của chúng tôi nhỏ, chúng tôi sẽ sử dụng 40% dữ liệu đào tạo ban đầu (25.000 hình ảnh) để đào tạo, 10% để xác thực và 10% để kiểm tra.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print("Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds))
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

Đây là 9 hình ảnh đầu tiên trong tập dữ liệu đào tạo – như bạn có thể thấy, chúng đều có kích thước khác nhau.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

Chúng ta cũng có thể thấy rằng nhãn 1 là "chó" và nhãn 0 là "mèo".

Chuẩn hóa dữ liệu

Hình ảnh thô của có nhiều kích cỡ khác nhau. Ngoài ra, mỗi pixel bao gồm 3 giá trị integer 0 đến 255 (giá trị RGB). Đây không phải là một sự phù hợp tuyệt vời cho mạng nơ-ron. Chúng ta cần làm 2 việc:

  • Chuẩn hóa thành kích thước hình ảnh cố định. Chúng ta chọn 150×150.
  • Chuẩn hóa các giá trị pixel từ -1 đến 1. Chúng tôi sẽ thực hiện việc này bằng cách sử dụng lớp Chuẩn hóa như một phần của chính mô hình.

Nói chung, bạn nên phát triển các mô hình lấy dữ liệu thô làm đầu vào, trái ngược với các mô hình lấy dữ liệu đã được xử lý trước. Lý do là, nếu mô hình của bạn yêu cầu dữ liệu được xử lý trước, bất kỳ khi nào bạn xuất mô hình của mình để sử dụng ở nơi khác (trong trình duyệt web, trong ứng dụng dành cho thiết bị di động), bạn sẽ cần phải thực hiện lại cùng một quy trình xử lý trước. Điều này trở nên rất phức tạp rất nhanh chóng. Vì vậy, chúng ta nên thực hiện ít tiền xử lý nhất có thể trước khi đưa vào mô hình.

Ở đây, chúng ta sẽ thực hiện thay đổi kích thước hình ảnh trong đường ống dữ liệu (vì mạng nơ-ron sâu chỉ có thể xử lý các lô dữ liệu liền kề) và chúng tôi sẽ thực hiện điều chỉnh tỷ lệ giá trị đầu vào như một phần của mô hình, khi chúng tôi tạo nó.

Hãy thay đổi kích thước hình ảnh thành 150×150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

Bên cạnh đó, hãy tập hợp dữ liệu và sử dụng bộ nhớ đệm & tìm nạp trước để tối ưu hóa tốc độ tải.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

Sử dụng tăng dữ liệu ngẫu nhiên

Khi bạn không có tập dữ liệu hình ảnh lớn, bạn nên đưa tính đa dạng mẫu vào một cách giả tạo bằng cách áp dụng các phép biến đổi ngẫu nhiên nhưng thực tế cho hình ảnh huấn luyện, chẳng hạn như lật ngang ngẫu nhiên hoặc xoay ngẫu nhiên nhỏ. Điều này giúp mô hình hiển thị các khía cạnh khác nhau của dữ liệu đào tạo trong khi làm chậm quá trình overfitting.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

Hãy hình dung hình ảnh đầu tiên của lô đầu tiên trông như thế nào sau nhiều lần biến đổi ngẫu nhiên:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")

Xây dựng một mô hình

Bây giờ chúng ta hãy xây dựng một mô hình theo kế hoạch chi tiết đã giải thích trước đó.

Lưu ý rằng:

  • Thêm một lớp Rescaling để chia tỷ lệ các giá trị đầu vào (ban đầu trong phạm vi [0, 255]) thành phạm vi [-1, 1].
  • Thêm một lớp Dropout trước lớp phân loại, để chính quy hóa.
  • Đảm bảo training = False khi gọi mô hình cơ sở, để nó chạy ở chế độ suy luận, do đó thống kê batchnorm không được cập nhật ngay cả sau khi chúng tôi giải phóng mô hình cơ sở để fine-tuning.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

Đào tạo lớp trên cùng

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
291/291 [==============================] - 133s 451ms/step - loss: 0.1670 - binary_accuracy: 0.9267 - val_loss: 0.0830 - val_binary_accuracy: 0.9716
Epoch 2/20
291/291 [==============================] - 135s 465ms/step - loss: 0.1208 - binary_accuracy: 0.9502 - val_loss: 0.0768 - val_binary_accuracy: 0.9716
Epoch 3/20
291/291 [==============================] - 135s 463ms/step - loss: 0.1062 - binary_accuracy: 0.9572 - val_loss: 0.0757 - val_binary_accuracy: 0.9716
Epoch 4/20
291/291 [==============================] - 137s 469ms/step - loss: 0.1024 - binary_accuracy: 0.9554 - val_loss: 0.0733 - val_binary_accuracy: 0.9725
Epoch 5/20
291/291 [==============================] - 137s 470ms/step - loss: 0.1004 - binary_accuracy: 0.9587 - val_loss: 0.0735 - val_binary_accuracy: 0.9729
Epoch 6/20
291/291 [==============================] - 136s 467ms/step - loss: 0.0979 - binary_accuracy: 0.9577 - val_loss: 0.0747 - val_binary_accuracy: 0.9708
Epoch 7/20
291/291 [==============================] - 134s 462ms/step - loss: 0.0998 - binary_accuracy: 0.9596 - val_loss: 0.0706 - val_binary_accuracy: 0.9725
Epoch 8/20
291/291 [==============================] - 133s 457ms/step - loss: 0.1029 - binary_accuracy: 0.9592 - val_loss: 0.0720 - val_binary_accuracy: 0.9733
Epoch 9/20
291/291 [==============================] - 135s 466ms/step - loss: 0.0937 - binary_accuracy: 0.9625 - val_loss: 0.0707 - val_binary_accuracy: 0.9721
Epoch 10/20
291/291 [==============================] - 137s 472ms/step - loss: 0.0967 - binary_accuracy: 0.9580 - val_loss: 0.0720 - val_binary_accuracy: 0.9712
Epoch 11/20
291/291 [==============================] - 135s 463ms/step - loss: 0.0961 - binary_accuracy: 0.9612 - val_loss: 0.0802 - val_binary_accuracy: 0.9699
Epoch 12/20
291/291 [==============================] - 134s 460ms/step - loss: 0.0963 - binary_accuracy: 0.9638 - val_loss: 0.0721 - val_binary_accuracy: 0.9716
Epoch 13/20
291/291 [==============================] - 136s 468ms/step - loss: 0.0925 - binary_accuracy: 0.9635 - val_loss: 0.0736 - val_binary_accuracy: 0.9686
Epoch 14/20
291/291 [==============================] - 138s 476ms/step - loss: 0.0909 - binary_accuracy: 0.9624 - val_loss: 0.0766 - val_binary_accuracy: 0.9703
Epoch 15/20
291/291 [==============================] - 136s 467ms/step - loss: 0.0949 - binary_accuracy: 0.9598 - val_loss: 0.0704 - val_binary_accuracy: 0.9725
Epoch 16/20
291/291 [==============================] - 133s 456ms/step - loss: 0.0969 - binary_accuracy: 0.9586 - val_loss: 0.0722 - val_binary_accuracy: 0.9708
Epoch 17/20
291/291 [==============================] - 135s 464ms/step - loss: 0.0913 - binary_accuracy: 0.9635 - val_loss: 0.0718 - val_binary_accuracy: 0.9716
Epoch 18/20
291/291 [==============================] - 137s 472ms/step - loss: 0.0915 - binary_accuracy: 0.9639 - val_loss: 0.0727 - val_binary_accuracy: 0.9725
Epoch 19/20
291/291 [==============================] - 134s 460ms/step - loss: 0.0938 - binary_accuracy: 0.9631 - val_loss: 0.0707 - val_binary_accuracy: 0.9733
Epoch 20/20
291/291 [==============================] - 134s 460ms/step - loss: 0.0971 - binary_accuracy: 0.9609 - val_loss: 0.0714 - val_binary_accuracy: 0.9716

<keras.callbacks.History at 0x7f4494e38f70>

Thực hiện một vòng tinh chỉnh toàn bộ mô hình

Cuối cùng, hãy giải phóng mô hình cơ sở và đào tạo toàn bộ mô hình từ đầu đến cuối với tỷ lệ học tập thấp.

Quan trọng là, mặc dù mô hình cơ sở trở nên có thể huấn luyện được, nhưng nó vẫn đang chạy ở chế độ suy luận vì chúng ta đã đặt training=False khi gọi nó khi chúng ta xây dựng mô hình. Điều này có nghĩa là các lớp chuẩn hóa hàng loạt bên trong sẽ không cập nhật thống kê hàng loạt của chúng. Nếu họ làm vậy, họ sẽ phá hủy các đại diện mà mô hình đã học cho đến nay.

base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 567s 2s/step - loss: 0.0749 - binary_accuracy: 0.9689 - val_loss: 0.0605 - val_binary_accuracy: 0.9776
Epoch 2/10
291/291 [==============================] - 551s 2s/step - loss: 0.0559 - binary_accuracy: 0.9770 - val_loss: 0.0507 - val_binary_accuracy: 0.9798
Epoch 3/10
291/291 [==============================] - 545s 2s/step - loss: 0.0444 - binary_accuracy: 0.9832 - val_loss: 0.0502 - val_binary_accuracy: 0.9807
Epoch 4/10
291/291 [==============================] - 558s 2s/step - loss: 0.0365 - binary_accuracy: 0.9874 - val_loss: 0.0506 - val_binary_accuracy: 0.9807
Epoch 5/10
291/291 [==============================] - 550s 2s/step - loss: 0.0276 - binary_accuracy: 0.9890 - val_loss: 0.0477 - val_binary_accuracy: 0.9802
Epoch 6/10
291/291 [==============================] - 588s 2s/step - loss: 0.0206 - binary_accuracy: 0.9916 - val_loss: 0.0444 - val_binary_accuracy: 0.9832
Epoch 7/10
291/291 [==============================] - 542s 2s/step - loss: 0.0206 - binary_accuracy: 0.9923 - val_loss: 0.0502 - val_binary_accuracy: 0.9828
Epoch 8/10
291/291 [==============================] - 544s 2s/step - loss: 0.0153 - binary_accuracy: 0.9939 - val_loss: 0.0509 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 548s 2s/step - loss: 0.0156 - binary_accuracy: 0.9934 - val_loss: 0.0610 - val_binary_accuracy: 0.9807
Epoch 10/10
291/291 [==============================] - 546s 2s/step - loss: 0.0176 - binary_accuracy: 0.9936 - val_loss: 0.0561 - val_binary_accuracy: 0.9789

<keras.callbacks.History at 0x7f4495056040> 

Sau 10 epochs, fine-tuning mang lại cho chúng ta một cải tiến tốt đẹp ở đây.

Kết bài

Trên đây chúng ta đã tìm hiểu kỹ thuật Fine-tuning. Cảm ơn các bạn đã giành thời gian theo dõi. Thân ái !

Tham khảo https://keras.io/guides/

Leave a Comment

* By using this form you agree with the storage and handling of your data by this website.

You may also like