# -*- coding: utf-8 -*-
"""
Created on Sat Apr 12 18:24:38 2025

@author: AKourgli
"""

import numpy as np
import matplotlib.pyplot as plt

def generate_gray_code(n_bits):
    """Return the reflected binary Gray code sequence on ``n_bits`` bits.

    Entry ``k`` of the returned list is the Gray code of the integer ``k``,
    for ``k`` in ``0 .. 2**n_bits - 1``; consecutive entries differ in
    exactly one bit.
    """
    total = 2 ** n_bits
    codes = []
    for value in range(total):
        # Gray(k) = k XOR (k >> 1)
        codes.append(value ^ (value >> 1))
    return codes

def constellation_QAM(M):
    """Build a square M-QAM constellation normalised to unit average energy.

    Parameters
    ----------
    M : int
        Constellation size; must be a perfect square of at least 4
        (4, 16, 64, ...).

    Returns
    -------
    constellation : ndarray of complex
        The M points, flattened row by row from ``np.meshgrid``: the in-phase
        (I) level varies fastest and the quadrature (Q) level increases from
        one row to the next. Scaled so that ``mean(|c|**2) == 1``.
    points : ndarray of int
        The un-normalised per-axis amplitude levels
        ``-(sqrt(M)-1), ..., sqrt(M)-1`` in steps of 2.

    Raises
    ------
    ValueError
        If M is smaller than 4 or is not a perfect square.
    """
    # M = 1 would yield the single point 0 (zero mean energy -> division by
    # zero and a NaN result) and M <= 0 yields no point at all: reject early.
    if M < 4:
        raise ValueError("M doit être au moins 4 (4, 16, 64, ...)")
    sqrt_M = int(np.sqrt(M))
    if sqrt_M**2 != M:
        raise ValueError("M doit être un carré parfait (4, 16, 64, ...)")

    # Equally spaced levels (step 2), symmetric about 0, on each axis.
    points = np.arange(-sqrt_M + 1, sqrt_M, 2)
    I, Q = np.meshgrid(points, points)
    # Row-major flatten: I varies fastest, Q increases from row to row.
    constellation = (I + 1j * Q).flatten()
    # Scale to unit average symbol energy.
    constellation /= np.sqrt(np.mean(np.abs(constellation) ** 2))
    return constellation, points

def get_gray_labels(M, points):
    """Build the binary Gray label string of every constellation point.

    Each label is the Gray code of the point's I-level index followed by the
    Gray code of its Q-level index, each ``log2(sqrt(M))`` bits wide.

    Parameters
    ----------
    M : int
        Constellation size (4, 16, 64, ...).
    points : sequence
        Per-axis amplitude levels, as returned by ``constellation_QAM``; a
        level's position in this sequence is its index.

    Returns
    -------
    list of str
        ``len(points)**2`` labels. The outer ordering walks the Q indices in
        reverse, the inner ordering walks the I indices in order: label ``n``
        uses I index ``n % len(points)`` and Q index
        ``len(points) - 1 - n // len(points)``.

    NOTE(review): ``constellation_QAM`` orders its points with Q *increasing*
    row by row, so zipping these labels with that array attaches to each
    point the label of its mirror image across the I axis. The labelling is
    still a valid Gray labelling either way; confirm this is intended.
    """
    n_bits = int(np.log2(np.sqrt(M)))
    gray_codes = generate_gray_code(n_bits)
    # Pre-format every Gray code once, zero-padded to n_bits.
    gray_bits = [format(code, f'0{n_bits}b') for code in gray_codes]
    n_levels = len(points)

    # Use positions directly instead of re-finding each level with np.where
    # (a linear search per lookup in the original double loop).
    return [gray_bits[idx_i] + gray_bits[idx_q]
            for idx_q in reversed(range(n_levels))  # Q axis inverted
            for idx_i in range(n_levels)]

def modulateur_QAM(bits, M):
    """Map a bit stream onto M-QAM symbols.

    Bits are taken in groups of ``k = log2(M)`` (most significant bit first).
    Each group is read as a plain natural-binary integer that indexes the
    array returned by ``constellation_QAM(M)``. Note that this is
    natural-binary indexing, not a Gray mapping.

    Parameters
    ----------
    bits : array-like of {0, 1}
        Bit stream (lists and arrays of any shape are accepted; it is
        flattened in C order). Its length must be a multiple of ``log2(M)``.
    M : int
        Constellation size (see ``constellation_QAM``).

    Returns
    -------
    ndarray of complex
        One symbol per group of ``log2(M)`` bits.

    Raises
    ------
    ValueError
        If the number of bits is not a multiple of ``log2(M)``, or if
        ``constellation_QAM`` rejects M.
    """
    k = int(np.log2(M))
    constellation, _ = constellation_QAM(M)

    # Accept plain Python sequences as well as ndarrays.
    bits = np.asarray(bits).reshape(-1)
    if k < 1 or bits.size % k != 0:
        raise ValueError(
            f"Le nombre de bits ({bits.size}) doit être un multiple de log2(M) = {k}"
        )

    groupes = bits.reshape(-1, k)
    # MSB-first weights 2**(k-1), ..., 2, 1.
    poids = 2 ** np.arange(k - 1, -1, -1)
    indices = np.dot(groupes, poids).astype(int)
    return constellation[indices]

def demodulateur_QAM(symboles, M):
    """Hard-decision M-QAM demodulator (minimum Euclidean distance).

    Every received sample is replaced by the index of the closest point of
    ``constellation_QAM(M)``. That index is then written out as ``log2(M)``
    bits, most significant bit first (natural binary).

    Parameters
    ----------
    symboles : ndarray of complex
        Received samples (any shape; flattened in C order).
    M : int
        Constellation size.

    Returns
    -------
    ndarray of int
        Flat bit array of length ``symboles.size * log2(M)``.
    """
    k = int(np.log2(M))
    reference, _ = constellation_QAM(M)

    # Squared distance from every sample (rows) to every point (columns).
    ecarts = np.abs(symboles.reshape(-1, 1) - reference) ** 2
    decisions = ecarts.argmin(axis=1)

    # Expand each index into k bits, MSB first, by broadcasting the shifts.
    decalages = np.arange(k - 1, -1, -1)
    matrice_bits = (decisions[:, np.newaxis] >> decalages) & 1
    return matrice_bits.astype(int).reshape(-1)

def ajouter_bruit(signal, SNRdB, rng=None):
    """Add complex white Gaussian noise (AWGN) at a given SNR.

    The noise power is derived from the *measured* mean power of ``signal``:
    ``P_noise = mean(|signal|**2) / 10**(SNRdB/10)``. It is split equally
    between the real and imaginary parts.

    Parameters
    ----------
    signal : array-like
        Input samples (real or complex). The output is always complex.
    SNRdB : float
        Target signal-to-noise ratio in dB.
    rng : numpy.random.Generator, optional
        Source of randomness. When None (default), the legacy global
        ``np.random`` state is used exactly as before. Pass
        ``np.random.default_rng(seed)`` for reproducible results.

    Returns
    -------
    ndarray of complex
        ``signal`` plus the noise, with the same shape as ``signal``.
    """
    signal = np.asarray(signal)
    puissance_signal = np.mean(np.abs(signal)**2)
    SNR = 10**(SNRdB/10)
    puissance_bruit = puissance_signal / SNR

    if rng is None:
        # Same draws, in the same order, as the original implementation.
        partie_re = np.random.randn(*signal.shape)
        partie_im = np.random.randn(*signal.shape)
    else:
        partie_re = rng.standard_normal(signal.shape)
        partie_im = rng.standard_normal(signal.shape)

    # Each quadrature carries half the noise power so |noise|**2 averages P_noise.
    bruit = (partie_re + 1j*partie_im) * np.sqrt(puissance_bruit/2)
    return signal + bruit

# Simulation configuration
M_liste = [4, 16, 64]
SNRdB = 20
nb_symboles = 1000


def _tracer_labels_gray(ax, constellation, gray_labels, fontsize, weight='normal'):
    """Write each Gray label next to its constellation point on ``ax``.

    Labels are paired with points by position: ``gray_labels[n]`` is drawn
    at ``constellation[n]``.
    """
    for x, y, lbl in zip(np.real(constellation), np.imag(constellation), gray_labels):
        ax.text(x, y, lbl, ha='right', va='bottom', color='black',
                fontsize=fontsize, weight=weight)


def _mettre_en_forme(ax, titre, avec_legende=False):
    """Apply the shared title / axis-label / grid formatting to one subplot."""
    ax.set_title(titre)
    ax.set_xlabel('Composante I')
    ax.set_ylabel('Composante Q')
    ax.axhline(0, color='black', linewidth=0.5)
    ax.axvline(0, color='black', linewidth=0.5)
    ax.grid(True, alpha=0.3)
    ax.axis('equal')
    if avec_legende:
        ax.legend()


def _seuils_decision(points):
    """Return the per-axis decision thresholds in normalised units.

    The thresholds are the midpoints between consecutive amplitude levels,
    divided by sqrt(2 * mean(points**2)). That is the RMS amplitude of the
    square constellation built from ``points`` on both axes, so the
    thresholds line up with the normalised points.
    """
    niveaux = np.sort(np.asarray(points))
    milieux = (niveaux[:-1] + niveaux[1:]) / 2
    return milieux / np.sqrt(2 * np.mean(niveaux**2))


# --- Figure 1: ideal constellations with their Gray labels -----------------
# squeeze=False keeps a 2-D array of Axes whatever len(M_liste) is.
fig, axs = plt.subplots(1, len(M_liste), figsize=(20, 6), squeeze=False)

for ax, M in zip(axs[0], M_liste):
    constellation, points = constellation_QAM(M)
    gray_labels = get_gray_labels(M, points)

    ax.scatter(np.real(constellation), np.imag(constellation), c='red', s=100, edgecolors='black')

    petit = M <= 16
    _tracer_labels_gray(ax, constellation, gray_labels,
                        fontsize=12 if petit else 8,
                        weight='bold' if petit else 'normal')
    # No labelled artist in this figure, so no legend (it would only warn).
    _mettre_en_forme(ax, f'QAM-{M}')

plt.tight_layout()
plt.show()


# --- Figure 2: noisy received symbols, decision thresholds and BER ---------
fig, axs = plt.subplots(1, len(M_liste), figsize=(20, 6), squeeze=False)

for ax, M in zip(axs[0], M_liste):
    constellation, points = constellation_QAM(M)
    gray_labels = get_gray_labels(M, points)

    # Chain: bits -> symbols -> AWGN channel -> decision -> bits.
    # NOTE(review): modulateur_QAM indexes the constellation in natural
    # binary, not with the Gray labels drawn below, so this BER is that of
    # a natural-binary mapping.
    k = int(np.log2(M))
    bits = np.random.randint(0, 2, nb_symboles * k)
    symboles = modulateur_QAM(bits, M)
    signal_bruite = ajouter_bruit(symboles, SNRdB)
    bits_recus = demodulateur_QAM(signal_bruite, M)
    BER = np.mean(bits != bits_recus)

    ax.scatter(np.real(signal_bruite), np.imag(signal_bruite), c='blue', marker='.', alpha=0.3, label='Symboles bruités')
    ax.scatter(np.real(constellation), np.imag(constellation), c='red', s=100, edgecolors='black', label='Constellation')

    # The levels are symmetric about 0, so the thresholds already contain
    # both signs: a single pass draws every boundary once.
    for seuil in _seuils_decision(points):
        ax.axvline(seuil, color='green', linestyle='--', linewidth=0.8, alpha=0.7)
        ax.axhline(seuil, color='green', linestyle='--', linewidth=0.8, alpha=0.7)

    _tracer_labels_gray(ax, constellation, gray_labels, fontsize=10 if M <= 16 else 5)
    _mettre_en_forme(ax, f'QAM-{M} (BER = {BER:.2e})', avec_legende=True)

plt.tight_layout()
plt.show()