Compare commits

..

3 Commits

6 changed files with 109 additions and 51 deletions

View File

@ -5,9 +5,9 @@ This project is meant as a way to gradually bring improvements on the bayesian n
## Table of Contents
* [Objectives 🎯](#objectives)
* [Requirements 📋](#requirements)
* [Running the project 🚀](#running-the-project)
* [Development 🔨](#development)
* [Requirements 📋](#requirements)
* [Citations 📝](#citations)
## Objectives 🎯
@ -16,6 +16,12 @@ This project is meant as a way to gradually bring improvements on the bayesian n
- [ ] Generate some graphs to visualize the data
- [ ] Make a CLI
## Requirements 📋
To run the projet you need the following requirements:
- Python 3.12
- venv
## Running the project 🚀
```sh
@ -29,6 +35,4 @@ python main.py
## Development 🔨
### Requirements 📋
## Citations 📝

80
main.py
View File

@ -1,46 +1,54 @@
import inquirer
import typer
import pyfiglet
from yaspin import yaspin
from train import train_model
from predict import make_predictions
import os
import cv2
import torch
import numpy as np
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
choice = ""
model = torch.load("bayes_cat_dog_classifier.pth")
model.eval()
model.to(DEVICE)
def main():
choice = inquirer.list_input("What would you like to do?", choices=["Run tests", "Train a model", "Visualize training data"])
IMG_SIZE = 128
if choice == "Run tests":
predictions()
elif choice == "Train a model":
training()
else:
visualize()
def predict_image(image_path):
img = cv2.imread(image_path)
img = cv2.resize(img, (IMG_SIZE, IMG_SIZE)) / 255.0
img = np.transpose(img, (2, 0, 1)) # Convert to (C, H, W)
img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).to(DEVICE) # Add batch dimension
def predictions():
default_cats_path = "dataset/test_set/cats/"
default_dogs_path = "dataset/test_set/dogs/"
models_base_path = "models/"
model.eval() # Set model to evaluation mode
with torch.no_grad():
output = model(img_tensor)
predicted = torch.argmax(output, dim=1).item()
model_name = inquirer.list_input("Select the model to use", choices=os.listdir(models_base_path))
model_path = os.path.join(models_base_path, model_name)
return "Dog" if predicted == 1 else "Cat"
dataset = inquirer.list_input("Select the testing data (default dataset)", choices=['Cats', 'Dogs'])
if dataset == "Cats":
with yaspin(text="Making predictions...", color="cyan") as sp:
make_predictions(model_path, default_cats_path, "Cat", sp)
sp.ok("DONE")
else:
with yaspin(text="Making predictions...", color="cyan") as sp:
make_predictions(model_path, default_dogs_path, "Dog", sp)
sp.ok("DONE")
def training():
text = inquirer.text(message="Enter the name of the new model")
with yaspin(text="Training new model...", color="cyan") as sp:
train_model(text, sp)
sp.ok("DONE")
# Cats
preds = []
for filename in os.listdir("dataset/test_set/cats/"):
img_path = os.path.join("dataset/test_set/cats/", filename)
prediction = predict_image(img_path)
preds.append(prediction)
def visualize():
print("Not available yet...\n")
main()
print(preds.count("Cat"))
print(preds.count("Cat") / 1000)
# Dogs
preds = []
for filename in os.listdir("dataset/test_set/dogs/"):
img_path = os.path.join("dataset/test_set/dogs/", filename)
prediction = predict_image(img_path)
preds.append(prediction)
print(preds.count("Dog"))
print(preds.count("Dog") / 1000)
if __name__ == "__main__":
print(pyfiglet.figlet_format("Cats and Dogs"))
print(pyfiglet.figlet_format("classification"))
typer.run(main)

39
predict.py Normal file
View File

@ -0,0 +1,39 @@
import os
import cv2
import torch
import numpy as np
IMG_SIZE = 128
if torch.cuda.is_available():
DEVICE = torch.device("cuda")
elif torch.mps.is_available():
DEVICE = torch.device("mps")
else:
DEIVCE = torch.device("cpu")
def predict_image(image_path, model):
img = cv2.imread(image_path)
img = cv2.resize(img, (IMG_SIZE, IMG_SIZE)) / 255.0
img = np.transpose(img, (2, 0, 1)) # Convert to (C, H, W)
img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).to(DEVICE) # Add batch dimension
model.eval()
with torch.no_grad():
output = model(img_tensor)
predicted = torch.argmax(output, dim=1).item()
return "Dog" if predicted == 1 else "Cat"
def make_predictions(model_path, dataset_path, type, spinner):
model = torch.load(model_path, map_location=DEVICE)
model.eval()
model.to(DEVICE)
preds = []
for filename in os.listdir(dataset_path):
img_path = os.path.join(dataset_path, filename)
prediction = predict_image(img_path, model)
preds.append(prediction)
spinner.write(f'Precision : {preds.count(type) / 1000 * 100}%')

View File

@ -1,5 +1,9 @@
inquirer==3.4.0
matplotlib==3.10.0
numpy==2.2.2
opencv-python==4.11.0.86
pyfiglet==1.0.2
torch==2.5.1
torchvision==0.20.1
typer==0.15.1
yaspin==3.1.0

View File

@ -6,16 +6,17 @@ from torch.utils.data import DataLoader, TensorDataset
from torch import nn, optim
from BN import CatDogClassifier
import time
from yaspin import yaspin
IMG_SIZE = 128
#DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE = "mps"
if torch.cuda.is_available():
DEVICE = torch.device("cuda")
elif torch.mps.is_available():
DEVICE = torch.device("mps")
else:
DEIVCE = torch.device("cpu")
"""
This function loads all the images from the folder and labels them
"""
def load_images_from_folder(folder, label):
data = []
for filename in os.listdir(folder):
@ -27,7 +28,9 @@ def load_images_from_folder(folder, label):
data.append((img, label))
return data
if __name__ == "__main__":
def train_model(model_name, spinner):
spinner.write(f'Using the following device : {DEVICE}')
# Loading the dataset
cat_data = load_images_from_folder("dataset/training_set/cats", label=0)
@ -66,8 +69,8 @@ if __name__ == "__main__":
optimizer.step()
total_loss += loss.item()
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader):.4f}")
spinner.write(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader):.4f}")
print(f"Time taken: {(time.time() - start_time):.2f} seconds")
spinner.write(f"Time taken: {(time.time() - start_time):.2f} seconds")
torch.save(model, f"bayes_cat_dog_classifier.pth")
torch.save(model, f"models/{model_name}.pth")