# -*- coding: utf-8 -*-
"""
Created on Sun Apr 20 21:33:55 2025

@author: AKourgli
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy import signal

# ---------------------------------------------------------------------------
# Simulation parameters
# ---------------------------------------------------------------------------
tb = 0.1      # symbol (bit) duration, in seconds
fs = 1000     # sampling frequency, in Hz
fc = 100      # centre carrier frequency, in Hz
h = 0.5       # modulation index for MSK/GMSK
BT = 0.3      # GMSK bandwidth-time product
span = 4      # Gaussian filter length, in symbol durations

# Tone spacing for orthogonal FSK. For binary orthogonal FSK the spacing is
# df = 1/tb; M-ary tones are laid out df apart, symmetric about fc.
df = 1 / tb

# Utility: time axis
def time_axis(n):
    return np.arange(n) / fs

# Gaussian pulse-shaping filter for GMSK
def gaussian_filter(BT, tb, fs, span):
    """Return unit-sum taps of the GMSK Gaussian pulse-shaping filter.

    The impulse response is h(t) ∝ exp(-t**2 / (2 * sigma**2)) with
    sigma = sqrt(ln 2) * tb / (2 * pi * BT) seconds, i.e. the Gaussian
    whose frequency response is 3 dB down at B = BT / tb.

    Parameters
    ----------
    BT : float
        Bandwidth-time product (> 0).
    tb : float
        Symbol duration in seconds (> 0).
    fs : float
        Sampling frequency in Hz (> 0).
    span : float
        Filter length in symbol durations (> 0).

    Returns
    -------
    numpy.ndarray
        Odd-length taps, symmetric about the centre tap (t = 0), normalised
        to sum to 1 (unity DC gain).

    Raises
    ------
    ValueError
        If any argument is not strictly positive.
    """
    if BT <= 0 or tb <= 0 or fs <= 0 or span <= 0:
        raise ValueError("BT, tb, fs and span must all be positive")
    # Standard deviation of the Gaussian, in seconds (note the tb factor:
    # BT is normalised, the physical 3 dB bandwidth is BT / tb).
    sigma = np.sqrt(np.log(2)) * tb / (2 * np.pi * BT)
    # Integer sample grid symmetric about 0: avoids float-step arange drift
    # and guarantees a well-defined centre tap.
    half = int(round(span * tb * fs / 2))
    t = np.arange(-half, half + 1) / fs
    taps = np.exp(-t**2 / (2 * sigma**2))
    return taps / np.sum(taps)

# Generate M-ary FSK (binary, 4-FSK, or any power-of-two M)
def generate_fsk(bits, M, tb, fs, fc, df):
    """Generate a switched-oscillator M-ary FSK signal.

    Bits are grouped MSB-first into log2(M)-bit words (for M == 4 this is
    ``bits[2j]*2 + bits[2j+1]``). Symbol ``i`` selects the tone
    ``f_i = fc + (2*i - (M-1)) * df / 2``, so the M tones are ``df`` apart
    and symmetric about ``fc``. Every tone is evaluated on the absolute time
    axis, so the signal is phase-continuous at symbol boundaries when
    ``df * tb`` is an integer (e.g. ``df = 1/tb``).

    Parameters
    ----------
    bits : array_like
        Input bit stream (0/1); its length must be a multiple of log2(M).
    M : int
        Alphabet size, a power of two >= 2.
    tb : float
        Symbol duration in seconds.
    fs : float
        Sampling frequency in Hz (used for both the samples per symbol and
        the time axis).
    fc : float
        Centre frequency in Hz.
    df : float
        Tone spacing in Hz.

    Returns
    -------
    branches : list of numpy.ndarray
        ``M`` real arrays; branch ``i`` carries tone ``f_i`` during symbols
        equal to ``i`` and zeros elsewhere.
    s_fsk : numpy.ndarray
        Sum of all branches (the passband FSK signal).
    t : numpy.ndarray
        Time axis in seconds.
    s_bb : numpy.ndarray
        Complex baseband equivalent relative to ``fc``.

    Raises
    ------
    ValueError
        If ``M`` is not a power of two >= 2, or ``len(bits)`` is not a
        multiple of log2(M).
    """
    if M < 2 or M & (M - 1):
        raise ValueError(f"M must be a power of two >= 2, got {M!r}")
    k = int(M).bit_length() - 1  # bits per symbol
    bits = np.asarray(bits)
    if bits.size % k:
        raise ValueError(
            f"number of bits ({bits.size}) must be a multiple of log2(M) = {k}"
        )
    # MSB-first grouping of k bits into one symbol index in [0, M).
    weights = 2 ** np.arange(k - 1, -1, -1)
    symbols = bits.reshape(-1, k) @ weights

    # round() rather than truncation: int(0.29 * 100) would give 28.
    ns = int(round(tb * fs))
    n_total = len(symbols) * ns
    # Build the time axis from the *argument* fs, consistent with ns above.
    t = np.arange(n_total) / fs
    sym_per_sample = np.repeat(symbols, ns)

    branches = []
    s_bb = np.zeros(n_total, dtype=complex)
    for i in range(M):
        # Centre frequencies symmetric around fc.
        f_i = fc + (2 * i - (M - 1)) * df / 2
        active = sym_per_sample == i
        branches.append(np.where(active, np.cos(2 * np.pi * f_i * t), 0.0))
        # Complex baseband with respect to fc.
        s_bb[active] = np.exp(1j * 2 * np.pi * (f_i - fc) * t[active])
    # Summed FSK signal.
    s_fsk = np.sum(branches, axis=0)
    return branches, s_fsk, t, s_bb


# PSD plotting for FSK
def plot_psd_fsk(branches, s_fsk, fs, title, nperseg=1024):
    """Plot Welch PSD estimates of each FSK branch and of their sum.

    Parameters
    ----------
    branches : sequence of 1-D arrays
        Per-tone signals (e.g. the ``branches`` output of ``generate_fsk``).
    s_fsk : 1-D array
        Summed FSK signal.
    fs : float
        Sampling frequency in Hz.
    title : str
        Text appended to the plot title.
    nperseg : int, optional
        Welch segment length (default 1024).

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    fig = plt.figure(figsize=(8, 5))
    for i, branch in enumerate(branches):
        f, Pxx = signal.welch(branch, fs, nperseg=nperseg)
        plt.semilogy(f, Pxx, label=f'Porteuse {i}')
    f_sum, Pxx_sum = signal.welch(s_fsk, fs, nperseg=nperseg)
    plt.semilogy(f_sum, Pxx_sum, '--', label='Sum')
    plt.title(f'PSD: {title}')
    plt.xlabel('Frequency (Hz)')
    # semilogy draws the linear PSD on a log axis; the values are not in dB.
    plt.ylabel('PSD (power/Hz)')
    plt.legend()
    plt.grid(True)
    return fig


# Main execution
if __name__ == '__main__':
    np.random.seed(0)  # fixed seed so the plots are reproducible
    N = 1000  # number of symbols per modulation

    # 2-FSK: one bit per symbol.
    bits2 = np.random.randint(0, 2, N)
    branches2f, s2f, _, _ = generate_fsk(bits2, 2, tb, fs, fc, df)
    # generate_fsk produces plain orthogonal FSK (no Gaussian shaping is
    # applied), so the plots are labelled FSK rather than GMSK.
    plot_psd_fsk(branches2f, s2f, fs, '2-FSK')

    # 4-FSK: two bits per symbol, hence 2*N bits for N symbols.
    bits4 = np.random.randint(0, 2, 2 * N)
    branches4f, s4f, _, _ = generate_fsk(bits4, 4, tb, fs, fc, df)
    plot_psd_fsk(branches4f, s4f, fs, '4-FSK')

    plt.show()
