71 lines
2.0 KiB
Python
71 lines
2.0 KiB
Python
import os
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import DataLoader, TensorDataset
|
|
from torch import nn, optim
|
|
from BN import CatDogClassifier
|
|
import time
|
|
|
|
# Side length, in pixels, of the square images fed to the classifier.
IMG_SIZE = 128

# Run on the first CUDA GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
|
|
|
|
"""
|
|
This function loads all the images from the folder and labels them
|
|
"""
|
|
def load_images_from_folder(folder, label):
|
|
data = []
|
|
for filename in os.listdir(folder):
|
|
img_path = os.path.join(folder, filename)
|
|
img = cv2.imread(img_path)
|
|
if img is not None:
|
|
img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
|
|
img = img / 255.0
|
|
data.append((img, label))
|
|
return data
|
|
|
|
def _build_dataloader(batch_size=32):
    """Load the cat/dog training images into a shuffling ``DataLoader``.

    Cats are labelled 0 and dogs 1. The tensors are kept on the CPU; the
    training loop moves one batch at a time to ``DEVICE``, so the whole
    dataset never has to fit in GPU memory.

    Raises:
        RuntimeError: if no readable images were found.
    """
    cat_data = load_images_from_folder("dataset/training_set/cats", label=0)
    dog_data = load_images_from_folder("dataset/training_set/dogs", label=1)

    samples = cat_data + dog_data
    if not samples:
        raise RuntimeError("no training images found under dataset/training_set")
    np.random.shuffle(samples)

    X = np.array([img for img, _ in samples], dtype=np.float32)
    Y = np.array([label for _, label in samples], dtype=np.int64)

    # (N, H, W, C) -> (N, C, H, W): channels-first before feeding the model.
    X = np.transpose(X, (0, 3, 1, 2))

    # from_numpy shares memory with the arrays instead of copying them.
    dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(Y))
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        # Page-locked host memory speeds up host->GPU copies; useless on CPU.
        pin_memory=DEVICE.type == "cuda",
    )


def _train(model, dataloader, num_epochs=20, lr=0.001):
    """Train ``model`` with Adam and cross-entropy, printing mean loss per epoch."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for images, labels in dataloader:
            # Transfer one batch at a time; non_blocking pairs with pin_memory.
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader):.4f}")


def main():
    """Train a ``CatDogClassifier`` on the training set and save it to disk."""
    dataloader = _build_dataloader(batch_size=32)

    model = CatDogClassifier().to(DEVICE)

    # perf_counter is monotonic, unlike time.time, so it is safe for durations.
    start_time = time.perf_counter()
    _train(model, dataloader, num_epochs=20, lr=0.001)
    print(f"Time taken: {(time.perf_counter() - start_time):.2f} seconds")

    # Saves the whole pickled module (not just state_dict()), so loading it
    # requires BN.CatDogClassifier to be importable. Format kept unchanged so
    # existing code that torch.load()s this file keeps working.
    # NOTE(review): consider switching to model.state_dict() for portability.
    torch.save(model, "bayes_cat_dog_classifier.pth")


if __name__ == "__main__":
    main()