# -*- coding: utf-8 -*-
"""
Created on Sun Apr 27 18:41:14 2025

@author: AKourgli
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.special import erfc

# Simulation parameters
N_symbols = 100000       # Number of transmitted symbols
Fs = 1000                # Sampling frequency (Hz)
T = 1                    # Symbol duration (s)
f0 = 100                 # Frequency of the lowest tone (Hz)
freq_sep = 1/(2*T)       # Spacing between adjacent tones (Hz)
SNR_dB = np.arange(0, 30, 5)  # SNR values to sweep (dB)
M = 2                    # Modulation order (2 -> binary FSK)

# Derived quantities
# Round before truncating so that floating-point round-off in Fs * T
# (e.g. 2999.9999...) cannot drop a sample.
sps = int(round(Fs * T))  # Samples per symbol
t = np.linspace(0, T, sps, endpoint=False)  # Sample instants inside one symbol
# One tone per symbol value. Deriving the list from M keeps it in sync with
# the symbol alphabet, so changing M cannot index past the end of the list.
frequencies = [f0 + k * freq_sep for k in range(M)]

# Random symbol sequence, each symbol drawn uniformly from 0 .. M-1
symbols = np.random.randint(0, M, N_symbols)

# =================================================================
# Modulation FSK
# =================================================================
def mod_fsk(symbols, frequencies, t, sps):
    """Map a symbol sequence to an FSK waveform.

    Each symbol ``s`` is replaced by ``cos(2*pi*frequencies[s]*t)``; the
    per-symbol waveforms are concatenated in order.

    Parameters
    ----------
    symbols : sequence of int
        Symbol values, used as indices into ``frequencies``.
    frequencies : sequence of float
        Tone frequency for each symbol value.
    t : array-like
        Sample instants within one symbol.
    sps : int
        Number of samples per symbol.

    Returns
    -------
    numpy.ndarray
        1-D float array of length ``len(symbols) * sps``.
    """
    symbols = np.asarray(symbols)
    if symbols.size == 0:
        # Nothing to modulate (also avoids indexing with an empty float array).
        return np.zeros(0)

    # The waveform of a tone does not depend on the symbol position, so build
    # each one once instead of re-evaluating the cosine for every symbol.
    tones = np.empty((len(frequencies), sps))
    for k, f in enumerate(frequencies):
        tones[k] = np.cos(2*np.pi*f*t)

    # Row lookup by symbol value, then flatten to a single sample stream.
    return tones[symbols].reshape(-1)

# Build the transmitted sample stream for the whole symbol sequence
modulated = mod_fsk(symbols=symbols, frequencies=frequencies, t=t, sps=sps)

# =================================================================
# Canal AWGN
# =================================================================
def add_awgn(signal, snr_dB):
    """Add white Gaussian noise to a real signal at a given SNR.

    The SNR is defined per sample: mean signal power divided by the noise
    variance.

    Parameters
    ----------
    signal : array-like
        Real-valued samples, of any shape.
    snr_dB : float
        Target signal-to-noise ratio in dB.

    Returns
    -------
    numpy.ndarray
        ``signal`` plus independent zero-mean Gaussian noise, same shape as
        ``signal``.
    """
    signal = np.asarray(signal)
    if signal.size == 0:
        # The mean power of an empty signal is undefined; nothing to corrupt.
        return signal.astype(float)

    signal_power = np.mean(signal**2)
    snr_linear = 10**(snr_dB/10)
    noise_power = signal_power / snr_linear
    # Draw one noise sample per signal sample. Sizing by shape rather than
    # len() keeps the noise independent for every element of N-D input.
    noise = np.sqrt(noise_power) * np.random.randn(*signal.shape)
    return signal + noise

# =================================================================
# Démodulateurs
# =================================================================
def demod_coherent(received, frequencies, t, sps):
    """Detect FSK symbols by correlating with phase-aligned reference tones.

    The received stream is cut into consecutive segments of ``sps`` samples.
    Each segment is correlated with ``cos(2*pi*f*t)`` for every candidate
    frequency and the index of the largest correlation is the decision.
    Trailing samples that do not fill a whole symbol are ignored.

    Parameters
    ----------
    received : array-like
        Received real samples.
    frequencies : sequence of float
        Candidate tone frequencies; the decision is an index into this list.
    t : array-like
        Sample instants within one symbol (``sps`` values).
    sps : int
        Number of samples per symbol.

    Returns
    -------
    numpy.ndarray
        Integer array with one detected symbol index per full segment.
    """
    received = np.asarray(received)
    N = len(received)//sps
    if N == 0:
        return np.zeros(0, dtype=int)

    # Reference tones are identical for every symbol: build them once.
    references = np.array([np.cos(2*np.pi*f*t) for f in frequencies])

    # One row per symbol; the matrix product yields every symbol/tone
    # correlation in a single call.
    segments = received[:N*sps].reshape(N, sps)
    corr = segments @ references.T

    # Use the signed correlation. The references are phase-aligned with the
    # transmitter, so a large *negative* correlation is evidence against a
    # tone; taking the magnitude would let an anti-phase component win.
    return np.argmax(corr, axis=1).astype(int)

def demod_noncoherent(received, frequencies, sps, Fs):
    """Detect FSK symbols from the peak of each segment's FFT magnitude.

    The received stream is cut into consecutive segments of ``sps`` samples.
    For each segment the positive-frequency FFT bin with the largest
    magnitude is located, and the decision is the index of the candidate
    frequency closest to that bin. Trailing samples that do not fill a whole
    symbol are ignored.

    Parameters
    ----------
    received : array-like
        Received real samples.
    frequencies : sequence of float
        Candidate tone frequencies; the decision is an index into this list.
    sps : int
        Number of samples per symbol.
    Fs : float
        Sampling frequency, used to convert FFT bins to Hz.

    Returns
    -------
    numpy.ndarray
        Integer array with one detected symbol index per full segment.
    """
    received = np.asarray(received)
    N = len(received)//sps
    detected = np.zeros(N, dtype=int)

    # Keep the historical 1024-point transform for short symbols, but never
    # go below the symbol length: np.fft.fft(x, n) silently drops samples
    # when n < len(x), which would discard part of every symbol.
    fft_size = max(1024, 1 << (int(sps) - 1).bit_length())
    half = fft_size//2
    freqs = np.fft.fftfreq(fft_size, 1/Fs)[:half]

    # Pre-compute, for every FFT bin, which candidate tone is nearest.
    tones = np.asarray(frequencies, dtype=float)
    distance = np.abs(freqs[:, None] - tones[None, :])
    # argmin on the reversed tone axis resolves ties toward the higher index,
    # matching the strict '<' comparison used for the two-tone case.
    bin_to_symbol = (len(tones) - 1) - np.argmin(distance[:, ::-1], axis=1)

    for i in range(N):
        segment = received[i*sps:(i+1)*sps]
        spectrum = np.abs(np.fft.fft(segment, fft_size)[:half])
        detected[i] = bin_to_symbol[np.argmax(spectrum)]

    return detected

# =================================================================
# Simulation BER
# =================================================================
ber_coherent = np.zeros(len(SNR_dB))
ber_noncoherent = np.zeros(len(SNR_dB))

# add_awgn() interprets its argument as a per-sample SNR (mean signal power
# over noise variance), whereas the theoretical curves computed further down
# use the same SNR_dB values as Eb/N0. For a real signal carrying one bit per
# symbol (M = 2) the two are related by
#     Eb/N0 = SNR_sample * sps / 2
# (Eb = P * sps and N0 = 2 * sigma^2), so the swept value is converted before
# the noise is generated. Without this the simulated points are shifted by
# 10*log10(sps/2) dB relative to the theory plotted on the same axis.
snr_offset_dB = 10 * np.log10(sps / 2)

for idx, snr in enumerate(SNR_dB):
    # Add channel noise at the per-sample SNR equivalent to this Eb/N0
    received = add_awgn(modulated, snr - snr_offset_dB)

    # Demodulate the same noisy stream with both detectors
    detected_coherent = demod_coherent(received, frequencies, t, sps)
    detected_noncoherent = demod_noncoherent(received, frequencies, sps, Fs)

    # Error rate = fraction of symbols detected incorrectly
    ber_coherent[idx] = np.mean(symbols != detected_coherent)
    ber_noncoherent[idx] = np.mean(symbols != detected_noncoherent)

# =================================================================
# Théorique
# =================================================================
# Closed-form reference curves evaluated on the swept SNR values.
# NOTE(review): these look like the textbook expressions for binary
# orthogonal FSK with SNR_linear read as Eb/N0 - confirm against the
# SNR definition used by the simulation.
SNR_linear = np.power(10.0, SNR_dB / 10)   # dB -> linear ratio
_half_snr = 0.5 * SNR_linear
ber_coherent_theo = erfc(np.sqrt(_half_snr)) / 2
ber_noncoherent_theo = np.exp(-_half_snr) / 2

# =================================================================
# Affichage
# =================================================================
# Single figure with simulated (markers) and theoretical (dashed) BER curves
fig, ax = plt.subplots(figsize=(12, 8))

# (values, line format, legend label), drawn in this order
curves = (
    (ber_coherent, 'bo-', 'Cohérent (simulé)'),
    (ber_noncoherent, 'ro-', 'Non-cohérent (simulé)'),
    (ber_coherent_theo, 'b--', 'Cohérent (théorique)'),
    (ber_noncoherent_theo, 'r--', 'Non-cohérent (théorique)'),
)
for values, fmt, label in curves:
    ax.semilogy(SNR_dB, values, fmt, label=label)

ax.set_title('Comparaison des performances BER pour la FSK')
ax.set_xlabel('SNR (dB)')
ax.set_ylabel('Bit Error Rate')
ax.grid(True, which='both', alpha=0.3)
ax.legend()
# Clip the log axis at 1e-5 and fit the x axis to the swept SNR range
ax.set_ylim(1e-5, 1)
ax.set_xlim(min(SNR_dB), max(SNR_dB))
plt.show()