From cb6c6957122e550975a936f98d437dc57168ecb0 Mon Sep 17 00:00:00 2001 From: Nabil Ould Hamou Date: Mon, 27 Jan 2025 23:21:33 +0100 Subject: [PATCH] fix: device issue in main --- main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 5517bba..2e32937 100644 --- a/main.py +++ b/main.py @@ -3,9 +3,11 @@ import cv2 import torch import numpy as np +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model = torch.load("bayes_cat_dog_classifier.pth") model.eval() -model.to("cuda") +model.to(DEVICE) IMG_SIZE = 128 @@ -13,7 +15,7 @@ def predict_image(image_path): img = cv2.imread(image_path) img = cv2.resize(img, (IMG_SIZE, IMG_SIZE)) / 255.0 img = np.transpose(img, (2, 0, 1)) # Convert to (C, H, W) - img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).to("cuda") # Add batch dimension + img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).to(DEVICE) # Add batch dimension model.eval() # Set model to evaluation mode with torch.no_grad():