# -*- coding: utf-8 -*-
"""
Created on Sun Apr 13 18:03:19 2025

@author: AKourgli
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.special import erfc

# Parameters
M = 4  # Modulation order: must be a power of two (2, 4, 8, ...)
n_bits = 100000  # Number of bits (trimmed below to a multiple of k)
EbN0dB_range = np.arange(0, 30, 2)  # Eb/N0 range in dB
if M < 2 or M & (M - 1):
    raise ValueError(f"M must be a power of two >= 2, got {M}")
k = int(np.log2(M))  # Bits per symbol
n_symbols = n_bits // k  # Number of symbols
# Only whole symbols can be sent: drop the leftover bits so that the reshape
# below also works when k does not divide the requested bit count (e.g. M = 8).
n_bits = n_symbols * k

# Unipolar M-ASK constellation: levels 0, 1, ..., M-1 scaled so that the
# average symbol energy (equiprobable symbols) is 1.
original_symbols = np.arange(M)
Es_original = np.mean(original_symbols**2)  # Average energy before scaling
scaling_factor = np.sqrt(1 / Es_original)  # Scale to have Es = 1
scaled_symbols = original_symbols * scaling_factor

# Random source bits, grouped k at a time into symbol indices using natural
# binary with the first bit of each group as the most significant bit.
bits = np.random.randint(0, 2, n_bits)
symbols_idx = np.reshape(bits, (n_symbols, k)).dot(2**np.arange(k)[::-1])

# Modulate
transmitted_symbols = scaled_symbols[symbols_idx]

# Random sign flips (phase inversions, h = +/-1), one per symbol
h = np.random.choice([-1, 1], size=n_symbols)

# Simulated BER per Eb/N0 point, filled in by the simulation loop
ber_coherent = np.zeros_like(EbN0dB_range, dtype=float)
ber_noncoherent = np.zeros_like(EbN0dB_range, dtype=float)

def _nearest_level(samples, levels):
    """Return, for each sample, the index of the closest constellation level."""
    return np.argmin(np.abs(samples[:, np.newaxis] - levels), axis=1)


def _indices_to_bits(indices, bits_per_symbol):
    """Expand symbol indices into a flat bit array, most significant bit first.

    This inverts the transmitter mapping, which weights each group of
    ``bits_per_symbol`` bits by 2**(k-1), ..., 2, 1. Shifting the integer
    indices avoids a uint8 cast, so any number of bits per symbol works.
    """
    shifts = np.arange(bits_per_symbol - 1, -1, -1)
    return ((indices[:, np.newaxis] >> shifts) & 1).ravel()


Es = 1  # Average symbol energy (constellation normalised to unit energy)
Eb = Es / k  # Energy per bit
for i, EbN0dB in enumerate(EbN0dB_range):
    # Noise level for this Eb/N0: real-valued AWGN with variance N0/2
    EbN0_linear = 10**(EbN0dB / 10)
    N0 = Eb / EbN0_linear
    sigma = np.sqrt(N0 / 2)  # Noise standard deviation

    # Channel: per-symbol sign flip h plus AWGN
    noise = np.random.normal(0, sigma, n_symbols)
    received_coherent = h * transmitted_symbols + noise

    # Coherent detection: h is known, so multiplying by h undoes the flip
    received_coherent_corrected = received_coherent * h
    demodulated_idx_coherent = _nearest_level(received_coherent_corrected, scaled_symbols)
    demodulated_bits_coherent = _indices_to_bits(demodulated_idx_coherent, k)
    ber_coherent[i] = np.mean(bits != demodulated_bits_coherent[:n_bits])

    # Non-coherent detection: h is unknown, so decide on |r|, which the sign
    # flip does not change
    received_noncoherent = np.abs(received_coherent)
    demodulated_idx_noncoherent = _nearest_level(received_noncoherent, scaled_symbols)
    demodulated_bits_noncoherent = _indices_to_bits(demodulated_idx_noncoherent, k)
    ber_noncoherent[i] = np.mean(bits != demodulated_bits_noncoherent[:n_bits])
    
# Theoretical BER.
# Both receivers take a nearest-level decision on one real Gaussian sample:
# r*h for the coherent receiver and |r| for the non-coherent one. The sign flip
# changes the distribution of neither, so the probability of deciding level j
# when level i was sent is a Gaussian interval probability. Weighting each
# (i, j) pair by the Hamming distance between the natural-binary labels of i
# and j gives the exact BER of the simulated system for any M, with no
# Gray-coding or high-SNR approximation.


def _gauss_interval(lo, hi):
    """Return P(lo < Z < hi) for a standard normal Z, accurate in both tails.

    In the upper tail (lo > 0) the result is computed as Q(lo) - Q(hi) to avoid
    the 1 - 1 cancellation of Phi(hi) - Phi(lo); elsewhere Phi(hi) - Phi(lo) is
    used, which is accurate in the lower tail. Infinite bounds are allowed.
    """
    root2 = np.sqrt(2)
    upper_tail = 0.5 * (erfc(lo / root2) - erfc(hi / root2))
    other = 0.5 * (erfc(-hi / root2) - erfc(-lo / root2))
    return np.where(lo > 0, upper_tail, other)


def _exact_ber(levels, sigma, bits_per_symbol, folded):
    """Exact BER of nearest-level detection of equiprobable M-ASK in real AWGN.

    Parameters
    ----------
    levels : ndarray, shape (M,)
        Constellation levels in ascending order (non-negative when ``folded``).
        Level i carries the natural-binary label i.
    sigma : array_like, shape (S,)
        Noise standard deviations, one per operating point.
    bits_per_symbol : int
        k = log2(M).
    folded : bool
        False: decide on r (coherent, sign removed).
        True: decide on |r| (sign-blind receiver).

    Returns
    -------
    ndarray, shape (S,)
        Bit error rate for each sigma.
    """
    n_levels = levels.size
    midpoints = (levels[:-1] + levels[1:]) / 2
    upper = np.append(midpoints, np.inf)
    # Decision region j is (lower[j], upper[j]); for |r| the first region starts at 0
    lower = np.insert(midpoints, 0, 0.0 if folded else -np.inf)
    sent = levels[np.newaxis, :, np.newaxis]  # axis 1: sent level i
    s = np.asarray(sigma, dtype=float)[:, np.newaxis, np.newaxis]
    # prob[s, i, j] = P(decide j | sent i)
    prob = _gauss_interval((lower - sent) / s, (upper - sent) / s)
    if folded:
        # |r| falls in region j also when r lies in the mirrored negative interval
        prob = prob + _gauss_interval((-upper - sent) / s, (-lower - sent) / s)
    idx = np.arange(n_levels)
    differing = idx[:, np.newaxis] ^ idx[np.newaxis, :]
    hamming = ((differing[..., np.newaxis] >> np.arange(bits_per_symbol)) & 1).sum(axis=-1)
    return (prob * hamming).sum(axis=(1, 2)) / (n_levels * bits_per_symbol)


EbN0_linear_theory = 10**(EbN0dB_range / 10)
# Same normalisation as the simulation: Es = 1, Eb = Es / k, noise variance N0/2
sigma_theory = np.sqrt((1 / k) / EbN0_linear_theory / 2)
ber_coherent_theory = _exact_ber(scaled_symbols, sigma_theory, k, folded=False)
ber_noncoherent_theory = _exact_ber(scaled_symbols, sigma_theory, k, folded=True)


def _mask_zeros(values):
    """Replace zero BERs (no errors observed) by NaN so the log axis skips them."""
    return np.where(values > 0, values, np.nan)


# Plot (simulated BERs below roughly 1/n_bits cannot be observed)
plt.figure(figsize=(10, 6))
plt.semilogy(EbN0dB_range, _mask_zeros(ber_coherent), 'bo-', label='Coherent (Sim.)')
plt.semilogy(EbN0dB_range, ber_coherent_theory, 'b--', label='Coherent (Theory)')
plt.semilogy(EbN0dB_range, _mask_zeros(ber_noncoherent), 'ro-', label='Non-coherent (Sim.)')
plt.semilogy(EbN0dB_range, ber_noncoherent_theory, 'r--', label='Non-coherent (Theory)')
plt.xlabel('Eb/N0 (dB)')
plt.ylabel('BER')
plt.grid(True, which="both", ls="--")
plt.legend()
plt.title(f'BER vs Eb/N0 for {M}-ASK')
plt.ylim(1e-6, 1)
plt.show()

