Bayesian final version + choix du fichier analysé
This commit is contained in:
parent
7abdb91d06
commit
922b9acf18
4 changed files with 83 additions and 60 deletions
22
main.py
22
main.py
|
@ -1,16 +1,22 @@
|
||||||
import os
|
import os
|
||||||
import cv2
|
|
||||||
from src.pipeline import ObjectDetectionPipeline
|
from src.pipeline import ObjectDetectionPipeline
|
||||||
from src.classifiers.bayesian import BayesianClassifier
|
from src.classifiers.bayesian import BayesianClassifier
|
||||||
from collections import defaultdict
|
|
||||||
|
# Définissez le mode d'analyse ici : "plan" ou "page"
|
||||||
|
analysis_mode = "plan"
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Chemin vers le modèle entraîné
|
# Configuration basée sur le mode
|
||||||
|
if analysis_mode == "plan":
|
||||||
|
model_path = "models/bayesian_modelPLAN.pth"
|
||||||
|
image_path = "data/plan.png"
|
||||||
|
else:
|
||||||
model_path = "models/bayesian_modelPAGE.pth"
|
model_path = "models/bayesian_modelPAGE.pth"
|
||||||
|
image_path = "data/page.png"
|
||||||
|
|
||||||
# Chargement du modèle bayésien
|
# Chargement du modèle bayésien
|
||||||
print(f"Chargement du modèle bayésien depuis {model_path}")
|
print(f"Chargement du modèle bayésien depuis {model_path}")
|
||||||
bayesian_model = BayesianClassifier()
|
bayesian_model = BayesianClassifier(mode=analysis_mode)
|
||||||
try:
|
try:
|
||||||
bayesian_model.load_model(model_path)
|
bayesian_model.load_model(model_path)
|
||||||
print(f"Modèle bayésien chargé depuis {model_path}")
|
print(f"Modèle bayésien chargé depuis {model_path}")
|
||||||
|
@ -18,8 +24,7 @@ if __name__ == "__main__":
|
||||||
print(f"Erreur lors du chargement du modèle : {e}")
|
print(f"Erreur lors du chargement du modèle : {e}")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
# Chemin de l'image de test
|
# Vérification de l'existence de l'image
|
||||||
image_path = "data/page.png"
|
|
||||||
if not os.path.exists(image_path):
|
if not os.path.exists(image_path):
|
||||||
print(f"L'image de test {image_path} n'existe pas.")
|
print(f"L'image de test {image_path} n'existe pas.")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
@ -33,6 +38,9 @@ if __name__ == "__main__":
|
||||||
print("Initialisation de la pipeline...")
|
print("Initialisation de la pipeline...")
|
||||||
pipeline = ObjectDetectionPipeline(image_path=image_path, model=bayesian_model, output_dir=output_dir)
|
pipeline = ObjectDetectionPipeline(image_path=image_path, model=bayesian_model, output_dir=output_dir)
|
||||||
|
|
||||||
|
# Définition du mode (plan ou page)
|
||||||
|
pipeline.set_mode(analysis_mode)
|
||||||
|
|
||||||
# Chargement de l'image
|
# Chargement de l'image
|
||||||
print("Chargement de l'image...")
|
print("Chargement de l'image...")
|
||||||
try:
|
try:
|
||||||
|
@ -45,7 +53,7 @@ if __name__ == "__main__":
|
||||||
print("Détection et classification des objets...")
|
print("Détection et classification des objets...")
|
||||||
try:
|
try:
|
||||||
class_counts, detected_objects = pipeline.detect_and_classify_objects()
|
class_counts, detected_objects = pipeline.detect_and_classify_objects()
|
||||||
print("Classes détectées :", class_counts) # Added debug info
|
print("Classes détectées :", class_counts)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Erreur lors de la détection/classification : {e}")
|
print(f"Erreur lors de la détection/classification : {e}")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
|
@ -7,43 +7,51 @@ import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
class BayesianClassifier:
|
class BayesianClassifier:
|
||||||
def __init__(self):
|
def __init__(self, mode="page"):
|
||||||
self.feature_means = {}
|
self.feature_means = {}
|
||||||
self.feature_variances = {}
|
self.feature_variances = {}
|
||||||
self.class_priors = {}
|
self.class_priors = {}
|
||||||
self.classes = []
|
self.classes = []
|
||||||
|
|
||||||
|
self.allowed_classes = (
|
||||||
|
['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6']
|
||||||
|
if mode == "plan"
|
||||||
|
else ['2', 'd', 'I', 'n', 'o', 'u']
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize HOG descriptor with standard parameters
|
# Initialize HOG descriptor with standard parameters
|
||||||
self.hog = cv2.HOGDescriptor(
|
self.hog = cv2.HOGDescriptor(
|
||||||
_winSize=(28, 28),
|
_winSize=(28, 28),
|
||||||
_blockSize=(8, 8),
|
_blockSize=(8, 8),
|
||||||
_blockStride=(4, 4),
|
_blockStride=(4, 4),
|
||||||
_cellSize=(8, 8),
|
_cellSize=(8, 8),
|
||||||
_nbins=9
|
_nbins=9,
|
||||||
)
|
)
|
||||||
|
|
||||||
def extract_features(self, image):
|
def extract_features(self, image):
|
||||||
try:
|
try:
|
||||||
# Convert image to grayscale
|
# Convert to grayscale if image is RGB
|
||||||
if len(image.shape) == 3 and image.shape[2] == 3:
|
if len(image.shape) == 3 and image.shape[2] == 3:
|
||||||
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||||
else:
|
else:
|
||||||
gray_image = image
|
gray_image = image
|
||||||
|
|
||||||
# Apply adaptive thresholding for better segmentation
|
# Apply adaptive thresholding for segmentation
|
||||||
binary_image = cv2.adaptiveThreshold(
|
binary_image = cv2.adaptiveThreshold(
|
||||||
gray_image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 11, 2
|
gray_image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 11, 2
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find contours
|
# Find contours
|
||||||
contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
contours, _ = cv2.findContours(
|
||||||
|
binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
||||||
|
)
|
||||||
if not contours:
|
if not contours:
|
||||||
print("No contours found.")
|
print("No contours found.")
|
||||||
return np.array([])
|
return np.array([])
|
||||||
|
|
||||||
features = []
|
features = []
|
||||||
for contour in contours:
|
for contour in contours:
|
||||||
if cv2.contourArea(contour) < 20: # Lowered area threshold
|
if cv2.contourArea(contour) < 20: # Filter small areas
|
||||||
continue
|
continue
|
||||||
|
|
||||||
x, y, w, h = cv2.boundingRect(contour)
|
x, y, w, h = cv2.boundingRect(contour)
|
||||||
|
@ -59,7 +67,7 @@ class BayesianClassifier:
|
||||||
print("No features extracted.")
|
print("No features extracted.")
|
||||||
return np.array([])
|
return np.array([])
|
||||||
|
|
||||||
# Normalize features for better consistency
|
# Normalize features
|
||||||
norms = np.linalg.norm(features, axis=1, keepdims=True)
|
norms = np.linalg.norm(features, axis=1, keepdims=True)
|
||||||
features = features / np.where(norms > 1e-6, norms, 1)
|
features = features / np.where(norms > 1e-6, norms, 1)
|
||||||
|
|
||||||
|
@ -70,11 +78,10 @@ class BayesianClassifier:
|
||||||
|
|
||||||
def train(self, dataset_path):
|
def train(self, dataset_path):
|
||||||
class_features = defaultdict(list)
|
class_features = defaultdict(list)
|
||||||
total_images = 0
|
total_samples = 0
|
||||||
|
|
||||||
allowed_classes = ['2', 'd', 'I', 'n', 'o', 'u'] # Modifiez selon vos besoins
|
|
||||||
for class_name in os.listdir(dataset_path):
|
for class_name in os.listdir(dataset_path):
|
||||||
if class_name not in allowed_classes:
|
if class_name not in self.allowed_classes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
class_folder_path = os.path.join(dataset_path, class_name)
|
class_folder_path = os.path.join(dataset_path, class_name)
|
||||||
|
@ -85,27 +92,25 @@ class BayesianClassifier:
|
||||||
for img_name in os.listdir(class_folder_path):
|
for img_name in os.listdir(class_folder_path):
|
||||||
img_path = os.path.join(class_folder_path, img_name)
|
img_path = os.path.join(class_folder_path, img_name)
|
||||||
if os.path.isfile(img_path):
|
if os.path.isfile(img_path):
|
||||||
try:
|
|
||||||
image = cv2.imread(img_path)
|
image = cv2.imread(img_path)
|
||||||
if image is not None:
|
if image is not None:
|
||||||
features = self.extract_features(image)
|
features = self.extract_features(image)
|
||||||
if features.size > 0:
|
if features.size > 0:
|
||||||
for feature in features:
|
for feature in features:
|
||||||
class_features[class_name].append(feature)
|
class_features[class_name].append(feature)
|
||||||
total_images += 1
|
total_samples += len(features)
|
||||||
else:
|
else:
|
||||||
print(f"No features extracted for {img_path}")
|
print(f"No features extracted for {img_path}")
|
||||||
else:
|
else:
|
||||||
print(f"Failed to load image: {img_path}")
|
print(f"Failed to load image: {img_path}")
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {img_path}: {e}")
|
|
||||||
|
|
||||||
|
# Compute means, variances, and priors
|
||||||
for class_name in self.classes:
|
for class_name in self.classes:
|
||||||
if class_name in class_features:
|
if class_name in class_features:
|
||||||
features = np.array(class_features[class_name])
|
features = np.array(class_features[class_name])
|
||||||
self.feature_means[class_name] = np.mean(features, axis=0)
|
self.feature_means[class_name] = np.mean(features, axis=0)
|
||||||
self.feature_variances[class_name] = np.var(features, axis=0) + 1e-6
|
self.feature_variances[class_name] = np.var(features, axis=0) + 1e-6 # Avoid zero variance
|
||||||
self.class_priors[class_name] = len(features) / total_images
|
self.class_priors[class_name] = len(features) / total_samples
|
||||||
|
|
||||||
print("Training completed for classes:", self.classes)
|
print("Training completed for classes:", self.classes)
|
||||||
|
|
||||||
|
@ -114,16 +119,15 @@ class BayesianClassifier:
|
||||||
"feature_means": self.feature_means,
|
"feature_means": self.feature_means,
|
||||||
"feature_variances": self.feature_variances,
|
"feature_variances": self.feature_variances,
|
||||||
"class_priors": self.class_priors,
|
"class_priors": self.class_priors,
|
||||||
"classes": self.classes
|
"classes": self.classes,
|
||||||
}
|
}
|
||||||
if not os.path.exists(os.path.dirname(model_path)):
|
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||||
os.makedirs(os.path.dirname(model_path))
|
|
||||||
torch.save(model_data, model_path)
|
torch.save(model_data, model_path)
|
||||||
print(f"Model saved to {model_path}")
|
print(f"Model saved to {model_path}")
|
||||||
|
|
||||||
def load_model(self, model_path):
|
def load_model(self, model_path):
|
||||||
if os.path.exists(model_path):
|
if os.path.exists(model_path):
|
||||||
model_data = torch.load(model_path, weights_only=False)
|
model_data = torch.load(model_path)
|
||||||
self.feature_means = model_data["feature_means"]
|
self.feature_means = model_data["feature_means"]
|
||||||
self.feature_variances = model_data["feature_variances"]
|
self.feature_variances = model_data["feature_variances"]
|
||||||
self.class_priors = model_data["class_priors"]
|
self.class_priors = model_data["class_priors"]
|
||||||
|
@ -132,7 +136,7 @@ class BayesianClassifier:
|
||||||
else:
|
else:
|
||||||
print(f"No model found at {model_path}.")
|
print(f"No model found at {model_path}.")
|
||||||
|
|
||||||
def predict(self, image, threshold=0.3): # Lowered threshold
|
def predict(self, image, threshold=0.3):
|
||||||
try:
|
try:
|
||||||
features = self.extract_features(image)
|
features = self.extract_features(image)
|
||||||
if features.size == 0:
|
if features.size == 0:
|
||||||
|
@ -145,14 +149,17 @@ class BayesianClassifier:
|
||||||
variance = self.feature_variances[class_name]
|
variance = self.feature_variances[class_name]
|
||||||
prior = self.class_priors[class_name]
|
prior = self.class_priors[class_name]
|
||||||
|
|
||||||
likelihood = -0.5 * np.sum(((features - mean) ** 2) / variance + np.log(2 * np.pi * variance))
|
# Compute log-likelihood
|
||||||
posterior = likelihood + np.log(prior)
|
log_likelihood = -0.5 * np.sum(
|
||||||
posteriors[class_name] = posterior
|
((features - mean) ** 2) / variance + np.log(2 * np.pi * variance),
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
posterior = log_likelihood + np.log(prior)
|
||||||
|
posteriors[class_name] = np.sum(posterior)
|
||||||
|
|
||||||
max_class = max(posteriors, key=posteriors.get)
|
max_class = max(posteriors, key=posteriors.get)
|
||||||
max_posterior = posteriors[max_class]
|
max_posterior = posteriors[max_class]
|
||||||
|
|
||||||
print(f"Class: {max_class}, Posterior: {max_posterior}") # Added debug info
|
|
||||||
if max_posterior < threshold:
|
if max_posterior < threshold:
|
||||||
return None
|
return None
|
||||||
return max_class
|
return max_class
|
||||||
|
|
|
@ -5,8 +5,7 @@ from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
class ObjectDetectionPipeline:
|
class ObjectDetectionPipeline:
|
||||||
def __init__(self, image_path, model=None, output_dir="output", min_contour_area=20, binary_threshold=None):
|
def __init__(self, image_path, model=None, output_dir="output", mode="page", min_contour_area=20, binary_threshold=None):
|
||||||
# Initialize the object detection pipeline
|
|
||||||
self.image_path = image_path
|
self.image_path = image_path
|
||||||
self.image = None
|
self.image = None
|
||||||
self.binary_image = None
|
self.binary_image = None
|
||||||
|
@ -14,19 +13,28 @@ class ObjectDetectionPipeline:
|
||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
self.min_contour_area = min_contour_area
|
self.min_contour_area = min_contour_area
|
||||||
self.binary_threshold = binary_threshold
|
self.binary_threshold = binary_threshold
|
||||||
|
self.mode = mode # Default mode is "page"
|
||||||
|
self.annotated_output_path = os.path.join(self.output_dir, f"annotated_{os.path.basename(image_path)}")
|
||||||
|
self.threshold = -395000 if mode == "plan" else -65000
|
||||||
|
|
||||||
if not os.path.exists(self.output_dir):
|
if not os.path.exists(self.output_dir):
|
||||||
os.makedirs(self.output_dir)
|
os.makedirs(self.output_dir)
|
||||||
|
|
||||||
|
def set_mode(self, mode):
|
||||||
|
"""Set the detection mode (page or plan)."""
|
||||||
|
if mode not in ["page", "plan"]:
|
||||||
|
raise ValueError("Mode must be 'page' or 'plan'.")
|
||||||
|
self.mode = mode
|
||||||
|
self.threshold = -395000 if mode == "plan" else -65000
|
||||||
|
print(f"Mode set to: {self.mode}, Threshold set to: {self.threshold}")
|
||||||
|
|
||||||
def load_image(self):
|
def load_image(self):
|
||||||
# Load the specified image
|
|
||||||
self.image = cv2.imread(self.image_path)
|
self.image = cv2.imread(self.image_path)
|
||||||
if self.image is None:
|
if self.image is None:
|
||||||
raise FileNotFoundError(f"Image {self.image_path} not found.")
|
raise FileNotFoundError(f"Image {self.image_path} not found.")
|
||||||
return self.image
|
return self.image
|
||||||
|
|
||||||
def preprocess_image(self):
|
def preprocess_image(self):
|
||||||
# Preprocess the image for inference
|
|
||||||
channels = cv2.split(self.image)
|
channels = cv2.split(self.image)
|
||||||
binary_images = []
|
binary_images = []
|
||||||
|
|
||||||
|
@ -43,7 +51,6 @@ class ObjectDetectionPipeline:
|
||||||
return binary_image
|
return binary_image
|
||||||
|
|
||||||
def detect_and_classify_objects(self):
|
def detect_and_classify_objects(self):
|
||||||
# Detect and classify objects in the image
|
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
raise ValueError("No classification model provided.")
|
raise ValueError("No classification model provided.")
|
||||||
|
|
||||||
|
@ -60,7 +67,7 @@ class ObjectDetectionPipeline:
|
||||||
x, y, w, h = cv2.boundingRect(contour)
|
x, y, w, h = cv2.boundingRect(contour)
|
||||||
letter_image = self.image[y:y + h, x:x + w]
|
letter_image = self.image[y:y + h, x:x + w]
|
||||||
|
|
||||||
predicted_class = self.model.predict(letter_image, threshold=-65000) # Adjusted threshold
|
predicted_class = self.model.predict(letter_image, threshold=self.threshold)
|
||||||
if predicted_class is None:
|
if predicted_class is None:
|
||||||
print("Object ignored due to low resemblance.")
|
print("Object ignored due to low resemblance.")
|
||||||
continue
|
continue
|
||||||
|
@ -71,7 +78,6 @@ class ObjectDetectionPipeline:
|
||||||
return dict(sorted(class_counts.items())), detected_objects
|
return dict(sorted(class_counts.items())), detected_objects
|
||||||
|
|
||||||
def save_results(self, class_counts, detected_objects):
|
def save_results(self, class_counts, detected_objects):
|
||||||
# Save detection and classification results
|
|
||||||
binary_output_path = os.path.join(self.output_dir, "binary_image.jpg")
|
binary_output_path = os.path.join(self.output_dir, "binary_image.jpg")
|
||||||
cv2.imwrite(binary_output_path, self.binary_image)
|
cv2.imwrite(binary_output_path, self.binary_image)
|
||||||
|
|
||||||
|
@ -80,8 +86,7 @@ class ObjectDetectionPipeline:
|
||||||
cv2.rectangle(annotated_image, (x, y), (x + w, y + h), (0, 255, 0), 2)
|
cv2.rectangle(annotated_image, (x, y), (x + w, y + h), (0, 255, 0), 2)
|
||||||
cv2.putText(annotated_image, str(predicted_class), (x, y - 10),
|
cv2.putText(annotated_image, str(predicted_class), (x, y - 10),
|
||||||
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
||||||
annotated_output_path = os.path.join(self.output_dir, "annotated_page.jpg")
|
cv2.imwrite(self.annotated_output_path, annotated_image)
|
||||||
cv2.imwrite(annotated_output_path, annotated_image)
|
|
||||||
|
|
||||||
results_text_path = os.path.join(self.output_dir, "results.txt")
|
results_text_path = os.path.join(self.output_dir, "results.txt")
|
||||||
with open(results_text_path, "w") as f:
|
with open(results_text_path, "w") as f:
|
||||||
|
@ -89,7 +94,6 @@ class ObjectDetectionPipeline:
|
||||||
f.write(f"{class_name}: {count}\n")
|
f.write(f"{class_name}: {count}\n")
|
||||||
|
|
||||||
def display_results(self, class_counts, detected_objects):
|
def display_results(self, class_counts, detected_objects):
|
||||||
# Display and save the results
|
|
||||||
self.save_results(class_counts, detected_objects)
|
self.save_results(class_counts, detected_objects)
|
||||||
|
|
||||||
plt.figure(figsize=(10, 5))
|
plt.figure(figsize=(10, 5))
|
||||||
|
|
20
train.py
20
train.py
|
@ -1,16 +1,24 @@
|
||||||
import os
|
import os
|
||||||
|
import argparse # Ajouté pour les arguments
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from src.classifiers.bayesian import BayesianClassifier
|
from src.classifiers.bayesian import BayesianClassifier
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Chemin vers le dataset d'entraînement
|
# Analyse des arguments
|
||||||
dataset_path = "data/catalogue"
|
parser = argparse.ArgumentParser(description="Train Bayesian model.")
|
||||||
|
parser.add_argument("--mode", type=str, choices=["page", "plan"], default="page", help="Mode de fonctionnement : 'page' ou 'plan'.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Configuration en fonction du mode
|
||||||
|
mode = args.mode
|
||||||
|
dataset_path = f"data/catalogue{'' if mode == 'page' else 'Symbol'}"
|
||||||
|
allowed_classes = ['Figure1', 'Figure2', 'Figure3', 'Figure4', 'Figure5', 'Figure6'] if mode == "plan" else ['2', 'd', 'I', 'n', 'o', 'u']
|
||||||
|
model_path = f"models/bayesian_model{mode.upper()}.pth"
|
||||||
|
|
||||||
# Initialisation du classifieur Bayésien
|
# Initialisation du classifieur Bayésien
|
||||||
bayesian_model = BayesianClassifier()
|
bayesian_model = BayesianClassifier(mode=mode)
|
||||||
|
|
||||||
print("Début de l'entraînement...")
|
print("Début de l'entraînement...")
|
||||||
|
|
||||||
|
@ -18,9 +26,6 @@ if __name__ == "__main__":
|
||||||
class_features = defaultdict(list)
|
class_features = defaultdict(list)
|
||||||
total_images = 0
|
total_images = 0
|
||||||
|
|
||||||
# Liste des classes autorisées
|
|
||||||
allowed_classes = ['2', 'd', 'I', 'n', 'o', 'u'] # Classes spécifiques au projet
|
|
||||||
|
|
||||||
# Parcours des classes dans le dataset
|
# Parcours des classes dans le dataset
|
||||||
for class_name in os.listdir(dataset_path):
|
for class_name in os.listdir(dataset_path):
|
||||||
if class_name not in allowed_classes:
|
if class_name not in allowed_classes:
|
||||||
|
@ -57,6 +62,5 @@ if __name__ == "__main__":
|
||||||
print("Entraînement terminé.")
|
print("Entraînement terminé.")
|
||||||
|
|
||||||
# Sauvegarde du modèle entraîné
|
# Sauvegarde du modèle entraîné
|
||||||
model_path = "models/bayesian_modelPAGE.pth"
|
|
||||||
bayesian_model.save_model(model_path)
|
bayesian_model.save_model(model_path)
|
||||||
print(f"Modèle sauvegardé dans : {model_path}")
|
print(f"Modèle sauvegardé dans : {model_path}")
|
||||||
|
|
Loading…
Add table
Reference in a new issue