76 lines
2.2 KiB
Python
Raw Normal View History

2025-01-27 22:38:23 +01:00
import os
import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch import nn, optim
from BN import CatDogClassifier
import time
2025-01-31 23:21:49 +01:00
from yaspin import yaspin
2025-01-27 22:38:23 +01:00
# Side length, in pixels, of the square images the pipeline resizes to and
# passes to the classifier as ``img_size``.
IMG_SIZE = 128

# Pick the compute device once at import time: CUDA first, then Apple's MPS
# backend, and finally the CPU.
# The MPS probe goes through torch.backends.mps, the documented availability
# check; torch.mps.is_available() is not present in older PyTorch releases and
# would raise AttributeError there before the CPU fallback could be reached.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
2025-01-27 22:38:23 +01:00
def load_images_from_folder(folder, label):
    """Read every decodable image in *folder* and pair it with *label*.

    Each directory entry is passed to ``cv2.imread``; entries for which it
    returns ``None`` are skipped. The remaining images are resized to
    ``IMG_SIZE`` x ``IMG_SIZE`` and divided by 255.0.

    Args:
        folder: Path of the directory to scan (not recursive).
        label: Value attached to every image loaded from this folder.

    Returns:
        A list of ``(image, label)`` tuples, in ``os.listdir`` order.
    """
    target_size = (IMG_SIZE, IMG_SIZE)
    samples = []
    for entry in os.listdir(folder):
        image = cv2.imread(os.path.join(folder, entry))
        if image is None:
            # Not a readable image; ignore it.
            continue
        scaled = cv2.resize(image, target_size) / 255.0
        samples.append((scaled, label))
    return samples
2025-01-31 23:21:49 +01:00
def train_model(model_name, spinner):
    """Train a CatDogClassifier on the cats/dogs training set and save it.

    Images are loaded from ``dataset/training_set/cats`` (label 0) and
    ``dataset/training_set/dogs`` (label 1), the model is trained for a fixed
    number of epochs with Adam and cross-entropy loss, and the trained model
    is written to ``models/{model_name}.pth``.

    Args:
        model_name: Base name of the saved file (without the ``.pth`` suffix).
        spinner: Progress reporter; only its ``write(str)`` method is used
            (presumably a yaspin spinner -- confirm against the caller).

    Raises:
        ValueError: If no readable image is found in either training folder.
    """
    spinner.write(f'Using the following device : {DEVICE}')

    # Loading the dataset
    cat_data = load_images_from_folder("dataset/training_set/cats", label=0)
    dog_data = load_images_from_folder("dataset/training_set/dogs", label=1)
    dataset = cat_data + dog_data
    if not dataset:
        # Without this guard the transpose below fails with an opaque
        # "axes don't match array" ValueError.
        raise ValueError(
            "No readable images found in dataset/training_set/cats or "
            "dataset/training_set/dogs"
        )

    # Create the output directory up front so a missing "models/" folder is
    # reported before training instead of losing the trained model at save time.
    save_path = f"models/{model_name}.pth"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    np.random.shuffle(dataset)
    X = np.array([item[0] for item in dataset], dtype=np.float32)
    Y = np.array([item[1] for item in dataset], dtype=np.int64)
    # Move the last axis to position 1: (N, H, W, C) -> (N, C, H, W),
    # presumably the layout CatDogClassifier expects.
    X = np.transpose(X, (0, 3, 1, 2))
    # The whole dataset is placed on DEVICE once, so batches need no
    # per-step transfer.
    X_tensor = torch.tensor(X).to(DEVICE)
    Y_tensor = torch.tensor(Y).to(DEVICE)
    dataset = TensorDataset(X_tensor, Y_tensor)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    model = CatDogClassifier(img_size=IMG_SIZE)
    model = model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    num_epochs = 20
    start_time = time.time()

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for images, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Report the mean batch loss for this epoch.
        spinner.write(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader):.4f}")

    spinner.write(f"Time taken: {(time.time() - start_time):.2f} seconds")

    # The full module object is saved (not just a state_dict), so code that
    # loads this file needs the CatDogClassifier class to be importable.
    torch.save(model, save_path)