ベイズ推定のアニメーション (混合ガンマ分布のサンプルの場合)

黒木玄

2017-11-19

In [1]:
versioninfo()
Julia Version 0.6.0
Commit 903644385b* (2017-06-19 13:05 UTC)
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz
  WORD_SIZE: 64
  BLAS: libopenblas (USE64BITINT DYNAMIC_ARCH NO_AFFINITY Haswell)
  LAPACK: libopenblas64_
  LIBM: libopenlibm
  LLVM: libLLVM-3.9.1 (ORCJIT, haswell)

パッケージの読み込みと諸定義

In [2]:
@time begin
    #using Mamba
    
    using KernelDensity
    # Build a callable density estimate from the sample X.
    # The kernel density estimate is computed once and wrapped in an
    # interpolation object; the returned function evaluates it at a point x.
    function makefunc_pdfkde(X)
        interp = InterpKDE(kde(X))
        return x -> pdf(interp, x)
    end

    # Bivariate variant: the estimate is built from the pair of samples (X, Y)
    # and the returned function evaluates it at a point (x, y).
    function makefunc_pdfkde(X,Y)
        interp = InterpKDE(kde((X,Y)))
        return (x, y) -> pdf(interp, x, y)
    end

    using Optim
    # Optim options that turn on trace storage (store_trace) together with the
    # extended trace (extended_trace).
    # NOTE(review): not referenced anywhere in the visible part of this notebook.
    optim_options = Optim.Options(
        store_trace = true,
        extended_trace = true
    )
    
    using QuadGK

    import PyPlot
    plt = PyPlot

    using Distributions
    # Location-scale t distribution: a TDist with ν degrees of freedom,
    # shifted by μ and scaled by ρ. All three arguments are converted to Float64.
    @everywhere function GTDist(μ, ρ, ν)
        return LocationScale(Float64(μ), Float64(ρ), TDist(Float64(ν)))
    end
    # Two-argument form: same distribution with location zero(ρ).
    @everywhere function GTDist(ρ, ν)
        return GTDist(zero(ρ), ρ, ν)
    end

    using JLD2
    using FileIO
end

using PyCall
@pyimport matplotlib.animation as anim
# Show the file `filename` inline: its contents are base64-encoded and
# embedded in an HTML <img> tag with a GIF data URI.
function showgif(filename)
    base64_video = open(base64encode, filename)
    display("text/html", """<img src="data:image/gif;base64,$base64_video">""")
end

# @sum(f_k, k, itr): sum of the expression f_k as the variable k runs over itr.
#
# The accumulator starts at zero(...) of the value of f_k at k = itr[1], so it
# has the same type as the terms. Consequently itr must support itr[1], and
# f_k is evaluated one extra time for that first element.
macro sum(f_k, k, itr)
    term, index, iter = esc(f_k), esc(k), esc(itr)
    return quote
        acc = zero(($index -> $term)($(iter)[1]))
        for $index in $iter
            acc += $term
        end
        acc
    end
end
 10.123873 seconds (12.12 M allocations: 807.434 MiB, 2.96% gc time)
Out[2]:
@sum (macro with 1 method)

MambaパッケージでMCMCを実行するための函数

In [3]:
# Mixture of Normal(b, 1) and Normal(c, 1) with weights 1-a and a.
@everywhere function mixnormal(a, b, c)
    return MixtureModel(Normal[Normal(b, 1.0), Normal(c, 1.0)], [1.0-a, a])
end

# Vector form: w = [a, b, c].
function mixnormal(w::AbstractVector)
    return mixnormal(w[1], w[2], w[3])
end

# unlink_mixnormal : (0,1)×R^2 -> R^3
function unlink_mixnormal(w)
    return [logit(w[1]), w[2], w[3]]
end

# link_mixnormal : R^3 -> (0,1)×R^2
function link_mixnormal(z)
    return [invlogit(z[1]), z[2], z[3]]
end

# Build the model, data dictionary and initial values for fitting
# dist_model(a, b, c) to the sample Y.
#
# Keywords:
#   dist_model                - (a, b, c) -> distribution used for node y
#   a0, b0, c0                - initial values of a, b, c (same for every chain)
#   prior_a, prior_b, prior_c - prior distributions of a, b, c
#   chains                    - number of initial-value dictionaries to create
#
# Returns the tuple (model, data, inits).
function sample2model_mixnormal(Y;
        dist_model = mixnormal,
        a0 = 0.5,
        b0 = 0.0,
        c0 = 0.0,
        prior_a = Uniform(0.0, 1.0),
        prior_b = Normal(0.0, 1.0),
        prior_c = Normal(0.0, 1.0),
        chains = 2
    )
    data = Dict{Symbol, Any}(
        :Y => Y,
        :n => length(Y),
        :a0 => a0,
        :b0 => b0,
        :c0 => c0,
    )
    
    # NOTE(review): the lambda argument names (a, b, c) presumably have to match
    # the node names given as keywords below — confirm before renaming.
    model = Model(
        y = Stochastic(1, (a, b, c) -> dist_model(a, b, c), false),
        a = Stochastic(() -> prior_a, true),
        b = Stochastic(() -> prior_b, true),
        # c uses its own prior (it was previously wired to prior_b,
        # which made the prior_c keyword have no effect).
        c = Stochastic(() -> prior_c, true),
    )
    scheme = [
        NUTS([:a, :b, :c])
    ]
    setsamplers!(model, scheme)
    
    # One independent initial-value dictionary per chain.
    inits = [
        Dict{Symbol, Any}(
            :y => data[:Y],
            :a => data[:a0],
            :b => data[:b0],
            :c => data[:c0],
        )
        for k in 1:chains
    ]
    return model, data, inits
end
Out[3]:
sample2model_mixnormal (generic function with 1 method)
In [4]:
# Normal model with parameters mu and sigma.
@everywhere function normal(mu, sigma)
    return Normal(mu, sigma)
end

# Vector form: w = [mu, sigma].
function normal(w::AbstractVector)
    return normal(w[1], w[2])
end

# unlink_normal : R×(0,∞) -> R^2
function unlink_normal(w)
    return [w[1], log(w[2])]
end
# Inverse of unlink_normal: the first component passes through unchanged and the
# second (log-scale) component is mapped back with exp.
link_normal(z)   = [z[1], exp(z[2])]  #   link_normal : R^2 → R×(0,∞)

# Build the model, data dictionary and initial values for fitting
# dist_model(mu, sigma) to the sample Y.
#
# Keywords:
#   dist_model            - (mu, sigma) -> distribution used for node y
#   mu0, sigma0           - initial values (same for every chain)
#   prior_mu, prior_sigma - prior distributions of mu and sigma
#   chains                - number of initial-value dictionaries to create
#
# Returns the tuple (model, data, inits).
function sample2model_normal(Y;
        dist_model = normal,
        mu0    = 0.0,
        sigma0 = 1.0,
        prior_mu = Normal(0.0, 1.0),
        prior_sigma = Truncated(Normal(1.0, 1.0), 0, Inf),
        chains = 2
    )
    data = Dict{Symbol, Any}(:Y => Y, :n => length(Y), :mu0 => mu0, :sigma0 => sigma0)

    # NOTE(review): the lambda argument names (mu, sigma) presumably have to match
    # the node names given as keywords — confirm before renaming.
    model = Model(
        y = Stochastic(1, (mu, sigma) -> dist_model(mu, sigma), false),
        mu = Stochastic(() -> prior_mu, true),
        sigma = Stochastic(() -> prior_sigma, true),
    )
    setsamplers!(model, [NUTS([:mu, :sigma])])

    # One independent initial-value dictionary per chain.
    inits = [
        Dict{Symbol, Any}(:y => Y, :mu => mu0, :sigma => sigma0)
        for chain_id in 1:chains
    ]
    return model, data, inits
end
Out[4]:
sample2model_normal (generic function with 1 method)
In [5]:
# Normal model with unknown mean mu and standard deviation fixed to 1.
@everywhere function normal1(mu::Real)
    return Normal(mu, 1.0)
end

# Vector form: w = [mu].
function normal1(w::AbstractVector)
    return normal1(w[1])
end

# The parameter is unconstrained, so both transforms return their argument.
function unlink_normal1(w)
    return w
end

function link_normal1(z)
    return z
end

# Build the model, data dictionary and initial values for fitting
# dist_model(mu) to the sample Y.
#
# Keywords:
#   dist_model - mu -> distribution used for node y
#   mu0        - initial value of mu (same for every chain)
#   prior_mu   - prior distribution of mu
#   chains     - number of initial-value dictionaries to create
#
# Returns the tuple (model, data, inits).
function sample2model_normal1(Y;
        dist_model = normal1,
        mu0    = 0.0,
        prior_mu = Normal(0.0, 1.0),
        chains = 2
    )
    data = Dict{Symbol, Any}(:Y => Y, :n => length(Y), :mu0 => mu0)

    # NOTE(review): the lambda argument name (mu) presumably has to match the
    # node name given as keyword — confirm before renaming.
    model = Model(
        y = Stochastic(1, mu -> dist_model(mu), false),
        mu = Stochastic(() -> prior_mu, true),
    )
    setsamplers!(model, [NUTS([:mu])])

    # One independent initial-value dictionary per chain.
    inits = [
        Dict{Symbol, Any}(:y => Y, :mu => mu0)
        for chain_id in 1:chains
    ]
    return model, data, inits
end
Out[5]:
sample2model_normal1 (generic function with 1 method)

Mambaによるシミュレーション結果のまとめの表示

In [6]:
## Summary
#
# Print and plot a summary of the simulation result `sim`.
#
# Keywords:
#   sortkeys      - passed on to the plotting functions
#   figsize_t     - figure size for the trace plots
#   figsize_c     - figure size for the contour plots
#   show_describe - display describe(sim)
#   show_gelman   - show gelmandiag(sim); only done when there are at least two chains
#   plot_trace    - call pyplot_trace
#   plot_contour  - call pyplot_contour
function showsummary(sim;
        sortkeys=true, figsize_t=(8, 3), figsize_c=(8, 3.5),
        show_describe=true, show_gelman=true, plot_trace=true, plot_contour=true)
    ## Summary of MCMC
    if show_describe
        println("\n========== Summary:\n")
        display(describe(sim))
    end

    # Convergence Diagnostics
    # The comparison operator was missing here (a syntax error); the diagnostic
    # compares chains with each other, so at least two chains are required.
    if show_gelman && length(sim.chains) >= 2
       println("========== Gelman Diagnostic:")
       show(gelmandiag(sim))
    end

    ## Plot
    sleep(0.1)
    if plot_trace
        pyplot_trace(sim, sortkeys=sortkeys, figsize=figsize_t)
    end
    if plot_contour
        pyplot_contour(sim, sortkeys=sortkeys, figsize=figsize_c)
    end
end

## plot traces
#
# For every variable in `sim`, open one figure with two panels:
# left, the trace of each chain against sim.range; right, a kernel density
# estimate of each chain, drawn between the 0.5% and 99.5% quantiles of the
# draws of all chains taken together.
function pyplot_trace(sim; sortkeys = false, figsize = (8, 3))
    keys_sim = sortkeys ? sort(keys(sim)) : keys(sim)
    for var in keys_sim
        draws = sim[:,var,:].value
        plt.figure(figsize=figsize)

        # Left panel: traces.
        plt.subplot(1,2,1)
        for k in sim.chains
            plt.plot(sim.range, draws[:,1,k], label="$k", lw=0.4, alpha=0.8)
        end
        plt.xlabel("iterations")
        plt.grid(ls=":")
        plt.title("trace of $var", fontsize=10)

        # Right panel: per-chain density estimates on a common grid.
        plt.subplot(1,2,2)
        pooled = vec(draws)
        x = linspace(quantile(pooled, 0.005), quantile(pooled, 0.995), 201)
        for k in sim.chains
            pdfkde = makefunc_pdfkde(draws[:,1,k])
            plt.plot(x, pdfkde.(x), label="$k", lw=0.8, alpha=0.8)
        end
        plt.xlabel("$var")
        plt.grid(ls=":")
        plt.title("empirical posterior pdf of $var", fontsize=10)

        plt.tight_layout()
    end
end

# plot contours
#
# For every pair (u, v) of variables in `sim` (u before v in keys_sim), draw a
# colour map of the 2-D kernel density estimate of the draws of all chains
# taken together. Two panels are placed in each figure.
function pyplot_contour(sim; sortkeys = true, figsize = (8, 3.5))
    if sortkeys
        keys_sim = sort(keys(sim))
    else
        keys_sim = keys(sim)
    end
    m = 0                   # number of pairs plotted so far
    K = length(keys_sim)
    for i in 1:K
        for j in i+1:K
            m += 1
            # Odd m opens a new figure and uses the left panel; even m uses the right panel.
            mod1(m,2) == 1 && plt.figure(figsize=figsize)
            plt.subplot(1,2, mod1(m,2))

            u = keys_sim[i]
            v = keys_sim[j]
            X = vec(sim[:,u,:].value)
            Y = vec(sim[:,v,:].value)
            pdfkde = makefunc_pdfkde(X,Y)
            # Plot window: 0.5% to 99.5% quantiles in each direction.
            xmin = quantile(X, 0.005)
            xmax = quantile(X, 0.995)
            ymin = quantile(Y, 0.005)
            ymax = quantile(Y, 0.995)
            x = linspace(xmin, xmax, 200)
            y = linspace(ymin, ymax, 200)
            
            # Broadcasting the row x' against the column y evaluates the density on the grid.
            plt.pcolormesh(x', y, pdfkde.(x',y), cmap="CMRmap")
            plt.colorbar()
            plt.grid(ls=":")
            plt.xlabel(u)
            plt.ylabel(v)
            plt.title("posterior of ($u, $v)", fontsize=10)

            mod1(m,2) == 2 && plt.tight_layout()

            # Last pair (m == K*(K-1)/2) landing in a left panel: fill the right
            # panel with the same density drawn with the two axes swapped.
            if 2*m == K*(K-1) && mod1(m,2) == 1
                plt.subplot(1,2,2)
                
                plt.pcolormesh(y', x, pdfkde.(x,y'), cmap="CMRmap")
                plt.colorbar()
                plt.grid(ls=":")
                plt.xlabel(v)
                plt.ylabel(u)
                plt.title("posterior of ($v, $u)", fontsize=10)

                plt.tight_layout()
            end
        end
    end
end
Out[6]:
pyplot_contour (generic function with 1 method)

WAICなどを計算するための函数

In [7]:
# Create a function that extracts loglik[l,i] = lpdf(w_l, Y_i) and
# chain[l,:] = w_l from a simulation result.
#
#   lpdf    - (w, y) -> log density of observation y under parameter vector w
#   symbols - selects the parameters taken from the simulation result
#   Y       - the sample
#
# The returned function loglikchainof(sim) returns (loglik, chain), where the
# draws of all chains are stacked vertically into `chain`.
function makefunc_loglikchainof(lpdf, symbols, Y)
    return function (sim)
        val = sim[:, symbols, :].value
        chain = vcat([val[:,:,k] for k in 1:size(val,3)]...)
        L = size(chain,1)
        n = length(Y)
        loglik = Float64[lpdf(chain[l,:], Y[i]) for l in 1:L, i in 1:n]
        return loglik, chain
    end
end

# Create the predictive density function
#   p^*(y) = mean of { exp(lpdf(w_l, y)) }_{l=1}^L,
# where w_l = chain[l,:] and L is the number of rows of chain.
function makefunc_pdfpred(lpdf, chain)
    L = size(chain,1)
    return y -> @sum(exp(lpdf((@view chain[l,:]), y)), l, 1:L)/L
end

# Compute WAIC from loglik[l,i] (rows l: draws, columns i: observations).
#
# training_loss is minus the mean over i of log(mean over l of exp(loglik[l,i]));
# functional_variance is the sum over i of the uncorrected variance of loglik[:,i].
#
# Returns (WAIC, 2n*training_loss, 2*functional_variance), the first being the
# sum of the other two.
function WAICof(loglik)
    L, n = size(loglik)
    training_loss = -mean(log(mean(exp(loglik[l,i]) for l in 1:L)) for i in 1:n)
    functional_variance = sum(var(loglik[:,i], corrected=false) for i in 1:n)
    T_WAIC = 2*n*training_loss
    V_WAIC = 2*functional_variance
    return T_WAIC + V_WAIC, T_WAIC, V_WAIC
end

# Naive computation of LOOCV from loglik[l,i]:
# twice the sum over i of log(mean over l of exp(-loglik[l,i])).
function LOOCVof(loglik)
    L, n = size(loglik)
    return 2*sum(log(mean(exp(-loglik[l,i]) for l in 1:L)) for i in 1:n)
end

# Compute (twice) the free energy.
#
# The derivative of twice the free energy with respect to the inverse
# temperature β equals E^β_w[2n L_n], so integrating that numerically from
# β=0.0 to 1.0 gives twice the free energy.
#
# Returns (F, E2nLn): the value of the integral and the integrand function.
function FreeEnergyof(loglik)
    integrand = makefunc_E2nLn(loglik)
    F = first(quadgk(integrand, 0.0, 1.0))
    return F, integrand
end

# Create the function beta -> E2nLn(beta) from loglik[l,i].
#
# With nll_l = -(sum over i of loglik[l,i]) and weights
# w_l = exp((1-beta)*(nll_l - maximum(nll))), the returned function evaluates
# 2 * mean(nll_l * w_l) / mean(w_l). It returns zero when mean(w_l) is zero
# or not finite.
function makefunc_E2nLn(loglik)
    L = size(loglik, 1)
    negloglik = -sum(loglik, 2)
    # Shift by the maximum so that the shifted values are all <= 0.
    negloglik_n = negloglik .- maximum(negloglik)
    return function (beta)
        Edenominator = @sum(exp((1-beta)*negloglik_n[l]), l, 1:L)/L
        if Edenominator != zero(Edenominator) && isfinite(Edenominator)
            Enumerator = @sum(negloglik[l]*exp((1-beta)*negloglik_n[l]), l, 1:L)/L
            return 2*Enumerator/Edenominator
        else
            return zero(Edenominator)
        end
    end
end

# Compute WBIC from loglik[l,i]: the function made by makefunc_E2nLn evaluated
# at beta = 1/log(n), where n is the number of columns of loglik.
function WBICof(loglik)
    n = size(loglik, 2)
    return makefunc_E2nLn(loglik)(1/log(n))
end

# Compute the generalization loss: the integral of -pdftrue(x)*log(pdfpred(x))
# over [xmin, xmax].
# Returns whatever quadgk returns; callers in this file take element [1].
function GeneralizationLossof(pdftrue, pdfpred; xmin=-10.0, xmax=10.0)
    return quadgk(x -> -pdftrue(x)*log(pdfpred(x)), xmin, xmax)
end

# Compute the Shannon information.
#
# Density-function form: the generalization loss of pdftrue against itself
# (same return value as GeneralizationLossof).
function ShannonInformationof(pdftrue; xmin=-10.0, xmax=10.0)
    return GeneralizationLossof(pdftrue, pdftrue, xmin=xmin, xmax=xmax)
end

# Distribution form: 2*n times the entropy of dist.
function ShannonInformationof(dist::Distribution, n)
    return 2*n*entropy(dist)
end

# Run maximum likelihood estimation and compute AIC and BIC.
#
# lpdf(w,y) = log p(y|w)
#
# w = link_model(z)   maps an arbitrary real vector z into the parameter space of w
# z = unlink_model(w) is the inverse of link_model(z)
# These keep the parameter w inside its domain while optimize() is running.
#
# Y is the sample.
#
# chain is the chain obtained from  loglik, chain = loglikchainof(sim).
#
# For a function of one variable, optimize() is no longer a solver that takes
# an initial point, so in that case the objective is extended to two variables
# before it is used.
#
# Returns (AIC, T_AIC, V_AIC, BIC, T_BIC, V_BIC, param_AIC).
function AICandBICof(lpdf, link_model, unlink_model, Y, chain)
    n = length(Y)
    L = size(chain,1)
    nparams = size(chain,2)
    negloglik(z) = -@sum(lpdf(link_model(z), Y[i]), i, 1:n)

    # Start the optimizer from the row of chain with the smallest negative log-likelihood.
    best_value, best = findmin(negloglik(unlink_model(chain[l,:])) for l in 1:L)
    z0 = unlink_model(chain[best,:])

    if nparams == 1
        # Extra dummy coordinate z[2], heavily penalised so that it stays at 0.
        penalized = z -> negloglik([z[1]]) + z[2]^2/eps()
        o = optimize(penalized, [z0[1], 0.0])
        z_hat = [o.minimizer[1]]
    else
        o = optimize(negloglik, z0)
        z_hat = o.minimizer
    end
    param_AIC = link_model(z_hat)

    T_AIC = 2.0*o.minimum
    V_AIC = 2.0*nparams
    T_BIC = T_AIC
    V_BIC = nparams*log(n)
    return T_AIC + V_AIC, T_AIC, V_AIC,
        T_BIC + V_BIC, T_BIC, V_BIC,
        param_AIC
end
Out[7]:
AICandBICof (generic function with 1 method)

情報をまとめて表示するための函数

In [8]:
# Compute all statistics for the simulation result `sim` and the sample Y.
#
# Keywords:
#   dist_true    - true distribution, used for the generalization loss
#   dist_model   - w -> model distribution for the parameter vector w
#   link_model, unlink_model - parameter transforms handed to AICandBICof
#
# Returns, in this order:
#   WAIC, T_WAIC, V_WAIC, LOOCV, GeneralizationLoss, WBIC, FreeEnergy,
#   param_Bayes, pred_Bayes, AIC, T_AIC, V_AIC, BIC, T_BIC, V_BIC,
#   param_MLE, pred_MLE
function statsof(sim, Y; 
        dist_true=mixnormal(0.0, 0.0, 0.0), 
        dist_model=mixnormal, link_model=link_mixnormal, unlink_model=unlink_mixnormal)
    n = length(Y)
    lpdf = (w, y) -> logpdf(dist_model(w), y)

    loglik, chain = makefunc_loglikchainof(lpdf, sort(keys(sim)), Y)(sim)

    # Bayesian side: posterior mean, predictive density and its generalization loss.
    param_Bayes = vec(mean(chain, 1))
    pred_Bayes = makefunc_pdfpred(lpdf, chain)
    pdf_true = x -> pdf(dist_true, x)
    GeneralizationLoss = 2*n*GeneralizationLossof(pdf_true, pred_Bayes)[1]

    # Criteria computed from the log-likelihood matrix alone.
    WAIC, T_WAIC, V_WAIC = WAICof(loglik)
    LOOCV = LOOCVof(loglik)
    WBIC = WBICof(loglik)
    FreeEnergy = FreeEnergyof(loglik)[1]

    # Maximum likelihood side.
    AIC, T_AIC, V_AIC, BIC, T_BIC, V_BIC, param_MLE =
        AICandBICof(lpdf, link_model, unlink_model, Y, chain)
    pred_MLE = y -> exp(lpdf(param_MLE, y))

    return WAIC, T_WAIC, V_WAIC, LOOCV, GeneralizationLoss,
        WBIC, FreeEnergy,
        param_Bayes, pred_Bayes,
        AIC, T_AIC, V_AIC,
        BIC, T_BIC, V_BIC,
        param_MLE, pred_MLE
end

# Print the estimates and information criteria returned by statsfunc and plot
# the true density, the sample, its kernel density estimate and the Bayesian
# and maximum-likelihood predictive densities on [xmin, xmax].
#
# Keywords:
#   statsfunc    - function with the interface of statsof
#   dist_model, link_model, unlink_model - forwarded to statsfunc
#                  (unlink_model now defaults to unlink_mixnormal; it used to
#                  default to link_mixnormal, i.e. the wrong direction)
#   figsize      - figure size
#   xmin, xmax   - plotting range
function show_all_results(dist_true, Y, sim; statsfunc=statsof, 
        dist_model=mixnormal, link_model=link_mixnormal, unlink_model=unlink_mixnormal,
        figsize=(6,4.2), xmin=-4.0, xmax=6.0)
    WAIC, T_WAIC, V_WAIC, LOOCV, GeneralizationLoss,
    WBIC, FreeEnergy,
    param_Bayes, pred_Bayes, 
    AIC, T_AIC, V_AIC, 
    BIC, T_BIC, V_BIC,
    param_MLE, pred_MLE = statsfunc(sim, Y, dist_true=dist_true, 
        dist_model=dist_model, link_model=link_model, unlink_model=unlink_model)
    
    n = length(Y)
    println("\n=== Estimates by $dist_model  (n = $n) ===")
    @show param_Bayes
    @show param_MLE
    
    println("--- Information Criterions")
    println("* AIC     = $AIC = $T_AIC + $V_AIC")
    println("* GenLoss = $GeneralizationLoss")
    println("* WAIC    = $WAIC = $T_WAIC + $V_WAIC")
    println("* LOOCV   = $LOOCV")
    println("---")
    println("* BIC        = $BIC = $T_BIC + $V_BIC")
    println("* FreeEnergy = $FreeEnergy")
    println("* WBIC       = $WBIC")

    println("="^78 * "\n")
    
    sleep(0.1)
    plt.figure(figsize=figsize)
    kde_sample = makefunc_pdfkde(Y)
    x = linspace(xmin, xmax, 201)
    plt.plot(x, pdf.(dist_true, x), label="true distribution")
    # Sample points are drawn on the curve of their own kernel density estimate.
    plt.scatter(Y, kde_sample.(Y), label="sample", s=10, color="k", alpha=0.5)
    plt.plot(x, kde_sample.(x), label="KDE of sample",   color="k", alpha=0.5)
    plt.plot(x, pred_Bayes.(x), label="Baysian predictive", ls="--")
    plt.plot(x, pred_MLE.(x),   label="MLE predictive",     ls="-.")
    plt.xlabel("x")
    plt.ylabel("probability density")
    plt.grid(ls=":")
    plt.legend(fontsize=8)
    plt.title("Estimates by $dist_model: n = $(length(Y))")
end

# Plot the true density, the sample and the kernel density estimate of the
# sample on [xmin, xmax].
function plotsample(dist_true, Y; figsize=(6,4.2), xmin=-4.0, xmax=6.0)
    sleep(0.1)
    plt.figure(figsize=figsize)

    density = makefunc_pdfkde(Y)
    grid = linspace(xmin, xmax, 201)

    plt.plot(grid, pdf.(dist_true, grid), label="true dist.")
    # Sample points are drawn on the curve of their own kernel density estimate.
    plt.scatter(Y, density.(Y), label="sample", s=10, color="k", alpha=0.5)
    plt.plot(grid, density.(grid), label="KDE of sample", color="k", alpha=0.5)

    plt.grid(ls=":")
    plt.xlabel("x")
    plt.ylabel("probability density")
    plt.legend(fontsize=8)
    plt.title("Sample size n = $(length(Y))")
end
Out[8]:
plotsample (generic function with 1 method)

単純な正規分布モデルを最尤法で解くための函数

In [9]:
# Given an array Y of samples generated from the distribution dist_true,
# fit a normal distribution by maximum likelihood and return AIC etc.
#
# Returns, in this order:
#   mu, sigma, pred_Normal,
#   AIC_Normal, T_Normal, V_Normal,
#   BIC_Normal, TB_Normal, VB_Normal,
#   GL_Normal
# GL_Normal is 2n times the integral over [-10, 10] of
# -pdf(dist_true, y)*logpdf(d, y), d being the fitted distribution.
function fit_Normal(dist_true, Y)
    n = length(Y)
    d = fit(Normal,Y)

    T_Normal = -2*sum(logpdf.(d,Y))
    V_Normal = 4.0
    TB_Normal = T_Normal
    VB_Normal = 2.0*log(n)

    cross_entropy_integrand = y -> -pdf(dist_true, y)*logpdf(d, y)
    GL_Normal = 2*n*quadgk(cross_entropy_integrand, -10, 10)[1]

    pred_Normal = y -> pdf(d,y)
    return d.μ, d.σ, pred_Normal,
        T_Normal + V_Normal, T_Normal, V_Normal,
        TB_Normal + VB_Normal, TB_Normal, VB_Normal,
        GL_Normal
end

# Print the result of fit_Normal and plot it together with the true density,
# the sample and its kernel density estimate.
# The graphs are plotted over the range from xmin to xmax.
function show_fit_Normal(dist_true, Y; figsize=(6,4.2), xmin=-4.0, xmax=6.0)
    mu, sigma, pred_Normal,
    AIC_Normal, T_Normal, V_Normal,
    BIC_Normal, TB_Normal, VB_Normal,
    GL_Normal = fit_Normal(dist_true, Y)

    for line in (
            "--- Normal Fitting",
            "* μ = $mu",
            "* σ = $sigma",
            "* GenLoss = $GL_Normal",
            "* AIC     = $AIC_Normal = $T_Normal + $V_Normal",
            "* BIC     = $BIC_Normal = $TB_Normal + $VB_Normal",
        )
        println(line)
    end

    sleep(0.1)
    plt.figure(figsize=figsize)
    density = makefunc_pdfkde(Y)
    grid = linspace(xmin, xmax, 201)
    plt.plot(grid, pdf.(dist_true, grid), label="true distribution")
    # Sample points are drawn on the curve of their own kernel density estimate.
    plt.scatter(Y, density.(Y), label="sample", s=10, color="k", alpha=0.5)
    plt.plot(grid, density.(grid), label="KDE of sample", color="k", alpha=0.5)
    plt.plot(grid, pred_Normal.(grid), label="Normal predictive", ls="--")
    plt.grid(ls=":")
    plt.xlabel("x")
    plt.ylabel("probability density")
    plt.legend(fontsize=8)
    plt.title("Sample size n = $(length(Y))")
end
Out[9]:
show_fit_Normal (generic function with 1 method)
In [10]:
# Compute the Bayesian information criteria for the simulation result `sim`
# and the sample Y.
#
# Returns (chain, WAIC, T_WAIC, V_WAIC, LOOCV, GL, WBIC, FreeEnergy), where GL
# is 2n times the generalization loss of the predictive density against dist_true.
function InformationCriterions(dist_true, Y, sim; dist_model=mixnormal)
    n = length(Y)
    lpdf = (w, y) -> logpdf(dist_model(w), y)

    loglik, chain = makefunc_loglikchainof(lpdf, sort(keys(sim)), Y)(sim)

    pred_Bayes = makefunc_pdfpred(lpdf, chain)
    pdf_true = x -> pdf(dist_true, x)
    GL = 2*n*GeneralizationLossof(pdf_true, pred_Bayes)[1]

    WAIC, T_WAIC, V_WAIC = WAICof(loglik)
    LOOCV = LOOCVof(loglik)
    WBIC = WBICof(loglik)
    FreeEnergy = FreeEnergyof(loglik)[1]

    return chain, WAIC, T_WAIC, V_WAIC, LOOCV, GL, WBIC, FreeEnergy
end
Out[10]:
InformationCriterions (generic function with 1 method)

プロットのための函数

In [11]:
# Draw one frame for sample size n = t into the current figure.
#
# Left panel: true density, the first n sample points and the predictive
# density of the model `modelname`. Right panel: the series kl and wt for
# sample sizes 1..n.
#
# Reads the globals Sample, dist_true and N, and looks up by name (via eval)
# the globals Chain_<modelname>, kl_<modelname>, wt_<modelname> and the model
# constructor <modelname>.
#
# Keywords:
#   modelname  - name of the model as a string
#   xmin, xmax - x range of the left panel
#   y1max      - upper y limit of the left panel
#   y2max      - upper y limit of the right panel
function plotsim(t; modelname="normal", xmin=-3, xmax=8, y1max=0.45, y2max=0.40)
    # Clear the current figure so that it can be reused for every frame.
    plt.clf()
    
    n = t
    Y = Sample[1:n]
    Chain = eval(Symbol(:Chain_, modelname))
    chain = Chain[n]
    dist_model = eval(Symbol(modelname))
    lpdf(w,x) = logpdf(dist_model(w),x)
    pred = makefunc_pdfpred(lpdf, chain)
    x = linspace(xmin,xmax,201)
    kl = eval(Symbol(:kl_, modelname))
    wt = eval(Symbol(:wt_, modelname))
    
    plt.subplot(121)
    plt.plot(x, pdf.(dist_true,x), color="black", ls=":", label="true dist")
    # Sample points are drawn on the curve of the true density.
    plt.scatter(Y, pdf.(dist_true,Y), color="red", s=10, alpha=0.5, label="sample") 
    plt.plot(x, pred.(x), label="predictive")
    plt.ylim(-y1max/40, y1max)
    plt.grid(ls=":")
    plt.legend()
    plt.title("$modelname: n = $n")
    
    plt.subplot(122)
    plt.plot(1:n, kl[1:n], label="KL information")
    plt.plot(1:n, wt[1:n], label="(WAIC \$-\$ T\$_\\mathrm{true}\$)/2n")
    plt.xlabel("sample size n")
    plt.ylabel("\$-\$ log(probability)")
    # The x range is fixed by the global N, not by n.
    plt.xlim(-1, N+1)
    # The lower limit uses all of kl but skips the first four entries of wt.
    plt.ylim(min(minimum(kl), minimum(wt[5:end]))-y2max/40, y2max)
    plt.grid(ls=":")
    plt.legend()
    plt.title("$modelname: n = $n")
    
    plt.tight_layout()
    plt.plot()
end
Out[11]:
plotsim (generic function with 1 method)

プロットの実行 (混合ガンマ分布のサンプルの場合)

GIFアニメーションを作成したい場合には適当にコメントアウトを外して実行する。

なぜか normal モデルの場合の計算がうまく行っていない.

サンプルサイズが 50 をちょっと超えたあたりから, 事後予測分布がおかしな形で動かなくなる.

これを書いている時点(2017-11-19)では原因不明.

In [12]:
# Sample size used in place of N for the plot and the animation frames of the
# "normal" model below.
N0 = 53
Out[12]:
53
In [13]:
# Seed that selects the data file (alternatives kept as comments).
seed = 4649
#seed = 12345
#seed = 2017

#dataname = "Gamma128Sample$seed"
dataname = "MixGamma128Sample$seed"

# Load the data file and turn every key into a global variable of the same name.
# NOTE(review): this evaluates code built from the keys of the file, so it should
# only be used with trusted .jld2 files.
# The names used below without being defined in this cell (dist_true, Shannon,
# T_true, N, GL_*, WAIC_*, ...) are presumably created here — confirm against the file.
data = load("$dataname.jld2")
for v in sort(collect(keys(data)))
    ex = parse("$v = data[\"$v\"]")
    # println(ex)
    eval(ex)
end

# Normal distribution with the same mean and standard deviation as dist_true.
@show mu_true = mean(dist_true)
@show sigma_true = std(dist_true)
@show dist_normal_fitting = Normal(mu_true, sigma_true)

# Generalization loss of that normal distribution against dist_true, and its
# difference from shannon_true.
shannon_true = Shannon[1]/2
gl_normal_fitting = GeneralizationLossof(x->pdf(dist_true,x), x->pdf(dist_normal_fitting,x))[1]
kl_normal_fitting = gl_normal_fitting - shannon_true

# Generalization loss minus Shannon, per model.
KL_normal    = GL_normal    - Shannon
KL_normal1   = GL_normal1   - Shannon
KL_mixnormal = GL_mixnormal - Shannon

# WAIC minus T_true, per model.
WT_normal    = WAIC_normal    - T_true
WT_normal1   = WAIC_normal1   - T_true
WT_mixnormal = WAIC_mixnormal - T_true

# Divide the n-th entry by 2n; these are the series that plotsim looks up by name.
kl_normal    = [KL_normal[n]/(2n) for n in 1:N]
kl_normal1   = [KL_normal1[n]/(2n) for n in 1:N]
kl_mixnormal = [KL_mixnormal[n]/(2n) for n in 1:N]

wt_normal    = [WT_normal[n]/(2n) for n in 1:N]
wt_normal1   = [WT_normal1[n]/(2n) for n in 1:N]
wt_mixnormal = [WT_mixnormal[n]/(2n) for n in 1:N]

@show shannon_true
@show gl_normal_fitting
@show kl_normal_fitting
@show exp(-kl_normal_fitting)
println()
@show kl_normal
@show kl_normal1
@show kl_mixnormal

# Static plots, one figure per model.
sleep(0.1)

# For the "normal" model the global N is temporarily replaced by N0
# (plotsim reads N for its x range) and restored afterwards.
N, Ntmp = N0, N
plt.figure(figsize=(10,3.5))
plotsim(N, modelname="normal", xmin=-1, xmax=11, y1max=0.35, y2max=1.0)
N = Ntmp

plt.figure(figsize=(10,3.5))
plotsim(N, modelname="normal1", xmin=-1, xmax=11, y1max=0.45, y2max=1.0)

plt.figure(figsize=(10,3.5))
plotsim(N, modelname="mixnormal", xmin=-1, xmax=11, y1max=0.35)

# Animation for the "normal" model, with the global N temporarily set to N0.
# The frame list starts with N, repeats the small sample sizes, runs through
# 11:N and ends with N repeated.
# The save call is commented out, so showgif(file) displays a GIF file that
# must already exist on disk.
N, Ntmp = N0, N
modelname = "normal"
file = "$dataname$modelname.gif"
plot1frame(t) = plotsim(t, modelname=modelname, xmin=-1, xmax=11, y1max=0.35, y2max=1.0)
fig = plt.figure(figsize=(10,3.5))
frames = [N; 1;1;1;1; 2;2;2;3;3;3;4;4;4; 5;5;5;6;6;7;7;8;8;9;9;10;10; 11:N; N;N;N;N;N;N;N;N;N]
interval = 100
@time myanim = anim.FuncAnimation(fig, plot1frame, frames=frames, interval=interval)
#@time myanim[:save](file, writer="imagemagick")
sleep(0.1)
showgif(file)
plt.clf()
N = Ntmp

# Same steps for the "normal1" model, using the full N.
modelname = "normal1"
file = "$dataname$modelname.gif"
plot1frame(t) = plotsim(t, modelname=modelname, xmin=-1, xmax=11, y1max=0.45, y2max=1.0)
fig = plt.figure(figsize=(10,3.5))
frames = [N; 1;1;1;1; 2;2;2;3;3;3;4;4;4; 5;5;5;6;6;7;7;8;8;9;9;10;10; 11:N; N;N;N;N;N;N;N;N;N]
interval = 100
@time myanim = anim.FuncAnimation(fig, plot1frame, frames=frames, interval=interval)
#@time myanim[:save](file, writer="imagemagick")
sleep(0.1)
showgif(file)
plt.clf()

# Same steps for the "mixnormal" model (y2max left at the plotsim default).
modelname = "mixnormal"
file = "$dataname$modelname.gif"
plot1frame(t) = plotsim(t, modelname=modelname, xmin=-1, xmax=11, y1max=0.35)
fig = plt.figure(figsize=(10,3.5))
frames = [N; 1;1;1;1; 2;2;2;3;3;3;4;4;4; 5;5;5;6;6;7;7;8;8;9;9;10;10; 11:N; N;N;N;N;N;N;N;N;N]
interval = 100
@time myanim = anim.FuncAnimation(fig, plot1frame, frames=frames, interval=interval)
#@time myanim[:save](file, writer="imagemagick")
sleep(0.1)
showgif(file)
plt.clf()
mu_true = mean(dist_true) = 4.638888888888888
sigma_true = std(dist_true) = 1.7206534073276198
dist_normal_fitting = Normal(mu_true, sigma_true) = Distributions.Normal{Float64}(μ=4.638888888888888, σ=1.7206534073276198)
shannon_true = 1.8659884383030085
gl_normal_fitting = 1.961056779446904
kl_normal_fitting = 0.09506834114389551
exp(-kl_normal_fitting) = 0.9093107890052324

kl_normal = [1.1077, 0.925337, 0.854223, 0.741204, 0.602992, 0.5357, 0.548984, 0.457051, 0.389074, 0.324967, 0.246177, 0.236216, 0.187792, 0.174761, 0.172603, 0.167863, 0.158658, 0.15131, 0.137564, 0.13118, 0.12527, 0.115938, 0.116273, 0.121606, 0.148139, 0.138757, 0.130535, 0.120655, 0.117812, 0.11497, 0.113215, 0.107526, 0.118771, 0.115241, 0.110996, 0.110608, 0.10695, 0.10197, 0.0998489, 0.103479, 0.107688, 0.107653, 0.111213, 0.107524, 0.11676, 0.114816, 0.112079, 0.111797, 0.109398, 0.107913, 0.112535, 0.108294, 0.109284, 0.109229, 0.654352, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884, 11.2884]
kl_normal1 = [0.469338, 0.317258, 0.491943, 0.512478, 0.426022, 0.466365, 0.567822, 0.520142, 0.563978, 0.509785, 0.470615, 0.47968, 0.487556, 0.488506, 0.49153, 0.497326, 0.486954, 0.48601, 0.488313, 0.493408, 0.489353, 0.493989, 0.496289, 0.494973, 0.510689, 0.511468, 0.50089, 0.497773, 0.50264, 0.49898, 0.501839, 0.502788, 0.504716, 0.504796, 0.50442, 0.503308, 0.507649, 0.51679, 0.516278, 0.53789, 0.524397, 0.526767, 0.518328, 0.522482, 0.51592, 0.518816, 0.526767, 0.520111, 0.530223, 0.537413, 0.526192, 0.523686, 0.520523, 0.523356, 0.522565, 0.52461, 0.52679, 0.527904, 0.522455, 0.523619, 0.523672, 0.524857, 0.522813, 0.521011, 0.519495, 0.520701, 0.522226, 0.523678, 0.519907, 0.518955, 0.519854, 0.519678, 0.51954, 0.518924, 0.519734, 0.521185, 0.522938, 0.521516, 0.520886, 0.522851, 0.524018, 0.524245, 0.522615, 0.52498, 0.52715, 0.52613, 0.527622, 0.527922, 0.527909, 0.525759, 0.524713, 0.526634, 0.527123, 0.527866, 0.527655, 0.530335, 0.529693, 0.529949, 0.532817, 0.526526, 0.528587, 0.525845, 0.526543, 0.528348, 0.533262, 0.535737, 0.538431, 0.53574, 0.5349, 0.536403, 0.532585, 0.538721, 0.537365, 0.534868, 0.537766, 0.53835, 0.536759, 0.536567, 0.53777, 0.535104, 0.535929, 0.535122, 0.538842, 0.540624, 0.542785, 0.54028, 0.537643, 0.534929]
kl_mixnormal = [0.789055, 0.432094, 0.572459, 0.549362, 0.38229, 0.45073, 0.542518, 0.464543, 0.507662, 0.439115, 0.367786, 0.382495, 0.392704, 0.393908, 0.400279, 0.404151, 0.379809, 0.365772, 0.331088, 0.316946, 0.218776, 0.206695, 0.204673, 0.151637, 0.132619, 0.147203, 0.119141, 0.113224, 0.117529, 0.11866, 0.130451, 0.120732, 0.107795, 0.109819, 0.104352, 0.118507, 0.0941215, 0.0911094, 0.0929003, 0.088038, 0.0676896, 0.0697837, 0.0589638, 0.0597126, 0.0562562, 0.0576191, 0.0590163, 0.059477, 0.0635582, 0.0656416, 0.0588323, 0.0589703, 0.0497435, 0.0479887, 0.052255, 0.0516188, 0.054265, 0.0556326, 0.0476583, 0.0492152, 0.0516787, 0.0545089, 0.0432992, 0.0374424, 0.0370132, 0.0393991, 0.0373817, 0.0388088, 0.0360922, 0.0323892, 0.0309972, 0.0320834, 0.0308647, 0.0329902, 0.0323674, 0.0325527, 0.0337361, 0.02849, 0.0287955, 0.029912, 0.0311397, 0.0307793, 0.0320212, 0.0310411, 0.0315611, 0.0323589, 0.0338515, 0.0341843, 0.0362223, 0.0338168, 0.0350037, 0.0363661, 0.0371371, 0.0375924, 0.0384804, 0.0379542, 0.037137, 0.0387719, 0.0387748, 0.0377198, 0.0390403, 0.0392754, 0.0393408, 0.038936, 0.0386931, 0.039674, 0.040589, 0.0352098, 0.0360733, 0.038247, 0.0340179, 0.0341761, 0.0343631, 0.0316964, 0.0337333, 0.0347629, 0.0335559, 0.0348608, 0.0366493, 0.0314204, 0.0316562, 0.0334519, 0.033067, 0.0333144, 0.0351068, 0.0320224, 0.029369, 0.0284031]
  0.250735 seconds (76.46 k allocations: 4.113 MiB)