ベイズ推定のアニメーション (ガンマ分布のサンプルの場合)

黒木玄

2017-11-19

In [1]:
# Record the Julia version and platform this notebook was run on (for reproducibility).
versioninfo()
Julia Version 0.6.0
Commit 903644385b* (2017-06-19 13:05 UTC)
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz
  WORD_SIZE: 64
  BLAS: libopenblas (USE64BITINT DYNAMIC_ARCH NO_AFFINITY Haswell)
  LAPACK: libopenblas64_
  LIBM: libopenlibm
  LLVM: libLLVM-3.9.1 (ORCJIT, haswell)

パッケージの読み込みと諸定義

In [2]:
# Load the packages used below and define small helpers; @time reports the load cost.
@time begin
    # NOTE(review): Mamba is not loaded here, but the sample2model_* functions below use
    # Mamba names (Model, Stochastic, NUTS, setsamplers!) — they only work after `using Mamba`.
    #using Mamba
    
    using KernelDensity
    # Return a closure x -> kernel-density estimate of the 1-D sample X evaluated at x.
    function makefunc_pdfkde(X)
        local ik = InterpKDE(kde(X))
        local pdfkde(x) = pdf(ik, x)
        return pdfkde
    end
    # Return a closure (x, y) -> 2-D kernel-density estimate of the paired sample (X, Y).
    function makefunc_pdfkde(X,Y)
        local ik = InterpKDE(kde((X,Y)))
        local pdfkde(x, y) = pdf(ik, x, y)
        return pdfkde
    end

    using Optim
    # Optim options requesting stored (extended) optimization traces.
    # NOTE(review): optim_options is not passed to optimize() anywhere in this notebook.
    optim_options = Optim.Options(
        store_trace = true,
        extended_trace = true
    )
    
    using QuadGK

    import PyPlot
    # Short alias used by all plotting code below.
    plt = PyPlot

    using Distributions
    # Location-scale Student t distribution: location μ, scale ρ, ν degrees of freedom,
    # defined on every process.
    @everywhere GTDist(μ, ρ, ν) = LocationScale(Float64(μ), Float64(ρ), TDist(Float64(ν)))
    # Zero-location variant.
    @everywhere GTDist(ρ, ν) = GTDist(zero(ρ), ρ, ν)

    # JLD2 + FileIO provide `load` for the .jld2 data file read later.
    using JLD2
    using FileIO
end

using PyCall
@pyimport matplotlib.animation as anim
# Show a GIF file inline in the notebook by embedding its bytes as a base64 data URI.
function showgif(filename)
    encoded = open(base64encode, filename)
    html = """<img src="data:image/gif;base64,$encoded">"""
    return display("text/html", html)
end

# @sum(f_k, k, itr) expands to the sum of the expression f_k over k ∈ itr, without
# allocating an intermediate array.
#
# The accumulator starts at zero(f_k evaluated at the first element of itr), so its type
# matches the summands. `itr` is evaluated exactly once and bound to a hygienic local.
# Previously it was evaluated twice: once to peek at its first element and once for the
# loop. The expansion is wrapped in `let` so its temporaries get their own scope.
macro sum(f_k, k, itr)
    quote
        let itr_ = $(esc(itr))
            local s = zero(($(esc(k)) -> $(esc(f_k)))(itr_[1]))
            for $(esc(k)) in itr_
                s += $(esc(f_k))
            end
            s
        end
    end
end
 11.834807 seconds (12.13 M allocations: 808.333 MiB, 2.92% gc time)
Out[2]:
@sum (macro with 1 method)

MambaパッケージでMCMCを実行するための函数

In [3]:
# Two-component normal mixture with unit variances, (1-a)·N(b, 1) + a·N(c, 1),
# defined on every process.
@everywhere mixnormal(a,b,c) = MixtureModel(Normal[Normal(b, 1.0), Normal(c, 1.0)], [1.0-a, a])

# Same model with the parameters packed as w = [a, b, c].
mixnormal(w::AbstractVector) = mixnormal(w[1], w[2], w[3])
# Maps between the parameter space and an unconstrained R^3 (logit on the mixing weight a),
# so that an unconstrained optimizer can search over z; link_mixnormal inverts unlink_mixnormal.
unlink_mixnormal(w) = [logit(w[1]), w[2], w[3]]     # unlink_mixnormal : (0,1)×R^2 -> R^3
link_mixnormal(z)   = [invlogit(z[1]), z[2], z[3]]  #   link_mixnormal : R^3 → (0,1)×R^2

# Build a Mamba model of the i.i.d. sample Y drawn from the normal mixture dist_model(a, b, c).
#
# Keyword arguments:
#   dist_model             -- (a, b, c) -> distribution of one observation (default: mixnormal)
#   a0, b0, c0             -- initial parameter values, used for every chain
#   prior_a/prior_b/prior_c -- prior distributions of a, b and c
#   chains                 -- number of chains to create initial values for
#
# Returns (model, data, inits); the parameters a, b, c are sampled with NUTS.
# NOTE(review): presumably these are passed on to Mamba's mcmc().
function sample2model_mixnormal(Y;
        dist_model = mixnormal,
        a0 = 0.5,
        b0 = 0.0,
        c0 = 0.0,
        prior_a = Uniform(0.0, 1.0),
        prior_b = Normal(0.0, 1.0),
        prior_c = Normal(0.0, 1.0),
        chains = 2
    )
    local data = Dict{Symbol, Any}(
        :Y => Y,
        :n => length(Y),
        :a0 => a0,
        :b0 => b0,
        :c0 => c0,
    )

    # NOTE(review): Mamba appears to find a node's parents through the argument names of
    # these anonymous functions, so keep them equal to the node names.
    local model = Model(
        y = Stochastic(1, (a, b, c) -> dist_model(a, b, c), false),
        a = Stochastic(() -> prior_a, true),
        b = Stochastic(() -> prior_b, true),
        # c has its own prior; it previously used prior_b, silently ignoring prior_c.
        c = Stochastic(() -> prior_c, true),
    )
    local scheme = [
        NUTS([:a, :b, :c])
    ]
    setsamplers!(model, scheme)

    # Every chain starts from the observed sample and (a0, b0, c0).
    local inits = [
        Dict{Symbol, Any}(
            :y => data[:Y],
            :a => data[:a0],
            :b => data[:b0],
            :c => data[:c0],
        )
        for k in 1:chains
    ]
    return model, data, inits
end
Out[3]:
sample2model_mixnormal (generic function with 1 method)
In [4]:
# Normal model with mean mu and standard deviation sigma, defined on every process.
@everywhere normal(mu, sigma) = Normal(mu, sigma)

# Same model with the parameters packed as w = [mu, sigma].
normal(w::AbstractVector) = normal(w[1], w[2])
# Maps between the parameter space and an unconstrained R^2 (log on sigma > 0), so that an
# unconstrained optimizer can search over z. link_normal must invert unlink_normal: the
# mean is passed through unchanged (it previously returned z[2] in the first slot).
unlink_normal(w) = [w[1], log(w[2])]  # unlink_normal : R×(0,∞) -> R^2
link_normal(z)   = [z[1], exp(z[2])]  #   link_normal : R^2 → R×(0,∞)

# Build a Mamba model of the i.i.d. sample Y drawn from dist_model(mu, sigma).
#
# Keyword arguments: dist_model (default: normal), initial values mu0 and sigma0, the priors
# prior_mu and prior_sigma, and the number of chains to create initial values for.
#
# Returns (model, data, inits); mu and sigma are sampled with NUTS.
function sample2model_normal(Y;
        dist_model = normal,
        mu0    = 0.0,
        sigma0 = 1.0,
        prior_mu = Normal(0.0, 1.0),
        prior_sigma = Truncated(Normal(1.0, 1.0), 0, Inf),
        chains = 2
    )
    local data = Dict{Symbol, Any}(:Y => Y, :n => length(Y), :mu0 => mu0, :sigma0 => sigma0)

    # NOTE(review): the lambda argument names must stay equal to the node names (Mamba
    # appears to resolve node dependencies through them).
    local model = Model(
        y = Stochastic(1, (mu, sigma) -> dist_model(mu, sigma), false),
        mu = Stochastic(() -> prior_mu, true),
        sigma = Stochastic(() -> prior_sigma, true),
    )
    setsamplers!(model, [NUTS([:mu, :sigma])])

    # All chains start from the same state: the sample itself and (mu0, sigma0).
    local inits = [Dict{Symbol, Any}(:y => Y, :mu => mu0, :sigma => sigma0) for ichain in 1:chains]
    return model, data, inits
end
Out[4]:
sample2model_normal (generic function with 1 method)
In [5]:
# Normal model with unit standard deviation; the only parameter is the mean mu.
# Defined on every process.
@everywhere normal1(mu::Real) = Normal(mu, 1.0)

# Same model with the parameter packed as w = [mu].
normal1(w::AbstractVector) = normal1(w[1])
# mu already ranges over all of R, so no reparametrization is needed: both maps are identities.
unlink_normal1(w) = w
link_normal1(z)   = z

# Build a Mamba model of the i.i.d. sample Y drawn from dist_model(mu) (unit variance).
#
# Keyword arguments: dist_model (default: normal1), the initial value mu0, the prior prior_mu,
# and the number of chains to create initial values for.
#
# Returns (model, data, inits); mu is sampled with NUTS.
function sample2model_normal1(Y;
        dist_model = normal1,
        mu0    = 0.0,
        prior_mu = Normal(0.0, 1.0),
        chains = 2
    )
    local data = Dict{Symbol, Any}(:Y => Y, :n => length(Y), :mu0 => mu0)

    # NOTE(review): keep the lambda argument name `mu` equal to the node name.
    local model = Model(
        y = Stochastic(1, mu -> dist_model(mu), false),
        mu = Stochastic(() -> prior_mu, true),
    )
    setsamplers!(model, [NUTS([:mu])])

    # All chains start from the sample itself and mu0.
    local inits = [Dict{Symbol, Any}(:y => Y, :mu => mu0) for ichain in 1:chains]
    return model, data, inits
end
Out[5]:
sample2model_normal1 (generic function with 1 method)

Mambaによるシミュレーション結果のまとめの表示

In [6]:
## Summary
# Print and plot a summary of the Mamba simulation result `sim`.
#
# Keyword arguments:
#   sortkeys            -- sort the parameter names before plotting
#   figsize_t/figsize_c -- figure sizes of the trace plots and of the pairwise density plots
#   show_describe       -- print describe(sim)
#   show_gelman         -- print the Gelman diagnostic (only when there are >= 2 chains)
#   plot_trace          -- draw trace / marginal-density plots (pyplot_trace)
#   plot_contour        -- draw pairwise posterior density plots (pyplot_contour)
function showsummary(sim;
        sortkeys=true, figsize_t=(8, 3), figsize_c=(8, 3.5),
        show_describe=true, show_gelman=true, plot_trace=true, plot_contour=true)
    ## Summary of MCMC
    if show_describe
        println("\n========== Summary:\n")
        display(describe(sim))
    end

    # Convergence diagnostics. The Gelman diagnostic compares chains, so it needs at least
    # two. (The comparison operator was lost to a character-encoding error, which made this
    # line a syntax error; it is restored as `>=`.)
    if show_gelman && length(sim.chains) >= 2
        println("========== Gelman Diagnostic:")
        show(gelmandiag(sim))
    end

    ## Plot
    # NOTE(review): presumably lets the notebook flush the text output before the figures.
    sleep(0.1)
    if plot_trace
        pyplot_trace(sim, sortkeys=sortkeys, figsize=figsize_t)
    end
    if plot_contour
        pyplot_contour(sim, sortkeys=sortkeys, figsize=figsize_c)
    end
end

## plot traces
# For every parameter of the Mamba simulation `sim`, draw one figure with
#   left : the trace of each chain over the iterations,
#   right: a kernel-density estimate of each chain's marginal posterior, drawn over the
#          0.5%–99.5% quantile range of all chains pooled together.
function pyplot_trace(sim; sortkeys = false, figsize = (8, 3))
    keys_sim = sortkeys ? sort(keys(sim)) : keys(sim)
    for var in keys_sim
        # Slice the draws once (iterations × 1 × chains) instead of re-slicing for every use.
        local vals = sim[:,var,:].value
        plt.figure(figsize=figsize)

        # The chain dimension of `vals` is indexed by position j. sim.chains holds chain
        # labels, which are only guaranteed to equal 1:nchains for an unsubsetted result.
        plt.subplot(1,2,1)
        for (j, k) in enumerate(sim.chains)
            plt.plot(sim.range, vals[:,1,j], label="$k", lw=0.4, alpha=0.8)
        end
        plt.xlabel("iterations")
        plt.grid(ls=":")
        #plt.legend(loc="upper right")
        plt.title("trace of $var", fontsize=10)

        plt.subplot(1,2,2)
        local xmin = quantile(vec(vals), 0.005)
        local xmax = quantile(vec(vals), 0.995)
        local x = linspace(xmin, xmax, 201)
        for (j, k) in enumerate(sim.chains)
            local pdfkde = makefunc_pdfkde(vals[:,1,j])
            plt.plot(x, pdfkde.(x), label="$k", lw=0.8, alpha=0.8)
        end
        plt.xlabel("$var")
        plt.grid(ls=":")
        plt.title("empirical posterior pdf of $var", fontsize=10)

        plt.tight_layout()
    end
end

# Plot contours: for every pair (u, v) of parameters (u before v in keys_sim), draw a
# pcolormesh of the 2-D kernel-density estimate of the pooled posterior draws over their
# 0.5%–99.5% quantile box. Panels are placed two per figure; a new figure starts at every odd
# pair index m. If the total number of pairs K(K-1)/2 is odd, the right panel of the last
# figure shows the last pair again with the axes swapped.
function pyplot_contour(sim; sortkeys = true, figsize = (8, 3.5))
    if sortkeys
        keys_sim = sort(keys(sim))
    else
        keys_sim = keys(sim)
    end
    local m = 0                    # running index of the pair being drawn
    local K = length(keys_sim)
    for i in 1:K
        for j in i+1:K
            m += 1
            mod1(m,2) == 1 && plt.figure(figsize=figsize)
            plt.subplot(1,2, mod1(m,2))

            local u = keys_sim[i]
            local v = keys_sim[j]
            # Draws of u and v from all chains pooled into vectors.
            local X = vec(sim[:,u,:].value)
            local Y = vec(sim[:,v,:].value)
            local pdfkde = makefunc_pdfkde(X,Y)
            local xmin = quantile(X, 0.005)
            local xmax = quantile(X, 0.995)
            local ymin = quantile(Y, 0.005)
            local ymax = quantile(Y, 0.995)
            local x = linspace(xmin, xmax, 200)
            local y = linspace(ymin, ymax, 200)
            
            # pdfkde.(x', y) is a matrix whose rows follow y and columns follow x,
            # the layout pcolormesh expects.
            plt.pcolormesh(x', y, pdfkde.(x',y), cmap="CMRmap")
            plt.colorbar()
            plt.grid(ls=":")
            plt.xlabel(u)
            plt.ylabel(v)
            plt.title("posterior of ($u, $v)", fontsize=10)

            mod1(m,2) == 2 && plt.tight_layout()

            # Last pair landed in a left panel: fill the right panel with the swapped view.
            if 2*m == K*(K-1) && mod1(m,2) == 1
                plt.subplot(1,2,2)
                
                plt.pcolormesh(y', x, pdfkde.(x,y'), cmap="CMRmap")
                plt.colorbar()
                plt.grid(ls=":")
                plt.xlabel(v)
                plt.ylabel(u)
                plt.title("posterior of ($v, $u)", fontsize=10)

                plt.tight_layout()
            end
        end
    end
end
Out[6]:
pyplot_contour (generic function with 1 method)

WAICなどを計算するための函数

In [7]:
# Build a function that extracts, from a Mamba simulation `sim`,
#   loglik[l, i] = lpdf(w_l, Y[i])   (L × n matrix)
#   chain[l, :]  = w_l               (L × d matrix)
# where w_l are the posterior draws of the parameters listed in `symbols`, with all chains
# stacked vertically (chain 1 first).
#
# Usage: loglik, chain = makefunc_loglikchainof(lpdf, symbols, Y)(sim)
function makefunc_loglikchainof(lpdf, symbols, Y)
    function loglikchainof(sim)
        draws = sim[:, symbols, :].value           # iterations × parameters × chains
        nchains = size(draws, 3)
        chain = vcat([draws[:, :, c] for c in 1:nchains]...)
        ndraws = size(chain, 1)
        nobs = length(Y)
        # Typed comprehension: l varies fastest, matching column-major storage.
        loglik = Float64[lpdf(chain[l, :], Y[i]) for l in 1:ndraws, i in 1:nobs]
        return loglik, chain
    end
    return loglikchainof
end

# Bayesian predictive density built from the posterior draws chain[l, :] = w_l:
#   p*(y) = (1/L) Σ_l p(y | w_l) = (1/L) Σ_l exp(lpdf(w_l, y)),
# where lpdf(w, y) = log p(y | w). (Note: the mean of exp(lpdf), not of lpdf.)
function makefunc_pdfpred(lpdf, chain)
    ndraws = size(chain, 1)
    function pred_Bayes(y)
        # Start from a zero of the summand's type, obtained by evaluating the first draw.
        total = zero(exp(lpdf(view(chain, 1, :), y)))
        for l in 1:ndraws
            total += exp(lpdf(view(chain, l, :), y))
        end
        return total / ndraws
    end
    return pred_Bayes
end

# WAIC from the log-likelihood matrix loglik[l, i] = log p(Y_i | w_l) (L draws × n points):
#   T_n  = -(1/n) Σ_i log( (1/L) Σ_l p(Y_i | w_l) )   (training loss)
#   V_n  =        Σ_i Var_l[ log p(Y_i | w_l) ]        (functional variance, 1/L-normalized)
#   WAIC = 2n T_n + 2 V_n
# Returns (WAIC, 2n T_n, 2 V_n).
function WAICof(loglik)
    ndraws, nobs = size(loglik)
    training_loss = -mean(log(mean(exp(loglik[l, i]) for l in 1:ndraws)) for i in 1:nobs)
    functional_variance = sum(var(loglik[:, i], corrected=false) for i in 1:nobs)
    twice_nT = 2 * nobs * training_loss
    twice_V = 2 * functional_variance
    return twice_nT + twice_V, twice_nT, twice_V
end

# Leave-one-out cross-validation loss estimated naively from the posterior draws, without
# refitting, by importance sampling:
#   LOOCV = 2 Σ_i log( (1/L) Σ_l 1 / p(Y_i | w_l) ),  with loglik[l, i] = log p(Y_i | w_l).
function LOOCVof(loglik)
    ndraws, nobs = size(loglik)
    loo_term(i) = log(mean(exp(-loglik[l, i]) for l in 1:ndraws))
    return 2 * sum(loo_term(i) for i in 1:nobs)
end

# Twice the Bayesian free energy, computed from loglik[l, i] = log p(Y_i | w_l).
#
# The derivative of twice the free energy with respect to the inverse temperature β equals
# E^β_w[2n L_n], so integrating that expectation numerically from β = 0 to β = 1 gives twice
# the free energy.
# Returns (2F, β -> E^β_w[2n L_n]).
function FreeEnergyof(loglik)
    expected_2nLn = makefunc_E2nLn(loglik)
    twiceF = first(quadgk(expected_2nLn, 0.0, 1.0))
    return twiceF, expected_2nLn
end

# Build β -> E^β_w[2n L_n(w)], the mean of 2n L_n(w) = -2 Σ_i log p(Y_i | w) under the
# posterior tempered to inverse temperature β. It is estimated from the draws w_l of the
# ordinary (β = 1) posterior by importance reweighting:
#
#   E^β[f] = Σ_l f(w_l) exp((1-β) n L_n(w_l)) / Σ_l exp((1-β) n L_n(w_l)),
#
# where loglik[l, i] = log p(Y_i | w_l) is an L × n matrix.
#
# Before exponentiating, the exponents are shifted by their maximum over l, so the largest
# weight is exactly 1 for every β. The previous shift by (1-β)·max(n L_n) only achieved this
# for β ≤ 1. For β > 1 (e.g. WBIC with n = 2, where β = 1/log 2) the weights could overflow
# to Inf, and the function then silently returned 0. Because (1-β)·n L_n is linear in n L_n,
# its maximum over l is attained at the largest or the smallest n L_n; both are precomputed.
function makefunc_E2nLn(loglik)
    local L = size(loglik, 1)
    # negloglik[l] = n L_n(w_l) = -Σ_i log p(Y_i | w_l)
    local negloglik = [-sum(view(loglik, l, :)) for l in 1:L]
    local nll_max = maximum(negloglik)
    local nll_min = minimum(negloglik)
    local function E2nLn(beta)
        local c = 1 - beta
        local shift = max(c * nll_max, c * nll_min)   # == maximum over l of c * negloglik[l]
        local den = zero(shift)
        local num = zero(shift)
        for l in 1:L
            local weight = exp(c * negloglik[l] - shift)
            den += weight
            num += negloglik[l] * weight
        end
        # Degenerate input (non-finite β or log-likelihoods): keep the fallback value 0.
        if den == zero(den) || !isfinite(den)
            return zero(den)
        end
        return 2 * num / den
    end
    return E2nLn
end

# WBIC from loglik[l, i] = log p(Y_i | w_l): the tempered-posterior mean E^β_w[2n L_n]
# evaluated at the inverse temperature β = 1/log(n), where n = size(loglik, 2).
function WBICof(loglik)
    sample_size = size(loglik, 2)
    beta_wbic = 1 / log(sample_size)
    return makefunc_E2nLn(loglik)(beta_wbic)
end

# Generalization loss G = -∫ q(x) log p(x) dx of the predictive density p = pdfpred with
# respect to the true density q = pdftrue, integrated numerically over [xmin, xmax].
# Returns the (integral, error estimate) tuple produced by quadgk.
#
# Points with q(x) == 0 contribute nothing (the convention 0·log 0 = 0). Without this, a point
# outside the support of both densities evaluates to 0 * (-Inf) = NaN and poisons the whole
# integral. An example is a Gamma density integrated from the default xmin < 0, as in
# ShannonInformationof.
function GeneralizationLossof(pdftrue, pdfpred; xmin=-10.0, xmax=10.0)
    local function integrand(x)
        local q = pdftrue(x)
        return q == zero(q) ? zero(q) : -q * log(pdfpred(x))
    end
    return quadgk(integrand, xmin, xmax)
end

# Shannon information (entropy) of the true distribution.
#
# ShannonInformationof(pdftrue; xmin, xmax): numerical -∫ q log q over [xmin, xmax],
#   returned as the (integral, error estimate) tuple of GeneralizationLossof.
# ShannonInformationof(dist, n): 2n times the entropy of `dist`, i.e. on the same 2n scale
#   as the other criteria computed in this notebook.
ShannonInformationof(pdftrue; xmin=-10.0, xmax=10.0) = GeneralizationLossof(pdftrue, pdftrue, xmin=xmin, xmax=xmax)
ShannonInformationof(dist::Distribution, n) = 2*n*entropy(dist)

# Run maximum-likelihood estimation and compute AIC and BIC.
#
# lpdf(w, y) = log p(y | w)
#
# w = link_model(z)   maps a real vector z into the parameter space of w;
# z = unlink_model(w) is the inverse of link_model.
# They keep w inside its domain while optimize() searches over an unconstrained z.
#
# Y is the sample. chain is the matrix of posterior draws, as produced by
# `loglik, chain = loglikchainof(sim)`. The optimizer starts from the draw with the highest
# likelihood.
#
# For a function of one variable, optimize() no longer works from an initial value, so in that
# case the objective is extended to two variables. The dummy second coordinate is pinned to 0
# by a huge quadratic penalty.
#
# Returns (AIC, T_AIC, V_AIC, BIC, T_BIC, V_BIC, param_AIC), where param_AIC is the MLE of w,
# T_AIC = T_BIC = -2 × (maximum log-likelihood), V_AIC = 2 × nparams and
# V_BIC = nparams × log n.
function AICandBICof(lpdf, link_model, unlink_model, Y, chain)
    local n = length(Y)
    local L = size(chain, 1)
    local nparams = size(chain, 2)
    # Negative log-likelihood of the whole sample at the unconstrained point z.
    local negloglik(z) = -@sum(lpdf(link_model(z), Y[i]), i, 1:n)

    # Starting point: the posterior draw with the smallest negative log-likelihood.
    local l_best = findmin(negloglik(unlink_model(chain[l, :])) for l in 1:L)[2]
    local z_start = unlink_model(chain[l_best, :])

    local o, param_AIC
    if nparams == 1
        local pinned(z) = negloglik([z[1]]) + z[2]^2/eps()
        o = optimize(pinned, [z_start[1], 0.0])
        param_AIC = link_model([o.minimizer[1]])
    else
        o = optimize(negloglik, z_start)
        param_AIC = link_model(o.minimizer)
    end

    local T_AIC = 2.0 * o.minimum
    local V_AIC = 2.0 * nparams
    local T_BIC = T_AIC
    local V_BIC = nparams * log(n)
    return T_AIC + V_AIC, T_AIC, V_AIC,
        T_BIC + V_BIC, T_BIC, V_BIC,
        param_AIC
end
Out[7]:
AICandBICof (generic function with 1 method)

情報をまとめて表示するための函数

In [8]:
# Compute every estimate and information criterion for the Mamba simulation `sim` of the model
# `dist_model` fitted to the sample Y, whose true distribution is `dist_true`.
#
# Returns, in this order:
#   WAIC, T_WAIC, V_WAIC, LOOCV, GeneralizationLoss,
#   WBIC, FreeEnergy,
#   param_Bayes (posterior mean of the parameters), pred_Bayes (Bayesian predictive pdf),
#   AIC, T_AIC, V_AIC,
#   BIC, T_BIC, V_BIC,
#   param_MLE, pred_MLE (predictive pdf of the maximum-likelihood fit).
function statsof(sim, Y; 
        dist_true=mixnormal(0.0, 0.0, 0.0), 
        dist_model=mixnormal, link_model=link_mixnormal, unlink_model=unlink_mixnormal)
    n = length(Y)
    lpdf(w, y) = logpdf(dist_model(w), y)

    # Log-likelihood matrix and stacked posterior draws (parameters in sorted key order).
    loglik, chain = makefunc_loglikchainof(lpdf, sort(keys(sim)), Y)(sim)

    # Criteria computed from the log-likelihood matrix alone.
    WAIC, T_WAIC, V_WAIC = WAICof(loglik)
    LOOCV = LOOCVof(loglik)
    WBIC = WBICof(loglik)
    FreeEnergy = first(FreeEnergyof(loglik))

    # Bayesian estimate, predictive density and its generalization loss (2n scale).
    param_Bayes = vec(mean(chain, 1))
    pred_Bayes = makefunc_pdfpred(lpdf, chain)
    GeneralizationLoss = 2*n*first(GeneralizationLossof(x -> pdf(dist_true, x), pred_Bayes))

    # Maximum-likelihood fit with AIC and BIC.
    AIC, T_AIC, V_AIC, BIC, T_BIC, V_BIC, param_MLE =
        AICandBICof(lpdf, link_model, unlink_model, Y, chain)
    pred_MLE(y) = exp(lpdf(param_MLE, y))

    return WAIC, T_WAIC, V_WAIC, LOOCV, GeneralizationLoss,
        WBIC, FreeEnergy,
        param_Bayes, pred_Bayes,
        AIC, T_AIC, V_AIC,
        BIC, T_BIC, V_BIC,
        param_MLE, pred_MLE
end

# Print all estimates and information criteria for the Mamba simulation `sim` of the model
# `dist_model` fitted to the sample Y (true distribution `dist_true`). Then plot, over
# [xmin, xmax], the true density, the sample with its KDE, and the Bayesian and MLE
# predictive densities.
#
# `statsfunc` computes the statistics (default: statsof) and must return them in statsof's
# order. link_model/unlink_model map the parameters to and from the unconstrained space used
# by the maximum-likelihood search. The unlink_model default used to be link_mixnormal, i.e.
# the link map instead of its inverse, which transformed the MLE starting point wrongly.
function show_all_results(dist_true, Y, sim; statsfunc=statsof, 
        dist_model=mixnormal, link_model=link_mixnormal, unlink_model=unlink_mixnormal,
        figsize=(6,4.2), xmin=-4.0, xmax=6.0)
    WAIC, T_WAIC, V_WAIC, LOOCV, GeneralizationLoss,
    WBIC, FreeEnergy,
    param_Bayes, pred_Bayes, 
    AIC, T_AIC, V_AIC, 
    BIC, T_BIC, V_BIC,
    param_MLE, pred_MLE = statsfunc(sim, Y, dist_true=dist_true, 
        dist_model=dist_model, link_model=link_model, unlink_model=unlink_model)
    
    n = length(Y)
    println("\n=== Estimates by $dist_model  (n = $n) ===")
    @show param_Bayes
    @show param_MLE
    
    println("--- Information Criterions")
    println("* AIC     = $AIC = $T_AIC + $V_AIC")
    println("* GenLoss = $GeneralizationLoss")
    println("* WAIC    = $WAIC = $T_WAIC + $V_WAIC")
    println("* LOOCV   = $LOOCV")
    println("---")
    println("* BIC        = $BIC = $T_BIC + $V_BIC")
    println("* FreeEnergy = $FreeEnergy")
    println("* WBIC       = $WBIC")

    println("="^78 * "\n")
    
    sleep(0.1)
    plt.figure(figsize=figsize)
    kde_sample = makefunc_pdfkde(Y)
    x = linspace(xmin, xmax, 201)
    plt.plot(x, pdf.(dist_true, x), label="true distribution")
    # Sample points are drawn at the height of their own KDE value.
    plt.scatter(Y, kde_sample.(Y), label="sample", s=10, color="k", alpha=0.5)
    plt.plot(x, kde_sample.(x), label="KDE of sample",   color="k", alpha=0.5)
    plt.plot(x, pred_Bayes.(x), label="Bayesian predictive", ls="--")
    plt.plot(x, pred_MLE.(x),   label="MLE predictive",     ls="-.")
    plt.xlabel("x")
    plt.ylabel("probability density")
    plt.grid(ls=":")
    plt.legend(fontsize=8)
    plt.title("Estimates by $dist_model: n = $(length(Y))")
end

# Plot, over [xmin, xmax], the true density `dist_true`, the sample Y (each point drawn at
# the height of the sample's kernel-density estimate) and that KDE itself.
function plotsample(dist_true, Y; figsize=(6,4.2), xmin=-4.0, xmax=6.0)
    xs = linspace(xmin, xmax, 201)
    sample_kde = makefunc_pdfkde(Y)

    sleep(0.1)
    plt.figure(figsize=figsize)
    plt.plot(xs, pdf.(dist_true, xs), label="true dist.")
    plt.scatter(Y, sample_kde.(Y), label="sample", s=10, color="k", alpha=0.5)
    plt.plot(xs, sample_kde.(xs), label="KDE of sample",   color="k", alpha=0.5)
    plt.xlabel("x")
    plt.ylabel("probability density")
    plt.grid(ls=":")
    plt.legend(fontsize=8)
    return plt.title("Sample size n = $(length(Y))")
end
Out[8]:
plotsample (generic function with 1 method)

単純な正規分布モデルを最尤法で解くための函数

In [9]:
# Given a sample Y generated from dist_true, fit a normal model by maximum likelihood and
# return its estimates and criteria:
#   mu, sigma                          -- MLE of the mean and standard deviation
#   pred_Normal                        -- predictive density y -> pdf(fitted normal, y)
#   AIC_Normal, T_Normal, V_Normal     -- AIC = T + V, T = -2 Σ_i log p(Y_i), V = 2 × 2 params
#   BIC_Normal, TB_Normal, VB_Normal   -- BIC = T + V with V = 2 log n
#   GL_Normal                          -- 2n × generalization loss w.r.t. dist_true
#
# The generalization loss is integrated over [xmin, xmax]. The defaults reproduce the
# previously hard-coded range [-10, 10]; widen it for true distributions with mass outside.
function fit_Normal(dist_true, Y; xmin=-10.0, xmax=10.0)
    local n = length(Y)
    local d = fit(Normal,Y)
    local mu = d.μ
    local sigma = d.σ
    local pred_Normal(y) = pdf(d,y)
    local T_Normal = -2*sum(logpdf.(d,Y))
    local V_Normal = 4.0
    local AIC_Normal = T_Normal + V_Normal
    local TB_Normal = T_Normal
    local VB_Normal = 2.0*log(n)
    local BIC_Normal = TB_Normal + VB_Normal
    local f(y) = -pdf(dist_true, y)*logpdf(d, y)
    local GL_Normal = 2*n*quadgk(f, xmin, xmax)[1]
    return mu, sigma, pred_Normal, 
        AIC_Normal, T_Normal, V_Normal, 
        BIC_Normal, TB_Normal, VB_Normal, 
        GL_Normal
end

# Fit a normal distribution to Y by maximum likelihood (fit_Normal), print its estimates and
# criteria, then plot the fit against the true density and the sample. The plot range is
# [xmin, xmax].
function show_fit_Normal(dist_true, Y; figsize=(6,4.2), xmin=-4.0, xmax=6.0)
    (mu, sigma, pred_Normal, AIC_Normal, T_Normal, V_Normal,
     BIC_Normal, TB_Normal, VB_Normal, GL_Normal) = fit_Normal(dist_true, Y)

    println("--- Normal Fitting")
    println("* μ = $mu")
    println("* σ = $sigma")
    println("* GenLoss = $GL_Normal")
    println("* AIC     = $AIC_Normal = $T_Normal + $V_Normal")
    println("* BIC     = $BIC_Normal = $TB_Normal + $VB_Normal")

    xs = linspace(xmin, xmax, 201)
    sample_kde = makefunc_pdfkde(Y)
    sleep(0.1)
    plt.figure(figsize=figsize)
    plt.plot(xs, pdf.(dist_true, xs), label="true distribution")
    plt.scatter(Y, sample_kde.(Y), label="sample", s=10, color="k", alpha=0.5)
    plt.plot(xs, sample_kde.(xs), label="KDE of sample",   color="k", alpha=0.5)
    plt.plot(xs, pred_Normal.(xs), label="Normal predictive", ls="--")
    plt.xlabel("x")
    plt.ylabel("probability density")
    plt.grid(ls=":")
    plt.legend(fontsize=8)
    return plt.title("Sample size n = $(length(Y))")
end
Out[9]:
show_fit_Normal (generic function with 1 method)
In [10]:
# Bayesian information criteria for the Mamba simulation `sim` of the model `dist_model`
# fitted to Y, whose true distribution is `dist_true`.
#
# Returns (chain, WAIC, T_WAIC, V_WAIC, LOOCV, GL, WBIC, FreeEnergy). chain holds the stacked
# posterior draws; GL is 2n × the generalization loss of the Bayesian predictive.
function InformationCriterions(dist_true, Y, sim; dist_model=mixnormal)
    lpdf(w, y) = logpdf(dist_model(w), y)
    loglik, chain = makefunc_loglikchainof(lpdf, sort(keys(sim)), Y)(sim)

    WAIC, T_WAIC, V_WAIC = WAICof(loglik)
    LOOCV = LOOCVof(loglik)

    predictive = makefunc_pdfpred(lpdf, chain)
    GL = 2 * length(Y) * first(GeneralizationLossof(x -> pdf(dist_true, x), predictive))

    WBIC = WBICof(loglik)
    FreeEnergy = first(FreeEnergyof(loglik))

    return chain, WAIC, T_WAIC, V_WAIC, LOOCV, GL, WBIC, FreeEnergy
end
Out[10]:
InformationCriterions (generic function with 1 method)

プロットのための函数

In [11]:
# Draw one animation frame for sample size n = t and the model `modelname`
# ("normal", "normal1" or "mixnormal"):
#   left : the true density, the first n sample points, and the Bayesian predictive density
#          built from the posterior draws Chain_<modelname>[n];
#   right: the KL information and (WAIC - T_true)/2n as functions of the sample size, up to n.
#
# Reads notebook globals: Sample, N, dist_true, Chain_<modelname>, kl_<modelname>,
# wt_<modelname> and the model function <modelname> (looked up with eval at call time).
function plotsim(t; modelname="normal", xmin=-3, xmax=8, y1max=0.45, y2max=0.40)
    plt.clf()
    
    local n = t
    local Y = Sample[1:n]
    local Chain = eval(Symbol(:Chain_, modelname))
    local chain = Chain[n]
    local dist_model = eval(Symbol(modelname))
    local lpdf(w,x) = logpdf(dist_model(w),x)
    local pred = makefunc_pdfpred(lpdf, chain)
    local x = linspace(xmin,xmax,201)
    local kl = eval(Symbol(:kl_, modelname))
    local wt = eval(Symbol(:wt_, modelname))
    
    plt.subplot(121)
    plt.plot(x, pdf.(dist_true,x), color="black", ls=":", label="true dist")
    plt.scatter(Y, pdf.(dist_true,Y), color="red", s=10, alpha=0.5, label="sample") 
    plt.plot(x, pred.(x), label="predictive")
    plt.ylim(-y1max/40, y1max)
    plt.grid(ls=":")
    plt.legend()
    plt.title("$modelname: n = $n")
    
    plt.subplot(122)
    plt.plot(1:n, kl[1:n], label="KL information")
    plt.plot(1:n, wt[1:n], label="(WAIC \$-\$ T\$_\\mathrm{true}\$)/2n")
    plt.xlabel("sample size n")
    plt.ylabel("\$-\$ log(probability)")
    plt.xlim(-1, N+1)
    # The first four wt values are left out when choosing the lower y-limit (presumably
    # because they are unstable for very small n). The start index is clamped so that
    # N < 5 does not give an empty slice, whose minimum would throw.
    plt.ylim(min(minimum(kl), minimum(wt[min(5, length(wt)):end]))-y2max/40, y2max)
    plt.grid(ls=":")
    plt.legend()
    plt.title("$modelname: n = $n")
    
    plt.tight_layout()
    # NOTE(review): empty plot call; its return value is what FuncAnimation receives.
    plt.plot()
end
Out[11]:
plotsim (generic function with 1 method)

プロットの実行 (ガンマ分布のサンプルの場合)

GIFアニメーションを作成したい場合には適当にコメントアウトを外して実行する。

In [12]:
seed = 4649
#seed = 12345
#seed = 2017

dataname = "Gamma128Sample$seed"
#dataname = "MixGamma128Sample$seed"

# Load every variable stored in the JLD2 file into a global of the same name (presumably
# Sample, N, dist_true, Chain_*, GL_*, WAIC_*, Shannon, T_true, ... as used below).
# The assignment is built as an expression object instead of by parsing a string made from
# the key, so a key can only ever name a variable and can never inject code.
data = load("$dataname.jld2")
for v in sort(collect(keys(data)))
    eval(:($(Symbol(v)) = data[$v]))
end

# Mean and standard deviation of the true distribution, and the normal distribution with the
# same two moments.
@show mu_true = mean(dist_true)
@show sigma_true = std(dist_true)
@show dist_normal_fitting = Normal(mu_true, sigma_true)

# Per-observation Shannon entropy of the true distribution.
# NOTE(review): presumably Shannon[n] = 2n × entropy, so Shannon[1]/2 is the entropy — confirm.
shannon_true = Shannon[1]/2
# Generalization loss and KL information of the moment-matched normal distribution.
gl_normal_fitting = GeneralizationLossof(x->pdf(dist_true,x), x->pdf(dist_normal_fitting,x))[1]
kl_normal_fitting = gl_normal_fitting - shannon_true

# KL information of each model's Bayesian predictive for every sample size, on the 2n scale.
KL_normal    = GL_normal    - Shannon
KL_normal1   = GL_normal1   - Shannon
KL_mixnormal = GL_mixnormal - Shannon

# WAIC minus T_true (on the 2n scale).
# NOTE(review): T_true is loaded from the data file; presumably the loss of the true density.
WT_normal    = WAIC_normal    - T_true
WT_normal1   = WAIC_normal1   - T_true
WT_mixnormal = WAIC_mixnormal - T_true

# Per-observation versions (divided by 2n), plotted by plotsim.
kl_normal    = [KL_normal[n]/(2n) for n in 1:N]
kl_normal1   = [KL_normal1[n]/(2n) for n in 1:N]
kl_mixnormal = [KL_mixnormal[n]/(2n) for n in 1:N]

wt_normal    = [WT_normal[n]/(2n) for n in 1:N]
wt_normal1   = [WT_normal1[n]/(2n) for n in 1:N]
wt_mixnormal = [WT_mixnormal[n]/(2n) for n in 1:N]

@show shannon_true
@show gl_normal_fitting
@show kl_normal_fitting
@show exp(-kl_normal_fitting)
println()
@show kl_normal
@show kl_normal1
@show kl_mixnormal
println()
# Maximum-likelihood Gamma fit to the full sample; `pred` is its density.
Y = Sample
n = length(Y)
d = fit(Gamma, Y)
pred(x) = pdf(d, x)

# Generalization loss and KL information of the Gamma fit.
# NOTE(review): integration starts at 0, presumably to stay inside the support of the Gamma
# densities, where their log pdf is finite.
gl_gamma = GeneralizationLossof(x->pdf(dist_true,x), x->pdf(d,x), xmin=0.0, xmax=20.0)[1]
kl_gamma = gl_gamma - shannon_true
@show shannon_true
@show gl_gamma
@show kl_gamma

sleep(0.1)

# Static figures of the final frame (n = N) for each model.
plt.figure(figsize=(10,3.5))
plotsim(N, modelname="normal")

plt.figure(figsize=(10,3.5))
plotsim(N, modelname="normal1")

plt.figure(figsize=(10,3.5))
plotsim(N, modelname="mixnormal")

# One GIF animation per model; each frame is drawn by plotsim at sample size t.
# frames: first the final state n = N, then n = 1, 2, ... with the smallest n repeated so they
# stay on screen longer, then n = 11..N, and finally N repeated to hold the last frame.
# interval is the delay between frames in milliseconds.
# Uncomment the `myanim[:save]` lines to (re)create the GIF files. Otherwise an existing GIF
# is shown, or a message is printed instead of aborting the cell.
modelname = "normal"
file = "$dataname$modelname.gif"
plot1frame(t) = plotsim(t, modelname=modelname)
fig = plt.figure(figsize=(10,3.5))
frames = [N; 1;1;1;1; 2;2;2;3;3;3;4;4;4; 5;5;5;6;6;7;7;8;8;9;9;10;10; 11:N; N;N;N;N;N;N;N;N;N]
interval = 100
@time myanim = anim.FuncAnimation(fig, plot1frame, frames=frames, interval=interval)
#@time myanim[:save](file, writer="imagemagick")
sleep(0.1)
if isfile(file)
    showgif(file)
else
    println("$file not found: uncomment the myanim[:save] line to create it")
end
plt.clf()

modelname = "normal1"
file = "$dataname$modelname.gif"
plot1frame(t) = plotsim(t, modelname=modelname)
fig = plt.figure(figsize=(10,3.5))
frames = [N; 1;1;1;1; 2;2;2;3;3;3;4;4;4; 5;5;5;6;6;7;7;8;8;9;9;10;10; 11:N; N;N;N;N;N;N;N;N;N]
interval = 100
@time myanim = anim.FuncAnimation(fig, plot1frame, frames=frames, interval=interval)
#@time myanim[:save](file, writer="imagemagick")
sleep(0.1)
if isfile(file)
    showgif(file)
else
    println("$file not found: uncomment the myanim[:save] line to create it")
end
plt.clf()

modelname = "mixnormal"
file = "$dataname$modelname.gif"
plot1frame(t) = plotsim(t, modelname=modelname)
fig = plt.figure(figsize=(10,3.5))
frames = [N; 1;1;1;1; 2;2;2;3;3;3;4;4;4; 5;5;5;6;6;7;7;8;8;9;9;10;10; 11:N; N;N;N;N;N;N;N;N;N]
interval = 100
@time myanim = anim.FuncAnimation(fig, plot1frame, frames=frames, interval=interval)
#@time myanim[:save](file, writer="imagemagick")
sleep(0.1)
if isfile(file)
    showgif(file)
else
    println("$file not found: uncomment the myanim[:save] line to create it")
end
plt.clf()

# Compare the density of the maximum-likelihood Gamma fit (`pred`) with the true density.
sleep(0.1)
x = linspace(-3,8)
plt.figure(figsize=(5,3.0))
plt.plot(x, pdf.(dist_true,x), color="black", ls=":", label="true dist")
#plt.scatter(Y, pdf.(dist_true,Y), color="red", s=10, alpha=0.5, label="sample") 
plt.plot(x, pred.(x), label="predictive")
plt.ylim(-0.02, 0.45)
plt.grid(ls=":")
plt.legend()
plt.title("gamma: n = $n")
mu_true = mean(dist_true) = 2.5
sigma_true = std(dist_true) = 1.0
dist_normal_fitting = Normal(mu_true, sigma_true) = Distributions.Normal{Float64}(μ=2.5, σ=1.0)
shannon_true = 1.3634322389764533
gl_normal_fitting = 1.4188708216864312
kl_normal_fitting = 0.055438582709977924
exp(-kl_normal_fitting) = 0.9460701269495115

kl_normal = [1.31064, 0.943061, 0.750522, 0.656015, 0.467435, 0.325275, 0.280714, 0.183553, 0.14322, 0.109568, 0.133723, 0.114118, 0.134027, 0.13493, 0.165084, 0.129665, 0.0916437, 0.109228, 0.122627, 0.128549, 0.144589, 0.118252, 0.12343, 0.104349, 0.0963167, 0.0944973, 0.10444, 0.0742909, 0.0781926, 0.0830617, 0.0784852, 0.0842484, 0.0801226, 0.0893429, 0.0927317, 0.0942377, 0.100268, 0.108295, 0.081999, 0.079874, 0.0838082, 0.0843932, 0.0866203, 0.0807927, 0.0828459, 0.0811817, 0.0795774, 0.0762512, 0.0768373, 0.0720845, 0.072021, 0.0710833, 0.069372, 0.0677738, 0.0695583, 0.0721299, 0.0692544, 0.0710969, 0.0747382, 0.0721481, 0.07313, 0.0730386, 0.0720626, 0.0750822, 0.0726765, 0.0700707, 0.071154, 0.0701254, 0.0693048, 0.0688835, 0.0698035, 0.0727116, 0.0711797, 0.0741647, 0.0717155, 0.0710162, 0.0686571, 0.0685826, 0.0663693, 0.0664087, 0.0675754, 0.0677752, 0.0660047, 0.0678938, 0.0659807, 0.0710998, 0.0713109, 0.0702581, 0.0713986, 0.070322, 0.0716038, 0.0703248, 0.0687681, 0.0724746, 0.070599, 0.0704589, 0.0687648, 0.0675405, 0.0690222, 0.0677945, 0.069507, 0.0676838, 0.0727986, 0.0720131, 0.0738819, 0.0729581, 0.0730847, 0.0709425, 0.0707688, 0.0697003, 0.0694465, 0.0677609, 0.0663643, 0.0741785, 0.0723012, 0.0716901, 0.0712664, 0.0707643, 0.0698307, 0.0691062, 0.0689839, 0.0694474, 0.0698461, 0.0693909, 0.0711229, 0.0698827, 0.0690576, 0.0686411]
kl_normal1 = [1.47104, 1.04922, 0.856708, 0.646608, 0.505585, 0.350638, 0.301355, 0.159513, 0.151451, 0.125463, 0.151449, 0.134613, 0.15219, 0.152594, 0.174633, 0.141745, 0.109988, 0.127805, 0.134692, 0.141619, 0.160252, 0.134758, 0.134227, 0.111601, 0.102443, 0.101496, 0.108661, 0.0845866, 0.0882789, 0.0921233, 0.0870597, 0.0909659, 0.089052, 0.092995, 0.0945692, 0.0912503, 0.0920543, 0.0970527, 0.0715839, 0.0695211, 0.0740948, 0.0805521, 0.0693679, 0.0694012, 0.0701721, 0.0724424, 0.0742791, 0.0728815, 0.0720971, 0.0718206, 0.0705532, 0.0632158, 0.0634734, 0.0624744, 0.0668962, 0.0671967, 0.0662609, 0.0693644, 0.0721963, 0.0699896, 0.0741157, 0.072758, 0.0711214, 0.072918, 0.068315, 0.0677923, 0.064398, 0.0656126, 0.0673253, 0.0626848, 0.0641893, 0.0644017, 0.0622744, 0.0645114, 0.0631903, 0.0618102, 0.0608189, 0.059143, 0.0598792, 0.0592656, 0.0609539, 0.0608831, 0.0623464, 0.0630496, 0.0619223, 0.0587734, 0.0587298, 0.0595798, 0.0601699, 0.0613004, 0.0624615, 0.0627157, 0.0628518, 0.0597006, 0.0599679, 0.0591537, 0.0592126, 0.0590257, 0.0599146, 0.060057, 0.0609821, 0.0601873, 0.0576759, 0.0579047, 0.0568314, 0.0575244, 0.0569019, 0.0564451, 0.0561743, 0.0562749, 0.0561539, 0.0562237, 0.0561309, 0.055473, 0.0554871, 0.0554596, 0.0554911, 0.0554578, 0.055464, 0.055468, 0.0554557, 0.0554974, 0.0554629, 0.0555053, 0.0554721, 0.0554648, 0.0554687, 0.0555168]
kl_mixnormal = [1.51714, 1.23032, 0.956264, 0.764635, 0.611714, 0.461523, 0.386516, 0.242406, 0.227114, 0.192356, 0.210525, 0.189439, 0.204289, 0.198312, 0.229041, 0.184636, 0.146015, 0.165102, 0.170953, 0.17621, 0.189366, 0.167973, 0.164655, 0.140756, 0.132038, 0.125436, 0.132522, 0.110978, 0.112889, 0.114524, 0.106492, 0.112621, 0.106069, 0.11177, 0.113272, 0.11073, 0.111814, 0.114561, 0.0845257, 0.0841041, 0.0865583, 0.094262, 0.0841411, 0.0810539, 0.083699, 0.0863985, 0.0887248, 0.0836628, 0.0846884, 0.0831397, 0.081651, 0.0718591, 0.0712221, 0.0727205, 0.0739206, 0.0764266, 0.074292, 0.0750328, 0.0811272, 0.0781057, 0.0813954, 0.0800016, 0.080839, 0.0804213, 0.0752033, 0.0760047, 0.0724008, 0.07307, 0.0724738, 0.0677019, 0.0692192, 0.0692594, 0.0665389, 0.0674687, 0.0682824, 0.0668629, 0.0660012, 0.0649428, 0.0652945, 0.0644028, 0.0655677, 0.0657117, 0.0666502, 0.0680592, 0.0677797, 0.0608641, 0.0609913, 0.0623759, 0.0632366, 0.0641064, 0.0632388, 0.0634444, 0.0637755, 0.0597194, 0.0590873, 0.0577823, 0.05877, 0.0578065, 0.0596795, 0.0596639, 0.0607269, 0.0602364, 0.0547192, 0.0559958, 0.0561995, 0.0557398, 0.0549923, 0.0552078, 0.0553165, 0.0546666, 0.0547095, 0.055189, 0.0535805, 0.0541584, 0.053423, 0.0520137, 0.0519102, 0.0518535, 0.05307, 0.0520342, 0.0513495, 0.0521865, 0.0526227, 0.052231, 0.0531176, 0.0521187, 0.0523198, 0.052091]

shannon_true = 1.3634322389764533
gl_gamma = 1.3654148880562635
kl_gamma = 0.0019826490798102725
  0.229449 seconds (76.46 k allocations: 4.110 MiB, 2.80% gc time)