# File listing header: 167 lines, 6.2 KiB, Python

import os
import cv2
import numpy as np
import torch
from collections import defaultdict
import matplotlib.pyplot as plt
class BayesianClassifier:
    """Gaussian naive Bayes classifier over HOG descriptors of image regions.

    Every image is thresholded, the bounding box of each external contour is
    cropped from the grayscale image, resized to ``PATCH_SIZE`` x
    ``PATCH_SIZE`` and described with a HOG descriptor. Training estimates a
    per-class mean, variance and prior from those descriptors; prediction
    scores an image against each class with a diagonal Gaussian
    log-likelihood.

    The learned parameters are plain dictionaries keyed by class name
    (``feature_means``, ``feature_variances``, ``class_priors``) plus the
    list ``classes``.
    """

    # Side length in pixels of the square patch each region is resized to.
    # The HOG window uses the same size so one descriptor covers one patch.
    PATCH_SIZE = 28
    # Contours with an area below this value are skipped.
    MIN_CONTOUR_AREA = 22
    # Added to every variance so the likelihood never divides by zero.
    VARIANCE_FLOOR = 1e-6

    def __init__(self):
        """Create an untrained classifier and its HOG descriptor."""
        self.feature_means = {}
        self.feature_variances = {}
        self.class_priors = {}
        self.classes = []
        self.hog = cv2.HOGDescriptor(
            _winSize=(self.PATCH_SIZE, self.PATCH_SIZE),
            _blockSize=(8, 8),
            _blockStride=(4, 4),
            _cellSize=(8, 8),
            _nbins=9,
        )

    def extract_features(self, image):
        """Return one L2-normalized HOG descriptor per sufficiently large contour.

        Args:
            image: Image array. A three-channel image is converted from BGR
                to grayscale; any other input is used unchanged.

        Returns:
            A 2-D array with one row per kept contour, or an empty 1-D array
            when no contour is found, every contour is too small, or an
            error occurs (the error is printed, not raised).
        """
        try:
            if len(image.shape) == 3 and image.shape[2] == 3:
                gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray_image = image

            # Inverted threshold: dark strokes become the foreground that
            # findContours traces.
            binary_image = cv2.adaptiveThreshold(
                gray_image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 11, 2
            )
            contours, _ = cv2.findContours(
                binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            if not contours:
                print("No contours found.")
                return np.array([])

            descriptors = []
            patch_shape = (self.PATCH_SIZE, self.PATCH_SIZE)
            for contour in contours:
                if cv2.contourArea(contour) < self.MIN_CONTOUR_AREA:
                    continue
                x, y, w, h = cv2.boundingRect(contour)
                patch = cv2.resize(gray_image[y:y + h, x:x + w], patch_shape)
                descriptors.append(self.hog.compute(patch).flatten())

            if not descriptors:
                print("No features extracted.")
                return np.array([])

            features = np.array(descriptors)
            norms = np.linalg.norm(features, axis=1, keepdims=True)
            # Rows with a (near-)zero norm are left unscaled instead of
            # being divided by zero.
            return features / np.where(norms > 1e-6, norms, 1)
        except Exception as e:
            print(f"Error in extract_features: {e}")
            return np.array([])

    def train(self, dataset_path):
        """Estimate per-class Gaussian parameters from a folder-per-class dataset.

        Args:
            dataset_path: Directory whose sub-directories are class names and
                contain that class's image files.

        Every descriptor returned by ``extract_features`` counts as one
        training sample. Priors are the class's share of all samples, so they
        always sum to 1 even when an image contributes several descriptors.
        Classes that yield no samples are listed in ``classes`` but get no
        statistics. Unreadable images are reported and skipped.
        """
        class_features = defaultdict(list)

        for class_name in os.listdir(dataset_path):
            class_folder_path = os.path.join(dataset_path, class_name)
            if not os.path.isdir(class_folder_path):
                continue
            if class_name not in self.classes:
                self.classes.append(class_name)

            for img_name in os.listdir(class_folder_path):
                img_path = os.path.join(class_folder_path, img_name)
                if not os.path.isfile(img_path):
                    continue
                try:
                    image = cv2.imread(img_path)
                    if image is None:
                        print(f"Failed to load image: {img_path}")
                        continue
                    features = self.extract_features(image)
                    if features.size == 0:
                        print(f"No features extracted for {img_path}")
                        continue
                    class_features[class_name].extend(features)
                except Exception as e:
                    print(f"Error processing {img_path}: {e}")

        # Normalize by the number of samples (descriptors), the same unit as
        # the per-class counts; dividing by the image count would let priors
        # exceed 1 for multi-contour images.
        total_samples = sum(len(samples) for samples in class_features.values())
        for class_name, samples in class_features.items():
            features = np.array(samples)
            self.feature_means[class_name] = np.mean(features, axis=0)
            self.feature_variances[class_name] = (
                np.var(features, axis=0) + self.VARIANCE_FLOOR
            )
            self.class_priors[class_name] = len(features) / total_samples

        print("Training completed for classes:", self.classes)

    def save_model(self, model_path):
        """Serialize the learned parameters to ``model_path`` with ``torch.save``.

        Args:
            model_path: Destination file. Missing parent directories are
                created; a bare filename is written to the current directory.
        """
        model_data = {
            "feature_means": self.feature_means,
            "feature_variances": self.feature_variances,
            "class_priors": self.class_priors,
            "classes": self.classes,
        }
        directory = os.path.dirname(model_path)
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(model_data, model_path)
        print(f"Model saved to {model_path}")

    def load_model(self, model_path):
        """Restore parameters written by ``save_model``.

        Args:
            model_path: File to read. When it does not exist a message is
                printed and the classifier is left unchanged.
        """
        if not os.path.exists(model_path):
            print(f"No model found at {model_path}.")
            return

        # SECURITY: weights_only=False unpickles arbitrary objects, so only
        # load model files from a trusted source.
        model_data = torch.load(model_path, weights_only=False)
        self.feature_means = model_data["feature_means"]
        self.feature_variances = model_data["feature_variances"]
        self.class_priors = model_data["class_priors"]
        self.classes = model_data["classes"]
        print(f"Model loaded from {model_path}")

    def predict(self, image):
        """Return the most probable class name for ``image``.

        The Gaussian log-likelihood is summed over every descriptor extracted
        from the image, and the log prior is added once, so the whole image
        receives a single label.

        Args:
            image: Image array accepted by ``extract_features``.

        Returns:
            The class name with the highest log-posterior, or ``None`` when
            no features can be extracted, no class has trained statistics,
            or an error occurs (the error is printed, not raised).
        """
        try:
            features = self.extract_features(image)
            if features.size == 0:
                print("Empty features, skipping prediction.")
                return None

            posteriors = {}
            for class_name in self.classes:
                # A class folder that produced no training samples has no
                # statistics; skip it instead of failing the whole prediction.
                if class_name not in self.feature_means:
                    continue
                mean = self.feature_means[class_name]
                variance = self.feature_variances[class_name]
                prior = self.class_priors[class_name]
                likelihood = -0.5 * np.sum(
                    ((features - mean) ** 2) / variance + np.log(2 * np.pi * variance)
                )
                posteriors[class_name] = likelihood + np.log(prior)

            if not posteriors:
                print("No trained classes available for prediction.")
                return None
            return max(posteriors, key=posteriors.get)
        except Exception as e:
            print(f"Error in prediction: {e}")
            return None

    def visualize(self):
        """Plot the mean feature vector of every class that has statistics."""
        if not self.classes:
            print("No classes to visualize.")
            return

        for class_name in self.classes:
            mean_features = self.feature_means.get(class_name)
            if mean_features is None:
                # Listed class without training samples: nothing to plot.
                continue
            plt.figure(figsize=(10, 4))
            plt.title(f"Mean features for class: {class_name}")
            plt.plot(mean_features)
            plt.xlabel("Feature Index")
            plt.ylabel("Mean Value")
            plt.grid(True)
            plt.show()