Transfer learning with Keras
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import Adam
train_dir = "./data/catanddog/train"
test_dir = "./data/catanddog/test"
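flow_from_directory takes its labels from the folder structure. Each subdirectory of train_dir and test_dir is treated as one class. Classes are indexed in alphanumeric order of their folder names, and the resulting mapping is exposed as train_generator.class_indices.

The sketch below mimics that mapping using only the standard library. Two things in it are assumptions: the folder names cats and dogs (the post does not show the actual layout of ./data/catanddog), and infer_class_indices, which is a hypothetical helper rather than a Keras function.

```python
import os
import tempfile


def infer_class_indices(directory):
    """Hypothetical helper: map each subfolder name to a label index,
    sorted alphanumerically, as flow_from_directory does by default."""
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(classes)}


# Build a throwaway layout shaped like train_dir/<class>/<images>
# (the class folder names "dogs" and "cats" are assumed for illustration).
with tempfile.TemporaryDirectory() as train_dir:
    for class_name in ("dogs", "cats"):
        os.makedirs(os.path.join(train_dir, class_name))
    class_indices = infer_class_indices(train_dir)

print(class_indices)  # {'cats': 0, 'dogs': 1}
```

Under that assumed layout, class_mode = "binary" would give cats the label 0 and dogs the label 1. The sigmoid output of the model would then read as the probability of "dog".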
01. Data augmentation
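train_datagen = ImageDataGenerator(
    # ImageNet-pretrained ResNet50 expects its own preprocess_input, not a 1/255 rescale
    preprocessing_function = tf.keras.applications.resnet50.preprocess_input,
    horizontal_flip = True,  # randomly mirror training images left to right
)
# Test data gets the same preprocessing but no augmentation
test_datagen = ImageDataGenerator(preprocessing_function = tf.keras.applications.resnet50.preprocess_input)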
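With horizontal_flip = True, ImageDataGenerator mirrors each training image left to right at random, with probability 0.5. The model therefore sees both orientations of every cat and dog, and no extra files are written to disk.

A minimal NumPy sketch of that transform on a (height, width, channels) array follows. random_horizontal_flip is a hypothetical helper written for illustration; it is not the Keras implementation.

```python
import numpy as np


def random_horizontal_flip(image, rng, p=0.5):
    """Mirror a (height, width, channels) image along the width axis with probability p."""
    if rng.random() < p:
        return image[:, ::-1, :]  # reverse the column order
    return image


rng = np.random.default_rng(0)
image = np.arange(2 * 3 * 1).reshape(2, 3, 1)  # tiny 2x3 single-channel "image"

# p=1.0 forces a flip: rows [0, 1, 2] and [3, 4, 5] become [2, 1, 0] and [5, 4, 3]
flipped = random_horizontal_flip(image, rng, p=1.0)
print(flipped[..., 0].tolist())
```

The flip only reorders pixels within each row. The label and the image shape are unchanged, which is why the augmentation is safe for a cat/dog classifier.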
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size = (224, 224),
    batch_size = 32,
    class_mode = "binary"
)
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size = (224, 224),
    batch_size = 32,
    class_mode = "binary"
)
Found 385 images belonging to 2 classes.
Found 98 images belonging to 2 classes.
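The two "Found" lines give the dataset sizes. With batch_size = 32:

- The 385 training images form ceil(385 / 32) = 13 batches per epoch, and the last batch holds only 1 image.
- The 98 test images form 4 batches, and the last batch holds 2 images.

So len(train_generator) should be 13. A small check of that arithmetic:

```python
import math


def batch_layout(num_images, batch_size):
    """Return (number of batches per epoch, size of the final batch)."""
    steps = math.ceil(num_images / batch_size)
    last = num_images - (steps - 1) * batch_size
    return steps, last


print(batch_layout(385, 32))  # (13, 1): training batches per epoch
print(batch_layout(98, 32))   # (4, 2): validation batches per epoch
```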
02. Loading the ResNet model
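# Load ResNet50 with ImageNet weights, dropping its 1000-class classifier (include_top = False)
base_model = ResNet50(weights = "imagenet", include_top = False, input_shape = (224, 224, 3))
base_model.trainable = False  # freeze the pretrained weights; only the new head is trained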
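With include_top = False and a 224 x 224 input, ResNet50 ends in 2048 feature maps of size 7 x 7. This is because the network halves the spatial resolution five times, and 224 / 2^5 = 7.

The sketch traces this with the standard output-size formula floor((n + 2p - k) / s) + 1. The layer list is a simplified reading of the Keras ResNet50, not a full model. It covers conv1, the max-pool, and the stride-2 1 x 1 convolution that opens stages 3 to 5.

```python
def conv_out(size, kernel, stride, pad=0):
    """Output length of a convolution or pooling window along one spatial axis."""
    return (size + 2 * pad - kernel) // stride + 1


size = 224
size = conv_out(size, kernel=7, stride=2, pad=3)  # conv1: 224 -> 112
size = conv_out(size, kernel=3, stride=2, pad=1)  # max-pool: 112 -> 56
# conv2_x keeps 56; conv3_x, conv4_x and conv5_x each open with a stride-2 1x1 conv
for _ in range(3):
    size = conv_out(size, kernel=1, stride=2)     # 56 -> 28 -> 14 -> 7

print(size)  # 7, i.e. base_model outputs (None, 7, 7, 2048)
```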
model = models.Sequential([
    base_model,
    # Added to fit the output shape: averages each 7x7 feature map, (7, 7, 2048) -> (2048,)
    layers.GlobalAveragePooling2D(),
    layers.Dense(1, activation = "sigmoid")  # one unit: probability of class 1
])
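GlobalAveragePooling2D replaces each 7 x 7 feature map with its mean. This turns the (batch, 7, 7, 2048) output of base_model into a (batch, 2048) vector that Dense(1, activation = "sigmoid") can consume.

The NumPy sketch below runs the head's forward pass. It uses random stand-in features and weights, not real ResNet activations or trained parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
features = rng.standard_normal((4, 7, 7, 2048))  # stand-in for base_model output, batch of 4

# GlobalAveragePooling2D: average over the two spatial axes
pooled = features.mean(axis=(1, 2))              # shape (4, 2048)

# Dense(1, activation="sigmoid"): one weight per channel plus a bias (random stand-ins)
w = rng.standard_normal((2048, 1)) * 0.01
b = np.zeros(1)
probs = 1.0 / (1.0 + np.exp(-(pooled @ w + b)))  # shape (4, 1), values in (0, 1)

print(pooled.shape, probs.shape)  # (4, 2048) (4, 1)
```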
model.summary()
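Because base_model.trainable = False, model.summary() should list only the new Dense head as trainable. The head has 2048 inputs times 1 unit plus 1 bias, giving 2,049 trainable parameters. Every ResNet50 weight is counted as non-trainable. The arithmetic:

```python
def dense_params(in_features, units, use_bias=True):
    """Parameter count of a fully connected layer: weight matrix plus optional bias."""
    return in_features * units + (units if use_bias else 0)


head = dense_params(2048, 1)  # GlobalAveragePooling2D itself has no parameters
print(head)  # 2049
```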
model.compile(optimizer = Adam(), loss = "binary_crossentropy", metrics = ["accuracy"])
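loss = "binary_crossentropy" compares the sigmoid output p with the 0/1 label y, averaged over the batch:

L = -[y log p + (1 - y) log(1 - p)]

A minimal NumPy sketch follows. The clipping epsilon of 1e-7 matches Keras's default backend epsilon, but treat it as an assumption about this exact setup.

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1.0 - eps)
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return losses.mean()


print(binary_crossentropy([1, 0], [0.5, 0.5]))  # ~0.6931 (ln 2): an undecided model
print(binary_crossentropy([1, 0], [0.9, 0.1]))  # ~0.1054 (-ln 0.9): confident and correct
```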
# Training (the test set doubles as the validation set)
model.fit(train_generator, validation_data = test_generator, epochs = 10)
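With a single sigmoid output and binary_crossentropy, the "accuracy" metric passed to compile is resolved to binary accuracy. A prediction counts as class 1 when p > 0.5, and that is compared with the label.

The sketch below reproduces the rule on made-up probabilities. These are illustrative values, not results from this training run.

```python
def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of samples whose thresholded prediction matches the 0/1 label."""
    hits = sum(int(p > threshold) == int(y) for y, p in zip(y_true, y_prob))
    return hits / len(y_true)


labels = [1, 0, 1, 0]
probs = [0.92, 0.30, 0.45, 0.51]  # made-up sigmoid outputs
print(binary_accuracy(labels, probs))  # 0.5 (first two right, last two wrong)
```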