# -*- coding: utf-8 -*-
"""
Created on Mon Mar  9 19:15:35 2020

@author: lenovo
"""

#
# Copyright (c) 2011 Christopher Felton
# Modified by Akourgli


import numpy as np
import matplotlib.pyplot as plt
from  matplotlib import patches

    
def zplane(b, a, filename=None):
    """Plot the poles and zeros of a discrete-time transfer function.

    The coefficients are read in the z^-1 convention
    ``H(z) = (b[0] + b[1] z^-1 + ...) / (a[0] + a[1] z^-1 + ...)``.
    The shorter of ``b`` and ``a`` is right-padded with zeros so both
    polynomials have the same degree in ``z``. This produces the zeros (or
    poles) at the origin implied by the length mismatch.

    Zeros are drawn as blue circles with a black edge and poles as red
    crosses, together with a dashed unit circle. Coincident roots are
    annotated with their multiplicity via ``mark_overlapping``.

    Parameters
    ----------
    b : array_like
        Numerator coefficients.
    a : array_like
        Denominator coefficients.
    filename : str or path-like, optional
        If given, the figure is also saved to this file before being shown.

    Returns
    -------
    z : ndarray
        Zeros of the transfer function.
    p : ndarray
        Poles of the transfer function.
    """
    # Draw on a single subplot of the current figure.
    ax = plt.subplot(111)

    # Dashed reference unit circle.
    ax.add_patch(patches.Circle((0, 0), radius=1, fill=False,
                                color='black', ls='dashed'))

    # Equalize the lengths of b and a. np.roots turns trailing zero
    # coefficients into roots at z = 0, which is exactly what the z^-1
    # convention requires when one polynomial is shorter than the other.
    b = np.atleast_1d(b)
    a = np.atleast_1d(a)
    n = max(b.size, a.size)
    b = np.pad(b, (0, n - b.size))
    a = np.pad(a, (0, n - a.size))

    p = np.roots(a)
    z = np.roots(b)

    # Zeros: blue-filled circles with a black edge.
    plt.plot(z.real, z.imag, 'go', markersize=10.0, markeredgewidth=1.0,
             markeredgecolor='k', markerfacecolor='b')

    # Poles: thick red crosses.
    plt.plot(p.real, p.imag, 'rx', markersize=12.0, markeredgewidth=3.0,
             markeredgecolor='r', markerfacecolor='r')

    mark_overlapping(z)
    mark_overlapping(p)

    # Draw the axes through the origin instead of around the plot frame.
    ax.spines['left'].set_position('zero')
    ax.spines['bottom'].set_position('zero')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # The view spans at least +/-1.5 so the unit circle fits, and grows to
    # include every root. If any root lies beyond 10 the view falls back to
    # +/-1.5, leaving those far-away roots off-screen.
    # initial=0 keeps this working when there are no roots at all.
    radii = np.abs(np.concatenate([p, z]))
    r = max(1.5, radii.max(initial=0.0))
    if r > 10:
        r = 1.5
    plt.axis([-r - 0.1, r + 0.1, -r - 0.1, r + 0.1])
    plt.xlabel('Partie réelle', loc='right')
    plt.ylabel('Partie imaginaire', loc='top')
    plt.title('Tracé des pôles et zéros')
    ticks = np.arange(-r, r + 0.1, 0.5)
    plt.xticks(ticks)
    plt.yticks(ticks)

    # Save before show(): once the window is closed, the figure may no longer
    # be available and saving afterwards can produce an empty image.
    if filename is not None:
        plt.savefig(filename)
    plt.show()

    return z, p

from collections import defaultdict
def mark_overlapping(items, tol=1e-3):
    """
    Annotate coincident points in the complex plane with their multiplicity.

    Values of `items` that lie within `tol` of an earlier group's first value
    are counted as the same point. Repeated polynomial roots computed
    numerically (e.g. by `np.roots`) are perturbed by rounding error, so exact
    equality rarely holds for them. For every group holding more than one
    value, the count is drawn as a superscript next to the group's first value
    on the current matplotlib axes.

    Parameters
    ----------
    items : iterable of complex
        Points to tally, e.g. the zeros or poles of a transfer function.
    tol : float, optional
        Maximum distance for two values to be treated as overlapping.
        ``tol=0`` restores exact-equality tallying.
    """
    # Each group is [representative value, count], kept in first-seen order.
    groups = []
    for item in items:
        for group in groups:
            if abs(item - group[0]) <= tol:
                group[1] += 1
                break
        else:
            groups.append([item, 1])

    for item, count in groups:
        if count > 1:
            plt.text(item.real, item.imag, r' ${}^{' + str(count) + '}$', fontsize=24)