# -*- coding: utf-8 -*-
"""
Created on Fri Apr 18 17:51:37 2025

@author: AKourgli
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from scipy.signal.windows import gaussian
from scipy.signal import convolve

# Paramètres de simulation
M = 2                # Ordre de modulation (2-FSK)
f0 = 100              # Fréquence de base (Hz)
Fs = 10*f0          # Fréquence d'échantillonnage (Hz)
T = 1                # Durée symbole (s)
EbN0_dB = 1         # Rapport signal/bruit (dB)
                    # Fréquence de base (Hz)

# Calculs dérivés
n_samples = Fs * T
t = np.linspace(0, T, n_samples, endpoint=False)
freq_separation = 1/(2*T)
frequencies = [f0 + i*freq_separation for i in range(M)]

# =================================================================
# AFFICHAGE DES FRÉQUENCES
# =================================================================
print("\nFréquences utilisées pour la 2-FSK:")
for i, f in enumerate(frequencies):
    print(f"Symbole {i} → Fréquence = {f:.2f} Hz")

# Visualisation des porteuses
plt.figure(figsize=(12, 6))
t_short = np.linspace(0, 0.1, 1000)  # 100 ms

for i, f in enumerate(frequencies):
    plt.plot(t_short, np.cos(2*np.pi*f*t_short), linewidth=2, label=f'Symbole {i} ({f:.2f} Hz)')

plt.title('Porteuses M-FSK (Premières 100 ms)', pad=15)
plt.xlabel('Temps (s)')
plt.ylabel('Amplitude')
plt.grid(True, alpha=0.3)
plt.legend(title='Légende:')
plt.xlim(0, 0.1)
plt.tight_layout()
plt.show()

# Génération des symboles et modulation
n_symbols = 100
symbols = np.random.randint(0, M, n_symbols)
modulated = np.concatenate([np.cos(2*np.pi*frequencies[s]*t) for s in symbols])

# =================================================================
# CALCUL ET TRACÉ DES DSP
# =================================================================
nperseg = 1024  # Taille des segments pour Welch
t_total = np.linspace(0, len(modulated)/Fs, len(modulated))

# DSP du signal modulé
f_mod, Pxx_mod = signal.welch(modulated, Fs, nperseg=nperseg, scaling='density')

# DSP des porteuses individuelles (théoriques)
individual_psds = []
plt.figure(figsize=(12, 6))

for i, f in enumerate(frequencies):
    carrier = np.cos(2 * np.pi * f * t_total)
    f_carrier, Pxx_carrier = signal.welch(carrier, Fs, nperseg=nperseg, scaling='density')
    individual_psds.append(Pxx_carrier)
    plt.plot(f_carrier, 10 * np.log10(Pxx_carrier),  label=f'Porteuse {i} ({f:.2f} Hz)', alpha=0.7)

# Somme des DSP théoriques
sum_psd = np.sum(individual_psds, axis=0)
plt.plot(f_carrier, 10 * np.log10(sum_psd), '--', label='Somme théorique', linewidth=2, color='black')

# DSP réelle du signal modulé
plt.plot(f_mod, 10 * np.log10(Pxx_mod), label='Signal modulé réel', linewidth=2,  color='red')

plt.title('DSP des Porteuses M-FSK', pad=15)
plt.xlabel('Fréquence (Hz)')
plt.ylabel('DSP (dB/Hz)')
plt.grid(True, alpha=0.8)
plt.legend()
#plt.xlim(frequencies[0] - 2, frequencies[-1] + 2)
plt.tight_layout()
plt.show()

# =================================================================
# AJOUT DU BRUIT ET DÉMODULATION
# =================================================================
Eb = T/2
N0 = Eb / (10**(EbN0_dB/10))
noise = np.sqrt(N0/2) * np.random.randn(len(modulated))
received = modulated + noise

# Démodulation par corrélation
received_symbols = received.reshape(n_symbols, n_samples)
corr_matrix = np.zeros((n_symbols, M))

for i in range(M):
    ref = np.cos(2*np.pi*frequencies[i]*t)
    corr_matrix[:,i] = np.sum(received_symbols * ref, axis=1)

detected = np.argmax(corr_matrix, axis=1)

# =================================================================
# CONSTELLATION
# =================================================================
plt.figure(figsize=(12, 6))

for s in range(M):
    mask = (symbols == s)
    for i in range(M):
        plt.scatter(frequencies[i]*np.ones(sum(mask)), 
                    corr_matrix[mask,i],
                    label=f'Symbole {s}' if i==0 else "",
                    alpha=0.6,
                    marker='o' if s==i else 'x',
                    color=plt.cm.tab10(s))

plt.title(f'Constellation {M}-FSK SNR={EbN0_dB} (BER={np.mean(symbols != detected):.4f})')
plt.xlabel('Fréquence (Hz)')
plt.ylabel('Amplitude de corrélation')
plt.xticks(frequencies)
plt.grid(True, alpha=0.3)

for f in frequencies:
    plt.axvline(x=f, color='gray', linestyle='--', alpha=0.3)

handles, labels = plt.gca().get_legend_handles_labels()
unique_labels = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]]
plt.legend(*zip(*unique_labels), title='Symboles émis')
plt.tight_layout()
plt.show()


"""
Ajout de la modulation GMSK
"""


# =================================================================
# PARAMÈTRES GMSK
# =================================================================
BT = 0.3               # Produit Bande-passante × Durée symbole
alpha = np.sqrt(np.log(2))/(2*np.pi*BT)  # Paramètre de lissage

# =================================================================
# GÉNÉRATION DU FILTRE GAUSSIEN
# =================================================================
span = 4               # Portée du filtre (en durées symbole)
sps = int(n_samples)   # Échantillons par symbole
t_filter = np.arange(-span*T, span*T, 1/Fs)
gaussian_filter = np.exp(-t_filter**2/(2*(alpha*T)**2))
gaussian_filter /= np.sum(gaussian_filter)  # Normalisation

# Visualisation du filtre
plt.figure(figsize=(12, 4))
plt.plot(t_filter, gaussian_filter, 'r-', label='Filtre Gaussien')
plt.title(f'Réponse impulsionnelle du filtre (BT={BT})')
plt.xlabel('Temps (s)')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()

# =================================================================
# MODULATION GMSK
# =================================================================
# Conversion des symboles en ±1
symbols_nrz = 2*symbols - 1

# Étalement des symboles
upsampled = np.zeros(n_symbols * sps)
upsampled[::sps] = symbols_nrz

# Filtrage Gaussien
filtered = convolve(upsampled, gaussian_filter, mode='same')

# Intégration de phase
phase = np.pi/(2*T) * np.cumsum(filtered)*1/Fs

# Génération du signal GMSK
t_gmsk = np.linspace(0, n_symbols*T, n_symbols*sps)
gmsk_signal = np.cos(2*np.pi*f0*t_gmsk + phase)

# =================================================================
# COMPARAISON DES DSP
# =================================================================
f_gmsk, Pxx_gmsk = signal.welch(gmsk_signal, Fs, nperseg=1024)

plt.figure(figsize=(12,6))
for i, f in enumerate(frequencies):
    carrier = np.cos(2 * np.pi * f * t_total)
    f_carrier, Pxx_carrier = signal.welch(carrier, Fs, nperseg=nperseg, scaling='density')
    individual_psds.append(Pxx_carrier)
    plt.plot(f_carrier, 10 * np.log10(Pxx_carrier),  label=f'Porteuse {i} ({f:.2f} Hz)', alpha=0.7)
    
plt.plot(f_mod, 10*np.log10(Pxx_mod), label='M-FSK',alpha=0.3)
plt.plot(f_gmsk, 10*np.log10(Pxx_gmsk), label='GMSK', alpha=0.7)
plt.title('Comparaison des DSP: M-FSK vs GMSK')
plt.xlabel('Fréquence (Hz)')
plt.ylabel('DSP (dB/Hz)')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()

# =================================================================
# VISUALISATION TEMPORELLE
# =================================================================
plt.figure(figsize=(12, 6))
plt.plot(t_gmsk[:500], gmsk_signal[:500], label='GMSK')
plt.plot(t_total[:500], modulated[:500], '--', label='M-FSK', alpha=0.7)
plt.title('Comparaison temporelle (500 premiers échantillons)')
plt.xlabel('Temps (s)')
plt.ylabel('Amplitude')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()