diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..0cafc1c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+.venv/
\ No newline at end of file
diff --git a/README.md b/README.md
index 1a0ef71..1f0e1b5 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,34 @@
-# SignsDetectionAI
-AI project for my classes
+
Signs Detection AI
+
+AI project for the Introduction to AI course at Université de Tours for the ISA Masters.
+
+## Requirements
+
+- Python 3.12
+
+## Installation
+
+Start by creating a virtual environment (Optional)
+```bash
+python3 -m venv .venv
+source .venv/bin/activate
+```
+
+Install the dependencies
+```bash
+pip install -r pip-dependencies.txt
+```
+## Running the project
+
+Run the project
+```bash
+python main.py
+```
+
+## Testing
+
+## Documentation
+
+## Authors
+- Nabil Ould Hamou - [@NabilOuldHamou](https://github.com/NabilOuldHamou)
+- Yanis Bouarfa - [@Yanax373](https://github.com/Yanax373)
\ No newline at end of file
diff --git a/data/page.png b/data/page.png
new file mode 100644
index 0000000..321d9b7
Binary files /dev/null and b/data/page.png differ
diff --git a/data/plan.png b/data/plan.png
new file mode 100644
index 0000000..e33cda1
Binary files /dev/null and b/data/plan.png differ
diff --git a/generate_dataset.py b/generate_dataset.py
new file mode 100644
index 0000000..79eea81
--- /dev/null
+++ b/generate_dataset.py
@@ -0,0 +1,78 @@
+from PIL import Image, ImageDraw, ImageFont
+import os
+
+# Répertoire pour sauvegarder les images générées
+output_dir = "data/catalogue"
+
+# Définir la taille de la police et de l'image
+font_size = 20 # Ajustez pour la taille souhaitée
+image_size = (28, 28) # Taille de l'image pour chaque caractère
+
+# Listes des caractères à générer
+uppercase_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+lowercase_letters = "abcdefghijklmnopqrstuvwxyz"
+numbers = "0123456789"
+
+# Chemin vers le fichier de police (à mettre à jour avec un chemin valide sur votre système)
+font_path = "arial.ttf" # Assurez-vous que cette police est disponible
+
+# Créer le répertoire de sortie s'il n'existe pas
+os.makedirs(output_dir, exist_ok=True)
+
+# Fonction pour créer des images de caractères
+def create_character_image(character, output_path):
+ """
+ Crée une image contenant un caractère spécifique et la sauvegarde dans le chemin donné.
+
+ :param character: Caractère à dessiner
+ :param output_path: Chemin où sauvegarder l'image
+ """
+ # Créer une image vierge avec un fond blanc
+ img = Image.new("RGB", image_size, "white")
+ draw = ImageDraw.Draw(img)
+
+ # Charger la police
+ try:
+ font = ImageFont.truetype(font_path, font_size)
+ except IOError:
+ print(f"Fichier de police introuvable : {font_path}")
+ return
+
+ # Calculer la position du texte pour centrer le caractère
+ bbox = font.getbbox(character)
+ text_width = bbox[2] - bbox[0]
+ text_height = bbox[3] - bbox[1]
+ text_x = (image_size[0] - text_width) // 2
+ text_y = (image_size[1] - text_height) // 2
+
+ # Dessiner le caractère sur l'image
+ draw.text((text_x, text_y), character, font=font, fill="black")
+
+ # Sauvegarder l'image
+ img.save(output_path)
+
+# Générer des images pour les lettres majuscules et minuscules
+for upper, lower in zip(uppercase_letters, lowercase_letters):
+ upper_dir = os.path.join(output_dir, f"{upper}_") # Sous-dossier pour les majuscules
+ lower_dir = os.path.join(output_dir, upper) # Sous-dossier pour les minuscules
+
+ os.makedirs(upper_dir, exist_ok=True) # Créer le sous-dossier pour les majuscules
+ os.makedirs(lower_dir, exist_ok=True) # Créer le sous-dossier pour les minuscules
+
+ # Sauvegarder l'image de la lettre majuscule
+ upper_image_path = os.path.join(upper_dir, f"{upper}.png")
+ create_character_image(upper, upper_image_path)
+
+ # Sauvegarder l'image de la lettre minuscule
+ lower_image_path = os.path.join(lower_dir, f"{lower}.png")
+ create_character_image(lower, lower_image_path)
+
+# Générer des images pour les chiffres
+for num in numbers:
+ num_dir = os.path.join(output_dir, num) # Sous-dossier pour chaque chiffre
+ os.makedirs(num_dir, exist_ok=True) # Créer le sous-dossier
+
+ num_image_path = os.path.join(num_dir, f"{num}.png")
+ create_character_image(num, num_image_path)
+
+print(f"Les images des lettres et des chiffres ont été générées dans le répertoire : {output_dir}")
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..4436bd2
--- /dev/null
+++ b/main.py
@@ -0,0 +1,93 @@
+import os
+import subprocess
+from src.pipeline import ObjectDetectionPipeline
+from src.classifiers.bayesian import BayesianClassifier
+
+# Définissez le mode d'analyse ici : "plan" ou "page"
+analysis_mode = "plan"
+
+if __name__ == "__main__":
+ # Configuration en fonction du mode sélectionné
+ if analysis_mode == "plan":
+ model_path = "models/bayesian_modelPLAN.pth"
+ image_path = "data/plan.png"
+ else:
+ model_path = "models/bayesian_modelPAGE.pth"
+ image_path = "data/page.png"
+
+ # Lancement de l'entraînement avec le mode choisi
+ print(f"Entraînement en cours avec le mode {analysis_mode}...")
+ try:
+ subprocess.run(["python", "train.py", "--mode", analysis_mode], check=True)
+ print("Entraînement terminé avec succès.")
+ except subprocess.CalledProcessError as e:
+ print(f"Une erreur s'est produite pendant l'entraînement : {e}")
+ exit(1)
+
+ # Chargement du modèle bayésien
+ print(f"Chargement du modèle depuis {model_path}...")
+ bayesian_model = BayesianClassifier(mode=analysis_mode)
+ try:
+ bayesian_model.load_model(model_path)
+ print(f"Modèle chargé depuis {model_path}.")
+ except Exception as e:
+ print(f"Erreur lors du chargement du modèle : {e}")
+ exit(1)
+
+ # Vérification de l'existence de l'image de test
+ if not os.path.exists(image_path):
+ print(f"L'image spécifiée ({image_path}) n'existe pas.")
+ exit(1)
+
+ # Création du dossier de sortie si nécessaire
+ output_dir = "output"
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ # Initialisation de la pipeline
+ print("Initialisation de la pipeline...")
+ pipeline = ObjectDetectionPipeline(image_path=image_path, model=bayesian_model, output_dir=output_dir)
+
+ # Configuration du mode d'analyse dans la pipeline
+ pipeline.set_mode(analysis_mode)
+
+ # Chargement de l'image de test
+ print("Chargement de l'image...")
+ try:
+ pipeline.load_image()
+ except FileNotFoundError as e:
+ print(e)
+ exit(1)
+
+ # Détection et classification des objets
+ print("Détection et classification des objets en cours...")
+ try:
+ class_counts, detected_objects, total_objects, ignored_objects, identified_objects = pipeline.detect_and_classify_objects()
+ print(f"Objets détectés par classe : {class_counts}")
+ print("Résumé de la détection :")
+ print(f"- Nombre total d'objets : {total_objects}")
+ print(f"- Objets identifiés : {identified_objects}")
+ print(f"- Objets ignorés : {ignored_objects}")
+ except Exception as e:
+ print(f"Erreur pendant la détection/classification : {e}")
+ exit(1)
+
+ # Sauvegarde et visualisation des résultats
+ print("Sauvegarde et affichage des résultats...")
+ pipeline.display_results(class_counts, detected_objects)
+
+ # Affichage de l'histogramme des classes détectées
+ print("Affichage de l'histogramme des résultats...")
+ try:
+ pipeline.display_histogram(class_counts)
+ except Exception as e:
+ print(f"Erreur lors de l'affichage de l'histogramme : {e}")
+
+ # Affichage du nuage de points
+ print("Affichage du nuage de points...")
+ try:
+ pipeline.display_scatter_plot(class_counts)
+ except Exception as e:
+ print(f"Erreur lors de l'affichage du nuage de points : {e}")
+
+ print(f"Tous les résultats sont sauvegardés dans le dossier : {output_dir}")
diff --git a/pip-dependencies.txt b/pip-dependencies.txt
new file mode 100644
index 0000000..042dc10
--- /dev/null
+++ b/pip-dependencies.txt
@@ -0,0 +1,6 @@
+opencv-python
+torch
+tensorflow
+numpy
+pandas
+matplotlib
\ No newline at end of file
diff --git a/src/classifiers/bayesian.py b/src/classifiers/bayesian.py
new file mode 100644
index 0000000..96f5938
--- /dev/null
+++ b/src/classifiers/bayesian.py
@@ -0,0 +1,188 @@
+import os
+import cv2
+import numpy as np
+import torch
+from collections import defaultdict
+import matplotlib.pyplot as plt
+
+class BayesianClassifier:
+ def __init__(self, mode="page"):
+ self.feature_means = {}
+ self.feature_variances = {}
+ self.class_priors = {}
+ self.classes = []
+
+ # Définir les classes autorisées selon le mode choisi
+ self.allowed_classes = (
+ ['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6']
+ if mode == "plan"
+ else ['2', 'd', 'I', 'n', 'o', 'u']
+ )
+
+ # Initialisation du descripteur HOG avec des paramètres standards
+ self.hog = cv2.HOGDescriptor(
+ _winSize=(28, 28),
+ _blockSize=(8, 8),
+ _blockStride=(4, 4),
+ _cellSize=(8, 8),
+ _nbins=9,
+ )
+
+ def extract_features(self, image):
+ """Extraire les caractéristiques d'une image donnée."""
+ try:
+ # Convertir en niveaux de gris si l'image est en couleurs
+ if len(image.shape) == 3 and image.shape[2] == 3:
+ gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+ else:
+ gray_image = image
+
+ # Appliquer un seuillage adaptatif pour la segmentation
+ binary_image = cv2.adaptiveThreshold(
+ gray_image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 11, 2
+ )
+
+ # Trouver les contours
+ contours, _ = cv2.findContours(
+ binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+ )
+ if not contours:
+ print("Aucun contour trouvé.")
+ return np.array([])
+
+ features = []
+ for contour in contours:
+ if cv2.contourArea(contour) < 20: # Filtrer les petites zones
+ continue
+
+ x, y, w, h = cv2.boundingRect(contour)
+ letter_image = gray_image[y:y + h, x:x + w]
+ letter_image = cv2.resize(letter_image, (28, 28))
+
+ # Calculer les descripteurs HOG
+ hog_features = self.hog.compute(letter_image)
+ features.append(hog_features.flatten())
+
+ features = np.array(features)
+ if features.size == 0:
+ return np.array([])
+
+ # Normaliser les caractéristiques
+ norms = np.linalg.norm(features, axis=1, keepdims=True)
+ features = features / np.where(norms > 1e-6, norms, 1)
+
+ return features
+ except Exception as e:
+ print(f"Erreur dans l'extraction des caractéristiques : {e}")
+ return np.array([])
+
+ def train(self, dataset_path):
+ """Entraîner le modèle Bayésien à partir d'un ensemble de données."""
+ class_features = defaultdict(list)
+ total_samples = 0
+
+ for class_name in os.listdir(dataset_path):
+ if class_name not in self.allowed_classes:
+ continue
+
+ class_folder_path = os.path.join(dataset_path, class_name)
+ if os.path.isdir(class_folder_path):
+ if class_name not in self.classes:
+ self.classes.append(class_name)
+
+ for img_name in os.listdir(class_folder_path):
+ img_path = os.path.join(class_folder_path, img_name)
+ if os.path.isfile(img_path):
+ image = cv2.imread(img_path)
+ if image is not None:
+ features = self.extract_features(image)
+ if features.size > 0:
+ for feature in features:
+ class_features[class_name].append(feature)
+ total_samples += len(features)
+ else:
+ print(f"Aucune caractéristique extraite pour {img_path}")
+ else:
+ print(f"Échec du chargement de l'image : {img_path}")
+
+ # Calculer les moyennes, variances et probabilités a priori
+ for class_name in self.classes:
+ if class_name in class_features:
+ features = np.array(class_features[class_name])
+ self.feature_means[class_name] = np.mean(features, axis=0)
+ self.feature_variances[class_name] = np.var(features, axis=0) + 1e-6 # Éviter une variance nulle
+ self.class_priors[class_name] = len(features) / total_samples
+
+ print("Entraînement terminé pour les classes :", self.classes)
+
+ def save_model(self, model_path):
+ """Sauvegarder le modèle Bayésien sur le disque."""
+ model_data = {
+ "feature_means": self.feature_means,
+ "feature_variances": self.feature_variances,
+ "class_priors": self.class_priors,
+ "classes": self.classes,
+ }
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+ torch.save(model_data, model_path)
+ print(f"Modèle sauvegardé à l'emplacement {model_path}")
+
+ def load_model(self, model_path):
+ """Charger un modèle Bayésien sauvegardé."""
+ if os.path.exists(model_path):
+ model_data = torch.load(model_path)
+ self.feature_means = model_data["feature_means"]
+ self.feature_variances = model_data["feature_variances"]
+ self.class_priors = model_data["class_priors"]
+ self.classes = model_data["classes"]
+ print(f"Modèle chargé depuis {model_path}")
+ else:
+ print(f"Modèle introuvable à l'emplacement {model_path}.")
+
+ def predict(self, image, threshold=0.3):
+ """Prédire la classe d'une image donnée."""
+ try:
+ features = self.extract_features(image)
+ if features.size == 0:
+ return None
+
+ posteriors = {}
+ for class_name in self.classes:
+ mean = self.feature_means[class_name]
+ variance = self.feature_variances[class_name]
+ prior = self.class_priors[class_name]
+
+ # Calculer la log-vraisemblance
+ log_likelihood = -0.5 * np.sum(
+ ((features - mean) ** 2) / variance + np.log(2 * np.pi * variance),
+ axis=1,
+ )
+ posterior = log_likelihood + np.log(prior)
+ posteriors[class_name] = np.sum(posterior)
+
+ max_class = max(posteriors, key=posteriors.get)
+ max_posterior = posteriors[max_class]
+
+ if max_posterior < threshold:
+ return None
+ return max_class
+ except Exception as e:
+ print(f"Erreur dans la prédiction : {e}")
+ return None
+
+ def visualize(self):
+ """Visualiser les moyennes des caractéristiques pour chaque classe."""
+ if not self.classes:
+ print("Aucune classe à visualiser.")
+ return
+
+ for class_name in self.classes:
+ mean_features = self.feature_means[class_name]
+
+ plt.figure(figsize=(10, 4))
+ plt.title(f"Moyennes des caractéristiques pour la classe : {class_name}")
+ plt.plot(mean_features)
+ plt.xlabel("Indice des caractéristiques")
+ plt.ylabel("Valeur moyenne")
+ plt.grid(True)
+ plt.show()
diff --git a/src/classifiers/kmeans.py b/src/classifiers/kmeans.py
new file mode 100644
index 0000000..90b3bbd
--- /dev/null
+++ b/src/classifiers/kmeans.py
@@ -0,0 +1,98 @@
+import numpy as np
+import os
+
+class KMeansClassifier:
+ def __init__(self, num_clusters=6, max_iter=100, tol=1e-4):
+ """
+ Initialiser le classifieur KMeans.
+
+ Paramètres :
+ - num_clusters : Nombre de clusters (classes).
+ - max_iter : Nombre maximal d'itérations pour l'algorithme k-means.
+ - tol : Tolérance pour la convergence.
+ """
+ self.num_clusters = num_clusters
+ self.max_iter = max_iter
+ self.tol = tol
+ self.cluster_centers_ = None
+ self.labels_ = None
+
+ def fit(self, features):
+ """
+ Entraîner le modèle k-means sur les données fournies.
+
+ Paramètres :
+ - features : Un tableau numpy de forme (n_samples, n_features).
+ """
+ if len(features) < self.num_clusters:
+ raise ValueError("Le nombre d'échantillons est inférieur au nombre de clusters.")
+
+ np.random.seed(42)
+ random_indices = np.random.choice(len(features), self.num_clusters, replace=False)
+ self.cluster_centers_ = features[random_indices]
+
+ for iteration in range(self.max_iter):
+ # Assigner des étiquettes en fonction du centre le plus proche
+ distances = self._compute_distances(features)
+ self.labels_ = np.argmin(distances, axis=1)
+
+ # Mettre à jour les centres des clusters
+ new_centers = np.array([features[self.labels_ == k].mean(axis=0) for k in range(self.num_clusters)])
+
+ # Vérifier la convergence
+ if np.all(np.abs(new_centers - self.cluster_centers_) < self.tol):
+ print(f"Convergence atteinte en {iteration + 1} itérations.")
+ break
+
+ self.cluster_centers_ = new_centers
+
+ def predict(self, features):
+ """
+ Prédire le cluster le plus proche pour les données fournies.
+
+ Paramètres :
+ - features : Un tableau numpy de forme (n_samples, n_features).
+
+ Retourne :
+ - Les étiquettes des clusters pour chaque échantillon.
+ """
+ distances = self._compute_distances(features)
+ return np.argmin(distances, axis=1)
+
+ def _compute_distances(self, features):
+ """
+ Calculer les distances entre les données et les centres des clusters.
+
+ Paramètres :
+ - features : Un tableau numpy de forme (n_samples, n_features).
+
+ Retourne :
+ - Un tableau numpy des distances de forme (n_samples, num_clusters).
+ """
+ return np.linalg.norm(features[:, np.newaxis] - self.cluster_centers_, axis=2)
+
+ def save_model(self, path):
+ """
+ Sauvegarder le modèle KMeans dans un fichier .npy.
+
+ Paramètres :
+ - path : Chemin pour sauvegarder le modèle.
+ """
+ if self.cluster_centers_ is None:
+ raise ValueError("Le modèle n'a pas encore été entraîné. Rien à sauvegarder.")
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ np.save(path, self.cluster_centers_)
+ print(f"Modèle KMeans sauvegardé à l'emplacement {path}")
+
+ def load_model(self, path):
+ """
+ Charger le modèle KMeans depuis un fichier .npy.
+
+ Paramètres :
+ - path : Chemin vers le fichier sauvegardé.
+ """
+ if os.path.exists(path):
+ self.cluster_centers_ = np.load(path)
+ print(f"Modèle KMeans chargé depuis {path}")
+ else:
+ raise FileNotFoundError(f"Fichier du modèle KMeans introuvable à l'emplacement {path}.")
diff --git a/src/pipeline.py b/src/pipeline.py
new file mode 100644
index 0000000..9b474b3
--- /dev/null
+++ b/src/pipeline.py
@@ -0,0 +1,117 @@
+
+import cv2
+import os
+from matplotlib import pyplot as plt
+from collections import defaultdict
+
+class ObjectDetectionPipeline:
+ def __init__(self, image_path, model=None, output_dir="output", mode="page", min_contour_area=20, binary_threshold=None):
+ self.image_path = image_path
+ self.image = None
+ self.binary_image = None
+ self.model = model
+ self.output_dir = output_dir
+ self.min_contour_area = min_contour_area
+ self.binary_threshold = binary_threshold
+ self.mode = mode # Le mode par défaut est "page"
+ self.annotated_output_path = os.path.join(self.output_dir, f"annotated_{os.path.basename(image_path)}")
+ self.threshold = -395000 if mode == "plan" else -65000
+
+ # Créez le dossier de sortie s'il n'existe pas
+ if not os.path.exists(self.output_dir):
+ os.makedirs(self.output_dir)
+
+ def set_mode(self, mode):
+ """Définir le mode de détection (page ou plan)."""
+ if mode not in ["page", "plan"]:
+ raise ValueError("Le mode doit être 'page' ou 'plan'.")
+ self.mode = mode
+ self.threshold = -395000 if mode == "plan" else -65000
+ print(f"Mode défini à : {self.mode}, Seuil défini à : {self.threshold}")
+
+ def load_image(self):
+ """Charger l'image à analyser."""
+ self.image = cv2.imread(self.image_path)
+ if self.image is None:
+ raise FileNotFoundError(f"L'image {self.image_path} est introuvable.")
+ return self.image
+
+ def preprocess_image(self):
+ """Effectuer un prétraitement de l'image pour obtenir une version binaire."""
+ channels = cv2.split(self.image)
+ binary_images = []
+
+ for channel in channels:
+ if self.binary_threshold is None:
+ _, binary_channel = cv2.threshold(channel, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
+ else:
+ _, binary_channel = cv2.threshold(channel, self.binary_threshold, 255, cv2.THRESH_BINARY_INV)
+ binary_images.append(binary_channel)
+
+ binary_image = cv2.bitwise_or(binary_images[0], binary_images[1])
+ binary_image = cv2.bitwise_or(binary_image, binary_images[2])
+ self.binary_image = binary_image
+ return binary_image
+
+ def detect_and_classify_objects(self):
+ """Détecter et classer les objets dans l'image."""
+ if self.model is None:
+ raise ValueError("Aucun modèle de classification n'a été fourni.")
+
+ self.binary_image = self.preprocess_image()
+ contours, _ = cv2.findContours(self.binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+ class_counts = defaultdict(int)
+ detected_objects = []
+ total_objects = 0
+ ignored_objects = 0
+ identified_objects = 0
+
+ for contour in contours:
+ total_objects += 1 # Incrémenter le compteur total des objets
+
+ if cv2.contourArea(contour) < self.min_contour_area:
+ ignored_objects += 1
+ continue
+
+ x, y, w, h = cv2.boundingRect(contour)
+ letter_image = self.image[y:y + h, x:x + w]
+
+ predicted_class = self.model.predict(letter_image, threshold=self.threshold)
+ if predicted_class is None:
+ ignored_objects += 1
+ continue
+
+ identified_objects += 1
+ class_counts[predicted_class] += 1
+ detected_objects.append((x, y, w, h, predicted_class))
+
+ return dict(sorted(class_counts.items())), detected_objects, total_objects, ignored_objects, identified_objects
+
+ def save_results(self, class_counts, detected_objects):
+ """Sauvegarder les résultats de la détection dans des fichiers."""
+ binary_output_path = os.path.join(self.output_dir, "binary_image.jpg")
+ cv2.imwrite(binary_output_path, self.binary_image)
+
+ annotated_image = self.image.copy()
+ for (x, y, w, h, predicted_class) in detected_objects:
+ cv2.rectangle(annotated_image, (x, y), (x + w, y + h), (0, 255, 0), 2)
+ cv2.putText(annotated_image, str(predicted_class), (x, y - 10),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
+ cv2.imwrite(self.annotated_output_path, annotated_image)
+
+ results_text_path = os.path.join(self.output_dir, "results.txt")
+ with open(results_text_path, "w") as f:
+ for class_name, count in class_counts.items():
+ f.write(f"{class_name}: {count}\n")
+
+ def display_results(self, class_counts, detected_objects):
+ """Afficher les résultats sous forme graphique."""
+ self.save_results(class_counts, detected_objects)
+
+ plt.figure(figsize=(10, 5))
+ plt.bar(class_counts.keys(), class_counts.values())
+ plt.xlabel("Classes")
+ plt.ylabel("Nombre d'objets")
+ plt.title("Répartition des classes détectées")
+ plt.show()
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..e233175
--- /dev/null
+++ b/train.py
@@ -0,0 +1,68 @@
+import os
+import argparse
+from collections import defaultdict
+import numpy as np
+import cv2
+from src.classifiers.bayesian import BayesianClassifier
+
+if __name__ == "__main__":
+ # Analyse des arguments pour configurer le mode
+ parser = argparse.ArgumentParser(description="Entraîner le modèle Bayésien.")
+ parser.add_argument("--mode", type=str, choices=["page", "plan"], default="page",
+ help="Mode de fonctionnement : 'page' pour les pages ou 'plan' pour les plans.")
+ args = parser.parse_args()
+
+ # Définir les chemins en fonction du mode
+ mode = args.mode
+ dataset_path = f"data/catalogue{'' if mode == 'page' else 'Symbol'}"
+ allowed_classes = ['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6'] \
+ if mode == "plan" else ['2', 'd', 'I', 'n', 'o', 'u']
+ model_path = f"models/bayesian_model{mode.upper()}.pth"
+
+ # Initialiser le classifieur Bayésien
+ bayesian_model = BayesianClassifier(mode=mode)
+
+ print("Lancement de l'entraînement...")
+
+ # Stockage des caractéristiques pour chaque classe
+ class_features = defaultdict(list)
+ total_images = 0
+
+ # Parcourir les classes dans le dataset
+ for class_name in os.listdir(dataset_path):
+ if class_name not in allowed_classes:
+ continue # Ignorer les classes non autorisées
+
+ class_folder_path = os.path.join(dataset_path, class_name)
+ if not os.path.isdir(class_folder_path):
+ continue # Ignorer les fichiers qui ne sont pas des dossiers
+
+ # Ajouter la classe au modèle si elle n'est pas encore listée
+ if class_name not in bayesian_model.classes:
+ bayesian_model.classes.append(class_name)
+
+ # Parcourir les images dans le dossier de la classe
+ for image_name in os.listdir(class_folder_path):
+ image_path = os.path.join(class_folder_path, image_name)
+ image = cv2.imread(image_path)
+
+ if image is not None:
+ # Extraire les caractéristiques de l'image
+ features = bayesian_model.extract_features(image)
+ for feature in features:
+ class_features[class_name].append(feature)
+ total_images += 1
+
+ # Calculer les statistiques pour chaque classe
+ for class_name in bayesian_model.classes:
+ if class_name in class_features:
+ features = np.array(class_features[class_name])
+ bayesian_model.feature_means[class_name] = np.mean(features, axis=0)
+ bayesian_model.feature_variances[class_name] = np.var(features, axis=0) + 1e-6 # Éviter la division par zéro
+ bayesian_model.class_priors[class_name] = len(features) / total_images
+
+ print("Entraînement terminé avec succès.")
+
+ # Sauvegarder le modèle entraîné
+ bayesian_model.save_model(model_path)
+ print(f"Le modèle a été sauvegardé à l'emplacement : {model_path}")
\ No newline at end of file