fix: device issue in main
This commit is contained in:
parent
73c9b6bead
commit
cb6c695712
6
main.py
6
main.py
@ -3,9 +3,11 @@ import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model = torch.load("bayes_cat_dog_classifier.pth")
|
||||
model.eval()
|
||||
model.to("cuda")
|
||||
model.to(DEVICE)
|
||||
|
||||
IMG_SIZE = 128
|
||||
|
||||
@ -13,7 +15,7 @@ def predict_image(image_path):
|
||||
img = cv2.imread(image_path)
|
||||
img = cv2.resize(img, (IMG_SIZE, IMG_SIZE)) / 255.0
|
||||
img = np.transpose(img, (2, 0, 1)) # Convert to (C, H, W)
|
||||
img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).to("cuda") # Add batch dimension
|
||||
img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).to(DEVICE) # Add batch dimension
|
||||
|
||||
model.eval() # Set model to evaluation mode
|
||||
with torch.no_grad():
|
||||
|
Loading…
x
Reference in New Issue
Block a user