推断工具

enable_validation(is_validate)[source]
is_validation_enabled()[source]
validation_enabled(is_validate=True)[source]

模型检查

get_dependencies(model: Callable, model_args: Optional[tuple] = None, model_kwargs: Optional[dict] = None, include_deterministic: bool = False) Dict[str, object][source]

推断关于条件模型的依赖结构。

这会返回一个嵌套字典,其结构类似于

{
    "prior_dependencies": {
        "variable1": {"variable1": set()},
        "variable2": {"variable1": set(), "variable2": set()},
        ...
    },
    "posterior_dependencies": {
        "variable1": {"variable1": {"plate1"}, "variable2": set()},
        ...
    },
}

其中

  • prior_dependencies 是一个字典,将下游的隐变量和观测变量映射到另一个字典,该字典再将它们依赖的上游隐变量映射到引起完全依赖的板(plates)的集合。也就是说,包含的板会像完全二分图一样引入二次数量的依赖,而排除的板只会像独立的平行边集合一样引入线性数量的依赖。先验依赖关系遵循原始模型顺序。

  • posterior_dependencies 是一个类似的字典,但将隐变量映射到它们在后验中依赖的隐变量或观测站点。后验依赖关系与模型顺序相反。

依赖关系省略了 pyro.deterministic 站点和 pyro.sample(..., Delta(...)) 站点。

示例

这是一个没有板的简单示例。我们看到每个节点都依赖于自身,并且只有隐变量出现在后验中

def model_1():
    a = pyro.sample("a", dist.Normal(0, 1))
    pyro.sample("b", dist.Normal(a, 1), obs=torch.tensor(0.0))

assert get_dependencies(model_1) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"a": set(), "b": set()},
    },
    "posterior_dependencies": {
        "a": {"a": set(), "b": set()},
    },
}

这是一个示例,其中两个变量 ab 在先验中条件独立,但在后验中对它们都依赖的所谓“碰撞变量” c 变得条件依赖。这在图模型文献中被称为“道德化”(moralization)

def model_2():
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", dist.LogNormal(0, 1))
    c = pyro.sample("c", dist.Normal(a, b))
    pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.))

assert get_dependencies(model_2) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"b": set()},
        "c": {"a": set(), "b": set(), "c": set()},
        "d": {"c": set(), "d": set()},
    },
    "posterior_dependencies": {
        "a": {"a": set(), "b": set(), "c": set()},
        "b": {"b": set(), "c": set()},
        "c": {"c": set(), "d": set()},
    },
}

在存在板的情况下,依赖关系可能更复杂。到目前为止,所有字典值都是空的板集合,但在下面的后验中,我们看到 a` 通过板 p` 依赖于自身。这意味着,在 a` 的元素中,例如 a[0]` 依赖于 a[1]`(这就是为什么我们明确允许变量依赖于自身)

def model_3():
    with pyro.plate("p", 5):
        a = pyro.sample("a", dist.Normal(0, 1))
    pyro.sample("b", dist.Normal(a.sum(), 1), obs=torch.tensor(0.0))

assert get_dependencies(model_3) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"a": set(), "b": set()},
    },
    "posterior_dependencies": {
        "a": {"a": {"p"}, "b": set()},
    },
}
[1] S.Webb, A.Goliński, R.Zinkov, N.Siddharth, T.Rainforth, Y.W.Teh, F.Wood (2018)

“生成模型的忠实反演用于有效的摊销推断” https://dl.acm.org/doi/10.5555/3327144.3327229

参数
  • model (可调用) – 一个模型。

  • model_args (tuple) – 模型参数的可选元组。

  • model_kwargs (dict) – 模型关键字参数的可选字典。

  • include_deterministic (bool) – 是否包含确定性站点。

返回

一个元数据字典(参见上文)。

返回类型

dict

render_model(model: Callable, model_args: Optional[Union[tuple, List[tuple]]] = None, model_kwargs: Optional[Union[dict, List[dict]]] = None, filename: Optional[str] = None, render_distributions: bool = False, render_params: bool = False, render_deterministic: bool = False) graphviz.graphs.Digraph[source]

使用 graphviz 渲染模型。

如果提供了 filename,则保存图像;否则绘制图形。用法示例参见模型渲染教程

参数
  • model – 要渲染的模型。

  • model_args – 要传递给模型的位置参数元组,或半监督模型的元组列表。

  • model_kwargs – 要传递给模型的关键字参数字典,或半监督模型的字典列表。

  • filename (str) – 保存渲染模型的文件的名称或路径。

  • render_distributions (bool) – 是否在图中包含随机变量分布标注(以及参数约束)。

  • render_params (bool) – 是否在图中显示参数。

  • render_deterministic (bool) – 是否包含确定性站点。

返回

模型图。

返回类型

graphviz.Digraph

交互式先验调整

class Resampler(guide: Callable, simulator: Optional[Callable] = None, *, num_guide_samples: int, max_plate_nesting: Optional[int] = None)[source]

用于交互式调整生成模型的重采样器,通常作为贝叶斯工作流程早期阶段进行先验预测检查时使用。

这旨在作为计算缓存,通过基于下游模拟的样本对先验分布的参数进行交互式调整。其思想是模拟可能成本高昂,但当微调参数分布的参数时,可以通过重要性重采样重用大部分先前样本。

参数
  • guide (可调用) – 一个不带参数的 pyro 模型。导引应该是弥散的,覆盖比随后传递给 sample()model 更大的空间。必须可以通过 pyro.plate 进行矢量化。

  • simulator (可调用) – 一个可选的更大 pyro 模型,包含导引隐变量的超集。必须可以通过 pyro.plate 进行矢量化。

  • num_guide_samples (int) – 从导引中抽取的初始样本数量。这应该远大于在后续调用 sample() 中请求的 num_samples

  • max_plate_nesting (int) – 模型中的最大板嵌套深度。如果缺失,将通过运行导引来猜测。

sample(model: Callable, num_samples: int, stable: bool = True) Dict[str, torch.Tensor][source]

绘制一组最多 num_samples 个模型样本,可选地由 simulator 扩展。

在内部,这会重要性重采样在 .__init__() 中由 guide 生成的样本,并且不会重新运行 guidesimulator。如果原始导引样本对模型分布覆盖不佳,样本将显示多样性较低。

参数
  • model (可调用) – 一个与原始 guide 具有相同隐变量的模型。必须可以通过 pyro.plate 进行矢量化。

  • num_samples (int) – 要抽取的样本数量。

  • stable (bool) – 是否使用分段常数多项分布采样。可视化时设为 True,蒙特卡洛积分时设为 False。默认为 True。

返回

一个堆叠样本的字典。

返回类型

Dict[str, torch.Tensor]