167 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			167 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import cv2
 | |
| import numpy as np
 | |
| import torch
 | |
| from collections import defaultdict
 | |
| import matplotlib.pyplot as plt
 | |
| 
 | |
| 
 | |
class BayesianClassifier:
    """Gaussian naive Bayes classifier over HOG features of image regions.

    Each input image is thresholded and split into external contours; every
    sufficiently large contour is cropped, resized to 28x28 and described by
    a HOG vector. Training stores, per class, the per-dimension mean and
    variance of those vectors plus a class prior. Prediction picks the class
    with the highest log-posterior.
    """

    def __init__(self):
        # Per-class statistics, keyed by class name (the sub-folder name
        # seen during training).
        self.feature_means = {}
        self.feature_variances = {}
        self.class_priors = {}
        self.classes = []

        # HOG descriptor sized for the 28x28 crops produced in
        # extract_features.
        self.hog = cv2.HOGDescriptor(
            _winSize=(28, 28),
            _blockSize=(8, 8),
            _blockStride=(4, 4),
            _cellSize=(8, 8),
            _nbins=9
        )

    def extract_features(self, image):
        """Return one L2-normalised HOG vector per detected contour.

        Args:
            image: Image array. A 3-channel image is converted from BGR to
                grayscale; anything else is used as-is.

        Returns:
            A 2-D array with one row per contour whose area is at least 22,
            or an empty array when nothing usable is found or an error
            occurs (the error is printed, not raised).
        """
        try:
            # Convert image to grayscale
            if len(image.shape) == 3 and image.shape[2] == 3:
                gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray_image = image

            # Apply adaptive thresholding (inverted binary output)
            binary_image = cv2.adaptiveThreshold(
                gray_image, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                cv2.THRESH_BINARY_INV, 11, 2
            )

            # Find outer contours only
            contours, _ = cv2.findContours(
                binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            if not contours:
                print("No contours found.")
                return np.array([])

            features = []
            for contour in contours:
                # Skip tiny contours.
                if cv2.contourArea(contour) < 22:
                    continue

                x, y, w, h = cv2.boundingRect(contour)
                letter_image = gray_image[y:y + h, x:x + w]
                letter_image = cv2.resize(letter_image, (28, 28))

                # Compute HOG features
                hog_features = self.hog.compute(letter_image)
                features.append(hog_features.flatten())

            features = np.array(features)
            if features.size == 0:
                print("No features extracted.")
                return np.array([])

            # Row-wise L2 normalisation; near-zero rows are divided by 1 so
            # they stay unchanged instead of producing NaN/inf.
            norms = np.linalg.norm(features, axis=1, keepdims=True)
            features = features / np.where(norms > 1e-6, norms, 1)

            return features
        except Exception as e:
            print(f"Error in extract_features: {e}")
            return np.array([])

    def train(self, dataset_path):
        """Estimate per-class Gaussian statistics from a folder of images.

        Args:
            dataset_path: Directory whose sub-directories are class names,
                each containing the image files of that class.

        Every feature vector returned by extract_features counts as one
        training sample. Unreadable images and images without features are
        reported and skipped.
        """
        class_features = defaultdict(list)

        # Sorted so the order of self.classes is reproducible across runs.
        for class_name in sorted(os.listdir(dataset_path)):
            class_folder_path = os.path.join(dataset_path, class_name)
            if not os.path.isdir(class_folder_path):
                continue

            if class_name not in self.classes:
                self.classes.append(class_name)

            for img_name in sorted(os.listdir(class_folder_path)):
                img_path = os.path.join(class_folder_path, img_name)
                if not os.path.isfile(img_path):
                    continue

                try:
                    image = cv2.imread(img_path)
                    if image is None:
                        print(f"Failed to load image: {img_path}")
                        continue

                    features = self.extract_features(image)
                    if features.size == 0:
                        print(f"No features extracted for {img_path}")
                        continue

                    class_features[class_name].extend(features)
                except Exception as e:
                    print(f"Error processing {img_path}: {e}")

        # Priors are normalised by the number of feature vectors (the unit
        # the statistics are computed over), not by the number of images,
        # so they form a proper distribution that sums to 1.
        total_samples = sum(len(rows) for rows in class_features.values())

        for class_name in self.classes:
            samples = class_features.get(class_name)
            if not samples:
                continue

            features = np.array(samples)
            self.feature_means[class_name] = np.mean(features, axis=0)
            # Small constant keeps the variance strictly positive.
            self.feature_variances[class_name] = np.var(features, axis=0) + 1e-6
            self.class_priors[class_name] = len(features) / total_samples

        print("Training completed for classes:", self.classes)

    def save_model(self, model_path):
        """Serialise the learned statistics to ``model_path`` via torch.save.

        The parent directory is created when the path has one; a bare file
        name is written to the current working directory.
        """
        model_data = {
            "feature_means": self.feature_means,
            "feature_variances": self.feature_variances,
            "class_priors": self.class_priors,
            "classes": self.classes
        }
        directory = os.path.dirname(model_path)
        # dirname is '' for a bare file name; os.makedirs('') would raise.
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(model_data, model_path)
        print(f"Model saved to {model_path}")

    def load_model(self, model_path):
        """Restore statistics previously written by save_model.

        Prints a message and leaves the instance untouched when the file
        does not exist.
        """
        if os.path.exists(model_path):
            # SECURITY: weights_only=False unpickles arbitrary objects and
            # can execute code embedded in the file. Only load model files
            # from a trusted source.
            model_data = torch.load(model_path, weights_only=False)
            self.feature_means = model_data["feature_means"]
            self.feature_variances = model_data["feature_variances"]
            self.class_priors = model_data["class_priors"]
            self.classes = model_data["classes"]
            print(f"Model loaded from {model_path}")
        else:
            print(f"No model found at {model_path}.")

    def predict(self, image):
        """Return the most probable class name for ``image``.

        The Gaussian log-likelihood is summed over every feature dimension
        of every contour found in the image, then the log-prior is added.

        Returns:
            The class name with the highest log-posterior, or None when no
            features could be extracted, no class has trained statistics,
            or an error occurs (the error is printed, not raised).
        """
        try:
            features = self.extract_features(image)
            if features.size == 0:
                print("Empty features, skipping prediction.")
                return None

            posteriors = {}
            for class_name in self.classes:
                # A class may be listed without statistics (for example a
                # training folder with no usable images); skip it rather
                # than failing the whole prediction with a KeyError.
                if class_name not in self.feature_means:
                    continue

                mean = self.feature_means[class_name]
                variance = self.feature_variances[class_name]
                prior = self.class_priors[class_name]

                likelihood = -0.5 * np.sum(
                    ((features - mean) ** 2) / variance
                    + np.log(2 * np.pi * variance)
                )
                posteriors[class_name] = likelihood + np.log(prior)

            if not posteriors:
                print("No trained classes available, skipping prediction.")
                return None

            return max(posteriors, key=posteriors.get)
        except Exception as e:
            print(f"Error in prediction: {e}")
            return None

    def visualize(self):
        """Plot the mean feature vector of every class that has statistics."""
        if not self.classes:
            print("No classes to visualize.")
            return

        for class_name in self.classes:
            mean_features = self.feature_means.get(class_name)
            if mean_features is None:
                # Listed class without trained statistics; nothing to plot.
                continue

            plt.figure(figsize=(10, 4))
            plt.title(f"Mean features for class: {class_name}")
            plt.plot(mean_features)
            plt.xlabel("Feature Index")
            plt.ylabel("Mean Value")
            plt.grid(True)
            plt.show()