# -*- coding: utf-8 -*-
"""
Created on Tue Feb 25 22:33:01 2025

@author: AKourgli
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.special import erfc

# Paramètres de la simulation
N_bits = 1000               # Nombre de bits (symboles) à transmettre
samples_per_symbol = 10     # Nombre d'échantillons par symbole
SNR_dB = 10
0                # Rapport signal/bruit en dB
E_b = 1                     # Énergie par bit (normalisée)

# Calcul des paramètres liés au bruit
SNR_lin = 10**(SNR_dB/10)     # SNR en linéaire
N0 = E_b / SNR_lin          # Puisque SNR = E_b/N0
sigma = np.sqrt(N0/2)       # Écart-type du bruit par dimension

# Génération des bits et mapping BPSK
bits = np.random.randint(0, 2, N_bits)
symbols = 2 * bits - 1      # Mapping : 0 -> -1, 1 -> +1

# Définition de la forme d'onde (pulse shaping)
pulse = np.ones(samples_per_symbol) / np.sqrt(samples_per_symbol)

# Génération du signal transmis
tx_signal = np.kron(symbols, pulse)

# Passage dans le canal AWGN
noise = sigma * np.random.randn(len(tx_signal))
rx_signal = tx_signal + noise

# Filtrage adapté (Matched Filter) avec mode 'same'
matched_filter = pulse[::-1]
filtered_signal = np.convolve(rx_signal, matched_filter, mode='same')

# Pour un filtre de longueur paire, utiliser delay = samples_per_symbol//2
delay = (samples_per_symbol) // 2

# Échantillonnage en tenant compte du délai corrigé
sampled = filtered_signal[delay + np.arange(N_bits) * samples_per_symbol]


# Détection par seuillage optimal (seuil = 0)
detected_symbols = np.where(sampled > 0, 1, -1)
detected_bits = (detected_symbols + 1) // 2  # Remapping : -1 -> 0 et 1 -> 1

# Calcul du taux d'erreur binaire (BER)
n_errors = np.sum(bits != detected_bits)
BER_sim = n_errors / N_bits

print("SNR (dB) =", SNR_dB)
print("Nombre de symboles transmis :", N_bits)
print("Nombre de symboles détectés :", len(detected_symbols))
print("BER simulé =", BER_sim)
# BER théorique pour BPSK sur canal AWGN
BER_theo = 0.5 * erfc(np.sqrt(SNR_lin))
print("BER théorique =", BER_theo)

# Affichage du signal émis et du signal reçu
n_symbols_plot = 100  # Affichage de 20 symboles
samples_to_plot = n_symbols_plot * samples_per_symbol
time_axis = np.arange(samples_to_plot)

plt.figure(figsize=(12, 6))
plt.plot(time_axis, tx_signal[:samples_to_plot], label="Signal émis", linewidth=2)
plt.plot(time_axis, rx_signal[:samples_to_plot], label="Signal reçu (AWGN)", linestyle='--')
plt.xlabel("Échantillons")
plt.ylabel("Amplitude")
plt.title("Signal émis et Signal reçu (pour {} symboles)".format(n_symbols_plot))
plt.legend()
plt.grid(True)
plt.show()

# Affichage du signal filtré et des points d'échantillonnage
n_symbols_plot = 25  # Affichage de 100 symboles pour une meilleure lisibilité
samples_to_plot = n_symbols_plot * samples_per_symbol
time_axis = np.arange(samples_to_plot) / samples_per_symbol

plt.figure(figsize=(12, 6))
plt.plot(time_axis, filtered_signal[:samples_to_plot], label="Signal filtré")
# Calcul des indices d'échantillonnage corrigés
sampled_indices = delay + np.arange(n_symbols_plot) * samples_per_symbol
plt.plot(time_axis[sampled_indices], filtered_signal[sampled_indices], 'ro', label="Échantillons")
plt.axhline(0, color='k', linestyle='--', label="Seuil")
plt.xlabel("Temps (en symboles)")
plt.ylabel("Amplitude")
plt.title("Signal filtré avec échantillonnage optimal")
plt.legend()
plt.grid(True)
plt.show()

# Affichage des symboles transmis vs détectés
plt.figure(figsize=(10, 4))
plt.stem(np.arange(n_symbols_plot), symbols[:n_symbols_plot], linefmt='b-', markerfmt='bo', basefmt=" ", label="Symboles transmis")
plt.stem(np.arange(n_symbols_plot), detected_symbols[:n_symbols_plot], linefmt='r--', markerfmt='rx', basefmt=" ", label="Symboles détectés")
plt.xlabel("Index du symbole")
plt.ylabel("Amplitude")
plt.title("Symboles transmis vs détectés (pour les premiers {} symboles)".format(n_symbols_plot))
plt.legend()
plt.grid(True)
plt.show()
