Zuko in Pyro

此文件包含在 Pyro 流水线中使用基于 Zuko 的归一化流的辅助工具。

配套教程位于 tutorial/svi_flow_guide.ipynbtutorial/vae_flow_prior.ipynb

ZukoToPyro(dist: torch.distributions.distribution.Distribution)[source]

将 Zuko 分布包装为 Pyro 分布。

如果 dist 具有 rsample_and_log_prob 方法(如 Zuko 的流),则在采样时将使用它而不是 rsample。返回的对数密度将缓存起来供以后评分使用。

参数

dist (torch.distributions.Distribution) – 一个分布实例。

flow = zuko.flows.MAF(features=5)

# flow() is a torch.distributions.Distribution

dist = flow()
x = dist.sample((2, 3))
log_p = dist.log_prob(x)

# ZukoToPyro(flow()) is a pyro.distributions.Distribution

dist = ZukoToPyro(flow())
x = dist((2, 3))
log_p = dist.log_prob(x)

with pyro.plate("data", 42):
    z = pyro.sample("z", dist)