Final push

This commit is contained in:
yanis.bouarfa 2025-01-08 12:20:57 +01:00
parent 9a0b597309
commit afb9de39c8
5 changed files with 196 additions and 70 deletions

52
main.py
View file

@ -7,7 +7,7 @@ from src.classifiers.bayesian import BayesianClassifier
analysis_mode = "plan" analysis_mode = "plan"
if __name__ == "__main__": if __name__ == "__main__":
# Configuration basée sur le mode # Configuration en fonction du mode sélectionné
if analysis_mode == "plan": if analysis_mode == "plan":
model_path = "models/bayesian_modelPLAN.pth" model_path = "models/bayesian_modelPLAN.pth"
image_path = "data/plan.png" image_path = "data/plan.png"
@ -15,31 +15,31 @@ if __name__ == "__main__":
model_path = "models/bayesian_modelPAGE.pth" model_path = "models/bayesian_modelPAGE.pth"
image_path = "data/page.png" image_path = "data/page.png"
# Exécuter le script train.py avec le mode sélectionné # Lancement de l'entraînement avec le mode choisi
print(f"Lancement de l'entraînement avec le mode {analysis_mode}...") print(f"Entraînement en cours avec le mode {analysis_mode}...")
try: try:
subprocess.run(["python", "train.py", "--mode", analysis_mode], check=True) subprocess.run(["python", "train.py", "--mode", analysis_mode], check=True)
print("Entraînement terminé.") print("Entraînement terminé avec succès.")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print(f"Erreur lors de l'entraînement : {e}") print(f"Une erreur s'est produite pendant l'entraînement : {e}")
exit(1) exit(1)
# Chargement du modèle bayésien # Chargement du modèle bayésien
print(f"Chargement du modèle bayésien depuis {model_path}") print(f"Chargement du modèle depuis {model_path}...")
bayesian_model = BayesianClassifier(mode=analysis_mode) bayesian_model = BayesianClassifier(mode=analysis_mode)
try: try:
bayesian_model.load_model(model_path) bayesian_model.load_model(model_path)
print(f"Modèle bayésien chargé depuis {model_path}") print(f"Modèle chargé depuis {model_path}.")
except Exception as e: except Exception as e:
print(f"Erreur lors du chargement du modèle : {e}") print(f"Erreur lors du chargement du modèle : {e}")
exit(1) exit(1)
# Vérification de l'existence de l'image # Vérification de l'existence de l'image de test
if not os.path.exists(image_path): if not os.path.exists(image_path):
print(f"L'image de test {image_path} n'existe pas.") print(f"L'image spécifiée ({image_path}) n'existe pas.")
exit(1) exit(1)
# Initialisation du dossier de sortie # Création du dossier de sortie si nécessaire
output_dir = "output" output_dir = "output"
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
@ -48,10 +48,10 @@ if __name__ == "__main__":
print("Initialisation de la pipeline...") print("Initialisation de la pipeline...")
pipeline = ObjectDetectionPipeline(image_path=image_path, model=bayesian_model, output_dir=output_dir) pipeline = ObjectDetectionPipeline(image_path=image_path, model=bayesian_model, output_dir=output_dir)
# Définition du mode (plan ou page) # Configuration du mode d'analyse dans la pipeline
pipeline.set_mode(analysis_mode) pipeline.set_mode(analysis_mode)
# Chargement de l'image # Chargement de l'image de test
print("Chargement de l'image...") print("Chargement de l'image...")
try: try:
pipeline.load_image() pipeline.load_image()
@ -60,20 +60,34 @@ if __name__ == "__main__":
exit(1) exit(1)
# Détection et classification des objets # Détection et classification des objets
print("Détection et classification des objets...") print("Détection et classification des objets en cours...")
try: try:
class_counts, detected_objects, total_objects, ignored_objects, identified_objects = pipeline.detect_and_classify_objects() class_counts, detected_objects, total_objects, ignored_objects, identified_objects = pipeline.detect_and_classify_objects()
print(f"Classes détectées : {class_counts}") print(f"Objets détectés par classe : {class_counts}")
print("Résumé des objets :") print("Résumé de la détection :")
print(f"- Objets totaux : {total_objects}") print(f"- Nombre total d'objets : {total_objects}")
print(f"- Objets identifiés : {identified_objects}") print(f"- Objets identifiés : {identified_objects}")
print(f"- Objets ignorés : {ignored_objects}") print(f"- Objets ignorés : {ignored_objects}")
except Exception as e: except Exception as e:
print(f"Erreur lors de la détection/classification : {e}") print(f"Erreur pendant la détection/classification : {e}")
exit(1) exit(1)
# Sauvegarde et affichage des résultats # Sauvegarde et visualisation des résultats
print("Sauvegarde et affichage des résultats...") print("Sauvegarde et affichage des résultats...")
pipeline.display_results(class_counts, detected_objects) pipeline.display_results(class_counts, detected_objects)
print(f"Les résultats ont été sauvegardés dans le dossier : {output_dir}") # Affichage de l'histogramme des classes détectées
print("Affichage de l'histogramme des résultats...")
try:
pipeline.display_histogram(class_counts)
except Exception as e:
print(f"Erreur lors de l'affichage de l'histogramme : {e}")
# Affichage du nuage de points
print("Affichage du nuage de points...")
try:
pipeline.display_scatter_plot(class_counts)
except Exception as e:
print(f"Erreur lors de l'affichage du nuage de points : {e}")
print(f"Tous les résultats sont sauvegardés dans le dossier : {output_dir}")

View file

@ -5,7 +5,6 @@ import torch
from collections import defaultdict from collections import defaultdict
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
class BayesianClassifier: class BayesianClassifier:
def __init__(self, mode="page"): def __init__(self, mode="page"):
self.feature_means = {} self.feature_means = {}
@ -13,13 +12,14 @@ class BayesianClassifier:
self.class_priors = {} self.class_priors = {}
self.classes = [] self.classes = []
# Définir les classes autorisées selon le mode choisi
self.allowed_classes = ( self.allowed_classes = (
['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6'] ['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6']
if mode == "plan" if mode == "plan"
else ['2', 'd', 'I', 'n', 'o', 'u'] else ['2', 'd', 'I', 'n', 'o', 'u']
) )
# Initialize HOG descriptor with standard parameters # Initialisation du descripteur HOG avec des paramètres standards
self.hog = cv2.HOGDescriptor( self.hog = cv2.HOGDescriptor(
_winSize=(28, 28), _winSize=(28, 28),
_blockSize=(8, 8), _blockSize=(8, 8),
@ -29,36 +29,37 @@ class BayesianClassifier:
) )
def extract_features(self, image): def extract_features(self, image):
"""Extraire les caractéristiques d'une image donnée."""
try: try:
# Convert to grayscale if image is RGB # Convertir en niveaux de gris si l'image est en couleurs
if len(image.shape) == 3 and image.shape[2] == 3: if len(image.shape) == 3 and image.shape[2] == 3:
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else: else:
gray_image = image gray_image = image
# Apply adaptive thresholding for segmentation # Appliquer un seuillage adaptatif pour la segmentation
binary_image = cv2.adaptiveThreshold( binary_image = cv2.adaptiveThreshold(
gray_image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 11, 2 gray_image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 11, 2
) )
# Find contours # Trouver les contours
contours, _ = cv2.findContours( contours, _ = cv2.findContours(
binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
) )
if not contours: if not contours:
print("No contours found.") print("Aucun contour trouvé.")
return np.array([]) return np.array([])
features = [] features = []
for contour in contours: for contour in contours:
if cv2.contourArea(contour) < 20: # Filter small areas if cv2.contourArea(contour) < 20: # Filtrer les petites zones
continue continue
x, y, w, h = cv2.boundingRect(contour) x, y, w, h = cv2.boundingRect(contour)
letter_image = gray_image[y:y + h, x:x + w] letter_image = gray_image[y:y + h, x:x + w]
letter_image = cv2.resize(letter_image, (28, 28)) letter_image = cv2.resize(letter_image, (28, 28))
# Compute HOG features # Calculer les descripteurs HOG
hog_features = self.hog.compute(letter_image) hog_features = self.hog.compute(letter_image)
features.append(hog_features.flatten()) features.append(hog_features.flatten())
@ -66,16 +67,17 @@ class BayesianClassifier:
if features.size == 0: if features.size == 0:
return np.array([]) return np.array([])
# Normalize features # Normaliser les caractéristiques
norms = np.linalg.norm(features, axis=1, keepdims=True) norms = np.linalg.norm(features, axis=1, keepdims=True)
features = features / np.where(norms > 1e-6, norms, 1) features = features / np.where(norms > 1e-6, norms, 1)
return features return features
except Exception as e: except Exception as e:
print(f"Error in extract_features: {e}") print(f"Erreur dans l'extraction des caractéristiques : {e}")
return np.array([]) return np.array([])
def train(self, dataset_path): def train(self, dataset_path):
"""Entraîner le modèle Bayésien à partir d'un ensemble de données."""
class_features = defaultdict(list) class_features = defaultdict(list)
total_samples = 0 total_samples = 0
@ -99,21 +101,22 @@ class BayesianClassifier:
class_features[class_name].append(feature) class_features[class_name].append(feature)
total_samples += len(features) total_samples += len(features)
else: else:
print(f"No features extracted for {img_path}") print(f"Aucune caractéristique extraite pour {img_path}")
else: else:
print(f"Failed to load image: {img_path}") print(f"Échec du chargement de l'image : {img_path}")
# Compute means, variances, and priors # Calculer les moyennes, variances et probabilités a priori
for class_name in self.classes: for class_name in self.classes:
if class_name in class_features: if class_name in class_features:
features = np.array(class_features[class_name]) features = np.array(class_features[class_name])
self.feature_means[class_name] = np.mean(features, axis=0) self.feature_means[class_name] = np.mean(features, axis=0)
self.feature_variances[class_name] = np.var(features, axis=0) + 1e-6 # Avoid zero variance self.feature_variances[class_name] = np.var(features, axis=0) + 1e-6 # Éviter une variance nulle
self.class_priors[class_name] = len(features) / total_samples self.class_priors[class_name] = len(features) / total_samples
print("Training completed for classes:", self.classes) print("Entraînement terminé pour les classes :", self.classes)
def save_model(self, model_path): def save_model(self, model_path):
"""Sauvegarder le modèle Bayésien sur le disque."""
model_data = { model_data = {
"feature_means": self.feature_means, "feature_means": self.feature_means,
"feature_variances": self.feature_variances, "feature_variances": self.feature_variances,
@ -122,20 +125,22 @@ class BayesianClassifier:
} }
os.makedirs(os.path.dirname(model_path), exist_ok=True) os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model_data, model_path) torch.save(model_data, model_path)
print(f"Model saved to {model_path}") print(f"Modèle sauvegardé à l'emplacement {model_path}")
def load_model(self, model_path): def load_model(self, model_path):
"""Charger un modèle Bayésien sauvegardé."""
if os.path.exists(model_path): if os.path.exists(model_path):
model_data = torch.load(model_path) model_data = torch.load(model_path)
self.feature_means = model_data["feature_means"] self.feature_means = model_data["feature_means"]
self.feature_variances = model_data["feature_variances"] self.feature_variances = model_data["feature_variances"]
self.class_priors = model_data["class_priors"] self.class_priors = model_data["class_priors"]
self.classes = model_data["classes"] self.classes = model_data["classes"]
print(f"Model loaded from {model_path}") print(f"Modèle chargé depuis {model_path}")
else: else:
print(f"No model found at {model_path}.") print(f"Modèle introuvable à l'emplacement {model_path}.")
def predict(self, image, threshold=0.3): def predict(self, image, threshold=0.3):
"""Prédire la classe d'une image donnée."""
try: try:
features = self.extract_features(image) features = self.extract_features(image)
if features.size == 0: if features.size == 0:
@ -147,7 +152,7 @@ class BayesianClassifier:
variance = self.feature_variances[class_name] variance = self.feature_variances[class_name]
prior = self.class_priors[class_name] prior = self.class_priors[class_name]
# Compute log-likelihood # Calculer la log-vraisemblance
log_likelihood = -0.5 * np.sum( log_likelihood = -0.5 * np.sum(
((features - mean) ** 2) / variance + np.log(2 * np.pi * variance), ((features - mean) ** 2) / variance + np.log(2 * np.pi * variance),
axis=1, axis=1,
@ -162,21 +167,22 @@ class BayesianClassifier:
return None return None
return max_class return max_class
except Exception as e: except Exception as e:
print(f"Error in prediction: {e}") print(f"Erreur dans la prédiction : {e}")
return None return None
def visualize(self): def visualize(self):
"""Visualiser les moyennes des caractéristiques pour chaque classe."""
if not self.classes: if not self.classes:
print("No classes to visualize.") print("Aucune classe à visualiser.")
return return
for class_name in self.classes: for class_name in self.classes:
mean_features = self.feature_means[class_name] mean_features = self.feature_means[class_name]
plt.figure(figsize=(10, 4)) plt.figure(figsize=(10, 4))
plt.title(f"Mean features for class: {class_name}") plt.title(f"Moyennes des caractéristiques pour la classe : {class_name}")
plt.plot(mean_features) plt.plot(mean_features)
plt.xlabel("Feature Index") plt.xlabel("Indice des caractéristiques")
plt.ylabel("Mean Value") plt.ylabel("Valeur moyenne")
plt.grid(True) plt.grid(True)
plt.show() plt.show()

98
src/classifiers/kmeans.py Normal file
View file

@ -0,0 +1,98 @@
import numpy as np
import os
class KMeansClassifier:
def __init__(self, num_clusters=6, max_iter=100, tol=1e-4):
"""
Initialiser le classifieur KMeans.
Paramètres :
- num_clusters : Nombre de clusters (classes).
- max_iter : Nombre maximal d'itérations pour l'algorithme k-means.
- tol : Tolérance pour la convergence.
"""
self.num_clusters = num_clusters
self.max_iter = max_iter
self.tol = tol
self.cluster_centers_ = None
self.labels_ = None
def fit(self, features):
"""
Entraîner le modèle k-means sur les données fournies.
Paramètres :
- features : Un tableau numpy de forme (n_samples, n_features).
"""
if len(features) < self.num_clusters:
raise ValueError("Le nombre d'échantillons est inférieur au nombre de clusters.")
np.random.seed(42)
random_indices = np.random.choice(len(features), self.num_clusters, replace=False)
self.cluster_centers_ = features[random_indices]
for iteration in range(self.max_iter):
# Assigner des étiquettes en fonction du centre le plus proche
distances = self._compute_distances(features)
self.labels_ = np.argmin(distances, axis=1)
# Mettre à jour les centres des clusters
new_centers = np.array([features[self.labels_ == k].mean(axis=0) for k in range(self.num_clusters)])
# Vérifier la convergence
if np.all(np.abs(new_centers - self.cluster_centers_) < self.tol):
print(f"Convergence atteinte en {iteration + 1} itérations.")
break
self.cluster_centers_ = new_centers
def predict(self, features):
"""
Prédire le cluster le plus proche pour les données fournies.
Paramètres :
- features : Un tableau numpy de forme (n_samples, n_features).
Retourne :
- Les étiquettes des clusters pour chaque échantillon.
"""
distances = self._compute_distances(features)
return np.argmin(distances, axis=1)
def _compute_distances(self, features):
"""
Calculer les distances entre les données et les centres des clusters.
Paramètres :
- features : Un tableau numpy de forme (n_samples, n_features).
Retourne :
- Un tableau numpy des distances de forme (n_samples, num_clusters).
"""
return np.linalg.norm(features[:, np.newaxis] - self.cluster_centers_, axis=2)
def save_model(self, path):
"""
Sauvegarder le modèle KMeans dans un fichier .npy.
Paramètres :
- path : Chemin pour sauvegarder le modèle.
"""
if self.cluster_centers_ is None:
raise ValueError("Le modèle n'a pas encore été entraîné. Rien à sauvegarder.")
os.makedirs(os.path.dirname(path), exist_ok=True)
np.save(path, self.cluster_centers_)
print(f"Modèle KMeans sauvegardé à l'emplacement {path}")
def load_model(self, path):
"""
Charger le modèle KMeans depuis un fichier .npy.
Paramètres :
- path : Chemin vers le fichier sauvegardé.
"""
if os.path.exists(path):
self.cluster_centers_ = np.load(path)
print(f"Modèle KMeans chargé depuis {path}")
else:
raise FileNotFoundError(f"Fichier du modèle KMeans introuvable à l'emplacement {path}.")

View file

@ -1,9 +1,9 @@
import cv2 import cv2
import os import os
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from collections import defaultdict from collections import defaultdict
class ObjectDetectionPipeline: class ObjectDetectionPipeline:
def __init__(self, image_path, model=None, output_dir="output", mode="page", min_contour_area=20, binary_threshold=None): def __init__(self, image_path, model=None, output_dir="output", mode="page", min_contour_area=20, binary_threshold=None):
self.image_path = image_path self.image_path = image_path
@ -13,28 +13,31 @@ class ObjectDetectionPipeline:
self.output_dir = output_dir self.output_dir = output_dir
self.min_contour_area = min_contour_area self.min_contour_area = min_contour_area
self.binary_threshold = binary_threshold self.binary_threshold = binary_threshold
self.mode = mode # Default mode is "page" self.mode = mode # Le mode par défaut est "page"
self.annotated_output_path = os.path.join(self.output_dir, f"annotated_{os.path.basename(image_path)}") self.annotated_output_path = os.path.join(self.output_dir, f"annotated_{os.path.basename(image_path)}")
self.threshold = -395000 if mode == "plan" else -65000 self.threshold = -395000 if mode == "plan" else -65000
# Créez le dossier de sortie s'il n'existe pas
if not os.path.exists(self.output_dir): if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir) os.makedirs(self.output_dir)
def set_mode(self, mode): def set_mode(self, mode):
"""Set the detection mode (page or plan).""" """Définir le mode de détection (page ou plan)."""
if mode not in ["page", "plan"]: if mode not in ["page", "plan"]:
raise ValueError("Mode must be 'page' or 'plan'.") raise ValueError("Le mode doit être 'page' ou 'plan'.")
self.mode = mode self.mode = mode
self.threshold = -395000 if mode == "plan" else -65000 self.threshold = -395000 if mode == "plan" else -65000
print(f"Mode set to: {self.mode}, Threshold set to: {self.threshold}") print(f"Mode défini à : {self.mode}, Seuil défini à : {self.threshold}")
def load_image(self): def load_image(self):
"""Charger l'image à analyser."""
self.image = cv2.imread(self.image_path) self.image = cv2.imread(self.image_path)
if self.image is None: if self.image is None:
raise FileNotFoundError(f"Image {self.image_path} not found.") raise FileNotFoundError(f"L'image {self.image_path} est introuvable.")
return self.image return self.image
def preprocess_image(self): def preprocess_image(self):
"""Effectuer un prétraitement de l'image pour obtenir une version binaire."""
channels = cv2.split(self.image) channels = cv2.split(self.image)
binary_images = [] binary_images = []
@ -51,8 +54,9 @@ class ObjectDetectionPipeline:
return binary_image return binary_image
def detect_and_classify_objects(self): def detect_and_classify_objects(self):
"""Détecter et classer les objets dans l'image."""
if self.model is None: if self.model is None:
raise ValueError("No classification model provided.") raise ValueError("Aucun modèle de classification n'a été fourni.")
self.binary_image = self.preprocess_image() self.binary_image = self.preprocess_image()
contours, _ = cv2.findContours(self.binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) contours, _ = cv2.findContours(self.binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
@ -64,7 +68,7 @@ class ObjectDetectionPipeline:
identified_objects = 0 identified_objects = 0
for contour in contours: for contour in contours:
total_objects += 1 # Compteur total des objets total_objects += 1 # Incrémenter le compteur total des objets
if cv2.contourArea(contour) < self.min_contour_area: if cv2.contourArea(contour) < self.min_contour_area:
ignored_objects += 1 ignored_objects += 1
@ -85,6 +89,7 @@ class ObjectDetectionPipeline:
return dict(sorted(class_counts.items())), detected_objects, total_objects, ignored_objects, identified_objects return dict(sorted(class_counts.items())), detected_objects, total_objects, ignored_objects, identified_objects
def save_results(self, class_counts, detected_objects): def save_results(self, class_counts, detected_objects):
"""Sauvegarder les résultats de la détection dans des fichiers."""
binary_output_path = os.path.join(self.output_dir, "binary_image.jpg") binary_output_path = os.path.join(self.output_dir, "binary_image.jpg")
cv2.imwrite(binary_output_path, self.binary_image) cv2.imwrite(binary_output_path, self.binary_image)
@ -101,11 +106,12 @@ class ObjectDetectionPipeline:
f.write(f"{class_name}: {count}\n") f.write(f"{class_name}: {count}\n")
def display_results(self, class_counts, detected_objects): def display_results(self, class_counts, detected_objects):
"""Afficher les résultats sous forme graphique."""
self.save_results(class_counts, detected_objects) self.save_results(class_counts, detected_objects)
plt.figure(figsize=(10, 5)) plt.figure(figsize=(10, 5))
plt.bar(class_counts.keys(), class_counts.values()) plt.bar(class_counts.keys(), class_counts.values())
plt.xlabel("Classes") plt.xlabel("Classes")
plt.ylabel("Object count") plt.ylabel("Nombre d'objets")
plt.title("Detected Class Distribution") plt.title("Répartition des classes détectées")
plt.show() plt.show()

View file

@ -1,32 +1,34 @@
import os import os
import argparse # Ajouté pour les arguments import argparse
from collections import defaultdict from collections import defaultdict
import numpy as np import numpy as np
import cv2 import cv2
from src.classifiers.bayesian import BayesianClassifier from src.classifiers.bayesian import BayesianClassifier
if __name__ == "__main__": if __name__ == "__main__":
# Analyse des arguments # Analyse des arguments pour configurer le mode
parser = argparse.ArgumentParser(description="Train Bayesian model.") parser = argparse.ArgumentParser(description="Entraîner le modèle Bayésien.")
parser.add_argument("--mode", type=str, choices=["page", "plan"], default="page", help="Mode de fonctionnement : 'page' ou 'plan'.") parser.add_argument("--mode", type=str, choices=["page", "plan"], default="page",
help="Mode de fonctionnement : 'page' pour les pages ou 'plan' pour les plans.")
args = parser.parse_args() args = parser.parse_args()
# Configuration en fonction du mode # Définir les chemins en fonction du mode
mode = args.mode mode = args.mode
dataset_path = f"data/catalogue{'' if mode == 'page' else 'Symbol'}" dataset_path = f"data/catalogue{'' if mode == 'page' else 'Symbol'}"
allowed_classes = ['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6'] if mode == "plan" else ['2', 'd', 'I', 'n', 'o', 'u'] allowed_classes = ['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6'] \
if mode == "plan" else ['2', 'd', 'I', 'n', 'o', 'u']
model_path = f"models/bayesian_model{mode.upper()}.pth" model_path = f"models/bayesian_model{mode.upper()}.pth"
# Initialisation du classifieur Bayésien # Initialiser le classifieur Bayésien
bayesian_model = BayesianClassifier(mode=mode) bayesian_model = BayesianClassifier(mode=mode)
print("Début de l'entraînement...") print("Lancement de l'entraînement...")
# Dictionnaire pour stocker les caractéristiques par classe # Stockage des caractéristiques pour chaque classe
class_features = defaultdict(list) class_features = defaultdict(list)
total_images = 0 total_images = 0
# Parcours des classes dans le dataset # Parcourir les classes dans le dataset
for class_name in os.listdir(dataset_path): for class_name in os.listdir(dataset_path):
if class_name not in allowed_classes: if class_name not in allowed_classes:
continue # Ignorer les classes non autorisées continue # Ignorer les classes non autorisées
@ -35,11 +37,11 @@ if __name__ == "__main__":
if not os.path.isdir(class_folder_path): if not os.path.isdir(class_folder_path):
continue # Ignorer les fichiers qui ne sont pas des dossiers continue # Ignorer les fichiers qui ne sont pas des dossiers
# Ajouter la classe au modèle si elle n'existe pas déjà # Ajouter la classe au modèle si elle n'est pas encore listée
if class_name not in bayesian_model.classes: if class_name not in bayesian_model.classes:
bayesian_model.classes.append(class_name) bayesian_model.classes.append(class_name)
# Parcours des images dans le dossier de la classe # Parcourir les images dans le dossier de la classe
for image_name in os.listdir(class_folder_path): for image_name in os.listdir(class_folder_path):
image_path = os.path.join(class_folder_path, image_name) image_path = os.path.join(class_folder_path, image_name)
image = cv2.imread(image_path) image = cv2.imread(image_path)
@ -51,7 +53,7 @@ if __name__ == "__main__":
class_features[class_name].append(feature) class_features[class_name].append(feature)
total_images += 1 total_images += 1
# Calcul des statistiques pour chaque classe # Calculer les statistiques pour chaque classe
for class_name in bayesian_model.classes: for class_name in bayesian_model.classes:
if class_name in class_features: if class_name in class_features:
features = np.array(class_features[class_name]) features = np.array(class_features[class_name])
@ -59,8 +61,8 @@ if __name__ == "__main__":
bayesian_model.feature_variances[class_name] = np.var(features, axis=0) + 1e-6 # Éviter la division par zéro bayesian_model.feature_variances[class_name] = np.var(features, axis=0) + 1e-6 # Éviter la division par zéro
bayesian_model.class_priors[class_name] = len(features) / total_images bayesian_model.class_priors[class_name] = len(features) / total_images
print("Entraînement terminé.") print("Entraînement terminé avec succès.")
# Sauvegarde du modèle entraîné # Sauvegarder le modèle entraîné
bayesian_model.save_model(model_path) bayesian_model.save_model(model_path)
print(f"Modèle sauvegardé dans : {model_path}") print(f"Le modèle a été sauvegardé à l'emplacement : {model_path}")