-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathimage_classifier.py
More file actions
66 lines (54 loc) · 1.76 KB
/
image_classifier.py
File metadata and controls
66 lines (54 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# image_classifier.py
import os
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing import image_dataset_from_directory
# Settings
BATCH_SIZE = 32
IMG_SIZE = (224, 224)
EPOCHS = 10
DATA_DIR = "data" # Folder structure: data/train/class1, data/val/class1
# Load Data
train_ds = image_dataset_from_directory(
os.path.join(DATA_DIR, "train"),
shuffle=True,
batch_size=BATCH_SIZE,
image_size=IMG_SIZE
)
val_ds = image_dataset_from_directory(
os.path.join(DATA_DIR, "val"),
shuffle=False,
batch_size=BATCH_SIZE,
image_size=IMG_SIZE
)
# Prefetch
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
# Data Augmentation
data_augmentation = tf.keras.Sequential([
layers.RandomFlip('horizontal'),
layers.RandomRotation(0.2),
layers.RandomZoom(0.1),
])
# Pretrained Model (MobileNetV2)
base_model = MobileNetV2(input_shape=IMG_SIZE + (3,), include_top=False, weights='imagenet')
base_model.trainable = False
# Build Model
inputs = tf.keras.Input(shape=IMG_SIZE + (3,))
x = data_augmentation(inputs)
x = tf.keras.applications.mobilenet_v2.preprocess_input(x)
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(len(train_ds.class_names), activation='softmax')(x)
model = models.Model(inputs, outputs)
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# Train
model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)
# Save Model
model.save("saved_model/image_classifier")
print("✅ Model saved to 'saved_model/image_classifier'")