# -*- coding: utf-8 -*-
"""
Created on Sun Apr 27 18:55:00 2025

@author: AKourgli
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter, sosfilt

# --- General simulation parameters
fs = 10_000      # sampling frequency (Hz)
Rb = 100         # bit rate (bits/s)
Tb = 1.0 / Rb    # duration of one bit (s)

# Paramètres FSK
def generate_frequencies(M, f_start=1000, spacing=500):
    """Return ``M`` regularly spaced FSK carrier frequencies (Hz).

    Carrier ``n`` (``0 <= n < M``) is placed at ``f_start + n * spacing``;
    the result is a NumPy array of length ``M``.
    """
    offsets = range(M)
    return np.asarray([f_start + n * spacing for n in offsets])

# Modulation M-FSK
def mfsk_mod(symbols, fs, Tb, freqs):
    """Generate a real M-FSK waveform, one carrier burst per symbol.

    Parameters
    ----------
    symbols : array_like of int
        Symbol indices into ``freqs``.
    fs : float
        Sampling frequency (Hz).
    Tb : float
        Duration of one symbol (s); each symbol spans ``int(Tb * fs)`` samples.
    freqs : array_like
        Carrier frequency (Hz) associated with each symbol index.

    Returns
    -------
    numpy.ndarray
        Concatenation of ``cos(2*pi*f*t)`` bursts, of length
        ``len(symbols) * int(Tb * fs)``. The phase restarts at zero for
        every symbol.
    """
    Ns = int(Tb * fs)
    # Derive the time axis from the sample count: np.arange(0, Tb, 1/fs) can
    # return Ns + 1 samples through floating-point rounding, which would no
    # longer fit the Ns-sample slot reserved for each symbol.
    t = np.arange(Ns) / fs
    carriers = np.asarray(freqs)[np.asarray(symbols, dtype=int)]
    # One row per symbol, then flatten row-major to concatenate the bursts.
    return np.cos(2 * np.pi * carriers[:, None] * t).reshape(-1)

# Canal AWGN
def awgn(x, EbN0_dB, k, Rb, fs, rng=None):
    """Add real white Gaussian noise to ``x`` for a target Eb/N0.

    Parameters
    ----------
    x : array_like
        Real transmitted signal sampled at ``fs``.
    EbN0_dB : float
        Target energy-per-bit to noise-PSD ratio, in dB.
    k : int
        Bits per symbol (log2 M).
    Rb : float
        Bit rate carried by ``x`` (bits/s).
    fs : float
        Sampling frequency (Hz).
    rng : numpy.random.Generator, optional
        Source of the Gaussian samples. When omitted, NumPy's legacy global
        generator (``np.random.randn``) is used, so ``np.random.seed`` applies.

    Returns
    -------
    numpy.ndarray
        ``x`` plus noise, same shape as ``x``.

    Notes
    -----
    With P the average power per sample, one symbol (rate Rb / k) carries
    Es = P * k / Rb and one bit carries Eb = Es / k. The noise PSD follows
    from N0 = Eb / (Eb/N0). A real discrete-time white noise of two-sided PSD
    N0 / 2 sampled at fs has variance N0 * fs / 2. The factor k cancels
    (Eb = P / Rb); it is kept for interface compatibility.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Avoid a 0/0 power estimate; nothing to corrupt.
        return x.astype(float)
    EbN0 = 10 ** (EbN0_dB / 10)
    P = np.mean(x ** 2)        # average power per sample (all elements)
    Es = P / (Rb / k)          # energy per symbol
    Eb = Es / k                # energy per bit -- previously omitted, which
                               # made the noise k times too strong
    N0 = Eb / EbN0
    sigma = np.sqrt(N0 * fs / 2)
    if rng is None:
        noise = sigma * np.random.randn(*x.shape)
    else:
        noise = sigma * rng.standard_normal(x.shape)
    return x + noise

# Démodulation cohérente
def demod_coherent(rx, fs, Tb, freqs):
    """Coherent M-FSK detection by correlation with each carrier.

    Each block of ``int(Tb * fs)`` samples is correlated with a zero-phase
    cosine at every frequency of ``freqs``; the index of the largest
    correlation is the decided symbol. A shorter trailing block is
    correlated with references truncated to the same length.

    Parameters
    ----------
    rx : array_like
        Received real signal sampled at ``fs``.
    fs : float
        Sampling frequency (Hz).
    Tb : float
        Symbol duration (s).
    freqs : array_like
        Candidate carrier frequencies (Hz).

    Returns
    -------
    numpy.ndarray
        One symbol index per (possibly partial) block of ``rx``.
    """
    rx = np.asarray(rx)
    Ns = int(Tb * fs)
    # Time axis from the sample count so every reference has exactly Ns
    # samples (np.arange(0, Tb, 1/fs) may return Ns + 1).
    t = np.arange(Ns) / fs
    refs = np.cos(2 * np.pi * np.asarray(freqs)[:, None] * t)   # (M, Ns)
    decisions = []
    for start in range(0, len(rx), Ns):
        segment = rx[start:start + Ns]
        metrics = refs[:, :len(segment)] @ segment
        decisions.append(np.argmax(metrics))
    return np.array(decisions)

# Démodulation non-cohérente
def bandpass_filter(data, fs, fc, bw=400, order=4):
    """Band-pass filter ``data`` around ``fc`` with a digital Butterworth filter.

    Parameters
    ----------
    data : array_like
        Real input signal sampled at ``fs``.
    fs : float
        Sampling frequency (Hz).
    fc : float
        Centre of the pass band (Hz).
    bw : float, optional
        Pass-band width (Hz); the band is ``[fc - bw/2, fc + bw/2]``.
    order : int, optional
        Design order passed to ``scipy.signal.butter`` (default 4, as before).

    Returns
    -------
    numpy.ndarray
        Causally filtered signal, same length as ``data``.

    Raises
    ------
    ValueError
        From ``scipy.signal.butter`` if the band edges do not lie strictly
        between 0 and the Nyquist frequency ``fs / 2``.
    """
    nyq = 0.5 * fs
    low = (fc - bw / 2) / nyq
    high = (fc + bw / 2) / nyq
    # Second-order sections are numerically more robust than (b, a)
    # polynomial coefficients for higher-order band-pass designs.
    sos = butter(order, [low, high], btype='band', output='sos')
    return sosfilt(sos, data)

def demod_noncoherent(rx, fs, Tb, freqs):
    """Non-coherent M-FSK detection with a band-pass filter bank.

    ``rx`` is filtered around every carrier with ``bandpass_filter``; the
    squared outputs are summed over each block of ``int(Tb * fs)`` samples
    (a shorter trailing block is summed as-is) and the carrier with the most
    energy is the decided symbol.

    Parameters
    ----------
    rx : array_like
        Received real signal sampled at ``fs``.
    fs : float
        Sampling frequency (Hz).
    Tb : float
        Symbol duration (s).
    freqs : array_like
        Candidate carrier frequencies (Hz).

    Returns
    -------
    numpy.ndarray
        One symbol index per (possibly partial) block of ``rx``.

    Raises
    ------
    ValueError
        If ``int(Tb * fs)`` is smaller than one sample.
    """
    rx = np.asarray(rx)
    Ns = int(Tb * fs)
    if Ns < 1:
        raise ValueError("Tb * fs must cover at least one sample")
    if rx.size == 0:
        return np.array([])
    # One filtered branch per carrier, squared in place -> (M, len(rx)).
    power = np.array([bandpass_filter(rx, fs, f) for f in freqs])
    np.square(power, out=power)
    # Sum each symbol slot; reduceat also handles a shorter last slot.
    starts = np.arange(0, len(rx), Ns)
    energies = np.add.reduceat(power, starts, axis=1)   # (M, n_symbols)
    return np.argmax(energies, axis=0)

# Démodulation par FFT
def demod_fft(rx, fs, Tb, freqs):
    """M-FSK detection from FFT bin magnitudes.

    Each block of ``int(Tb * fs)`` samples is transformed with a real FFT of
    that length (a shorter trailing block is zero-padded to it). Every
    carrier is mapped to its nearest FFT bin, whose spacing is
    ``fs / int(Tb * fs)``, and the carrier with the largest magnitude is
    the decided symbol.

    Parameters
    ----------
    rx : array_like
        Received real signal sampled at ``fs``.
    fs : float
        Sampling frequency (Hz).
    Tb : float
        Symbol duration (s).
    freqs : array_like
        Candidate carrier frequencies (Hz).

    Returns
    -------
    numpy.ndarray
        One symbol index per (possibly partial) block of ``rx``.
    """
    rx = np.asarray(rx)
    Ns = int(Tb * fs)
    # Only non-negative frequencies are needed for a real input signal.
    bin_freqs = np.fft.rfftfreq(Ns, 1 / fs)
    bins = [int(np.argmin(np.abs(bin_freqs - f))) for f in freqs]
    decisions = []
    for start in range(0, len(rx), Ns):
        # n=Ns keeps the bin grid fixed even for a partial final block.
        spectrum = np.fft.rfft(rx[start:start + Ns], n=Ns)
        decisions.append(np.argmax(np.abs(spectrum[bins])))
    return np.array(decisions)

# Conversion bits <-> symboles

def bits_to_symbols(bits, M):
    """Pack a bit stream into M-ary symbol indices, most significant bit first.

    Parameters
    ----------
    bits : array_like of {0, 1}
        Input bits.
    M : int
        Alphabet size; must be a power of two >= 2.

    Returns
    -------
    numpy.ndarray of int
        ``len(bits) // log2(M)`` symbols. Trailing bits that do not fill a
        whole symbol are dropped.

    Raises
    ------
    ValueError
        If ``M`` is smaller than 2 or not a power of two. This was an
        ``assert`` before, which is skipped under ``python -O``.
    """
    if M < 2:
        raise ValueError("M doit être >= 2")
    k = int(np.log2(M))
    if 2**k != M:
        raise ValueError("M doit être une puissance de 2")
    bits = np.asarray(bits)
    n_used = len(bits) - len(bits) % k
    groups = bits[:n_used].reshape(-1, k)
    weights = 2 ** np.arange(k - 1, -1, -1)   # MSB-first place values
    return np.dot(groups, weights)


def symbols_to_bits(symbols, M):
    """Unpack M-ary symbol indices into bits, most significant bit first.

    Parameters
    ----------
    symbols : array_like of int
        Symbol indices; only their ``log2(M)`` low-order bits are used.
    M : int
        Alphabet size; must be a power of two >= 2.

    Returns
    -------
    numpy.ndarray of int
        Flat array of ``len(symbols) * log2(M)`` bits.

    Raises
    ------
    ValueError
        If ``M`` is smaller than 2 or not a power of two, which previously
        produced silently wrong bit groupings.
    """
    if M < 2:
        raise ValueError("M doit être >= 2")
    k = int(np.log2(M))
    if 2**k != M:
        raise ValueError("M doit être une puissance de 2")
    symbols = np.asarray(symbols)
    if symbols.size == 0:
        # An empty list becomes a float array, which bit operations reject.
        return np.zeros(0, dtype=int)
    shifts = np.arange(k - 1, -1, -1)
    bits = (symbols[:, None] >> shifts) & 1
    return bits.reshape(-1).astype(int)

# Simulation BER

def simulate_ber(method, EbN0_dBs, M, n_bits=10000):
    """Estimate the bit error rate of an M-FSK link over an AWGN channel.

    For each Eb/N0 point, ``n_bits`` random bits are mapped to symbols,
    modulated, corrupted by ``awgn``, demodulated with the chosen receiver
    and compared with the transmitted bits. The result for each point is
    printed.

    Parameters
    ----------
    method : {'coherent', 'non', 'fft'}
        Receiver: correlation (``'coherent'``), band-pass filter bank energy
        (``'non'``) or FFT-bin magnitude (``'fft'``).
    EbN0_dBs : iterable of float
        Eb/N0 values (dB) to simulate.
    M : int
        Number of FSK carriers (power of two).
    n_bits : int, optional
        Random bits drawn per Eb/N0 point (drawn from NumPy's global RNG).

    Returns
    -------
    numpy.ndarray
        One BER estimate per entry of ``EbN0_dBs``.

    Raises
    ------
    ValueError
        If ``method`` is not a supported receiver. Previously any unknown
        string silently ran the FFT receiver.
    """
    demodulators = {
        'coherent': demod_coherent,
        'non': demod_noncoherent,
        'fft': demod_fft,
    }
    try:
        demodulate = demodulators[method]
    except KeyError:
        raise ValueError(
            f"Méthode inconnue : {method!r} (attendu 'coherent', 'non' ou 'fft')"
        ) from None
    k = int(np.log2(M))
    freqs = generate_frequencies(M)
    # Each symbol carries k bits, so it must last k bit durations for the
    # link to run at Rb bits/s -- the bit rate awgn() uses to derive Eb.
    # Using Tb here made the real bit rate k * Rb and mis-scaled Eb/N0.
    Ts = k * Tb
    bers = []
    for EbN0 in EbN0_dBs:
        bits = np.random.randint(0, 2, n_bits)
        symbols = bits_to_symbols(bits, M)
        tx = mfsk_mod(symbols, fs, Ts, freqs)
        rx = awgn(tx, EbN0, k, Rb, fs)
        sym_hat = demodulate(rx, fs, Ts, freqs)
        bits_hat = symbols_to_bits(sym_hat, M)
        # Trailing bits that did not fill a symbol were never transmitted.
        ber = np.mean(bits[:len(bits_hat)] != bits_hat)
        bers.append(ber)
        print(f"Méthode={method}, Eb/N0={EbN0} dB, BER={ber:.2e}")
    return np.array(bers)

if __name__ == '__main__':
    # Number of FSK carriers; must be a power of two (2, 4, 8, ...).
    M = 8
    EbN0_dBs = np.arange(0, 30, 2)

    # The three simulations share NumPy's global RNG, so they run in this
    # fixed order: coherent, non-coherent, then FFT-based.
    ber_coh, ber_non, ber_fft = (
        simulate_ber(method, EbN0_dBs, M) for method in ('coherent', 'non', 'fft')
    )

    # Plot the BER curves on a logarithmic scale.
    for ber, fmt, label in (
        (ber_coh, 'o-', 'Cohérent'),
        (ber_non, 's--', 'Non-cohérent'),
        (ber_fft, 'd-.', 'FFT-based'),
    ):
        plt.semilogy(EbN0_dBs, ber, fmt, label=label)
    plt.xlabel("Eb/N0 (dB)")
    plt.ylabel("BER")
    plt.ylim(1e-5, 1)
    plt.grid(True)
    plt.legend()
    plt.title(f"Comparaison BER M-FSK (M={M})")
    plt.show()
