35 lines
909 B
Python
Raw Normal View History

2025-01-27 20:01:02 +01:00
import os
import sys

import cv2
import numpy as np
import torch
2025-01-27 22:38:23 +01:00
# Run on the GPU when one is available, otherwise fall back to the CPU
# (previously "cuda" was hard-coded and the script failed without a GPU).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint is used as a full model object (it is .eval()'d, moved and
# called directly), not a state_dict. PyTorch >= 2.6 defaults torch.load to
# weights_only=True, which refuses such pickles, so opt out explicitly.
# SECURITY: weights_only=False unpickles arbitrary objects -- only load
# checkpoints from a trusted source.
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only box.
model = torch.load(
    "models/bayes_cat_dog_classifier.pth",
    map_location=DEVICE,
    weights_only=False,
)
model.to(DEVICE)
model.eval()  # switch any dropout / batch-norm layers to inference behaviour

# Side length, in pixels, that input images are resized to before inference.
IMG_SIZE = 128
2025-01-27 20:01:02 +01:00
2025-01-27 22:38:23 +01:00
def predict_image(image_path):
    """Classify a single image file as "Cat" or "Dog" using the global ``model``.

    The image is read with OpenCV, resized to ``IMG_SIZE`` x ``IMG_SIZE``,
    scaled to [0, 1], reordered from (H, W, C) to (C, H, W) and given a batch
    dimension of 1 before being passed through the model.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.

    Returns:
        ``"Dog"`` if the arg-max class index of the model output is 1,
        otherwise ``"Cat"``.

    Raises:
        ValueError: If OpenCV cannot read ``image_path`` as an image.
    """
    img = cv2.imread(image_path)
    # cv2.imread reports failure (missing file, directory, non-image) by
    # returning None rather than raising; fail here with a clear message
    # instead of an opaque assertion error from cv2.resize.
    if img is None:
        raise ValueError(f"Could not read image: {image_path!r}")

    # NOTE(review): OpenCV decodes to BGR channel order; this assumes the model
    # was trained on BGR input too -- confirm against the training pipeline.
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE)) / 255.0
    img = np.transpose(img, (2, 0, 1))  # (H, W, C) -> (C, H, W)

    # Send the input to whichever device the model's weights live on rather
    # than hard-coding "cuda", so this also works when the model is on CPU.
    device = next(model.parameters()).device
    img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).to(device)  # add batch dim

    model.eval()  # defensive: ensure inference-mode behaviour for each call
    with torch.no_grad():
        output = model(img_tensor)
        predicted = torch.argmax(output, dim=1).item()

    return "Dog" if predicted == 1 else "Cat"
2025-01-27 20:01:02 +01:00
2025-01-27 22:38:23 +01:00
# Classify every image in the test folder, then print how many were predicted
# as "Cat" followed by how many were predicted as "Dog".
TEST_DIR = "dataset/test_set/XD/"

preds = []
# sorted() makes the processing order (and any skip messages) deterministic;
# os.listdir returns entries in arbitrary order.
for filename in sorted(os.listdir(TEST_DIR)):
    img_path = os.path.join(TEST_DIR, filename)
    # os.listdir also lists sub-directories, which are not images.
    if not os.path.isfile(img_path):
        continue
    try:
        prediction = predict_image(img_path)
    except (ValueError, cv2.error) as err:
        # Unreadable / non-image file: report and skip it instead of
        # aborting the whole run.
        print(f"Skipping {img_path}: {err}", file=sys.stderr)
        continue
    preds.append(prediction)

print(preds.count("Cat"))
print(preds.count("Dog"))