# -*- coding: utf-8 -*-
"""
Created on Tue Apr 15 10:29:18 2025

@author: AKourgli
"""

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from itertools import cycle

# =============================================
# Fonctions de génération des constellations
# =============================================
def generate_constellation(mod_type, M):
    """Return the constellation points for a given modulation and order.

    Parameters
    ----------
    mod_type : str
        One of ``'ASK'``, ``'PSK'``, ``'QAM'`` or ``'FSK'``.
    M : int
        Number of constellation points (modulation order).

    Returns
    -------
    ndarray
        ``'ASK'``: 1-D real array of M equally spaced amplitudes.
        ``'PSK'``: 1-D complex array of M points on the unit circle.
        ``'QAM'``: 1-D complex array of M points on a rectangular grid,
        square when M is a perfect square (e.g. 16, 64) and as close to
        square as possible otherwise (e.g. 128 -> 16 x 8).
        ``'FSK'``: (M, M) identity matrix, one orthogonal unit-energy signal
        vector per row.
        ASK, PSK and QAM points are scaled so the mean of ``|point|**2`` is 1.

    Raises
    ------
    ValueError
        If ``mod_type`` is not one of the supported names.
    """
    if mod_type == 'ASK':
        points = np.linspace(-(M - 1), M - 1, M)
        points /= np.sqrt(np.mean(points ** 2))  # average energy = 1

    elif mod_type == 'PSK':
        angles = 2 * np.pi * np.arange(M) / M
        points = np.exp(1j * angles)  # unit modulus, so already unit energy

    elif mod_type == 'QAM':
        # Pick the most square n_i x n_q grid with n_i * n_q == M.
        # Using int(np.sqrt(M)) for both sides silently dropped points for
        # non-square orders (128 -> 11 x 11 = 121 points).
        n_q = max(int(np.sqrt(M)), 1)
        while M % n_q:
            n_q -= 1
        n_i = int(M // n_q)
        I = np.linspace(-(n_i - 1), n_i - 1, n_i)
        Q = np.linspace(-(n_q - 1), n_q - 1, n_q)
        # I-major ordering; identical to the original layout for square M.
        points = np.repeat(I, n_q) + 1j * np.tile(Q, n_i)
        points /= np.sqrt(np.mean(np.abs(points) ** 2))  # average energy = 1

    elif mod_type == 'FSK':
        # Orthogonal signals: every row of the identity already has unit
        # energy, so no extra normalisation is required.
        points = np.eye(M)

    else:
        raise ValueError(f"Unsupported modulation type: {mod_type!r}")

    return points

# =============================================
# Calcul de l'information mutuelle (Monte Carlo)
# =============================================
def compute_mutual_info(constellation, snr_db, mod_type, num_samples=10000):
    """Monte Carlo estimate of the mutual information I(X;Y) over AWGN.

    Symbols are drawn uniformly at random from ``constellation``. Gaussian
    noise with standard deviation ``1/sqrt(2*SNR)`` per real dimension is
    added, where ``SNR = 10**(snr_db/10)``. The returned estimate is

        log2(M) - mean_i( E[ log2(sum_j p(y|x_j) / p(y|x_i)) | x_i sent ] )

    Each conditional expectation is replaced by the sample mean over the
    trials in which symbol i was sent. The outer mean runs over the symbols
    that were drawn at least once.

    Parameters
    ----------
    constellation : ndarray
        For ``mod_type == 'FSK'``: an (M, M) array whose rows are real signal
        vectors, with real noise added per dimension. Otherwise: a length-M
        1-D array of (real or complex) points, with complex noise added.
    snr_db : float
        Signal-to-noise ratio in dB.
    mod_type : str
        ``'FSK'`` selects the vector/real-noise model; any other value
        selects the scalar/complex-noise model.
    num_samples : int, optional
        Number of Monte Carlo trials (default 10000).

    Returns
    -------
    float
        Estimated mutual information in bits per symbol.

    Raises
    ------
    ValueError
        If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    M = len(constellation)
    SNR_lin = 10 ** (snr_db / 10)
    sigma = 1 / np.sqrt(2 * SNR_lin)  # noise std per real dimension

    # Transmitted symbols (random draws in the same order as before:
    # indices first, then the noise).
    symbols_idx = np.random.randint(0, M, num_samples)
    symbols = constellation[symbols_idx]

    # Add noise and compute the squared distance to every constellation point.
    if mod_type == 'FSK':
        noise = np.random.normal(0, sigma, (num_samples, M))
        received = symbols + noise
        dist = np.linalg.norm(received[:, np.newaxis, :] - constellation, axis=2) ** 2
    else:
        noise = (np.random.normal(0, sigma, num_samples)
                 + 1j * np.random.normal(0, sigma, num_samples))
        received = symbols + noise
        dist = np.abs(received[:, np.newaxis] - constellation) ** 2

    # Log-likelihoods up to a common additive constant.
    log_lik = -SNR_lin * dist
    # Numerically stable log-sum-exp: shifting by the row maximum keeps
    # exp() from underflowing to 0. In the direct form log2(exp(...)), that
    # underflow could yield log2(0) = -inf.
    row_max = log_lik.max(axis=1)
    log_sum = row_max + np.log(
        np.sum(np.exp(log_lik - row_max[:, np.newaxis]), axis=1)
    )
    log_sent = log_lik[np.arange(num_samples), symbols_idx]
    terms = (log_sum - log_sent) / np.log(2)  # convert nats -> bits

    # Per-symbol sample means, then average over the symbols actually drawn.
    # Skipping never-drawn symbols avoids the NaN produced by np.mean([]).
    counts = np.bincount(symbols_idx, minlength=M)
    sums = np.bincount(symbols_idx, weights=terms, minlength=M)
    drawn = counts > 0
    return np.log2(M) - np.mean(sums[drawn] / counts[drawn])


# =============================================
# Paramètres de simulation
# =============================================
# SNR grid in dB: -5, -3, ..., 29 (np.arange excludes the upper bound 30).
SNR_dB = np.arange(-5, 30, 2)

# Modulation family -> modulation orders M to simulate. Dict insertion order
# is the order in which the families are processed by the simulation loop.
# NOTE(review): 128 is not a perfect square, so 128-QAM cannot form a square grid.
modulations = dict(
    ASK=[2, 4, 8, 16],
    PSK=[2, 4, 8, 16],
    QAM=[4, 16, 64, 128],
    FSK=[2, 4, 8, 16],
)

# =============================================
# Simulation et regroupement des capacités
# =============================================
# One entry per (family, order): {'capacity': MI values over SNR_dB, 'label': str}.
results = []
for mod_type, M_list in modulations.items():
    print(f"Processing {mod_type}...")
    for M in M_list:
        constellation = generate_constellation(mod_type, M)
        # Progress bar over the SNR grid for this modulation order.
        progress = tqdm(SNR_dB, desc=f'M={M}')
        mi_curve = np.array(
            [compute_mutual_info(constellation, snr, mod_type) for snr in progress]
        )
        results.append({'capacity': mi_curve, 'label': f'{mod_type} M={M}'})

# Group curves with similar capacity. A result joins the first existing group
# whose FIRST member matches it at every SNR point within an absolute
# tolerance of 0.05 bit/symbol (np.allclose with atol=0.05 and its default
# rtol=1e-05). This is an absolute tolerance, not a percentage.
groups = []
tolerance = 0.05
for result in results:
    for group in groups:
        if np.allclose(result['capacity'], group['capacities'][0], atol=tolerance):
            group['labels'].append(result['label'])
            group['capacities'].append(result['capacity'])
            break
    else:
        # No existing group matched (or there are none yet): start a new one.
        groups.append({
            'labels': [result['label']],
            'capacities': [result['capacity']],
        })

# =============================================
# Visualisation avec légendes regroupées
# =============================================
fig, ax = plt.subplots(figsize=(12, 8))
colors = cycle(plt.cm.tab20.colors)
markers = cycle(['o', 's', '^', 'D', 'v', '*', 'p', '<', '>', 'X'])

# One colour/marker pair per group. Only the group's first curve carries a
# legend entry, and that entry lists every label in the group.
for group, color, marker in zip(groups, colors, markers):
    legend_text = ', '.join(group['labels'])
    first_curve, *other_curves = group['capacities']
    ax.plot(SNR_dB, first_curve, marker=marker, linestyle='-', color=color, label=legend_text)
    for curve in other_curves:
        ax.plot(SNR_dB, curve, marker=marker, linestyle='-', color=color, label=None)

ax.set_xlabel('SNR (dB)')
ax.set_ylabel('Information mutuelle (bits/symbole)')
ax.set_title('Capacité des canaux pour les modulations numériques')
ax.grid(True)
ax.legend()
fig.tight_layout()
plt.show()
