# -*- coding: utf-8 -*-
"""
Created on Sat Mar 15 09:18:41 2025

@author: AKourgli
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import convolve

# Fonction pour générer la réponse impulsionnelle d'un filtre SRRC
def srrc_filter(beta, T, span, sps):
    """
    Génère la réponse impulsionnelle d'un filtre Square Root Raised Cosine (SRRC).
    
    Paramètres:
      beta : facteur de roll-off (entre 0 et 1)
      T    : période de symbole (ici T=1)
      span : étendue du filtre en nombre de symboles (total)
      sps  : nombre d'échantillons par symbole
      
    Retourne:
      h    : coefficients du filtre
      t    : vecteur temps associé
    """
    t = np.linspace(-span/2, span/2, span * sps + 1)
    h = np.zeros_like(t)
    for i in range(len(t)):
        # Cas t = 0
        if np.isclose(t[i], 0.0, atol=1e-8):
            h[i] = 1 - beta + (4 * beta / np.pi)
        # Cas particulier lorsque |t| = T/(4β)
        elif beta != 0 and np.isclose(np.abs(t[i]), 1/(4*beta), atol=1e-8):
            h[i] = (beta/np.sqrt(2)) * (((1 + 2/np.pi) * np.sin(np.pi/(4*beta))) +
                                         ((1 - 2/np.pi) * np.cos(np.pi/(4*beta))))
        else:
            numerator = np.sin(np.pi * t[i] * (1 - beta)) + 4 * beta * t[i] * np.cos(np.pi * t[i] * (1 + beta))
            denominator = np.pi * t[i] * (1 - (4 * beta * t[i])**2)
            h[i] = numerator / denominator
    return h, t

# Fonction pour suréchantillonner la suite de symboles
def upsample(symbols, sps):
    """
    Insère (sps-1) zéros entre chaque symbole.
    """
    upsampled = np.zeros(len(symbols) * sps)
    upsampled[::sps] = symbols
    return upsampled

# Paramètres de simulation
np.random.seed(0)       # pour reproductibilité
num_symbols = 20        # nombre total de symboles BPSK générés
sps = 8                 # échantillons par symbole
T = 1                   # temps de symbole
span = 6                # le filtre couvre "span" symboles (total)

# Génération aléatoire des symboles BPSK: 0 -> -1 et 1 -> +1
bits = np.random.randint(0, 2, num_symbols)
symbols = 2 * bits - 1  # conversion en -1 et +1
    
# Suréchantillonnage (insertion de zéros)
upsym = upsample(symbols, sps)

# Pour chaque valeur de roll-off de 0 à 1 par pas de 0.25
for beta in np.arange(0, 1.01, 0.25):
    
 
    # Création du filtre SRRC (pour émission et réception)
    h, t_filter = srrc_filter(beta, T, span, sps)
    
    # Mise en forme à l'émission : convolution de la suite suréchantillonnée par le filtre SRRC
    tx_signal = np.convolve(upsym, h, mode='same')
    
    # Tracé de la suite suréchantillonnée
    plt.figure(figsize=(10, 8))
    plt.subplot(5, 1, 1)
    plt.stem(np.arange(len(upsym)), upsym, use_line_collection=True)
    plt.title(f"Symboles BPSK suréchantillonnés (beta = {beta})")
    plt.xlabel("Indice d'échantillon")
    plt.ylabel("Amplitude")
    
    # Tracé du signal mis en forme (signal transmis)
    plt.subplot(5, 1, 2)
    plt.plot(tx_signal)
    plt.title("Signal après mise en forme (filtrage SRRC)")
    plt.xlabel("Indice d'échantillon")
    plt.ylabel("Amplitude")
    
    # Ajout de bruit (AWGN)
    SNR_dB = 5  # rapport signal/bruit en dB
    signal_power = np.mean(tx_signal**2)
    noise_power = signal_power / (10**(SNR_dB/10))
    noise = np.sqrt(noise_power) * np.random.randn(len(tx_signal))
    rx_signal = tx_signal + noise
    
    plt.subplot(5, 1, 3)
    plt.plot(rx_signal)
    plt.title("Signal reçu avec bruit AWGN")
    plt.xlabel("Indice d'échantillon")
    plt.ylabel("Amplitude")
    
    # Filtrage par le filtre adapté (même SRRC)
    mf_output = np.convolve(rx_signal, h, mode='same')
    
    plt.subplot(5, 1, 4)
    plt.plot(mf_output)
    plt.title("Signal après filtrage adapté (SRRC)")
    plt.xlabel("Indice d'échantillon")
    plt.ylabel("Amplitude")
    
    # Pour la détection, le système comporte deux filtres SRRC : 
    # le délai total effectif est donc de 2 * delay
    delay = (len(h) - 1) // 2
    total_delay = 2 * delay  # délai total en échantillons
    
    # Échantillonnage à partir de l'indice total_delay
    detected_samples = mf_output[total_delay::sps]
    n_valid = len(detected_samples)
    detected_symbols = np.where(detected_samples >= 0, 1, -1)
    
    # Comparaison avec la partie valide des symboles transmis :
    # Le délai total correspond à total_delay/sps symboles à ignorer
    symbol_delay = total_delay // sps
    valid_symbols = symbols[symbol_delay : symbol_delay + n_valid]
    
    # Calcul du nombre d'erreurs et du taux d'erreur
    error_count = np.sum(valid_symbols != detected_symbols)
    error_rate = error_count / len(valid_symbols)
    
    # Tracé des symboles émis (valides) et détectés sur le même graphe
    plt.subplot(5, 1, 5)
    plt.stem(np.arange(n_valid), valid_symbols, linefmt='b-', markerfmt='bo', basefmt=" ", label='Émis')
    plt.stem(np.arange(n_valid), detected_symbols, linefmt='r--', markerfmt='rx', basefmt=" ", label='Détectés')
    plt.axhline(0, color='gray', linestyle='--', linewidth=0.5)
    plt.title("Symboles émis (valides) et détectés")
    plt.xlabel("Indice du symbole")
    plt.ylabel("Amplitude")
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    
    print(f"=== Roll-off beta = {beta} ===")
    print("Symboles transmis (valides pour comparaison) :")
    print(valid_symbols)
    print("Symboles détectés:")
    print(detected_symbols)
    print(f"Erreur : {error_count} symboles erronés, taux d'erreur = {error_rate*100:.2f}%\n")
