using Gen
# Tutorial: https://www.gen.dev/tutorials/intro-to-modeling/tutorial
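# Installation
#
# From a shell (macOS / Homebrew):
#     brew install julia
#     julia
# Then, in the Julia REPL, type `]` to enter package mode and run:
#     add Gen
#     add Plots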
# Setup
using Plots
# Writing a probabilistic model as a generative function
# Bayesian linear regression
# Generative model for Bayesian linear regression.
#
# Random choices made by this model (trace addresses):
#   :slope      ~ normal(0, 1)
#   :intercept  ~ normal(0, 2)
#   (:y, i)     ~ normal(slope * xs[i] + intercept, 0.1), for each (i, x) in enumerate(xs)
#
# Returns the sampled line as a closure `y(x) = slope * x + intercept`.
@gen function line_model(xs::Vector{Float64})
    # We begin by sampling a slope and intercept for the line.
    # Before we have seen the data, we don't know the values of
    # these parameters, so we treat them as random choices. The
    # distributions they are drawn from represent our prior beliefs
    # about the parameters: in this case, that neither the slope nor the
    # intercept will be more than a couple points away from 0.
    slope = ({:slope} ~ normal(0, 1))
    intercept = ({:intercept} ~ normal(0, 2))

    # We define a function to compute y for a given x.
    function y(x)
        return slope * x + intercept
    end

    # Given the slope and intercept, we can sample y coordinates
    # for each of the x coordinates in our input vector.
    for (i, x) in enumerate(xs)
        # Note that we name each random choice in this loop
        # slightly differently: the first time through,
        # the name (:y, 1) will be used, then (:y, 2) for
        # the second point, and so on.
        ({(:y, i)} ~ normal(y(x), 0.1))
    end

    # Most of the time, we don't care about the return
    # value of a model, only the random choices it makes.
    # It can sometimes be useful to return something
    # meaningful, however; here, we return the function `y`.
    return y
end;
# Eleven evenly spaced x-coordinates from -5 to 5 (a Vector{Float64}, as
# `line_model` requires).
xs = [-5., -4., -3., -2., -1., 0., 1., 2., 3., 4., 5.];

# Calling the generative function directly runs the model once and returns
# its return value: the sampled line function `y`.
y = line_model(xs)
# Output: y (generic function with 1 method)
# Run the model and record its execution (arguments, random choices, and
# return value) in a trace.
trace = Gen.simulate(line_model, (xs,));

# The trace stores the arguments the model was called with.
Gen.get_args(trace)
# Output: ([-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0],)
# Plot the line recorded in a `line_model` trace and, optionally, its data.
#
# Arguments:
#   trace      - a trace of `line_model`; its single argument is the vector `xs`
#                and its choices include `(:y, i)` for each i in 1:length(xs).
#   show_data  - when `true` (the default), overlay the sampled (x, y) points.
#
# Returns the figure. Both axis limits are set to (minimum(xs), maximum(xs)).
function render_trace(trace; show_data=true)
    # Pull out xs from the trace (the model's only argument).
    xs, = get_args(trace)

    xmin = minimum(xs)
    xmax = maximum(xs)

    # Pull out the return value (the sampled function y), useful for plotting.
    y = get_retval(trace)

    # Draw the line across the plotted x-range, so it spans the axis for any xs
    # (previously this was hard-coded to -5..5).
    test_xs = range(xmin, stop=xmax, length=1000)
    fig = plot(test_xs, map(y, test_xs), color="black", alpha=0.5, label=nothing,
               xlim=(xmin, xmax), ylim=(xmin, xmax))

    if show_data
        # Addresses are (:y, 1), ..., (:y, n), matching `enumerate(xs)` in the model.
        ys = [trace[(:y, i)] for i in 1:length(xs)]

        # Plot the data set onto this figure explicitly rather than onto the
        # implicit "current" plot.
        scatter!(fig, xs, ys, c="black", label=nothing)
    end

    return fig
end;
# Visualize the single trace simulated above: the sampled line plus its data points.
render_trace(trace; show_data=true)
# Render each trace with `renderer` (e.g. `render_trace`, which maps a trace to
# a figure) and combine the resulting figures into a single Plots layout.
# Returns the combined plot.
function grid(renderer::Function, traces)
    return Plots.plot(map(renderer, traces)...)
end;
# Simulate 12 independent traces from the prior and show them side by side.
traces = [Gen.simulate(line_model, (xs,)) for _ in 1:12]
grid(render_trace, traces)