09_DL(Deep_Learning)

24_Transfer learning_keras

chuuvelop 2025. 5. 7. 23:02
Transfer learning with Keras

 

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import Adam
train_dir = "./data/catanddog/train"  # one subfolder per class
test_dir = "./data/catanddog/test"    # same class subfolders as train_dir

 

 

01. Data augmentation
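from tensorflow.keras.applications.resnet50 import preprocess_input

# ImageNet-pretrained ResNet50 expects its own preprocessing (RGB -> BGR, per-channel
# ImageNet mean subtraction), not rescale = 1./255
train_datagen = ImageDataGenerator(
    preprocessing_function = preprocess_input,
    horizontal_flip = True,  # augmentation: random left-right flips
)

test_datagen = ImageDataGenerator(preprocessing_function = preprocess_input) # no augmentation for test data; only the same preprocessing is applied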
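The Keras documentation for ResNet50 says its ImageNet weights expect inputs passed through `resnet50.preprocess_input`, which converts RGB to BGR and subtracts per-channel ImageNet means without scaling to [0, 1]. The pure-Python sketch below contrasts that ("caffe" mode) with plain `rescale = 1./255` on a single pixel. The helper names are hypothetical and this is not the Keras implementation, which operates on whole arrays; the mean values are the ones Keras uses for caffe-style preprocessing.

```python
# Per-channel ImageNet means in BGR order, as used by Keras' "caffe" preprocessing mode
IMAGENET_BGR_MEANS = (103.939, 116.779, 123.68)

def rescale(pixel_rgb, scale=1.0 / 255):
    """What rescale = 1./255 does: multiply every channel by 1/255."""
    return tuple(round(c * scale, 3) for c in pixel_rgb)

def caffe_preprocess(pixel_rgb):
    """Hypothetical helper: one RGB pixel (0-255) -> zero-centred BGR, as ResNet50 expects."""
    r, g, b = pixel_rgb
    bgr = (b, g, r)  # swap channel order RGB -> BGR
    # subtract the ImageNet mean of each channel; no division by 255
    return tuple(round(c - m, 3) for c, m in zip(bgr, IMAGENET_BGR_MEANS))

pixel = (255, 128, 0)            # an orange pixel in RGB
print(rescale(pixel))            # (1.0, 0.502, 0.0)
print(caffe_preprocess(pixel))   # (-103.939, 11.221, 131.32)
```

The two encodings of the same pixel differ in both range and channel order, which is why features from a frozen ImageNet backbone degrade when it is fed `1/255`-scaled images.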

 

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size = (224, 224),
    batch_size = 32,
    class_mode = "binary"
)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size = (224, 224),
    batch_size = 32,
    class_mode = "binary"
)
Found 385 images belonging to 2 classes.
Found 98 images belonging to 2 classes.
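`flow_from_directory` treats each subdirectory as one class and, by default, assigns label indices in alphanumeric order of the folder names; with `class_mode = "binary"` those indices become the 0/1 targets (the real mapping is available as `train_generator.class_indices`). The sketch below mimics that rule in plain Python and also works out how many batches of 32 one epoch takes for the 385 training and 98 test images reported above. The folder names `cats`/`dogs` are hypothetical, since the post does not show them.

```python
import math
import os
import tempfile

def infer_class_indices(directory):
    """Mimic the default rule: each subdirectory is a class, indexed in alphanumeric order."""
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: idx for idx, name in enumerate(classes)}

def steps_per_epoch(num_images, batch_size):
    """Batches needed to see every image once; the last batch may be partial."""
    return math.ceil(num_images / batch_size)

with tempfile.TemporaryDirectory() as root:
    # Hypothetical class folder names; the real ones are not shown in the post
    for name in ("dogs", "cats"):
        os.makedirs(os.path.join(root, name))
    print(infer_class_indices(root))  # {'cats': 0, 'dogs': 1}

print(steps_per_epoch(385, 32))  # 13 training batches per epoch
print(steps_per_epoch(98, 32))   # 4 validation batches per epoch
```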

 

 

 

02. Loading the ResNet model

 

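# ImageNet-pretrained ResNet50 without its 1000-class classifier head
base_model = ResNet50(weights = "imagenet", include_top = False, input_shape = (224, 224, 3))
base_model.trainable = False  # freeze the pretrained convolutional layers
model = models.Sequential([
    base_model,                              # outputs a (7, 7, 2048) feature map per image
    layers.GlobalAveragePooling2D(),         # averages it into a 2048-dim vector for the Dense layer
    layers.Dense(1, activation = "sigmoid")  # single probability for binary classification
])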
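`GlobalAveragePooling2D` turns the backbone's H x W x C feature map (7 x 7 x 2048 for a 224 x 224 input to ResNet50) into a length-C vector by averaging each channel over all spatial positions, which is what lets a `Dense` layer sit directly on top. A minimal pure-Python sketch on a toy 2 x 2 x 3 map follows; the function name is hypothetical and the layer itself works on batched tensors.

```python
def global_average_pool(feature_map):
    """Average an H x W x C feature map over H and W, returning a length-C vector."""
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    sums = [0.0] * channels
    for row in feature_map:
        for pixel in row:
            for c, value in enumerate(pixel):
                sums[c] += value  # accumulate each channel separately
    return [s / (height * width) for s in sums]

# Toy 2 x 2 map with 3 channels (ResNet50 here yields 7 x 7 x 2048)
fmap = [
    [[1.0, 0.0, 2.0], [3.0, 0.0, 2.0]],
    [[5.0, 4.0, 2.0], [7.0, 0.0, 2.0]],
]
print(global_average_pool(fmap))  # [4.0, 1.0, 2.0]
```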

 

model.summary()
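Because `base_model.trainable = False`, only the new head is trained, so the "Trainable params" line of `model.summary()` should correspond to the `Dense(1)` layer alone: `GlobalAveragePooling2D` has no weights, and a Dense layer has one weight per input per unit plus one bias per unit. The arithmetic is sketched below with a hypothetical helper; the 2-unit head is only a comparison, not something the post uses.

```python
def dense_params(input_dim, units):
    """Parameter count of a Dense layer: input_dim * units weights plus one bias per unit."""
    return input_dim * units + units

# Dense(1) sits on the 2048-dim pooled vector; the pooling layer adds nothing
head = dense_params(2048, 1)
print(head)                   # 2049

# Hypothetical alternative: a 2-unit softmax head would double the count
print(dense_params(2048, 2))  # 4098
```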

 

model.compile(optimizer = Adam(), loss = "binary_crossentropy", metrics = ["accuracy"])
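With a single sigmoid output, `binary_crossentropy` averages -[y log p + (1 - y) log(1 - p)] over the batch, and Keras resolves `"accuracy"` to binary accuracy, counting a prediction as class 1 when p > 0.5. The pure-Python sketch below reproduces both on four made-up predictions. Clipping p to [1e-7, 1 - 1e-7] mirrors Keras' default epsilon, but the function names are hypothetical and this is not the library code.

```python
import math

EPSILON = 1e-7  # assumed to match Keras' default epsilon used for clipping

def binary_crossentropy(y_true, y_pred):
    """Mean of -[y*log(p) + (1-y)*log(1-p)] over the batch, with p clipped away from 0 and 1."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, EPSILON), 1 - EPSILON)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)

def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of samples whose thresholded prediction matches the 0/1 label."""
    hits = sum(int((p > threshold) == bool(y)) for y, p in zip(y_true, y_pred))
    return hits / len(y_true)

labels = [1, 0, 1, 0]
probs = [0.9, 0.2, 0.4, 0.1]  # made-up sigmoid outputs = predicted P(class 1)
print(round(binary_crossentropy(labels, probs), 4))  # 0.3375
print(binary_accuracy(labels, probs))                # 0.75
```

The third sample (label 1, p = 0.4) is the only miss, and it also contributes most of the loss, since -log(0.4) is roughly 0.92.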

 

# Training (the test set is also used as validation data here)
model.fit(train_generator, validation_data = test_generator, epochs = 10)
