import numpy as np
from qiskit import QuantumCircuit
from qiskit_aer import Aer
from qiskit import transpile
from qiskit import transpile
from qiskit.quantum_info import DensityMatrix, entropy, partial_trace
from qiskit.quantum_info import random_density_matrix, random_statevector, state_fidelity, Operator
import numpy as np
import matplotlib.pyplot as plt
from functools import reduce
import pandas as pd
from scipy.linalg import logm

# Define the function of quantum relative entropy
def relative_entropy(rho, sigma):
    """Return the quantum relative entropy S(rho || sigma) in bits.

    Evaluates Tr[rho (log2 rho - log2 sigma)] from the eigen-decompositions
    of the two matrices, using the convention 0 * log 0 = 0.  Unlike a
    direct matrix logarithm, this stays well defined for singular inputs
    such as pure states or diagonal states with zero populations.

    Args:
        rho: Object exposing a square matrix through its ``.data``
            attribute.  Assumed to be Hermitian and positive semi-definite
            (a density matrix); this is not checked.
        sigma: Object of the same kind and dimension as ``rho``.

    Returns:
        The relative entropy as a float.  ``inf`` is returned when ``rho``
        has weight on the kernel of ``sigma``, where the quantity diverges.
    """
    rho_data = np.asarray(rho.data, dtype=complex)
    sigma_data = np.asarray(sigma.data, dtype=complex)
    # Eigenvalues and overlaps below this threshold are treated as zero.
    tol = 1e-12

    # Tr[rho log2 rho]: only strictly positive eigenvalues contribute.
    rho_vals = np.linalg.eigvalsh(rho_data)
    rho_vals = rho_vals[rho_vals > tol]
    tr_rho_log_rho = float(np.sum(rho_vals * np.log2(rho_vals)))

    # Tr[rho log2 sigma] = sum_i <v_i|rho|v_i> log2(s_i), where (s_i, v_i)
    # are the eigenpairs of sigma.
    sigma_vals, sigma_vecs = np.linalg.eigh(sigma_data)
    weights = np.einsum(
        'ji,jk,ki->i', sigma_vecs.conj(), rho_data, sigma_vecs
    ).real

    in_support = sigma_vals > tol
    if np.any(weights[~in_support] > tol):
        # rho is not supported on the support of sigma.
        return float('inf')

    tr_rho_log_sigma = float(
        np.sum(weights[in_support] * np.log2(sigma_vals[in_support]))
    )
    return tr_rho_log_rho - tr_rho_log_sigma

# Size of the register: one qubit each for A, B and the ancillas Ap, Bp.
n_qubits = 4

# Handle to Aer's statevector simulator backend.
backend = Aer.get_backend('statevector_simulator')

# Module-level accumulators for the sampling results.  Every name is bound
# to its own, initially empty list.

# Initial von Neumann entropies: A, B, and their sum.
S_A_init, S_B_init, S_init = [], [], []

# Discord and coherence related quantities.
Delta_discord_total, Delta_discord_AB = [], []
coherence_init, coherence_A_init = [], []
coherence_consumption_AB = []

# Candidate lower bounds for the causation measure.
lb1, lb2, lb3, lb4 = [], [], [], []

# Causation measures in both directions and their difference.
C_A2B, C_B2A, diff_C = [], [], []

### Create a quantum circuit to apply the CNOT transformations

#class ControledRotationCircuit:
#    def _init_(self, num_qubits, theta):
#        self.qc1 = QuantumCircuit(num_qubits)
#        self.theta = theta
#        self.num_qubits = num_qubits

#    def local_copy(self, control:int, target:int):
#        ## Apply a CNOT gate on a input qubit and its ancilla
#        self.qc1.cx(control,target)

#    def channel_evolution(self, control: int, traget: int):
#        ## Apply the controlled rotation of theta on the input qubits
#        self.qc1.crx(theta, control, target)

#    def get_circuit(self):
#        return self.qc1


# Qubit indices used below, following the gate labels in this script:
# Bp = 0, B = 1, A = 2, Ap = 3.
_Q_BP, _Q_B, _Q_A, _Q_AP = 0, 1, 2, 3

qc1 = QuantumCircuit(n_qubits)

# Step 1: CNOT from each system qubit onto its ancilla (control, target).
qc1.cx(_Q_A, _Q_AP)  # CNOT(A -> Ap)
qc1.cx(_Q_B, _Q_BP)  # CNOT(B -> Bp)

# Step 2: interaction between B and A.  The active choice is a SWAP.
# Disabled alternatives that were tried here: basis changes on A/B
# (h, s, y, rx(np.pi/8) on qubit 1), cx(2, 1), cx(1, 2), and
# crx(theta, 2, 1) with theta = np.pi/640.
qc1.swap(_Q_B, _Q_A)

# Step 3: CNOTs controlled by the ancillas, targeting the system qubits.
# NOTE(review): the original comment calls this the inverse of Step 1, but
# control and target are reversed relative to Step 1 -- confirm intended.
qc1.cx(_Q_AP, _Q_A)  # CNOT(Ap -> A)
qc1.cx(_Q_BP, _Q_B)  # CNOT(Bp -> B)

# Operator form of the whole circuit, applied to every sampled state.
U1 = Operator(qc1)

### Simulation with random initial state of A
def _dephased(rho):
    """Return the state obtained by zeroing the off-diagonal entries of rho."""
    return DensityMatrix(np.diag(np.diag(rho.data)))


### Simulation with random initial states of A and B
for _ in range(1000):

    # --- Initial states ---------------------------------------------------
    # A and B start in independent random pure states; the ancillas Ap and
    # Bp start in |0>.  Disabled alternatives for A/B: random_density_matrix(2)
    # (arbitrary mixed state), a diagonal (classical) state, or the fixed
    # state [[0.5, 0.5], [0.5, 0.5]].
    psi_A = random_statevector(2)
    rho_A = DensityMatrix(psi_A)
    psi_B = random_statevector(2)
    rho_B = DensityMatrix(psi_B)
    rho_Ap = DensityMatrix.from_label('0')
    rho_Bp = DensityMatrix.from_label('0')

    # Initial von Neumann entropies of A and B, and their sum.
    entropy_A_init = entropy(rho_A)
    entropy_B_init = entropy(rho_B)
    S_A_init.append(entropy_A_init)
    S_B_init.append(entropy_B_init)
    S_init.append(entropy_A_init + entropy_B_init)

    # Full 4-qubit state Ap (x) A (x) B (x) Bp; with this ordering the qubit
    # indices used by partial_trace are Bp = 0, B = 1, A = 2, Ap = 3.
    rho_list = [rho_Ap, rho_A, rho_B, rho_Bp]
    rho_full = reduce(lambda x, y: x.tensor(y), rho_list)

    # Initial coherence of A and of B: relative entropy between the state
    # and its diagonal (incoherent) part.
    C_A_init = relative_entropy(rho_A, _dephased(rho_A))
    coherence_A_init.append(C_A_init)
    C_B_init = relative_entropy(rho_B, _dephased(rho_B))

    # --- Evolution --------------------------------------------------------
    final_rho = rho_full.evolve(U1)

    # --- Entropy terms for the causation measure --------------------------
    # partial_trace takes the qubits to trace OUT; the name of each term
    # lists the subsystems that remain.
    S_BBp = entropy(partial_trace(final_rho, [2, 3]))
    S_ApABp = entropy(partial_trace(final_rho, [1]))
    S_Bp = entropy(partial_trace(final_rho, [1, 2, 3]))
    S_ApABBp = entropy(final_rho)
    S_AAp = entropy(partial_trace(final_rho, [0, 1]))
    S_BpBAp = entropy(partial_trace(final_rho, [2]))
    S_Ap = entropy(partial_trace(final_rho, [0, 1, 2]))

    # Causation measure in both directions and their difference.
    C_A_to_B = S_BBp + S_ApABp - S_Bp - S_ApABBp
    C_B_to_A = S_AAp + S_BpBAp - S_Ap - S_ApABBp
    C_A2B.append(C_A_to_B)
    C_B2A.append(C_B_to_A)
    diff_C.append(C_A_to_B - C_B_to_A)

    # --- Subsystem AB -----------------------------------------------------
    # Reduced state of A and B: trace out the ancillas Bp (0) and Ap (3).
    # The previous code traced out [0, 1] for the coherence term, which keeps
    # A and Ap rather than A and B.
    rho_AB_final = partial_trace(final_rho, [0, 3])

    # Entropy of AB, recorded as "lower bound 1".  (Disabled variants built
    # mutual-information bounds from S_A, S_Ap, S_B, S_AB and S_ApB.)
    S_AB = entropy(rho_AB_final)
    lb1.append(S_AB)

    # Coherence variation of AB: final coherence of the joint state minus
    # the initial coherences of A and B; the absolute value is recorded.
    # (Disabled variants computed per-qubit final coherences and the
    # coherence of the global final state.)
    C_final_rho_AB = relative_entropy(rho_AB_final, _dephased(rho_AB_final))
    cc = C_final_rho_AB - C_A_init - C_B_init
    coherence_consumption_AB.append(abs(cc))



# Print the causation measure results
#CausationMeasure_dict = {'S_init_total': S_init, 'S_init(A)': S_A_init, 
#    'S_init(B)': S_B_init, 'C(A->B)': C_A2B, 
#    'I(ApA:B)': I_ApA_B, 'coherence_consumption': coherence_consumption,
#    'discord':Delta_discord, 'lower bound': lb}

#CausationMeasure_dict = {'S_init_total': S_init, 'S_init(A)': S_A_init, 
#    'S_init(B)': S_B_init, 'coherence_A_init': coherence_A_init, 'C(A->B)': C_A2B, 
#    'C(B->A)':C_B2A, 'discord of AB':Delta_discord_AB, 'lower bound 1':lb1, 
#    'lower bound 2':lb2, 'lower bound 3':lb3, 'lower bound 4':lb4}

# Column name -> per-sample values; the insertion order fixes the column
# order of the table (and therefore of the CSV written later).
CausationMeasure_dict = dict([
    ('S_init_total', S_init),
    ('S_init(A)', S_A_init),
    ('S_init(B)', S_B_init),
    ('coherence_A_init', coherence_A_init),
    ('C(A->B)', C_A2B),
    ('C(B->A)', C_B2A),
    ('difference of causation entropy', diff_C),
    ('lower bound 1', lb1),
    ('coherence_consumption_AB', coherence_consumption_AB),
])
CausationMeasure_df = pd.DataFrame(data=CausationMeasure_dict)

# Report the range of the A -> B causation measure over all samples.
_c_a2b_column = CausationMeasure_df['C(A->B)']
print("Maxmum C_A2B is:", _c_a2b_column.max())
print("Minimum C_A2B is:", _c_a2b_column.min())
#print("Maxmum C_B2A is:", CausationMeasure_df['C(B->A)'].max())
#print("Minimum C_B2A is:", CausationMeasure_df['C(B->A)'].min())
#print("Maxmum lower bound 1 is:", CausationMeasure_df['lower bound 1'].max())
#print("Minimum lower bound 1 is:", CausationMeasure_df['lower bound 1'].min())
#print("Maxmum lower bound 3 is:", CausationMeasure_df['lower bound 3'].max())
#print("Minimum lower bound 3 is:", CausationMeasure_df['lower bound 3'].min())
#print("Initial quantum information:", S_A_init)
#print("C_A2Bs:", C_A2B)


# Order the samples by the initial coherence of A (in place) and write the
# sorted table to disk.
CausationMeasure_df.sort_values('coherence_A_init', ascending=True, inplace=True)
CausationMeasure_df.to_csv('CR64.csv')

# Render figure text through LaTeX.
plt.rc('text', usetex=True)

# One scatter figure per causation direction, both plotted against the
# initial coherence of A.
# Disabled alternatives kept for reference: scatter series against
# 'S_init(A)' for 'C(A->B)', 'C(B->A)', 'difference of causation entropy',
# 'coherence_consumption_AB' and 'lower bound 1' .. 'lower bound 4'; and
# DataFrame.plot.scatter of 'C(A->B)' against 'S_init(B)' or 'S_init_total'
# after sorting by that column.
for _column, _legend in (
    ('C(A->B)', r"$\mathcal{C}(A\to B)$"),
    ('C(B->A)', r"$\mathcal{C}(B\to A)$"),
):
    plt.figure(figsize=(6, 5))
    plt.scatter(
        CausationMeasure_df['coherence_A_init'],
        CausationMeasure_df[_column],
        label=_legend,
        marker='o',
    )
    plt.xlabel('coherence_A_init')
    plt.ylabel('')  # No y-axis label
    plt.legend()
    plt.grid(True)
    plt.title(r'Transfer Entropy v.s. Initial Coherence of A')
    plt.tight_layout()

plt.show()