# -*- coding: utf-8 -*-
"""
Created on Fri Apr 11 20:02:42 2025

@author: AKourgli
"""

import numpy as np
import matplotlib.pyplot as plt

# Gray -> binary conversion
def gray_to_binary(gray):
    """Decode a Gray-coded integer into its natural binary value.

    The binary value is the XOR of the Gray code with every right shift of
    itself (gray ^ gray>>1 ^ gray>>2 ^ ...). Non-positive inputs are
    returned unchanged.

    Parameters
    ----------
    gray : int
        Gray-coded value (Python int or NumPy integer scalar).

    Returns
    -------
    int
        The decoded binary value, same type as ``gray``.
    """
    decoded = gray
    shift = 1
    while (gray >> shift) > 0:
        decoded ^= gray >> shift
        shift += 1
    return decoded

# M-PSK modulation onto a carrier
def mod_psk(bits, M, fc, fs, t_sym):
    """Map a bit stream onto a Gray-coded M-PSK passband waveform.

    Each group of log2(M) bits (MSB first) forms a symbol index m. Its Gray
    code g = m ^ (m >> 1) selects the phase 2*pi*g/M. Each symbol is
    rendered as int(fs * t_sym) samples of cos(2*pi*fc*t + phase), with t
    starting at 0 for every symbol.

    Parameters
    ----------
    bits : array-like of {0, 1}
        Bit stream; its length must be a multiple of log2(M).
    M : int
        Constellation size; must be a power of two >= 2.
    fc : float
        Carrier frequency (Hz).
    fs : float
        Sampling frequency (Hz).
    t_sym : float
        Symbol duration (s).

    Returns
    -------
    passband : numpy.ndarray
        Concatenated passband samples, int(fs * t_sym) per symbol.
    phases : numpy.ndarray
        Transmitted phase of each symbol, in radians.

    Raises
    ------
    ValueError
        If M is not a power of two >= 2, or len(bits) is not a multiple of
        log2(M).
    """
    if M < 2:
        raise ValueError(f"M must be a power of two >= 2, got {M}")
    k = int(np.log2(M))
    # int(log2(M)) truncates silently; reject M that is not exactly 2**k.
    if 2 ** k != M:
        raise ValueError(f"M must be a power of two >= 2, got {M}")

    bits = np.asarray(bits, dtype=np.int64)
    if bits.size % k:
        raise ValueError(
            f"number of bits ({bits.size}) must be a multiple of log2(M) = {k}"
        )
    symbols = bits.reshape(-1, k)

    # Pack each row of bits MSB-first into an integer. Use int64 rather than
    # uint8 so that symbol indices cannot wrap for M > 256.
    weights = np.int64(1) << np.arange(k - 1, -1, -1, dtype=np.int64)
    m = symbols @ weights
    gray_m = m ^ (m >> 1)  # Gray coding: adjacent phases differ by one bit
    phases = 2 * np.pi * gray_m / M

    # Passband synthesis: one row per symbol, then flatten to a time series.
    samples_per_symbol = int(fs * t_sym)
    t = np.linspace(0, t_sym, samples_per_symbol, endpoint=False)
    passband = np.cos(2 * np.pi * fc * t.reshape(1, -1) + phases.reshape(-1, 1))
    return passband.flatten(), phases

# M-PSK demodulation from a carrier
def demod_psk(signal, M, fc, fs, t_sym):
    """Coherently demodulate a Gray-coded M-PSK passband signal into bits.

    The signal is cut into symbols of int(fs * t_sym) samples. Trailing
    samples that do not fill a whole symbol are ignored. Each symbol is
    correlated with the carrier to estimate its phase. The nearest of the M
    candidate phases 2*pi*g/M gives the Gray index g, which is decoded to
    binary and unpacked MSB-first. This inverts mod_psk's mapping.

    Parameters
    ----------
    signal : array-like of float
        Received passband samples.
    M : int
        Constellation size; must be a power of two >= 2.
    fc : float
        Carrier frequency (Hz).
    fs : float
        Sampling frequency (Hz).
    t_sym : float
        Symbol duration (s).

    Returns
    -------
    numpy.ndarray of int
        Recovered bits, log2(M) per received symbol.

    Raises
    ------
    ValueError
        If M is not a power of two >= 2.
    """
    if M < 2:
        raise ValueError(f"M must be a power of two >= 2, got {M}")
    k = int(np.log2(M))
    if 2 ** k != M:
        raise ValueError(f"M must be a power of two >= 2, got {M}")

    samples_per_symbol = int(fs * t_sym)
    num_symbols = len(signal) // samples_per_symbol
    t = np.linspace(0, t_sym, samples_per_symbol, endpoint=False)
    segments = np.asarray(signal)[: num_symbols * samples_per_symbol]
    segments = segments.reshape(num_symbols, samples_per_symbol)

    # Correlate each symbol with the carrier. For s(t) = cos(wt + phi):
    #   mean(s * cos(wt)) ~  cos(phi) / 2
    #   mean(s * sin(wt)) ~ -sin(phi) / 2
    # Q must therefore be negated, otherwise arctan2 returns -phi and symbols
    # whose phase is not 0 or pi are decoded as their mirror image.
    I = segments @ np.cos(2 * np.pi * fc * t) / samples_per_symbol
    Q = -(segments @ np.sin(2 * np.pi * fc * t)) / samples_per_symbol
    received_phases = np.arctan2(Q, I) % (2 * np.pi)

    # Symbol decision: nearest candidate phase, measured around the circle.
    candidate_phases = 2 * np.pi * np.arange(M) / M
    phase_diff = np.abs(received_phases[:, None] - candidate_phases)
    phase_diff = np.minimum(phase_diff, 2 * np.pi - phase_diff)
    gray_m = np.argmin(phase_diff, axis=1)

    # Gray -> binary, vectorized: XOR the code with all its right shifts.
    m = gray_m.copy()
    shifted = gray_m >> 1
    while shifted.any():
        m ^= shifted
        shifted >>= 1

    # Unpack each symbol index into k bits, MSB first.
    shifts = np.arange(k - 1, -1, -1)
    bits_rx = (m[:, None] >> shifts) & 1
    return bits_rx.astype(int).flatten()

# Simulation parameters
M = 4           # Constellation size (QPSK)
num_bits = 20   # Number of bits; must be a multiple of log2(M)
EbN0_dB = 10    # Eb/N0 (dB)
fc = 10         # Carrier frequency (Hz)
fs = 100        # Sampling frequency (Hz)
t_sym = 1       # Symbol duration (s)

# Random source bits
bits_tx = np.random.randint(0, 2, num_bits)
print("Bits émis      :", bits_tx)

# Modulation
passband_signal, tx_phases = mod_psk(bits_tx, M, fc, fs, t_sym)
print("Phases émises  :", np.round(np.degrees(tx_phases), 1), "degrés")

# Additive white Gaussian noise
k = int(np.log2(M))                   # bits per symbol
samples_per_symbol = int(fs * t_sym)
EbN0_lin = 10 ** (EbN0_dB / 10)
signal_power = np.mean(passband_signal**2)
Eb = signal_power * t_sym / k         # energy per bit = symbol energy / k
N0 = Eb / EbN0_lin
# Real white noise with two-sided PSD N0/2, sampled at fs, has per-sample
# variance (N0/2) * fs. Using N0 * fs would make the channel 3 dB noisier
# than the EbN0_dB requested above.
noise = np.sqrt(N0 * fs / 2) * np.random.randn(len(passband_signal))
received_signal = passband_signal + noise

# Demodulation
bits_rx = demod_psk(received_signal, M, fc, fs, t_sym)
print("Bits reçus     :", bits_rx)

# Bit errors and bit error rate
error_positions = np.where(bits_tx != bits_rx)[0]
ber = len(error_positions) / num_bits

# Plots: one figure with three stacked panels
plt.figure(figsize=(15, 10))

# Panel 1: transmitted passband waveform overlaid with the noisy received one
ax_signal = plt.subplot(3, 1, 1)
time = np.arange(passband_signal.size) / fs
ax_signal.plot(time, passband_signal, label="Émis")
ax_signal.plot(time, received_signal, alpha=0.6, label="Reçu")
ax_signal.set(
    title=f"Signal modulé (QPSK, Eb/N0 = {EbN0_dB} dB)",
    xlabel="Temps (s)",
    ylabel="Amplitude",
)
ax_signal.grid(True)
ax_signal.legend()

# Received I/Q components, one point per transmitted symbol.
# Use the same time grid as the modulator (endpoint=False) so the carrier
# reference is aligned with each symbol's samples.
t = np.linspace(0, t_sym, samples_per_symbol, endpoint=False)
num_rx_symbols = len(tx_phases)
rx_segments = received_signal[: num_rx_symbols * samples_per_symbol].reshape(
    num_rx_symbols, samples_per_symbol
)
# For s(t) = cos(wt + phi), the correlator gives cos(phi)/2 on the cosine
# branch and -sin(phi)/2 on the sine branch. Scale by 2 and negate Q so the
# noiseless point sits at (cos phi, sin phi). This matches the unit-circle
# transmitted markers instead of a mirrored, half-size constellation.
received_I = 2 * (rx_segments @ np.cos(2 * np.pi * fc * t)) / samples_per_symbol
received_Q = -2 * (rx_segments @ np.sin(2 * np.pi * fc * t)) / samples_per_symbol

# Constellation plot
plt.subplot(3, 1, 2)
plt.scatter(np.cos(tx_phases), np.sin(tx_phases), c='blue', marker='x', label="Émis")
plt.scatter(received_I, received_Q, c='red', alpha=0.5, label="Reçus")
plt.title("Constellation QPSK")
plt.xlabel("I")
plt.ylabel("Q")
plt.grid(True)
plt.axis('equal')
plt.legend()

# Panel 3: transmitted vs. received bits, with bit errors circled
ax_bits = plt.subplot(3, 1, 3)
ax_bits.stem(bits_tx, linefmt='b-', markerfmt='bo', basefmt="none", label="Émis")
ax_bits.stem(bits_rx, linefmt='r--', markerfmt='rx', basefmt="none", label="Reçus")
if error_positions.size:
    ax_bits.stem(
        error_positions,
        bits_rx[error_positions],
        linefmt='none',
        markerfmt='ro',
        label="Erreurs",
    )
ax_bits.set_title(f"Bits émis vs reçus (BER = {ber:.2f})")
ax_bits.set_xlabel("Index du bit")
ax_bits.set_ylabel("Valeur")
ax_bits.legend()
ax_bits.grid(True)

# Lay out the three panels and display the figure
plt.tight_layout()
plt.show()