import math
import random
from functools import reduce
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import numpy as np
import pandas as pd
from qiskit import QuantumCircuit
from qiskit import transpile
from qiskit.circuit.library import RXGate
from qiskit.quantum_info import DensityMatrix, Pauli, entropy, partial_trace, purity, concurrence
from qiskit.quantum_info import random_density_matrix, state_fidelity, Operator, random_statevector
from qiskit_aer import Aer
from scipy.linalg import logm

# Quantum relative entropy S(rho || sigma), in bits.
def relative_entropy(rho, sigma, *, tol=1e-12):
    """Return the quantum relative entropy S(rho || sigma) in bits.

    S(rho || sigma) = Tr[rho log2 rho] - Tr[rho log2 sigma].

    Both traces are evaluated from Hermitian eigendecompositions with the
    convention 0 * log 0 = 0.  Rank-deficient (e.g. pure) states are
    therefore handled exactly, instead of taking the matrix logarithm of a
    singular matrix.

    Args:
        rho: Density matrix, either an object exposing the matrix via
            ``.data`` (e.g. a qiskit ``DensityMatrix``) or a NumPy array.
        sigma: Density matrix of the same dimension, same accepted types.
        tol: Eigenvalues and weights at or below this value count as zero.

    Returns:
        float: S(rho || sigma), or ``math.inf`` when the support of ``rho``
        is not contained in the support of ``sigma``.
    """
    def _matrix(state):
        # ndarray.data is a raw memory buffer, so only unwrap non-arrays.
        data = state if isinstance(state, np.ndarray) else state.data
        return np.asarray(data, dtype=complex)

    rho_data = _matrix(rho)
    sigma_data = _matrix(sigma)

    # Tr[rho log2 rho] = sum_i p_i log2 p_i over the eigenvalues of rho.
    p = np.linalg.eigvalsh(rho_data)
    p = p[p > tol]  # drops zero (and round-off negative) eigenvalues
    rho_term = float(np.sum(p * np.log2(p)))

    # Tr[rho log2 sigma] = sum_j <v_j|rho|v_j> log2 q_j,
    # where sigma = sum_j q_j |v_j><v_j|.
    q, v = np.linalg.eigh(sigma_data)
    weights = np.real(np.diag(v.conj().T @ rho_data @ v))
    support = q > tol
    if np.any(weights[~support] > tol):
        # rho has weight where sigma vanishes: the divergence is infinite.
        return math.inf
    cross_term = float(np.sum(weights[support] * np.log2(q[support])))

    return rho_term - cross_term

def obs_LowerBound(rho, M_1, M_2):
    """Return the observable-correlation lower bound for a bipartite state.

    Evaluates

        C**2 / (2 * ||M_1||**2 * ||M_2||**2) + f

    with:

    - ``C = <M_1 (x) M_2>_rho - <M_1>_{rho_1} * <M_2>_{rho_0}``, the
      connected correlator.  Each single-site expectation is taken in the
      marginal of the qubit the observable acts on.
    - ``f = p ln p + (1 - p) ln(1 - p)``, using the convention
      ``0 ln 0 = 0``.
    - ``p = 1/d + sqrt((1 - 1/d) * (gamma - 1/d))``, where ``d = rho.dim``
      and ``gamma`` is the purity of ``rho``.

    Args:
        rho: qiskit ``DensityMatrix`` on two subsystems (qubits 0 and 1).
        M_1: Observable on qubit 1, i.e. the left factor of
            ``M_1.tensor(M_2)``; must provide ``to_matrix()``.
        M_2: Observable on qubit 0, i.e. the right factor.

    Returns:
        float: the bound value.
    """
    d = rho.dim
    # purity() can come back as a complex scalar with zero imaginary part.
    gamma = float(np.real(purity(rho)))

    # Clamp against round-off: gamma marginally below 1/d would make the
    # sqrt argument negative, and p must not exceed 1 (p == 1 exactly for a
    # pure state, where 1 - p == 0).
    p = 1 / d + math.sqrt(max(0.0, (1 - 1 / d) * (gamma - 1 / d)))
    p = min(p, 1.0)
    f = 0.0
    for x in (p, 1 - p):
        if x > 0:  # 0 * ln(0) is taken as 0
            f += x * math.log(x)

    # qiskit tensor order: M_1 (x) M_2 acts with M_2 on qubit 0 and M_1 on
    # qubit 1, so each factor must be paired with the matching marginal.
    M_global = M_1.tensor(M_2)
    rho0 = partial_trace(rho, [1])  # marginal on qubit 0 (where M_2 acts)
    rho1 = partial_trace(rho, [0])  # marginal on qubit 1 (where M_1 acts)

    # Induced matrix 1-norm (max absolute column sum); for Pauli operators
    # this coincides with the operator norm (= 1).
    norm_M1 = np.linalg.norm(M_1.to_matrix(), ord=1)
    norm_M2 = np.linalg.norm(M_2.to_matrix(), ord=1)

    C_M12 = np.real(
        rho.expectation_value(M_global)
        - rho1.expectation_value(M_1) * rho0.expectation_value(M_2)
    )

    return float(C_M12**2 / (2 * norm_M1**2 * norm_M2**2) + f)

def Ef_2qubits(C):
    """Return the entanglement of formation of a two-qubit state, in bits.

    Uses Ef = h((1 + sqrt(1 - C**2)) / 2), where
    h(x) = -x log2 x - (1 - x) log2(1 - x) and 0 * log2 0 = 0.

    Args:
        C: Concurrence of the state, nominally in [0, 1].  Round-off values
           slightly outside that range are clamped instead of making
           ``math.sqrt`` raise.

    Returns:
        float: Ef in [0, 1]; exactly 0.0 for C == 0 and 1.0 for C == 1.
    """
    c = min(abs(C), 1.0)
    p = (1 + math.sqrt(1 - c * c)) / 2

    ef = 0.0
    for x in (p, 1 - p):
        if x > 0:  # 0 * log2(0) is taken as 0 (separable states give Ef = 0)
            ef -= x * math.log2(x)
    return ef



# Register layout: 6 qubits in qiskit (little-endian) index order
#   0 = B', 1 = B, 2 = A, 3 = E, 4 = A', 5 = E'
# i.e. the system qubits A, B, E plus one ancilla copy (A', B', E') of each.
n_qubits = 6

# Statevector simulator backend.  It is not referenced anywhere else in this
# script as written; states are evolved with DensityMatrix.evolve instead.
backend = Aer.get_backend('statevector_simulator')

# Result containers.  Every name below is bound to its own fresh, empty
# list (26 names, 26 lists).
(
    # initial information of A, E and of the pair EA
    S_A_init, Sq_A_init, Sc_A_init, S_A_init_total,
    S_E_init, Sq_E_init, Sc_E_init, S_E_init_total,
    S_EA_init_total,
    # initial information of B
    S_B_init, Sq_B_init, Sc_B_init, S_B_init_total,
    S_EA_init,
    # causation measures, bounds and derived quantities
    C_EA2B, C_A2B_condE,
    lb_Lieb, obs_lb, classicaldet_lb,
    Ef_ApB, acc_info_Ap2B, acc_info_B2Ap,
    # stored density matrices (reduced and full)
    frho_ApEpB, inrho_ApEpB,
    inrho, frho,
) = ([] for _ in range(26))

# Build the 6-qubit unitary U acting on the register A'E'EABB'.
# qiskit qubit indices: 0 = B', 1 = B, 2 = A, 3 = E, 4 = A', 5 = E'.
qc = QuantumCircuit(n_qubits)
    
# Step 1: CNOT from each system qubit onto its ancilla
# (A -> A', B -> B', E -> E').
qc.cx(2, 4)  # CNOT(A -> Ap)
qc.cx(1, 0)  # CNOT(B -> Bp)
qc.cx(3, 5)  # CNOT(E -> Ep)


# Step 2: doubly-controlled RX(theta) on B, controlled by E and A.
# With theta = pi/2 this is NOT a Toffoli (CCNOT); the plain CCX and the
# single-control CNOT variants are kept below, commented out.
#qc.ccx(3, 2, 1)
#qc.cx(3,1)
#qc.cx(2,1)
theta = np.pi/2
ccx_theta = RXGate(theta).control(2)
qc.append(ccx_theta, [3, 2, 1]) # controls = qubits 3,2; target = 1
    
# Step 3: CNOTs with control and target swapped relative to Step 1
# (A' -> A, B' -> B, E' -> E).  NOTE: CNOT is self-inverse, so the literal
# inverse of Step 1 would repeat Step 1's gates; these are different gates.
qc.cx(4, 2)  # CNOT(Ap -> A)
qc.cx(0, 1)  # CNOT(Bp -> B)
qc.cx(5, 3)  # CNOT(Ep -> E)

# Full unitary used by the simulation loop.
U = Operator(qc)

# U1: Step 1 alone.  Only used by the commented-out two-stage evolution.
qc1 = QuantumCircuit(n_qubits)
# Step 1: CNOT on A -> Ap, B -> Bp and E -> Ep
qc1.cx(2, 4)  # CNOT(A -> Ap)
qc1.cx(1, 0)  # CNOT(B -> Bp)
qc1.cx(3, 5)  # CNOT(E -> Ep)
U1 = Operator(qc1)

# U2: Step 2 + Step 3.  Only used by the commented-out two-stage evolution.
# NOTE(review): qc2 uses a plain CCX rather than the controlled-RX(theta)
# gate appended to qc, so evolving by U1 then U2 is not the same as evolving
# by U -- confirm whether that is intended.
qc2 = QuantumCircuit(n_qubits)
# Step 2: Apply CCNOT on EA->B
qc2.ccx(3, 2, 1)
#qc2.cx(3,1)
#qc2.cx(2,1)
# Step 3: reverse-direction CNOTs (Ap -> A, Bp -> B, Ep -> E)
qc2.cx(4, 2)  # CNOT(Ap -> A)
qc2.cx(0, 1)  # CNOT(Bp -> B)
qc2.cx(5, 3)  # CNOT(Ep -> E)
U2 = Operator(qc2)

#qc3 = qc1.compose(qc2)
#U3 = Operator(qc3)

### Simulation loop.  The initial states below are fixed (|+><+| for A, E
### and B).  Swap in one of the commented-out alternatives to sample random
### inputs, and raise the range() argument to collect more rows.
for _ in range(1):
    # --- Initial state of A ------------------------------------------------
    #rho_A = random_density_matrix(2)
    
    # Pure states
    #psi_A = random_statevector(2)
    #rho_A = DensityMatrix(psi_A)
    #pA = random.uniform(0,1)
    #rho_Amatrix = np.array(
    #    [[pA, math.sqrt(pA*(1-pA))],
    #    [math.sqrt(pA*(1-pA)), 1-pA]]
    #)
    #rho_A = DensityMatrix(rho_Amatrix)
    rho_A = DensityMatrix(np.array(
        [[0.5, 0.5],
        [0.5, 0.5]]
    ))
    
    #rho_A = DensityMatrix.from_label('1')
    
    # --- Initial state of E ------------------------------------------------
    #rho_E = random_density_matrix(2)
    #rho_E = DensityMatrix.from_label('1')
    # Pure states
    #psi_E = random_statevector(2)
    #rho_E = DensityMatrix(psi_E)
    #pE = random.uniform(0,1)
    #rho_Ematrix = np.array(
    #    [[pE, math.sqrt(pE*(1-pE))],
    #    [math.sqrt(pE*(1-pE)), 1-pE]]
    #)
    #rho_E = DensityMatrix(rho_Ematrix)

    rho_E = DensityMatrix(np.array(
        [[0.5, 0.5],
        [0.5, 0.5]]
    ))

    # Compute the initial Von Neumann entropy of A and E
    S_A_init.append(entropy(rho_A))
    S_E_init.append(entropy(rho_E))
    S_EA_init.append(entropy(rho_A)+entropy(rho_E))

    # Split each input into the entropy of its diagonal (dephased) part and
    # the relative entropy of the state to that diagonal part.
    rho_A_data = rho_A.data
    A_init_incoherent_data = np.diag(np.diag(rho_A_data))
    A_init_incoherent = DensityMatrix(A_init_incoherent_data)
    Sc_A_init.append(entropy(A_init_incoherent))
    C_A_init = relative_entropy(rho_A, A_init_incoherent)
    Sq_A_init.append(C_A_init)
    S_A_init_total.append(entropy(A_init_incoherent) + C_A_init)
    
    rho_E_data = rho_E.data
    E_init_incoherent_data = np.diag(np.diag(rho_E_data))
    E_init_incoherent = DensityMatrix(E_init_incoherent_data)
    Sc_E_init.append(entropy(E_init_incoherent))
    C_E_init = relative_entropy(rho_E, E_init_incoherent)
    Sq_E_init.append(C_E_init)
    S_E_init_total.append(entropy(E_init_incoherent) + C_E_init)

    #rho_EA = rho_E.tensor(rho_A)
    #rho_EA_data = rho_EA.data
    #EA_init_incoherent_data = np.diag(np.diag(rho_EA_data))
    #EA_init_incoherent = DensityMatrix(EA_init_incoherent_data)
    ##print(EA_init_incoherent)
    #C_EA_init = relative_entropy(rho_EA, EA_init_incoherent)
    #S_EA_init_total.append(entropy(EA_init_incoherent) + C_EA_init)

    S_EA_init_total.append(entropy(A_init_incoherent)+C_A_init+entropy(E_init_incoherent)+C_E_init)
    
    # --- Remaining inputs --------------------------------------------------
    # Ancilla copies A', E', B' start in |0>; B starts in |+><+|.
    rho_Ap = DensityMatrix.from_label('0')
    rho_Ep = DensityMatrix.from_label('0')
    #rho_B = DensityMatrix.from_label('0')
    rho_B = DensityMatrix(np.array(
        [[0.5, 0.5],
        [0.5, 0.5]]
    ))
    #rho_B = random_density_matrix(2)
    # Pure states
    #psi_B = random_statevector(2)
    #rho_B = DensityMatrix(psi_B)
    #pB = random.uniform(0,1)
    #rho_Bmatrix = np.array(
    #    [[pB, math.sqrt(pB*(1-pB))],
    #    [math.sqrt(pB*(1-pB)), 1-pB]]
    #)
    #rho_B = DensityMatrix(rho_Bmatrix)
    
    rho_Bp = DensityMatrix.from_label('0')
    
    S_B_init.append(entropy(rho_B))

    # Full 6-qubit input Ep (x) Ap (x) E (x) A (x) B (x) Bp.  With qiskit's
    # tensor convention the last factor is qubit 0, giving the layout
    # 0 = B', 1 = B, 2 = A, 3 = E, 4 = A', 5 = E'.
    rho_list = [rho_Ep, rho_Ap, rho_E, rho_A, rho_B, rho_Bp]
    rho_full = reduce(lambda x, y: x.tensor(y), rho_list)

    #encoded_rho = rho_full.evolve(U1)
    #encoded_rho_data = np.real(encoded_rho.data)
    #inrho.append(encoded_rho_data)
    #Irho_ApEpB = partial_trace(encoded_rho, [0,2,3])
    #Irho_ApEpB = np.real(Irho_ApEpB)
    #inrho_ApEpB.append(Irho_ApEpB.data.tolist())
    #final_rho = encoded_rho.evolve(U2)
    #frho_data = np.real(final_rho.data)
    #frho.append(frho_data)
    final_rho = rho_full.evolve(U)

    # --- Entropy terms for the causation measures ---------------------------
    # partial_trace(final_rho, idx) traces OUT the qubits listed in idx;
    # each name lists the subsystems that remain.
    S_EpEBBp = entropy(partial_trace(final_rho, [2, 4]))
    S_ApEpEABp = entropy(partial_trace(final_rho, [1]))
    S_EpEBp = entropy(partial_trace(final_rho, [1, 2, 4]))
    # Entropy of the full register, computed once and used under both names.
    S_ApEpEABBp = entropy(final_rho)
    S_f = S_ApEpEABBp
    S_BBp = entropy(partial_trace(final_rho, [2, 3, 4, 5]))
    S_Bp = entropy(partial_trace(final_rho, [1, 2, 3, 4, 5]))
    S_Ap = entropy(partial_trace(final_rho, [0, 1, 2, 3, 5]))

    S_B = entropy(partial_trace(final_rho, [0, 2, 3, 4, 5]))
    S_ApA = entropy(partial_trace(final_rho, [0, 1, 3, 5]))
    temp0 = [S_ApA, S_B]
    S_ApAB = entropy(partial_trace(final_rho, [0, 3, 5]))
    temp = [0, S_ApA-S_ApAB, S_B-S_ApAB]
    lb_Lieb.append(2*max(temp))
    classicaldet_lb.append(2*max(temp) - min(temp0))

    # Two-qubit state of B (qubit 0 of the reduced state) and A' (qubit 1).
    rho_ApB = partial_trace(final_rho, [0, 2, 3, 5])
    M = Pauli('Z')
    obs_lb.append(obs_LowerBound(rho_ApB, M, M))

    rho_EpB = partial_trace(final_rho, [0, 2, 3, 4])
    C_EpB = concurrence(rho_EpB)
    VEf_EpB = Ef_2qubits(C_EpB)
    acc_info_Ap2B.append(S_B - VEf_EpB)

    rho_ApEp = partial_trace(final_rho, [0, 1, 2, 3])
    C_ApEp = concurrence(rho_ApEp)
    VEf_ApEp = Ef_2qubits(C_ApEp)
    acc_info_B2Ap.append(S_Ap - VEf_ApEp)
    
    C_ApB = concurrence(rho_ApB)
    VEf_ApB = Ef_2qubits(C_ApB)
    Ef_ApB.append(VEf_ApB)

    # Conditional mutual informations:
    #   C_A_to_B_condE = I(A'A : B | E'EB'),  C_E_to_B = I(E'E : B | B'),
    #   C_EA_to_B = I(A'E'EA : B | B').
    C_A_to_B_condE = S_EpEBBp + S_ApEpEABp - S_EpEBp - S_ApEpEABBp
    C_A2B_condE.append(C_A_to_B_condE)
    C_E_to_B = S_EpEBp + S_BBp - S_Bp - S_EpEBBp
    C_EA_to_B = S_BBp + S_ApEpEABp - S_Bp - S_f
    C_EA2B.append(C_EA_to_B)

    # Store the real part of the reduced state on B, A', E' as nested lists.
    rho_ApEpB = partial_trace(final_rho, [0, 2, 3])
    frho_ApEpB.append(np.real(rho_ApEpB.data).tolist())

    print("C_A_to_B_condE", C_A_to_B_condE)
    print("C_EA_to_B", C_EA_to_B)
    print("C_E_to_B", C_E_to_B)
    print("Lieb bound", 2*max(temp))
    print("obs_lb", obs_lb)
    print("classicaldet_lb", classicaldet_lb)


# Collect the per-run results into a table, save it as CSV and report the
# range of C(EA->B).
#CausationMeasure_dict = {'S_init_total': S_init, 'S_init(A)': S_A_init, 
#    'S_init(E)': S_E_init,'S_init(EA)':S_EA_init, 
#    'S_init(B)': S_B_init, 'C(A->B|E)': C_EA2B}
CausationMeasure_dict = {'S_EA_init_total': S_EA_init_total, 'S_init(A)': S_A_init, 
    'Sc_A_init': Sc_A_init, 'S_init(E)': S_E_init,'S_init(EA)': S_EA_init, 
    'S_A_init_total':S_A_init_total,
    'Sc_E_init': Sc_E_init, 'S_init(B)': S_B_init, 'C(EA->B)': C_EA2B,
    'C_A2B_condE': C_A2B_condE, 'Lieb lower bound': lb_Lieb, 
    'obs lower bound': obs_lb, 'Ef_ApB': Ef_ApB, 'Acc_info_Ap2B': acc_info_Ap2B,
    'Acc_info_B2Ap': acc_info_B2Ap, 'classicaldet_lb':classicaldet_lb} 
    
CausationMeasure_df = pd.DataFrame(CausationMeasure_dict)

# DataFrame.to_csv does not create missing directories, so make sure the
# output folder exists before writing.
output_path = Path("./CCNOT/nondiag_arbitrary.csv")
output_path.parent.mkdir(parents=True, exist_ok=True)
CausationMeasure_df.to_csv(output_path)

#with open('./stateEpApB_data.txt', 'w') as file:
#    for i, matrix in enumerate(frho):
#        file.write(f"encoded matrix {i}:\n")
#        for row in inrho[i]:
#            file.write(' '.join(map(str, row)) + '\n')
#        file.write('\n')
#        file.write(f"output matrix {i}:\n")
#        for row in matrix:
#            file.write(' '.join(map(str, row)) + '\n')
#        file.write('\n\n')

#print("Maximum causation measure is:", CausationMeasure_df['C(A->B|E)'].max())
#print("Minimum causation measure is:", CausationMeasure_df['C(A->B|E)'].min())
print("Maximum causation measure is:", CausationMeasure_df['C(EA->B)'].max())
print("Minimum causation measure is:", CausationMeasure_df['C(EA->B)'].min())

#plt.rc('text', usetex=True)

#plt.figure(figsize=(6, 5))
#plt.scatter(CausationMeasure_df['S_EA_init_total'], CausationMeasure_df['C(EA->B)'], label=r"$C(EA\to B)$", marker='o')
#plt.xlabel(r'$H_C(A) + C(A) + H_C(E) + C(E)$')
#plt.ylabel('')  # No y-axis label
#plt.legend()
#plt.grid(True)
#plt.title(r'Transfer Entropy v.s. Full Input Information of Control Qubit')
#plt.tight_layout()

#plt.figure(figsize=(6, 5))
#plt.scatter(CausationMeasure_df['S_init(EA)'], CausationMeasure_df['C(EA->B)'], label=r"$C(EA\to B)$", marker='o')
#plt.xlabel(r'S_init(EA)')
#plt.ylabel('')  # No y-axis label
#plt.legend()
#plt.grid(True)
#plt.title(r'Transfer Entropy v.s. Full Von Neumann Entropy of Control Qubit')
#plt.tight_layout()

#plt.show()