Zuko in Pyro¶
此文件包含在 Pyro 流水线中使用基于 Zuko 的归一化流的辅助工具。
配套教程位于 tutorial/svi_flow_guide.ipynb 和 tutorial/vae_flow_prior.ipynb。
- 类 ZukoToPyro(dist: torch.distributions.distribution.Distribution)[source]¶
将 Zuko 分布包装为 Pyro 分布。
如果
dist
具有rsample_and_log_prob
方法(如 Zuko 的流),则在采样时将使用它而不是rsample
。返回的对数密度将缓存起来供以后评分使用。- 参数
dist (torch.distributions.Distribution) – 一个分布实例。
flow = zuko.flows.MAF(features=5) # flow() is a torch.distributions.Distribution dist = flow() x = dist.sample((2, 3)) log_p = dist.log_prob(x) # ZukoToPyro(flow()) is a pyro.distributions.Distribution dist = ZukoToPyro(flow()) x = dist((2, 3)) log_p = dist.log_prob(x) with pyro.plate("data", 42): z = pyro.sample("z", dist)