import os
import cv2
import torch
import numpy as np

IMG_SIZE = 128  # side length (pixels) every image is resized to before inference

# Pick the inference device: CUDA first, then Apple-silicon MPS, then CPU.
# `torch.backends.mps.is_available()` is the documented MPS check and exists on
# every platform (returning False off macOS); `torch.mps.is_available` is not
# present in all PyTorch 2.x releases, so relying on it can raise AttributeError
# at import time on machines without CUDA.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

def predict_image(image_path, model):
    """Classify a single image file as "Dog" or "Cat".

    The image is read with OpenCV, resized to ``IMG_SIZE`` x ``IMG_SIZE``,
    divided by 255, reordered from (H, W, C) to (C, H, W), given a leading
    batch dimension and run through ``model`` on ``DEVICE`` without gradients.

    Args:
        image_path: Path to the image file.
        model: Torch module whose output is arg-maxed over dim 1; class index 1
            is reported as "Dog", any other index as "Cat".

    Returns:
        "Dog" if the predicted class index is 1, otherwise "Cat".

    Raises:
        ValueError: If OpenCV cannot read ``image_path`` as an image.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure (missing file, unsupported/corrupt format) by
    # returning None instead of raising; report it clearly rather than letting
    # cv2.resize fail with an opaque error.
    if img is None:
        raise ValueError(f"Could not read image: {image_path}")

    # NOTE(review): OpenCV loads channels in BGR order; presumably the model was
    # trained on cv2-loaded images as well — confirm against the training code.
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE)) / 255.0
    img = np.transpose(img, (2, 0, 1))  # (H, W, C) -> (C, H, W)
    # unsqueeze(0) adds the batch dimension expected by the model.
    img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).to(DEVICE)

    model.eval()
    with torch.no_grad():
        output = model(img_tensor)
        predicted = torch.argmax(output, dim=1).item()

    return "Dog" if predicted == 1 else "Cat"

def make_predictions(model_path, dataset_path, type, spinner):
    """Classify every file in a folder and report the share predicted as ``type``.

    Args:
        model_path: Path to a checkpoint containing a whole pickled model
            (the loaded object is used directly as a module, not a state_dict).
        dataset_path: Directory whose regular files are each classified with
            ``predict_image``; sub-directories are skipped.
        type: Label ("Dog" or "Cat", as returned by ``predict_image``) whose
            share among the predictions is reported. Shadows the builtin, but
            the name is kept for compatibility with keyword callers.
        spinner: Object with a ``write(str)`` method used to emit the result.
    """
    # The loaded object is used as a full nn.Module below. PyTorch >= 2.6
    # defaults torch.load to weights_only=True, which rejects such checkpoints,
    # so opt out explicitly (this matches the default of earlier versions).
    # SECURITY: this unpickles the file and can execute arbitrary code —
    # only load checkpoints from trusted sources.
    model = torch.load(model_path, map_location=DEVICE, weights_only=False)
    model.eval()
    model.to(DEVICE)

    preds = []
    for filename in os.listdir(dataset_path):
        img_path = os.path.join(dataset_path, filename)
        # Sub-directories and other non-file entries cannot be read as images.
        if not os.path.isfile(img_path):
            continue
        preds.append(predict_image(img_path, model))

    if not preds:
        spinner.write(f'No images found in {dataset_path}')
        return

    # Divide by the number of images actually classified (previously a
    # hard-coded 1000, which was wrong for any other dataset size).
    # NOTE(review): if the folder holds only images of class `type`, this is
    # recall/accuracy for that class rather than precision; the label is kept
    # unchanged for output compatibility.
    spinner.write(f'Precision : {preds.count(type) / len(preds) * 100}%')