Bayesian final version + choix du fichier analysé
This commit is contained in:
		
							parent
							
								
									7abdb91d06
								
							
						
					
					
						commit
						922b9acf18
					
				
							
								
								
									
										24
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								main.py
									
									
									
									
									
								
							| @ -1,16 +1,22 @@ | |||||||
| import os | import os | ||||||
| import cv2 |  | ||||||
| from src.pipeline import ObjectDetectionPipeline | from src.pipeline import ObjectDetectionPipeline | ||||||
| from src.classifiers.bayesian import BayesianClassifier | from src.classifiers.bayesian import BayesianClassifier | ||||||
| from collections import defaultdict | 
 | ||||||
|  | # Définissez le mode d'analyse ici : "plan" ou "page" | ||||||
|  | analysis_mode = "plan" | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     # Chemin vers le modèle entraîné |     # Configuration basée sur le mode | ||||||
|     model_path = "models/bayesian_modelPAGE.pth" |     if analysis_mode == "plan": | ||||||
|  |         model_path = "models/bayesian_modelPLAN.pth" | ||||||
|  |         image_path = "data/plan.png" | ||||||
|  |     else: | ||||||
|  |         model_path = "models/bayesian_modelPAGE.pth" | ||||||
|  |         image_path = "data/page.png" | ||||||
| 
 | 
 | ||||||
|     # Chargement du modèle bayésien |     # Chargement du modèle bayésien | ||||||
|     print(f"Chargement du modèle bayésien depuis {model_path}") |     print(f"Chargement du modèle bayésien depuis {model_path}") | ||||||
|     bayesian_model = BayesianClassifier() |     bayesian_model = BayesianClassifier(mode=analysis_mode) | ||||||
|     try: |     try: | ||||||
|         bayesian_model.load_model(model_path) |         bayesian_model.load_model(model_path) | ||||||
|         print(f"Modèle bayésien chargé depuis {model_path}") |         print(f"Modèle bayésien chargé depuis {model_path}") | ||||||
| @ -18,8 +24,7 @@ if __name__ == "__main__": | |||||||
|         print(f"Erreur lors du chargement du modèle : {e}") |         print(f"Erreur lors du chargement du modèle : {e}") | ||||||
|         exit(1) |         exit(1) | ||||||
| 
 | 
 | ||||||
|     # Chemin de l'image de test |     # Vérification de l'existence de l'image | ||||||
|     image_path = "data/page.png" |  | ||||||
|     if not os.path.exists(image_path): |     if not os.path.exists(image_path): | ||||||
|         print(f"L'image de test {image_path} n'existe pas.") |         print(f"L'image de test {image_path} n'existe pas.") | ||||||
|         exit(1) |         exit(1) | ||||||
| @ -33,6 +38,9 @@ if __name__ == "__main__": | |||||||
|     print("Initialisation de la pipeline...") |     print("Initialisation de la pipeline...") | ||||||
|     pipeline = ObjectDetectionPipeline(image_path=image_path, model=bayesian_model, output_dir=output_dir) |     pipeline = ObjectDetectionPipeline(image_path=image_path, model=bayesian_model, output_dir=output_dir) | ||||||
| 
 | 
 | ||||||
|  |     # Définition du mode (plan ou page) | ||||||
|  |     pipeline.set_mode(analysis_mode) | ||||||
|  | 
 | ||||||
|     # Chargement de l'image |     # Chargement de l'image | ||||||
|     print("Chargement de l'image...") |     print("Chargement de l'image...") | ||||||
|     try: |     try: | ||||||
| @ -45,7 +53,7 @@ if __name__ == "__main__": | |||||||
|     print("Détection et classification des objets...") |     print("Détection et classification des objets...") | ||||||
|     try: |     try: | ||||||
|         class_counts, detected_objects = pipeline.detect_and_classify_objects() |         class_counts, detected_objects = pipeline.detect_and_classify_objects() | ||||||
|         print("Classes détectées :", class_counts)  # Added debug info |         print("Classes détectées :", class_counts) | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"Erreur lors de la détection/classification : {e}") |         print(f"Erreur lors de la détection/classification : {e}") | ||||||
|         exit(1) |         exit(1) | ||||||
|  | |||||||
| @ -7,43 +7,51 @@ import matplotlib.pyplot as plt | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class BayesianClassifier: | class BayesianClassifier: | ||||||
|     def __init__(self): |     def __init__(self, mode="page"): | ||||||
|         self.feature_means = {} |         self.feature_means = {} | ||||||
|         self.feature_variances = {} |         self.feature_variances = {} | ||||||
|         self.class_priors = {} |         self.class_priors = {} | ||||||
|         self.classes = [] |         self.classes = [] | ||||||
| 
 | 
 | ||||||
|  |         self.allowed_classes = ( | ||||||
|  |             ['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6'] | ||||||
|  |             if mode == "plan" | ||||||
|  |             else ['2', 'd', 'I', 'n', 'o', 'u'] | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|         # Initialize HOG descriptor with standard parameters |         # Initialize HOG descriptor with standard parameters | ||||||
|         self.hog = cv2.HOGDescriptor( |         self.hog = cv2.HOGDescriptor( | ||||||
|             _winSize=(28, 28), |             _winSize=(28, 28), | ||||||
|             _blockSize=(8, 8), |             _blockSize=(8, 8), | ||||||
|             _blockStride=(4, 4), |             _blockStride=(4, 4), | ||||||
|             _cellSize=(8, 8), |             _cellSize=(8, 8), | ||||||
|             _nbins=9 |             _nbins=9, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def extract_features(self, image): |     def extract_features(self, image): | ||||||
|         try: |         try: | ||||||
|             # Convert image to grayscale |             # Convert to grayscale if image is RGB | ||||||
|             if len(image.shape) == 3 and image.shape[2] == 3: |             if len(image.shape) == 3 and image.shape[2] == 3: | ||||||
|                 gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) |                 gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) | ||||||
|             else: |             else: | ||||||
|                 gray_image = image |                 gray_image = image | ||||||
| 
 | 
 | ||||||
|             # Apply adaptive thresholding for better segmentation |             # Apply adaptive thresholding for segmentation | ||||||
|             binary_image = cv2.adaptiveThreshold( |             binary_image = cv2.adaptiveThreshold( | ||||||
|                 gray_image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 11, 2 |                 gray_image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 11, 2 | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             # Find contours |             # Find contours | ||||||
|             contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) |             contours, _ = cv2.findContours( | ||||||
|  |                 binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE | ||||||
|  |             ) | ||||||
|             if not contours: |             if not contours: | ||||||
|                 print("No contours found.") |                 print("No contours found.") | ||||||
|                 return np.array([]) |                 return np.array([]) | ||||||
| 
 | 
 | ||||||
|             features = [] |             features = [] | ||||||
|             for contour in contours: |             for contour in contours: | ||||||
|                 if cv2.contourArea(contour) < 20:  # Lowered area threshold |                 if cv2.contourArea(contour) < 20:  # Filter small areas | ||||||
|                     continue |                     continue | ||||||
| 
 | 
 | ||||||
|                 x, y, w, h = cv2.boundingRect(contour) |                 x, y, w, h = cv2.boundingRect(contour) | ||||||
| @ -59,7 +67,7 @@ class BayesianClassifier: | |||||||
|                 print("No features extracted.") |                 print("No features extracted.") | ||||||
|                 return np.array([]) |                 return np.array([]) | ||||||
| 
 | 
 | ||||||
|             # Normalize features for better consistency |             # Normalize features | ||||||
|             norms = np.linalg.norm(features, axis=1, keepdims=True) |             norms = np.linalg.norm(features, axis=1, keepdims=True) | ||||||
|             features = features / np.where(norms > 1e-6, norms, 1) |             features = features / np.where(norms > 1e-6, norms, 1) | ||||||
| 
 | 
 | ||||||
| @ -70,11 +78,10 @@ class BayesianClassifier: | |||||||
| 
 | 
 | ||||||
|     def train(self, dataset_path): |     def train(self, dataset_path): | ||||||
|         class_features = defaultdict(list) |         class_features = defaultdict(list) | ||||||
|         total_images = 0 |         total_samples = 0 | ||||||
| 
 | 
 | ||||||
|         allowed_classes = ['2', 'd', 'I', 'n', 'o', 'u']  # Modifiez selon vos besoins |  | ||||||
|         for class_name in os.listdir(dataset_path): |         for class_name in os.listdir(dataset_path): | ||||||
|             if class_name not in allowed_classes: |             if class_name not in self.allowed_classes: | ||||||
|                 continue |                 continue | ||||||
| 
 | 
 | ||||||
|             class_folder_path = os.path.join(dataset_path, class_name) |             class_folder_path = os.path.join(dataset_path, class_name) | ||||||
| @ -85,27 +92,25 @@ class BayesianClassifier: | |||||||
|                 for img_name in os.listdir(class_folder_path): |                 for img_name in os.listdir(class_folder_path): | ||||||
|                     img_path = os.path.join(class_folder_path, img_name) |                     img_path = os.path.join(class_folder_path, img_name) | ||||||
|                     if os.path.isfile(img_path): |                     if os.path.isfile(img_path): | ||||||
|                         try: |                         image = cv2.imread(img_path) | ||||||
|                             image = cv2.imread(img_path) |                         if image is not None: | ||||||
|                             if image is not None: |                             features = self.extract_features(image) | ||||||
|                                 features = self.extract_features(image) |                             if features.size > 0: | ||||||
|                                 if features.size > 0: |                                 for feature in features: | ||||||
|                                     for feature in features: |                                     class_features[class_name].append(feature) | ||||||
|                                         class_features[class_name].append(feature) |                                 total_samples += len(features) | ||||||
|                                     total_images += 1 |  | ||||||
|                                 else: |  | ||||||
|                                     print(f"No features extracted for {img_path}") |  | ||||||
|                             else: |                             else: | ||||||
|                                 print(f"Failed to load image: {img_path}") |                                 print(f"No features extracted for {img_path}") | ||||||
|                         except Exception as e: |                         else: | ||||||
|                             print(f"Error processing {img_path}: {e}") |                             print(f"Failed to load image: {img_path}") | ||||||
| 
 | 
 | ||||||
|  |         # Compute means, variances, and priors | ||||||
|         for class_name in self.classes: |         for class_name in self.classes: | ||||||
|             if class_name in class_features: |             if class_name in class_features: | ||||||
|                 features = np.array(class_features[class_name]) |                 features = np.array(class_features[class_name]) | ||||||
|                 self.feature_means[class_name] = np.mean(features, axis=0) |                 self.feature_means[class_name] = np.mean(features, axis=0) | ||||||
|                 self.feature_variances[class_name] = np.var(features, axis=0) + 1e-6 |                 self.feature_variances[class_name] = np.var(features, axis=0) + 1e-6  # Avoid zero variance | ||||||
|                 self.class_priors[class_name] = len(features) / total_images |                 self.class_priors[class_name] = len(features) / total_samples | ||||||
| 
 | 
 | ||||||
|         print("Training completed for classes:", self.classes) |         print("Training completed for classes:", self.classes) | ||||||
| 
 | 
 | ||||||
| @ -114,16 +119,15 @@ class BayesianClassifier: | |||||||
|             "feature_means": self.feature_means, |             "feature_means": self.feature_means, | ||||||
|             "feature_variances": self.feature_variances, |             "feature_variances": self.feature_variances, | ||||||
|             "class_priors": self.class_priors, |             "class_priors": self.class_priors, | ||||||
|             "classes": self.classes |             "classes": self.classes, | ||||||
|         } |         } | ||||||
|         if not os.path.exists(os.path.dirname(model_path)): |         os.makedirs(os.path.dirname(model_path), exist_ok=True) | ||||||
|             os.makedirs(os.path.dirname(model_path)) |  | ||||||
|         torch.save(model_data, model_path) |         torch.save(model_data, model_path) | ||||||
|         print(f"Model saved to {model_path}") |         print(f"Model saved to {model_path}") | ||||||
| 
 | 
 | ||||||
|     def load_model(self, model_path): |     def load_model(self, model_path): | ||||||
|         if os.path.exists(model_path): |         if os.path.exists(model_path): | ||||||
|             model_data = torch.load(model_path, weights_only=False) |             model_data = torch.load(model_path) | ||||||
|             self.feature_means = model_data["feature_means"] |             self.feature_means = model_data["feature_means"] | ||||||
|             self.feature_variances = model_data["feature_variances"] |             self.feature_variances = model_data["feature_variances"] | ||||||
|             self.class_priors = model_data["class_priors"] |             self.class_priors = model_data["class_priors"] | ||||||
| @ -132,7 +136,7 @@ class BayesianClassifier: | |||||||
|         else: |         else: | ||||||
|             print(f"No model found at {model_path}.") |             print(f"No model found at {model_path}.") | ||||||
| 
 | 
 | ||||||
|     def predict(self, image, threshold=0.3):  # Lowered threshold |     def predict(self, image, threshold=0.3): | ||||||
|         try: |         try: | ||||||
|             features = self.extract_features(image) |             features = self.extract_features(image) | ||||||
|             if features.size == 0: |             if features.size == 0: | ||||||
| @ -145,14 +149,17 @@ class BayesianClassifier: | |||||||
|                 variance = self.feature_variances[class_name] |                 variance = self.feature_variances[class_name] | ||||||
|                 prior = self.class_priors[class_name] |                 prior = self.class_priors[class_name] | ||||||
| 
 | 
 | ||||||
|                 likelihood = -0.5 * np.sum(((features - mean) ** 2) / variance + np.log(2 * np.pi * variance)) |                 # Compute log-likelihood | ||||||
|                 posterior = likelihood + np.log(prior) |                 log_likelihood = -0.5 * np.sum( | ||||||
|                 posteriors[class_name] = posterior |                     ((features - mean) ** 2) / variance + np.log(2 * np.pi * variance), | ||||||
|  |                     axis=1, | ||||||
|  |                 ) | ||||||
|  |                 posterior = log_likelihood + np.log(prior) | ||||||
|  |                 posteriors[class_name] = np.sum(posterior) | ||||||
| 
 | 
 | ||||||
|             max_class = max(posteriors, key=posteriors.get) |             max_class = max(posteriors, key=posteriors.get) | ||||||
|             max_posterior = posteriors[max_class] |             max_posterior = posteriors[max_class] | ||||||
| 
 | 
 | ||||||
|             print(f"Class: {max_class}, Posterior: {max_posterior}")  # Added debug info |  | ||||||
|             if max_posterior < threshold: |             if max_posterior < threshold: | ||||||
|                 return None |                 return None | ||||||
|             return max_class |             return max_class | ||||||
|  | |||||||
| @ -5,8 +5,7 @@ from collections import defaultdict | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ObjectDetectionPipeline: | class ObjectDetectionPipeline: | ||||||
|     def __init__(self, image_path, model=None, output_dir="output", min_contour_area=20, binary_threshold=None): |     def __init__(self, image_path, model=None, output_dir="output", mode="page", min_contour_area=20, binary_threshold=None): | ||||||
|         # Initialize the object detection pipeline |  | ||||||
|         self.image_path = image_path |         self.image_path = image_path | ||||||
|         self.image = None |         self.image = None | ||||||
|         self.binary_image = None |         self.binary_image = None | ||||||
| @ -14,19 +13,28 @@ class ObjectDetectionPipeline: | |||||||
|         self.output_dir = output_dir |         self.output_dir = output_dir | ||||||
|         self.min_contour_area = min_contour_area |         self.min_contour_area = min_contour_area | ||||||
|         self.binary_threshold = binary_threshold |         self.binary_threshold = binary_threshold | ||||||
|  |         self.mode = mode  # Default mode is "page" | ||||||
|  |         self.annotated_output_path = os.path.join(self.output_dir, f"annotated_{os.path.basename(image_path)}") | ||||||
|  |         self.threshold = -395000 if mode == "plan" else -65000 | ||||||
| 
 | 
 | ||||||
|         if not os.path.exists(self.output_dir): |         if not os.path.exists(self.output_dir): | ||||||
|             os.makedirs(self.output_dir) |             os.makedirs(self.output_dir) | ||||||
| 
 | 
 | ||||||
|  |     def set_mode(self, mode): | ||||||
|  |         """Set the detection mode (page or plan).""" | ||||||
|  |         if mode not in ["page", "plan"]: | ||||||
|  |             raise ValueError("Mode must be 'page' or 'plan'.") | ||||||
|  |         self.mode = mode | ||||||
|  |         self.threshold = -395000 if mode == "plan" else -65000 | ||||||
|  |         print(f"Mode set to: {self.mode}, Threshold set to: {self.threshold}") | ||||||
|  | 
 | ||||||
|     def load_image(self): |     def load_image(self): | ||||||
|         # Load the specified image |  | ||||||
|         self.image = cv2.imread(self.image_path) |         self.image = cv2.imread(self.image_path) | ||||||
|         if self.image is None: |         if self.image is None: | ||||||
|             raise FileNotFoundError(f"Image {self.image_path} not found.") |             raise FileNotFoundError(f"Image {self.image_path} not found.") | ||||||
|         return self.image |         return self.image | ||||||
| 
 | 
 | ||||||
|     def preprocess_image(self): |     def preprocess_image(self): | ||||||
|         # Preprocess the image for inference |  | ||||||
|         channels = cv2.split(self.image) |         channels = cv2.split(self.image) | ||||||
|         binary_images = [] |         binary_images = [] | ||||||
| 
 | 
 | ||||||
| @ -43,7 +51,6 @@ class ObjectDetectionPipeline: | |||||||
|         return binary_image |         return binary_image | ||||||
| 
 | 
 | ||||||
|     def detect_and_classify_objects(self): |     def detect_and_classify_objects(self): | ||||||
|         # Detect and classify objects in the image |  | ||||||
|         if self.model is None: |         if self.model is None: | ||||||
|             raise ValueError("No classification model provided.") |             raise ValueError("No classification model provided.") | ||||||
| 
 | 
 | ||||||
| @ -60,7 +67,7 @@ class ObjectDetectionPipeline: | |||||||
|             x, y, w, h = cv2.boundingRect(contour) |             x, y, w, h = cv2.boundingRect(contour) | ||||||
|             letter_image = self.image[y:y + h, x:x + w] |             letter_image = self.image[y:y + h, x:x + w] | ||||||
| 
 | 
 | ||||||
|             predicted_class = self.model.predict(letter_image, threshold=-65000)  # Adjusted threshold |             predicted_class = self.model.predict(letter_image, threshold=self.threshold) | ||||||
|             if predicted_class is None: |             if predicted_class is None: | ||||||
|                 print("Object ignored due to low resemblance.") |                 print("Object ignored due to low resemblance.") | ||||||
|                 continue |                 continue | ||||||
| @ -71,7 +78,6 @@ class ObjectDetectionPipeline: | |||||||
|         return dict(sorted(class_counts.items())), detected_objects |         return dict(sorted(class_counts.items())), detected_objects | ||||||
| 
 | 
 | ||||||
|     def save_results(self, class_counts, detected_objects): |     def save_results(self, class_counts, detected_objects): | ||||||
|         # Save detection and classification results |  | ||||||
|         binary_output_path = os.path.join(self.output_dir, "binary_image.jpg") |         binary_output_path = os.path.join(self.output_dir, "binary_image.jpg") | ||||||
|         cv2.imwrite(binary_output_path, self.binary_image) |         cv2.imwrite(binary_output_path, self.binary_image) | ||||||
| 
 | 
 | ||||||
| @ -80,8 +86,7 @@ class ObjectDetectionPipeline: | |||||||
|             cv2.rectangle(annotated_image, (x, y), (x + w, y + h), (0, 255, 0), 2) |             cv2.rectangle(annotated_image, (x, y), (x + w, y + h), (0, 255, 0), 2) | ||||||
|             cv2.putText(annotated_image, str(predicted_class), (x, y - 10), |             cv2.putText(annotated_image, str(predicted_class), (x, y - 10), | ||||||
|                         cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2) |                         cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2) | ||||||
|         annotated_output_path = os.path.join(self.output_dir, "annotated_page.jpg") |         cv2.imwrite(self.annotated_output_path, annotated_image) | ||||||
|         cv2.imwrite(annotated_output_path, annotated_image) |  | ||||||
| 
 | 
 | ||||||
|         results_text_path = os.path.join(self.output_dir, "results.txt") |         results_text_path = os.path.join(self.output_dir, "results.txt") | ||||||
|         with open(results_text_path, "w") as f: |         with open(results_text_path, "w") as f: | ||||||
| @ -89,7 +94,6 @@ class ObjectDetectionPipeline: | |||||||
|                 f.write(f"{class_name}: {count}\n") |                 f.write(f"{class_name}: {count}\n") | ||||||
| 
 | 
 | ||||||
|     def display_results(self, class_counts, detected_objects): |     def display_results(self, class_counts, detected_objects): | ||||||
|         # Display and save the results |  | ||||||
|         self.save_results(class_counts, detected_objects) |         self.save_results(class_counts, detected_objects) | ||||||
| 
 | 
 | ||||||
|         plt.figure(figsize=(10, 5)) |         plt.figure(figsize=(10, 5)) | ||||||
|  | |||||||
							
								
								
									
										20
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								train.py
									
									
									
									
									
								
							| @ -1,16 +1,24 @@ | |||||||
| import os | import os | ||||||
|  | import argparse  # Ajouté pour les arguments | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| import numpy as np | import numpy as np | ||||||
| import cv2 | import cv2 | ||||||
| 
 |  | ||||||
| from src.classifiers.bayesian import BayesianClassifier | from src.classifiers.bayesian import BayesianClassifier | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     # Chemin vers le dataset d'entraînement |     # Analyse des arguments | ||||||
|     dataset_path = "data/catalogue" |     parser = argparse.ArgumentParser(description="Train Bayesian model.") | ||||||
|  |     parser.add_argument("--mode", type=str, choices=["page", "plan"], default="page", help="Mode de fonctionnement : 'page' ou 'plan'.") | ||||||
|  |     args = parser.parse_args() | ||||||
|  | 
 | ||||||
|  |     # Configuration en fonction du mode | ||||||
|  |     mode = args.mode | ||||||
|  |     dataset_path = f"data/catalogue{'' if mode == 'page' else 'Symbol'}" | ||||||
|  |     allowed_classes = ['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6'] if mode == "plan" else ['2', 'd', 'I', 'n', 'o', 'u'] | ||||||
|  |     model_path = f"models/bayesian_model{mode.upper()}.pth" | ||||||
| 
 | 
 | ||||||
|     # Initialisation du classifieur Bayésien |     # Initialisation du classifieur Bayésien | ||||||
|     bayesian_model = BayesianClassifier() |     bayesian_model = BayesianClassifier(mode=mode) | ||||||
| 
 | 
 | ||||||
|     print("Début de l'entraînement...") |     print("Début de l'entraînement...") | ||||||
| 
 | 
 | ||||||
| @ -18,9 +26,6 @@ if __name__ == "__main__": | |||||||
|     class_features = defaultdict(list) |     class_features = defaultdict(list) | ||||||
|     total_images = 0 |     total_images = 0 | ||||||
| 
 | 
 | ||||||
|     # Liste des classes autorisées |  | ||||||
|     allowed_classes = ['2', 'd', 'I', 'n', 'o', 'u']  # Classes spécifiques au projet |  | ||||||
| 
 |  | ||||||
|     # Parcours des classes dans le dataset |     # Parcours des classes dans le dataset | ||||||
|     for class_name in os.listdir(dataset_path): |     for class_name in os.listdir(dataset_path): | ||||||
|         if class_name not in allowed_classes: |         if class_name not in allowed_classes: | ||||||
| @ -57,6 +62,5 @@ if __name__ == "__main__": | |||||||
|     print("Entraînement terminé.") |     print("Entraînement terminé.") | ||||||
| 
 | 
 | ||||||
|     # Sauvegarde du modèle entraîné |     # Sauvegarde du modèle entraîné | ||||||
|     model_path = "models/bayesian_modelPAGE.pth" |  | ||||||
|     bayesian_model.save_model(model_path) |     bayesian_model.save_model(model_path) | ||||||
|     print(f"Modèle sauvegardé dans : {model_path}") |     print(f"Modèle sauvegardé dans : {model_path}") | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user