using HGF
using Turing
# Two separate calls to `premade_hgf` for a 3-level binary HGF: one is passed to
# `premade_agent` inside a NamedTuple, the other is kept on its own as `test_hgf`.
params_list = (hgf = HGF.premade_hgf("binary_3level"),);
test_hgf = HGF.premade_hgf("binary_3level");

# Create an agent whose action model is the unit-square sigmoid (not a softmax),
# as named by the premade agent string.
test_agent = HGF.premade_agent("hgf_unit_square_sigmoid_action", params_list);

# Binary inputs fed to the agent, and the responses later used when fitting.
test_input = Float64[1, 0, 1, 1, 0]
test_responses = Float64[1, 0, 0, 1, 1]

# Fixed parameter values for the agent. Keys appear to follow `<node>__<param>`.
# NOTE(review): the "strenght" spelling is kept as-is; presumably it matches HGF's
# own parameter names, so don't "fix" it without checking HGF.
params_list = (
    sigmoid_action_precision = 5,
    u__category_means = Real[0.0, 1.0],
    u__input_precision = Inf,
    u__x1_coupling_strenght = 1.0,
    x1__x2_coupling_strenght = 1.0,
    x2__x3_coupling_strenght = 1.0,
    x2__initial_mean = 0,
    x2__initial_precision = 1,
    x3__initial_mean = 1,
    x3__initial_precision = 1,
)

# Apply the values, call `reset!` (presumably clears earlier state), then run the inputs.
HGF.set_params!(test_agent, params_list)
HGF.reset!(test_agent)
HGF.give_inputs!(test_agent, test_input)

using Plots

# Plot the "posterior_mean" trajectory of node "x1" for the agent run above.
hgf_trajectory_plot(test_agent, "x1", "posterior_mean")

# Priors over the two evolution rates. In Distributions.jl, `Normal(μ, σ)` takes the
# standard deviation as its second argument.
# NOTE(review): confirm 16 is meant as σ and not as a variance.
params_prior_list = (
    x2__evolution_rate = Normal(-3.0, 16),
    x3__evolution_rate = Normal(-6.0, 16),
)

# Sample "x2__posterior_mean" under these priors (the 100 presumably sets the number
# of samples).
psl = HGF.model_sampling(
    test_agent, params_prior_list, "x2__posterior_mean", 100, test_input
)


# Fit the agent to the observed responses with NUTS. Priors come from
# `params_prior_list` and fixed values from `params_list`; 2000 presumably sets the
# number of iterations.
chn = HGF.fit_model(
    test_agent,
    test_input,
    test_responses,
    params_prior_list,
    params_list,
    NUTS(),
    2000,
)

# Save the chain to a .jls file, read it back, and summarise the reloaded copy.
write("chain-file.jls", chn)
chn2 = read("chain-file.jls", Chains)
describe(chn2)

# Sample "x2__posterior_mean" driven by the fitted chain; keep only the first result.
sl = HGF.model_sampling(test_agent, chn, "x2__posterior_mean", 100, test_input)[1]

# The second table returned by `describe` holds the "50.0%" quantile column (the
# per-parameter medians); read that column out of the table's `nt` field.
b = let quantile_table = describe(chn)[2]
    getfield(quantile_table.nt, Symbol("50.0%"))
end

using Plots

# Predictive simulation plot of "x2__posterior_mean" using the reloaded chain. The
# 10000 presumably sets the number of simulations; the low alpha keeps overlapping
# traces readable.
# NOTE(review): this call passes the standalone `test_hgf`, whereas the other HGF
# calls here pass `test_agent`. Confirm this is intended.
HGF.predictive_simulation_plot(
    test_hgf,
    chn2,
    "x2__posterior_mean",
    10000,
    test_input;
    title = "x2__posterior_mean",
    alpha = 0.01,
)

# Scratch checks. An empty Float64 vector keeps its element type when an Int is pushed
# (the 1 is converted to 1.0). `keys` of a NamedTuple returns its field names.
a = Vector{Float64}()
typeof(a)
push!(a, 1)

c = (a = 3, b = 4)
keys(c)

# Plot prior trajectories of "x2__posterior_mean" (the 2 presumably sets how many).
HGF.prior_trajectory_plot(
    test_agent,
    params_prior_list,
    "x2__posterior_mean",
    2,
    test_input;
    title = "x2__posterior_mean",
    alpha = 0.1,
)

# Inspect the x2 evolution-rate prior: take one random draw and compute its median.
dist = params_prior_list.x2__evolution_rate
rand(dist)
median(dist)

# Second configuration: the same fixed values as the first parameter set, plus
# explicit evolution rates. Apply it, re-run the agent, and plot x2.
# The action-precision key is `sigmoid_action_precision`, the same key the first
# parameter set uses for this "hgf_unit_square_sigmoid_action" agent. A bare
# `action_precision` key is inconsistent with it.
# NOTE(review): confirm against HGF's sigmoid action model.
params_list = (
    sigmoid_action_precision = 5,
    u__category_means = Real[0.0, 1.0],
    u__input_precision = Inf,
    u__x1_coupling_strenght = 1.0,
    x1__x2_coupling_strenght = 1.0,
    x2__x3_coupling_strenght = 1.0,
    x2__initial_mean = 0,
    x2__initial_precision = 1,
    x3__initial_mean = 1,
    x3__initial_precision = 1,
    x2__evolution_rate = -3.0,
    x3__evolution_rate = -6.0,
)

HGF.set_params!(test_agent, params_list)
HGF.reset!(test_agent)

HGF.give_inputs!(test_agent, test_input)

# Plot the "posterior_mean" trajectory of node "x2" under this configuration.
hgf_trajectory_plot(test_agent, "x2", "posterior_mean")

# Scratch check: swap `missing` for NaN. `replace` returns a new vector (Float64
# eltype here, since `missing` is dropped from the element type) and leaves `a` as is.
a = Union{Missing, Int}[1, missing, 2]
b = replace(a, missing => NaN)
a
b