简易定制指南

EasyGuide

class EasyGuide(model)[source]

基类: pyro.nn.module.PyroModule

“简易指南”的基类,它比 AutoGuide 更灵活,但比原始的 Pyro 指南更容易编写。

派生类应该定义一个 guide() 方法。这个 guide() 方法可以将普通指南语句(例如 pyro.samplepyro.param)与以下特殊语句结合使用

  • group = self.group(...) 选择模型中的多个 pyro.sample 站点。有关后续方法,请参见 Group

  • with self.plate(...): ... 应该用于代替 pyro.plate

  • self.map_estimate(...) 对单个站点使用 Delta 指南。

派生类还可以重写 init() 方法,为模型站点提供自定义初始化。

参数

model (callable) – 一个 Pyro 模型。

property model
abstract guide(*args, **kargs)[source]

指南实现,由用户重写。

init(site)[source]

模型初始化方法,可由用户重写。

这应该输入一个站点并输出该站点的有效样本。默认行为是抽取随机样本

return site["fn"]()

有关其他可能的初始化函数,请参见 https://docs.pyro.org.cn/en/stable/infer.autoguide.html#module-pyro.infer.autoguide.initialization

forward(*args, **kwargs)[source]

运行指南。这通常用于推断算法。

注意

此方法在内部由 Module 使用。用户应该改用 __call__()

plate(name, size=None, subsample_size=None, subsample=None, *args, **kwargs)[source]

围绕 pyro.plate 的包装器,允许 EasyGuide 自动构造板。你应该在 guide() 实现中使用它,而不是 pyro.plate

group(match='.*')[source]

选择一组 Group 模型站点进行联合指导。

参数

match (str) – 匹配模型样本站点名称的正则表达式字符串。

返回值

一组模型站点。

返回类型

Group

map_estimate(name)[source]

使用 Delta 分布构建最大后验 (MAP) 指南。

参数

name (str) – 模型样本站点的名称。

返回值

采样值。

返回类型

torch.Tensor

training: bool

easy_guide

easy_guide(model)[source]

创建 EasyGuide 的便捷装饰器。以下是等效的

# Version 1. Decorate a function.
@easy_guide(model)
def guide(self, foo, bar):
    return my_guide(foo, bar)

# Version 2. Create and instantiate a subclass of EasyGuide.
class Guide(EasyGuide):
    def guide(self, foo, bar):
        return my_guide(foo, bar)
guide = Guide(model)

注意 @easy_guide 包装器不能被 pickle 化;要构建可以被 pickle 化的指南,请改为继承 EasyGuide

参数

model (callable) – 一个 Pyro 模型。

Group

class Group(guide, sites)[source]

基类: object

一个自动指南助手,用于匹配一组模型站点。

变量
  • event_shape (torch.Size) – 模型中所有匹配样本站点的展平连接后的总形状。

  • prototype_sites (list) – 模型原型追踪中所有匹配样本站点的列表。

参数
  • guide (EasyGuide) – 一个 EasyGuide 实例。

  • sites (list) – 模型站点的列表。

property guide
sample(guide_name, fn, infer=None)[source]

围绕 pyro.sample() 的包装器,用于创建一个单一的辅助样本站点,然后解包到多个样本站点以供模型重放。

参数
  • guide_name (str) – 辅助指南站点的名称。

  • fn (callable) – 形状为 self.event_shape 的分布。

  • infer (dict) – 可选的推断配置字典。

返回值

一个对 (guide_z, model_zs),其中 guide_z 是单个连接的 blob,而 model_zs 是一个将站点名称映射到受限模型样本的字典。

返回类型

tuple

map_estimate()[source]

使用 Delta 分布构建最大后验 (MAP) 指南。

返回值

一个将模型站点名称映射到采样值的字典。

返回类型

dict