Final push
This commit is contained in:
parent
9a0b597309
commit
afb9de39c8
5 changed files with 196 additions and 70 deletions
52
main.py
52
main.py
|
@ -7,7 +7,7 @@ from src.classifiers.bayesian import BayesianClassifier
|
|||
analysis_mode = "plan"
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Configuration basée sur le mode
|
||||
# Configuration en fonction du mode sélectionné
|
||||
if analysis_mode == "plan":
|
||||
model_path = "models/bayesian_modelPLAN.pth"
|
||||
image_path = "data/plan.png"
|
||||
|
@ -15,31 +15,31 @@ if __name__ == "__main__":
|
|||
model_path = "models/bayesian_modelPAGE.pth"
|
||||
image_path = "data/page.png"
|
||||
|
||||
# Exécuter le script train.py avec le mode sélectionné
|
||||
print(f"Lancement de l'entraînement avec le mode {analysis_mode}...")
|
||||
# Lancement de l'entraînement avec le mode choisi
|
||||
print(f"Entraînement en cours avec le mode {analysis_mode}...")
|
||||
try:
|
||||
subprocess.run(["python", "train.py", "--mode", analysis_mode], check=True)
|
||||
print("Entraînement terminé.")
|
||||
print("Entraînement terminé avec succès.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Erreur lors de l'entraînement : {e}")
|
||||
print(f"Une erreur s'est produite pendant l'entraînement : {e}")
|
||||
exit(1)
|
||||
|
||||
# Chargement du modèle bayésien
|
||||
print(f"Chargement du modèle bayésien depuis {model_path}")
|
||||
print(f"Chargement du modèle depuis {model_path}...")
|
||||
bayesian_model = BayesianClassifier(mode=analysis_mode)
|
||||
try:
|
||||
bayesian_model.load_model(model_path)
|
||||
print(f"Modèle bayésien chargé depuis {model_path}")
|
||||
print(f"Modèle chargé depuis {model_path}.")
|
||||
except Exception as e:
|
||||
print(f"Erreur lors du chargement du modèle : {e}")
|
||||
exit(1)
|
||||
|
||||
# Vérification de l'existence de l'image
|
||||
# Vérification de l'existence de l'image de test
|
||||
if not os.path.exists(image_path):
|
||||
print(f"L'image de test {image_path} n'existe pas.")
|
||||
print(f"L'image spécifiée ({image_path}) n'existe pas.")
|
||||
exit(1)
|
||||
|
||||
# Initialisation du dossier de sortie
|
||||
# Création du dossier de sortie si nécessaire
|
||||
output_dir = "output"
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
@ -48,10 +48,10 @@ if __name__ == "__main__":
|
|||
print("Initialisation de la pipeline...")
|
||||
pipeline = ObjectDetectionPipeline(image_path=image_path, model=bayesian_model, output_dir=output_dir)
|
||||
|
||||
# Définition du mode (plan ou page)
|
||||
# Configuration du mode d'analyse dans la pipeline
|
||||
pipeline.set_mode(analysis_mode)
|
||||
|
||||
# Chargement de l'image
|
||||
# Chargement de l'image de test
|
||||
print("Chargement de l'image...")
|
||||
try:
|
||||
pipeline.load_image()
|
||||
|
@ -60,20 +60,34 @@ if __name__ == "__main__":
|
|||
exit(1)
|
||||
|
||||
# Détection et classification des objets
|
||||
print("Détection et classification des objets...")
|
||||
print("Détection et classification des objets en cours...")
|
||||
try:
|
||||
class_counts, detected_objects, total_objects, ignored_objects, identified_objects = pipeline.detect_and_classify_objects()
|
||||
print(f"Classes détectées : {class_counts}")
|
||||
print("Résumé des objets :")
|
||||
print(f"- Objets totaux : {total_objects}")
|
||||
print(f"Objets détectés par classe : {class_counts}")
|
||||
print("Résumé de la détection :")
|
||||
print(f"- Nombre total d'objets : {total_objects}")
|
||||
print(f"- Objets identifiés : {identified_objects}")
|
||||
print(f"- Objets ignorés : {ignored_objects}")
|
||||
except Exception as e:
|
||||
print(f"Erreur lors de la détection/classification : {e}")
|
||||
print(f"Erreur pendant la détection/classification : {e}")
|
||||
exit(1)
|
||||
|
||||
# Sauvegarde et affichage des résultats
|
||||
# Sauvegarde et visualisation des résultats
|
||||
print("Sauvegarde et affichage des résultats...")
|
||||
pipeline.display_results(class_counts, detected_objects)
|
||||
|
||||
print(f"Les résultats ont été sauvegardés dans le dossier : {output_dir}")
|
||||
# Affichage de l'histogramme des classes détectées
|
||||
print("Affichage de l'histogramme des résultats...")
|
||||
try:
|
||||
pipeline.display_histogram(class_counts)
|
||||
except Exception as e:
|
||||
print(f"Erreur lors de l'affichage de l'histogramme : {e}")
|
||||
|
||||
# Affichage du nuage de points
|
||||
print("Affichage du nuage de points...")
|
||||
try:
|
||||
pipeline.display_scatter_plot(class_counts)
|
||||
except Exception as e:
|
||||
print(f"Erreur lors de l'affichage du nuage de points : {e}")
|
||||
|
||||
print(f"Tous les résultats sont sauvegardés dans le dossier : {output_dir}")
|
||||
|
|
|
@ -5,7 +5,6 @@ import torch
|
|||
from collections import defaultdict
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class BayesianClassifier:
|
||||
def __init__(self, mode="page"):
|
||||
self.feature_means = {}
|
||||
|
@ -13,13 +12,14 @@ class BayesianClassifier:
|
|||
self.class_priors = {}
|
||||
self.classes = []
|
||||
|
||||
# Définir les classes autorisées selon le mode choisi
|
||||
self.allowed_classes = (
|
||||
['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6']
|
||||
if mode == "plan"
|
||||
else ['2', 'd', 'I', 'n', 'o', 'u']
|
||||
)
|
||||
|
||||
# Initialize HOG descriptor with standard parameters
|
||||
# Initialisation du descripteur HOG avec des paramètres standards
|
||||
self.hog = cv2.HOGDescriptor(
|
||||
_winSize=(28, 28),
|
||||
_blockSize=(8, 8),
|
||||
|
@ -29,36 +29,37 @@ class BayesianClassifier:
|
|||
)
|
||||
|
||||
def extract_features(self, image):
|
||||
"""Extraire les caractéristiques d'une image donnée."""
|
||||
try:
|
||||
# Convert to grayscale if image is RGB
|
||||
# Convertir en niveaux de gris si l'image est en couleurs
|
||||
if len(image.shape) == 3 and image.shape[2] == 3:
|
||||
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
gray_image = image
|
||||
|
||||
# Apply adaptive thresholding for segmentation
|
||||
# Appliquer un seuillage adaptatif pour la segmentation
|
||||
binary_image = cv2.adaptiveThreshold(
|
||||
gray_image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 11, 2
|
||||
)
|
||||
|
||||
# Find contours
|
||||
# Trouver les contours
|
||||
contours, _ = cv2.findContours(
|
||||
binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
if not contours:
|
||||
print("No contours found.")
|
||||
print("Aucun contour trouvé.")
|
||||
return np.array([])
|
||||
|
||||
features = []
|
||||
for contour in contours:
|
||||
if cv2.contourArea(contour) < 20: # Filter small areas
|
||||
if cv2.contourArea(contour) < 20: # Filtrer les petites zones
|
||||
continue
|
||||
|
||||
x, y, w, h = cv2.boundingRect(contour)
|
||||
letter_image = gray_image[y:y + h, x:x + w]
|
||||
letter_image = cv2.resize(letter_image, (28, 28))
|
||||
|
||||
# Compute HOG features
|
||||
# Calculer les descripteurs HOG
|
||||
hog_features = self.hog.compute(letter_image)
|
||||
features.append(hog_features.flatten())
|
||||
|
||||
|
@ -66,16 +67,17 @@ class BayesianClassifier:
|
|||
if features.size == 0:
|
||||
return np.array([])
|
||||
|
||||
# Normalize features
|
||||
# Normaliser les caractéristiques
|
||||
norms = np.linalg.norm(features, axis=1, keepdims=True)
|
||||
features = features / np.where(norms > 1e-6, norms, 1)
|
||||
|
||||
return features
|
||||
except Exception as e:
|
||||
print(f"Error in extract_features: {e}")
|
||||
print(f"Erreur dans l'extraction des caractéristiques : {e}")
|
||||
return np.array([])
|
||||
|
||||
def train(self, dataset_path):
|
||||
"""Entraîner le modèle Bayésien à partir d'un ensemble de données."""
|
||||
class_features = defaultdict(list)
|
||||
total_samples = 0
|
||||
|
||||
|
@ -99,21 +101,22 @@ class BayesianClassifier:
|
|||
class_features[class_name].append(feature)
|
||||
total_samples += len(features)
|
||||
else:
|
||||
print(f"No features extracted for {img_path}")
|
||||
print(f"Aucune caractéristique extraite pour {img_path}")
|
||||
else:
|
||||
print(f"Failed to load image: {img_path}")
|
||||
print(f"Échec du chargement de l'image : {img_path}")
|
||||
|
||||
# Compute means, variances, and priors
|
||||
# Calculer les moyennes, variances et probabilités a priori
|
||||
for class_name in self.classes:
|
||||
if class_name in class_features:
|
||||
features = np.array(class_features[class_name])
|
||||
self.feature_means[class_name] = np.mean(features, axis=0)
|
||||
self.feature_variances[class_name] = np.var(features, axis=0) + 1e-6 # Avoid zero variance
|
||||
self.feature_variances[class_name] = np.var(features, axis=0) + 1e-6 # Éviter une variance nulle
|
||||
self.class_priors[class_name] = len(features) / total_samples
|
||||
|
||||
print("Training completed for classes:", self.classes)
|
||||
print("Entraînement terminé pour les classes :", self.classes)
|
||||
|
||||
def save_model(self, model_path):
|
||||
"""Sauvegarder le modèle Bayésien sur le disque."""
|
||||
model_data = {
|
||||
"feature_means": self.feature_means,
|
||||
"feature_variances": self.feature_variances,
|
||||
|
@ -122,20 +125,22 @@ class BayesianClassifier:
|
|||
}
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
torch.save(model_data, model_path)
|
||||
print(f"Model saved to {model_path}")
|
||||
print(f"Modèle sauvegardé à l'emplacement {model_path}")
|
||||
|
||||
def load_model(self, model_path):
|
||||
"""Charger un modèle Bayésien sauvegardé."""
|
||||
if os.path.exists(model_path):
|
||||
model_data = torch.load(model_path)
|
||||
self.feature_means = model_data["feature_means"]
|
||||
self.feature_variances = model_data["feature_variances"]
|
||||
self.class_priors = model_data["class_priors"]
|
||||
self.classes = model_data["classes"]
|
||||
print(f"Model loaded from {model_path}")
|
||||
print(f"Modèle chargé depuis {model_path}")
|
||||
else:
|
||||
print(f"No model found at {model_path}.")
|
||||
print(f"Modèle introuvable à l'emplacement {model_path}.")
|
||||
|
||||
def predict(self, image, threshold=0.3):
|
||||
"""Prédire la classe d'une image donnée."""
|
||||
try:
|
||||
features = self.extract_features(image)
|
||||
if features.size == 0:
|
||||
|
@ -147,7 +152,7 @@ class BayesianClassifier:
|
|||
variance = self.feature_variances[class_name]
|
||||
prior = self.class_priors[class_name]
|
||||
|
||||
# Compute log-likelihood
|
||||
# Calculer la log-vraisemblance
|
||||
log_likelihood = -0.5 * np.sum(
|
||||
((features - mean) ** 2) / variance + np.log(2 * np.pi * variance),
|
||||
axis=1,
|
||||
|
@ -162,21 +167,22 @@ class BayesianClassifier:
|
|||
return None
|
||||
return max_class
|
||||
except Exception as e:
|
||||
print(f"Error in prediction: {e}")
|
||||
print(f"Erreur dans la prédiction : {e}")
|
||||
return None
|
||||
|
||||
def visualize(self):
|
||||
"""Visualiser les moyennes des caractéristiques pour chaque classe."""
|
||||
if not self.classes:
|
||||
print("No classes to visualize.")
|
||||
print("Aucune classe à visualiser.")
|
||||
return
|
||||
|
||||
for class_name in self.classes:
|
||||
mean_features = self.feature_means[class_name]
|
||||
|
||||
plt.figure(figsize=(10, 4))
|
||||
plt.title(f"Mean features for class: {class_name}")
|
||||
plt.title(f"Moyennes des caractéristiques pour la classe : {class_name}")
|
||||
plt.plot(mean_features)
|
||||
plt.xlabel("Feature Index")
|
||||
plt.ylabel("Mean Value")
|
||||
plt.xlabel("Indice des caractéristiques")
|
||||
plt.ylabel("Valeur moyenne")
|
||||
plt.grid(True)
|
||||
plt.show()
|
||||
|
|
98
src/classifiers/kmeans.py
Normal file
98
src/classifiers/kmeans.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
import numpy as np
|
||||
import os
|
||||
|
||||
class KMeansClassifier:
|
||||
def __init__(self, num_clusters=6, max_iter=100, tol=1e-4):
|
||||
"""
|
||||
Initialiser le classifieur KMeans.
|
||||
|
||||
Paramètres :
|
||||
- num_clusters : Nombre de clusters (classes).
|
||||
- max_iter : Nombre maximal d'itérations pour l'algorithme k-means.
|
||||
- tol : Tolérance pour la convergence.
|
||||
"""
|
||||
self.num_clusters = num_clusters
|
||||
self.max_iter = max_iter
|
||||
self.tol = tol
|
||||
self.cluster_centers_ = None
|
||||
self.labels_ = None
|
||||
|
||||
def fit(self, features):
|
||||
"""
|
||||
Entraîner le modèle k-means sur les données fournies.
|
||||
|
||||
Paramètres :
|
||||
- features : Un tableau numpy de forme (n_samples, n_features).
|
||||
"""
|
||||
if len(features) < self.num_clusters:
|
||||
raise ValueError("Le nombre d'échantillons est inférieur au nombre de clusters.")
|
||||
|
||||
np.random.seed(42)
|
||||
random_indices = np.random.choice(len(features), self.num_clusters, replace=False)
|
||||
self.cluster_centers_ = features[random_indices]
|
||||
|
||||
for iteration in range(self.max_iter):
|
||||
# Assigner des étiquettes en fonction du centre le plus proche
|
||||
distances = self._compute_distances(features)
|
||||
self.labels_ = np.argmin(distances, axis=1)
|
||||
|
||||
# Mettre à jour les centres des clusters
|
||||
new_centers = np.array([features[self.labels_ == k].mean(axis=0) for k in range(self.num_clusters)])
|
||||
|
||||
# Vérifier la convergence
|
||||
if np.all(np.abs(new_centers - self.cluster_centers_) < self.tol):
|
||||
print(f"Convergence atteinte en {iteration + 1} itérations.")
|
||||
break
|
||||
|
||||
self.cluster_centers_ = new_centers
|
||||
|
||||
def predict(self, features):
|
||||
"""
|
||||
Prédire le cluster le plus proche pour les données fournies.
|
||||
|
||||
Paramètres :
|
||||
- features : Un tableau numpy de forme (n_samples, n_features).
|
||||
|
||||
Retourne :
|
||||
- Les étiquettes des clusters pour chaque échantillon.
|
||||
"""
|
||||
distances = self._compute_distances(features)
|
||||
return np.argmin(distances, axis=1)
|
||||
|
||||
def _compute_distances(self, features):
|
||||
"""
|
||||
Calculer les distances entre les données et les centres des clusters.
|
||||
|
||||
Paramètres :
|
||||
- features : Un tableau numpy de forme (n_samples, n_features).
|
||||
|
||||
Retourne :
|
||||
- Un tableau numpy des distances de forme (n_samples, num_clusters).
|
||||
"""
|
||||
return np.linalg.norm(features[:, np.newaxis] - self.cluster_centers_, axis=2)
|
||||
|
||||
def save_model(self, path):
|
||||
"""
|
||||
Sauvegarder le modèle KMeans dans un fichier .npy.
|
||||
|
||||
Paramètres :
|
||||
- path : Chemin pour sauvegarder le modèle.
|
||||
"""
|
||||
if self.cluster_centers_ is None:
|
||||
raise ValueError("Le modèle n'a pas encore été entraîné. Rien à sauvegarder.")
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
np.save(path, self.cluster_centers_)
|
||||
print(f"Modèle KMeans sauvegardé à l'emplacement {path}")
|
||||
|
||||
def load_model(self, path):
|
||||
"""
|
||||
Charger le modèle KMeans depuis un fichier .npy.
|
||||
|
||||
Paramètres :
|
||||
- path : Chemin vers le fichier sauvegardé.
|
||||
"""
|
||||
if os.path.exists(path):
|
||||
self.cluster_centers_ = np.load(path)
|
||||
print(f"Modèle KMeans chargé depuis {path}")
|
||||
else:
|
||||
raise FileNotFoundError(f"Fichier du modèle KMeans introuvable à l'emplacement {path}.")
|
|
@ -1,9 +1,9 @@
|
|||
|
||||
import cv2
|
||||
import os
|
||||
from matplotlib import pyplot as plt
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class ObjectDetectionPipeline:
|
||||
def __init__(self, image_path, model=None, output_dir="output", mode="page", min_contour_area=20, binary_threshold=None):
|
||||
self.image_path = image_path
|
||||
|
@ -13,28 +13,31 @@ class ObjectDetectionPipeline:
|
|||
self.output_dir = output_dir
|
||||
self.min_contour_area = min_contour_area
|
||||
self.binary_threshold = binary_threshold
|
||||
self.mode = mode # Default mode is "page"
|
||||
self.mode = mode # Le mode par défaut est "page"
|
||||
self.annotated_output_path = os.path.join(self.output_dir, f"annotated_{os.path.basename(image_path)}")
|
||||
self.threshold = -395000 if mode == "plan" else -65000
|
||||
|
||||
# Créez le dossier de sortie s'il n'existe pas
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir)
|
||||
|
||||
def set_mode(self, mode):
|
||||
"""Set the detection mode (page or plan)."""
|
||||
"""Définir le mode de détection (page ou plan)."""
|
||||
if mode not in ["page", "plan"]:
|
||||
raise ValueError("Mode must be 'page' or 'plan'.")
|
||||
raise ValueError("Le mode doit être 'page' ou 'plan'.")
|
||||
self.mode = mode
|
||||
self.threshold = -395000 if mode == "plan" else -65000
|
||||
print(f"Mode set to: {self.mode}, Threshold set to: {self.threshold}")
|
||||
print(f"Mode défini à : {self.mode}, Seuil défini à : {self.threshold}")
|
||||
|
||||
def load_image(self):
|
||||
"""Charger l'image à analyser."""
|
||||
self.image = cv2.imread(self.image_path)
|
||||
if self.image is None:
|
||||
raise FileNotFoundError(f"Image {self.image_path} not found.")
|
||||
raise FileNotFoundError(f"L'image {self.image_path} est introuvable.")
|
||||
return self.image
|
||||
|
||||
def preprocess_image(self):
|
||||
"""Effectuer un prétraitement de l'image pour obtenir une version binaire."""
|
||||
channels = cv2.split(self.image)
|
||||
binary_images = []
|
||||
|
||||
|
@ -51,8 +54,9 @@ class ObjectDetectionPipeline:
|
|||
return binary_image
|
||||
|
||||
def detect_and_classify_objects(self):
|
||||
"""Détecter et classer les objets dans l'image."""
|
||||
if self.model is None:
|
||||
raise ValueError("No classification model provided.")
|
||||
raise ValueError("Aucun modèle de classification n'a été fourni.")
|
||||
|
||||
self.binary_image = self.preprocess_image()
|
||||
contours, _ = cv2.findContours(self.binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
@ -64,7 +68,7 @@ class ObjectDetectionPipeline:
|
|||
identified_objects = 0
|
||||
|
||||
for contour in contours:
|
||||
total_objects += 1 # Compteur total des objets
|
||||
total_objects += 1 # Incrémenter le compteur total des objets
|
||||
|
||||
if cv2.contourArea(contour) < self.min_contour_area:
|
||||
ignored_objects += 1
|
||||
|
@ -85,6 +89,7 @@ class ObjectDetectionPipeline:
|
|||
return dict(sorted(class_counts.items())), detected_objects, total_objects, ignored_objects, identified_objects
|
||||
|
||||
def save_results(self, class_counts, detected_objects):
|
||||
"""Sauvegarder les résultats de la détection dans des fichiers."""
|
||||
binary_output_path = os.path.join(self.output_dir, "binary_image.jpg")
|
||||
cv2.imwrite(binary_output_path, self.binary_image)
|
||||
|
||||
|
@ -101,11 +106,12 @@ class ObjectDetectionPipeline:
|
|||
f.write(f"{class_name}: {count}\n")
|
||||
|
||||
def display_results(self, class_counts, detected_objects):
|
||||
"""Afficher les résultats sous forme graphique."""
|
||||
self.save_results(class_counts, detected_objects)
|
||||
|
||||
plt.figure(figsize=(10, 5))
|
||||
plt.bar(class_counts.keys(), class_counts.values())
|
||||
plt.xlabel("Classes")
|
||||
plt.ylabel("Object count")
|
||||
plt.title("Detected Class Distribution")
|
||||
plt.show()
|
||||
plt.ylabel("Nombre d'objets")
|
||||
plt.title("Répartition des classes détectées")
|
||||
plt.show()
|
34
train.py
34
train.py
|
@ -1,32 +1,34 @@
|
|||
import os
|
||||
import argparse # Ajouté pour les arguments
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
import cv2
|
||||
from src.classifiers.bayesian import BayesianClassifier
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Analyse des arguments
|
||||
parser = argparse.ArgumentParser(description="Train Bayesian model.")
|
||||
parser.add_argument("--mode", type=str, choices=["page", "plan"], default="page", help="Mode de fonctionnement : 'page' ou 'plan'.")
|
||||
# Analyse des arguments pour configurer le mode
|
||||
parser = argparse.ArgumentParser(description="Entraîner le modèle Bayésien.")
|
||||
parser.add_argument("--mode", type=str, choices=["page", "plan"], default="page",
|
||||
help="Mode de fonctionnement : 'page' pour les pages ou 'plan' pour les plans.")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configuration en fonction du mode
|
||||
# Définir les chemins en fonction du mode
|
||||
mode = args.mode
|
||||
dataset_path = f"data/catalogue{'' if mode == 'page' else 'Symbol'}"
|
||||
allowed_classes = ['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6'] if mode == "plan" else ['2', 'd', 'I', 'n', 'o', 'u']
|
||||
allowed_classes = ['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6'] \
|
||||
if mode == "plan" else ['2', 'd', 'I', 'n', 'o', 'u']
|
||||
model_path = f"models/bayesian_model{mode.upper()}.pth"
|
||||
|
||||
# Initialisation du classifieur Bayésien
|
||||
# Initialiser le classifieur Bayésien
|
||||
bayesian_model = BayesianClassifier(mode=mode)
|
||||
|
||||
print("Début de l'entraînement...")
|
||||
print("Lancement de l'entraînement...")
|
||||
|
||||
# Dictionnaire pour stocker les caractéristiques par classe
|
||||
# Stockage des caractéristiques pour chaque classe
|
||||
class_features = defaultdict(list)
|
||||
total_images = 0
|
||||
|
||||
# Parcours des classes dans le dataset
|
||||
# Parcourir les classes dans le dataset
|
||||
for class_name in os.listdir(dataset_path):
|
||||
if class_name not in allowed_classes:
|
||||
continue # Ignorer les classes non autorisées
|
||||
|
@ -35,11 +37,11 @@ if __name__ == "__main__":
|
|||
if not os.path.isdir(class_folder_path):
|
||||
continue # Ignorer les fichiers qui ne sont pas des dossiers
|
||||
|
||||
# Ajouter la classe au modèle si elle n'existe pas déjà
|
||||
# Ajouter la classe au modèle si elle n'est pas encore listée
|
||||
if class_name not in bayesian_model.classes:
|
||||
bayesian_model.classes.append(class_name)
|
||||
|
||||
# Parcours des images dans le dossier de la classe
|
||||
# Parcourir les images dans le dossier de la classe
|
||||
for image_name in os.listdir(class_folder_path):
|
||||
image_path = os.path.join(class_folder_path, image_name)
|
||||
image = cv2.imread(image_path)
|
||||
|
@ -51,7 +53,7 @@ if __name__ == "__main__":
|
|||
class_features[class_name].append(feature)
|
||||
total_images += 1
|
||||
|
||||
# Calcul des statistiques pour chaque classe
|
||||
# Calculer les statistiques pour chaque classe
|
||||
for class_name in bayesian_model.classes:
|
||||
if class_name in class_features:
|
||||
features = np.array(class_features[class_name])
|
||||
|
@ -59,8 +61,8 @@ if __name__ == "__main__":
|
|||
bayesian_model.feature_variances[class_name] = np.var(features, axis=0) + 1e-6 # Éviter la division par zéro
|
||||
bayesian_model.class_priors[class_name] = len(features) / total_images
|
||||
|
||||
print("Entraînement terminé.")
|
||||
print("Entraînement terminé avec succès.")
|
||||
|
||||
# Sauvegarde du modèle entraîné
|
||||
# Sauvegarder le modèle entraîné
|
||||
bayesian_model.save_model(model_path)
|
||||
print(f"Modèle sauvegardé dans : {model_path}")
|
||||
print(f"Le modèle a été sauvegardé à l'emplacement : {model_path}")
|
Loading…
Add table
Reference in a new issue