fix: device issue in main

This commit is contained in:
Nabil Ould Hamou 2025-01-27 23:21:33 +01:00
parent 73c9b6bead
commit cb6c695712

View File

@ -3,9 +3,11 @@ import cv2
import torch
import numpy as np
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("bayes_cat_dog_classifier.pth")
model.eval()
model.to("cuda")
model.to(DEVICE)
IMG_SIZE = 128
@ -13,7 +15,7 @@ def predict_image(image_path):
img = cv2.imread(image_path)
img = cv2.resize(img, (IMG_SIZE, IMG_SIZE)) / 255.0
img = np.transpose(img, (2, 0, 1)) # Convert to (C, H, W)
img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).to("cuda") # Add batch dimension
img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).to(DEVICE) # Add batch dimension
model.eval() # Set model to evaluation mode
with torch.no_grad():