Poutine (效果处理器)

在内置的推断算法之下,Pyro 拥有一个由可组合的效果处理器构成的库,用于创建新的推断算法以及处理概率程序。Pyro 的推断算法都是通过将这些处理器应用于随机函数构建而成。为了对效果处理器是什么以及它们解决了什么问题有一个总体了解,请阅读 Matija Pretnar 所著的《代数效应和处理器简介》

处理器

Poutine 是一个由可组合的效果处理器构成的库,用于记录和修改 Pyro 程序的行为。这些底层组件简化了新推断算法和行为的实现。

处理器可以用作高阶函数、装饰器或上下文管理器来修改函数或代码块的行为

例如,考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

我们可以使用 condition 将采样站点标记为观察值,它返回一个与 model 具有相同输入和输出签名的可调用对象

>>> conditioned_model = poutine.condition(model, data={"z": 1.0})

我们也可以将处理器用作装饰器

>>> @pyro.condition(data={"z": 1.0})
... def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

或用作上下文管理器

>>> with pyro.condition(data={"z": 1.0}):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(0., s))
...     y = z ** 2

处理器可以自由组合

>>> conditioned_model = poutine.condition(model, data={"z": 1.0})
>>> traced_model = poutine.trace(conditioned_model)

许多推断算法或算法组件只需几行代码即可实现

guide_tr = poutine.trace(guide).get_trace(...)
model_tr = poutine.trace(poutine.replay(conditioned_model, trace=guide_tr)).get_trace(...)
monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
block(fn: None = None, hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None) pyro.poutine.block_messenger.BlockMessenger[源代码]
block(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

BlockMessenger 的便捷包装器

此处理器有选择地对外隐藏 Pyro 原语站点。默认行为:阻塞所有内容。

如果满足以下至少一个条件,则站点被隐藏

  1. hide_fn(msg) is True(not expose_fn(msg)) is True

  2. msg["name"]hide

  3. msg["type"]hide_types

  4. msg["name"] 不在 expose 中且 msg["type"] 不在 expose_types

  5. hidehide_typesexpose_types 都为 None

例如,假设随机函数 fn 有两个采样站点“a”和“b”。那么 BlockMessenger(fn, hide=["a"]) 之外的任何效果都不会应用于站点“a”,而只会看到站点“b”

>>> def fn():
...     a = pyro.sample("a", dist.Normal(0., 1.))
...     return pyro.sample("b", dist.Normal(a, 1.))
>>> fn_inner = pyro.poutine.trace(fn)
>>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"]))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer  = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True
参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • hide_fn – 一个接受站点并返回 True 以隐藏站点,或返回 False/None 以暴露站点的函数。如果指定此参数,则忽略所有其他参数。仅指定 hide_fn 或 expose_fn 中的一个,不能同时指定两者。

  • expose_fn – 一个接受站点并返回 True 以暴露站点,或返回 False/None 以隐藏站点的函数。如果指定此参数,则忽略所有其他参数。仅指定 hide_fn 或 expose_fn 中的一个,不能同时指定两者。

  • hide_all (bool) – 隐藏所有站点

  • expose_all (bool) – 正常暴露所有站点

  • hide (list) – 要隐藏的站点名称列表

  • expose (list) – 要暴露的站点名称列表,同时隐藏所有其他站点

  • hide_types (list) – 要隐藏的站点类型列表

  • expose_types (list) – 要暴露的站点类型列表,同时隐藏所有其他站点

返回

一个用 BlockMessenger 装饰的随机函数

broadcast(fn: None = None) pyro.poutine.broadcast_messenger.BroadcastMessenger[源代码]
broadcast(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

BroadcastMessenger 的便捷包装器

在单个或嵌套的 plate 上下文中时,自动广播采样站点处随机函数的批形状。现有的 batch_shape 必须与 cond_indep_stack 中安装的 plate 上下文的大小可广播。

请注意,下面的 model_automatic_broadcast 如何自动化地扩展分布的批形状。这使得模块化 Pyro 模型变得容易,因为子组件与外部包装的 plate 上下文无关。

>>> def model_broadcast_by_hand():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5)
...                                                .expand_by(100))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
>>> @poutine.broadcast
... def model_automatic_broadcast():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5)))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
collapse(fn: None = None, *args: Any, **kwargs: Any) pyro.poutine.collapse_messenger.CollapseMessenger[源代码]
collapse(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], *args: Any, **kwargs: Any) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

CollapseMessenger 的便捷包装器

实验性 - 通过惰性采样并尝试使用共轭关系来折叠上下文中的所有站点。如果不知道共轭关系,这将失败。使用采样站点结果的代码必须编写为接受 Funsors 而非 Tensors。这需要安装 funsor

警告

这与自动猜测 max_plate_nesting 不兼容。如果在折叠的上下文中出现任何 plate,您应该手动向您的推断算法声明 max_plate_nesting (例如 Trace_ELBO(max_plate_nesting=1))。

condition(data: Union[Dict[str, torch.Tensor], Trace]) pyro.poutine.condition_messenger.ConditionMessenger[源代码]
condition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Union[Dict[str, torch.Tensor], Trace]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

ConditionMessenger 的便捷包装器

给定一个包含一些采样语句和按名称提供的观察值字典的随机函数,将这些名称处的采样语句更改为使用这些值的观察。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

要观察站点 z 的值,我们可以这样写

>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})

这等效于在 model 中的 pyro.sample(“z”, …) 中将 obs=value 添加为关键字参数。

参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • data – 一个 dict 或一个 Trace

返回

一个用 ConditionMessenger 装饰的随机函数

do(data: Dict[str, Union[torch.Tensor, numbers.Number]]) pyro.poutine.do_messenger.DoMessenger[源代码]
do(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, Union[torch.Tensor, numbers.Number]]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

DoMessenger 的便捷包装器

给定一个包含一些采样语句和按名称提供的值字典的随机函数,将这些站点的返回值设置为等于这些值,就好像它们被硬编码为这些值一样,并引入名称相同但其值不传播的新采样站点。

可以与 condition() 自由组合来表示潜在结果的反事实分布。有关更多详细信息和理论,请参阅 Single World Intervention Graphs [1]。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

要对站点 z 进行干预(设置一个值),我们可以这样写

>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})

这等效于将 z = pyro.sample(“z”, …) 替换为 z = torch.tensor(1.),并引入一个新的采样站点 pyro.sample(“z”, …),其值不用于其他地方。

参考

[1] Single World Intervention Graphs: A Primer,

Thomas Richardson, James Robins

参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • data – 一个将采样站点名称映射到干预值的 dict

返回

一个用 DoMessenger 装饰的随机函数

enum(fn: None = None, first_available_dim: Optional[int] = None) pyro.poutine.enum_messenger.EnumMessenger[源代码]
enum(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], first_available_dim: Optional[int] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

EnumMessenger 的便捷包装器

在标记为 infer={"enumerate": "parallel"} 的离散采样站点上并行枚举。

参数

first_available_dim (int) – 可用于并行枚举的第一个张量维度(从右计数)。此维度及其左侧的所有维度都可能被 Pyro 内部使用。这应该是一个负整数或 None。

escape(escape_fn: Callable[[Message], bool]) pyro.poutine.escape_messenger.EscapeMessenger[源代码]
escape(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], escape_fn: Callable[[Message], bool]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

EscapeMessenger 的便捷包装器

一个通过引发 util.NonlocalExit 异常来进行非本地退出的信使

infer_config(config_fn: Callable[[Message], InferDict]) pyro.poutine.infer_config_messenger.InferConfigMessenger[源代码]
infer_config(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], config_fn: Callable[[Message], InferDict]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

InferConfigMessenger 的便捷包装器

给定一个包含 Pyro 原语调用的可调用对象 fn,以及一个接受轨迹站点并返回字典的可调用对象 config_fn,将采样站点处 infer kwarg 的值更新为 config_fn(site)。

参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • config_fn – 一个接受站点并返回一个 infer 字典的可调用对象

返回

一个用 InferConfigMessenger 装饰的随机函数

lift(prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]]) pyro.poutine.lift_messenger.LiftMessenger[源代码]
lift(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

LiftMessenger 的便捷包装器

给定一个包含 param 调用和先验分布的随机函数,创建一个随机函数,其中所有 param 调用都被替换为从先验中采样。先验应该是一个可调用对象或一个名称到可调用对象的字典。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})

lift 使用 prior 中的分布使 param 语句的行为类似于 sample 语句。在此示例中,站点 s 现在将表现得好像它被 s = pyro.sample("s", dist.Exponential(0.3)) 替换了一样

>>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> tr.nodes["s"]["type"] == "sample"
True
>>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all())
False
参数
  • fn – 其参数将被提升为随机值的函数

  • prior – 作为 Distribution 或随机函数字典形式的先验函数

返回

一个用 LiftMessenger 装饰的 fn

markov(fn: None = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) pyro.poutine.markov_messenger.MarkovMessenger[源代码]
markov(fn: Iterable[int] = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) pyro.poutine.markov_messenger.MarkovMessenger
markov(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T] = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

马尔可夫依赖声明。

这可以通过多种方式使用

  • 用作上下文管理器

  • 用作递归函数的装饰器

  • 用作马尔可夫链的迭代器

参数
  • history (int) – 从当前上下文可见的先前上下文的数量。默认为 1。如果为零,则类似于 pyro.plate

  • keep (bool) – 如果为 True,则帧是可回放的。这在分支时很重要:如果 keep=True,同一级别的相邻分支可以相互依赖;如果 keep=False,相邻分支是独立的(以它们的共享部分为条件)

  • dim (int) – 可选的维度,用于此独立性索引。接口存根,行为尚未实现。

  • name (str) – 可选的唯一名称,用于帮助推断算法匹配模型和 guide 之间的 pyro.markov() 站点。接口存根,行为尚未实现。

mask(mask: Union[bool, torch.BoolTensor]) pyro.poutine.mask_messenger.MaskMessenger[源代码]
mask(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], mask: Union[bool, torch.BoolTensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

MaskMessenger 的便捷包装器

给定一个包含一些批处理采样语句和掩码张量的随机函数,按元素掩掉一些采样语句。

参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • mask (torch.BoolTensor) – 一个取值为 {0,1} 的掩码张量(1 表示包含站点,0 表示排除站点)

返回

一个用 MaskMessenger 装饰的随机函数

queue(fn=None, queue=None, max_tries=None, extend_fn=None, escape_fn=None, num_samples=None)[源代码]

用于离散变量的顺序枚举。

给定一个随机函数和一个队列,从队列中的一个完整轨迹中返回一个返回值。

参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • queue – 一个用于存储部分轨迹的队列数据结构,例如 multiprocessing.Queue

  • max_tries – 计算一个完整轨迹的最大尝试次数

  • extend_fn – 一个函数(可能是随机的),它接受一个部分轨迹和一个站点,并返回一个扩展轨迹列表

  • escape_fn – 一个函数(可能是随机的),它接受一个部分轨迹和一个站点,并返回一个布尔值来决定是否退出

  • num_samples – extend_fn 返回的扩展轨迹的可选数量

返回

一个用 poutine 逻辑装饰的随机函数

reparam(config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]]) pyro.poutine.reparam_messenger.ReparamMessenger[源代码]
reparam(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]]) pyro.poutine.reparam_messenger.ReparamHandler[pyro.poutine.handlers._P, pyro.poutine.handlers._T]

ReparamMessenger 的便捷包装器

将每个受影响的采样站点重新参数化为一个或多个辅助采样站点,然后进行确定性变换 [1]。

要指定重新参数化器,请将 config 字典或可调用对象传递给构造函数。有关可用的重新参数化器,请参阅 pyro.infer.reparam 模块。

注意,一些重新参数化器可以检查它们影响的函数的 *args,**kwargs 输入;这些重新参数化器要求将 poutine.reparam 用作装饰器而不是上下文管理器。

[1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

《概率程序的自动重新参数化》 https://arxiv.org/pdf/1906.03028.pdf

参数

config (dictcallable) – 配置,可以是将站点名称映射到 Reparameterizer 的 dict,也可以是将站点映射到 Reparam 或 None 的函数。有关内置的配置策略,请参阅 pyro.infer.reparam.strategies

replay(fn: None = None, trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None) pyro.poutine.replay_messenger.ReplayMessenger[源代码]
replay(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

ReplayMessenger 的便捷包装器

给定一个包含 Pyro 原语调用的可调用对象,返回一个可调用对象,该对象运行原始函数,并在新轨迹中重用轨迹中站点的值

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

replay 使 sample 语句的行为如同它们在轨迹中相应站点处采样了值一样

>>> old_trace = pyro.poutine.trace(model).get_trace(1.0)
>>> replayed_model = pyro.poutine.replay(model, trace=old_trace)
>>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"])
True
参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • trace – 用于回放的 Trace 数据结构

  • params – fn 中 param 站点名称和约束值的字典,用于回放

返回

一个用 ReplayMessenger 装饰的随机函数

scale(scale: Union[float, torch.Tensor]) pyro.poutine.scale_messenger.ScaleMessenger[source]
scale(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], scale: Union[float, torch.Tensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

便捷的包装器 ScaleMessenger

给定一个包含一些采样语句的随机函数和一个正比例因子,对函数中所有采样(sample)和观测(observe)站点的得分进行缩放。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))

scale 对采样站点的对数概率进行乘性缩放

>>> scaled_model = pyro.poutine.scale(model, scale=0.5)
>>> scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0)
>>> unscaled_tr = pyro.poutine.trace(model).get_trace(0.0)
>>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all())
True
参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • scale – 一个正比例因子

返回

ScaleMessenger 装饰的随机函数

seed(rng_seed: int) pyro.poutine.seed_messenger.SeedMessenger[source]
seed(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], rng_seed: int) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

便捷的包装器 SeedMessenger

通过设置随机数生成器的种子,将其设置为预定义的状态。这与在调用 fn 之前调用 pyro.set_rng_seed() 相同。这个处理器在标准 Pyro 后端对原始语句没有额外的效果,但在其他后端(例如 NumPy 后端)中可能会拦截 pyro.sample 调用。

参数
  • fn – 随机函数(包含 Pyro 原始调用)。

  • rng_seed (int) – 随机数生成器种子。

substitute(data: Dict[str, torch.Tensor]) pyro.poutine.substitute_messenger.SubstituteMessenger[source]
substitute(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, torch.Tensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

便捷的包装器 SubstituteMessenger

给定一个包含参数调用的随机函数和一组参数值,创建一个随机函数,其中所有参数调用都替换为固定值。data 应该是一个从名称到值的字典。考虑以下 Pyro 程序

>>> def model(x):
...     a = pyro.param("a", torch.tensor(0.5))
...     x = pyro.sample("x", dist.Bernoulli(probs=a))
...     return x
>>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})

在此示例中,站点 a 现在的值将是 torch.tensor(0.3)。 :param data: 站点名称为键的字典值。 :returns: 用 SubstituteMessenger 装饰的 fn

trace(fn: None = None, graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None) pyro.poutine.trace_messenger.TraceMessenger[source]
trace(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None) pyro.poutine.trace_messenger.TraceHandler[pyro.poutine.handlers._P, pyro.poutine.handlers._T]

便捷的包装器 TraceMessenger

返回一个处理器,用于记录原始调用的输入和输出及其依赖关系。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

我们可以使用 trace 记录其执行,并使用生成的数据结构计算执行中所有采样站点的对数联合概率或提取所有参数。

>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • graph_type – 指定构建图类型的字符串

  • param_only – 如果为 true,则仅记录参数而不记录采样

返回

TraceMessenger 装饰的随机函数

uncondition(fn: None = None) pyro.poutine.uncondition_messenger.UnconditionMessenger[source]
uncondition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

便捷的包装器 UnconditionMessenger

用于强制观测节点的值从其分布中采样(忽略观测值)的消息处理器。

config_enumerate(guide=None, default='parallel', expand=False, num_samples=None, tmc='diagonal')[source]

为 guide 中的所有相关站点配置枚举。这主要与 TraceEnum_ELBO 结合使用。

当配置离散变量的穷举枚举时,这将配置所有满足 .has_enumerate_support == True 的采样站点。当通过 default="parallel", num_samples=n 配置局部并行蒙特卡洛采样时,这将配置所有采样站点。这不会覆盖现有的注释 infer={"enumerate": ...}

这可以作为函数使用

guide = config_enumerate(guide)

或作为装饰器使用

@config_enumerate
def guide1(*args, **kwargs):
    ...

@config_enumerate(default="sequential", expand=True)
def guide2(*args, **kwargs):
    ...
参数
  • guide (callable) – 一个 Pyro 模型,将在 SVI 中用作 guide。

  • default (str) – 要使用的枚举策略,可以是“sequential”、“parallel”或 None。默认为“parallel”。

  • expand (bool) – 是否展开枚举的采样值。详见 enumerate_support()。这仅适用于穷举枚举(num_samples=None)。如果 num_samplesNone,则采样值总是会展开。

  • num_samples (intNone) – 如果非 None,则使用局部蒙特卡洛采样而不是穷举枚举。这适用于连续和离散分布。

  • tmc (stringNone) – 在 Tensor Monte Carlo 中使用的“mixture”或“diagonal”策略

返回

经过注释的 guide

返回类型

callable

Trace

class Trace(graph_type: Literal['flat', 'dense'] = 'flat')[source]

继承自: object

表示执行跟踪中不同 Pyro 原始类型之间关系的图数据结构。

Pyro 程序的执行跟踪是对该程序单次执行中每次 pyro.sample()pyro.param() 调用的记录。跟踪是定向图,其节点表示原始调用或输入/输出,其边表示这些原始调用之间的条件依赖关系。它们由 poutine.trace 创建和填充。

跟踪中的每个节点(或站点)包含站点的名称、输入和输出值,以及由推理算法或用户注释添加的额外元数据。对于 pyro.sample,跟踪还包括站点的随机函数以及用户添加的任何观测数据。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

我们可以使用 pyro.poutine.trace 记录其执行,并使用生成的数据结构计算执行中所有采样站点的对数联合概率或提取所有参数。

>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]

我们还可以检查或操作跟踪中的单个节点。trace.nodes 包含一个 collections.OrderedDict,其中包含与 xsz 以及返回值对应的站点名称和元数据

>>> list(name for name in trace.nodes.keys())  
["_INPUT", "s", "z", "_RETURN"]

trace.nodes 的值是节点元数据的字典

>>> trace.nodes["z"]  
{'type': 'sample', 'name': 'z', 'is_observed': False,
 'fn': Normal(), 'value': tensor(0.6480), 'args': (), 'kwargs': {},
 'infer': {}, 'scale': 1.0, 'cond_indep_stack': (),
 'done': True, 'stop': False, 'continuation': None}

'infer' 是一个由用户或算法指定的元数据字典。'args''kwargs' 是通过 pyro.sample 传递给 fn.__call__fn.log_prob 的参数。'scale' 用于在计算对数联合概率时缩放站点的对数概率。'cond_indep_stack' 包含与执行中出现的 pyro.plate 上下文对应的数据结构。'done''stop''continuation' 仅由 Pyro 内部使用。

参数

graph_type (string) – 指定构建跟踪图类型的字符串

add_edge(site1: str, site2: str) None[source]
add_node(site_name: str, **kwargs: Any) None[source]
参数

site_name (string) – 要添加的站点名称

向跟踪添加一个站点。

尝试添加重复节点时会引发错误,而不是静默覆盖。

compute_log_prob(site_filter: Callable[[str, Message], bool] = <function allow_all_sites>) None[source]

计算跟踪的站点对数概率。每个 log_prob 的形状等于对应的 batch_shape。每个 log_prob_sum 是一个标量。这两个计算都经过 memoization(记忆化)。

compute_score_parts() None[source]

计算跟踪每个站点的批量局部得分部分。每个 log_prob 的形状等于对应的 batch_shape。每个 log_prob_sum 是一个标量。所有计算都经过 memoization(记忆化)。

copy() pyro.poutine.trace_struct.Trace[source]

创建自身的浅拷贝,保留节点和边。

detach_() None[source]

分离(原地)跟踪每个采样站点的值。

property edges: Iterable[Tuple[str, str]]
format_shapes(title: str = 'Trace Shapes:', last_site: Optional[str] = None) str[source]

返回一个字符串,显示跟踪中所有站点的形状表格。

iter_stochastic_nodes() Iterator[Tuple[str, Message]][source]
返回

跟踪中随机节点的迭代器。

log_prob_sum(site_filter: Callable[[str, Message], bool] = <function allow_all_sites>) Union[torch.Tensor, float][source]

计算跟踪的站点对数概率。每个 log_prob 的形状等于对应的 batch_shape。每个 log_prob_sum 是一个标量。log_prob_sum 的计算经过 memoization(记忆化)。

返回

总对数概率。

返回类型

torch.Tensor

property nonreparam_stochastic_nodes: List[str]

其随机函数不是可重参数化原始分布的采样站点名称列表

类型

返回值

property observation_nodes: List[str]

观测站点名称列表

类型

返回值

pack_tensors(plate_to_symbol: Optional[Dict[str, str]] = None) None[source]

计算跟踪中张量的紧凑表示。这应该在调用 compute_log_prob()compute_score_parts() 之后调用。

property param_nodes: List[str]

参数站点名称列表

类型

返回值

predecessors(site_name: str) Set[str][source]
remove_node(site_name: str) None[source]
property reparameterized_nodes: List[str]

其随机函数是可重参数化原始分布的采样站点名称列表

类型

返回值

property stochastic_nodes: List[str]

采样站点名称列表

类型

返回值

successors(site_name: str) Set[str][source]
symbolize_dims(plate_to_symbol: Optional[Dict[str, str]] = None) None[source]

为所有张量维度分配唯一符号。

topological_sort(reverse: bool = False) List[str][source]

以拓扑排序的顺序返回节点(站点名称)列表。

参数

reverse (bool) – 是否按逆序返回列表。

返回

拓扑排序的节点(站点名称)列表。

Runtime

class InferDict[source]

继承自: typing_extensions.TypedDict

包含推理相关信息的字典。

这可用于配置按站点推理策略,例如

pyro.sample(
    "x",
    dist.Bernoulli(0.5),
    infer={"enumerate": "parallel"},
)
enumerate (str)

如果是字符串“sequential”或“parallel”,则启用枚举。并行枚举通常更快,但需要安全的广播操作和静态结构。

expand (bool)

枚举时是否展开分布。如果缺失,默认为 False。

is_auxiliary (bool)

采样站点是否为辅助站点,例如用于确定性地转换辅助变量的 guide。如果缺失,默认为 False。

is_observed (bool)

采样站点是否为观测(即非潜在)站点。如果缺失,默认为 False。

num_samples (int)

抽取的样本数量。如果缺失,默认为 1。

obs (可选 torch.Tensor)

观测值,对于潜在变量为 None。如果缺失,默认为 None。

prior (可选 torch.distributions.Distribution)

(内部) 用于 GuideMessenger 中存储模型的先验分布(以其上游站点为条件)。

tmc (str)

在 TraceTMC_ELBO 中是否使用对角线或混合近似进行 Tensor Monte Carlo。

was_observed (bool)

(内部) 在通过 Reweighted Wake Sleep 或 Compiled Sequential Importance Sampling 进行推理的上下文中,采样站点最初是否被观测。

enumerate: typing_extensions.Literal[sequential, parallel]
expand: bool
is_auxiliary: bool
is_observed: bool
num_samples: int
obs: Optional[torch.Tensor]
prior: TorchDistributionMixin
tmc: typing_extensions.Literal[diagonal, mixture]
was_observed: bool
class Message[source]

继承自: typing_extensions.TypedDict, Generic[pyro.poutine.runtime._P, pyro.poutine.runtime._T]

Pyro 内部用于效果处理的消息类型。

消息存储在跟踪对象中,例如

trace.nodes["my_site_name"]  # This is a Message.
type (str)

消息类型,通常是字符串“sample”、“param”、“plate”或“markov”,但也可能是自定义类型。

name (str)

站点名称,通常是采样或参数的名称。

fn (可调用对象)

用于生成采样的分布或函数。

is_observed (bool)

指示值是否被观测的标志。

args (元组)

传递给分布或函数的按位置参数。

kwargs (字典)

传递给分布或函数的关键字参数。

value (torch.Tensor)

采样值(无论是观测值还是采样值)。

scale (torch.Tensor)

对数概率的缩放因子。

mask (bool torch.Tensor)

用于遮盖对数概率的布尔值或张量。

cond_indep_stack (元组)

站点本地的条件独立元数据栈。不可变。

done (bool)

指示消息是否已处理的标志。

stop (bool)

停止进一步处理消息的标志。

continuation (可调用对象)

处理消息后要调用的函数。

infer (可选 InferDict)

推理参数字典。

obs (torch.Tensor)

观测值。

log_prob (torch.Tensor)

采样的对数概率。

log_prob_sum (torch.Tensor)

对数概率之和。

unscaled_log_prob (torch.Tensor)

未缩放的对数概率。

score_parts (pyro.distributions.ScoreParts)

得分部分的集合。

packed (Message)

经过打包的消息,在枚举过程中使用。

args: Tuple
cond_indep_stack: Tuple[CondIndepStackFrame, ...]
continuation: Optional[Callable[[Message], None]]
done: bool
fn: Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T]
infer: Optional[pyro.poutine.runtime.InferDict]
is_observed: bool
kwargs: Dict
log_prob: torch.Tensor
log_prob_sum: torch.Tensor
mask: Optional[Union[bool, torch.Tensor]]
name: Optional[str]
obs: Optional[torch.Tensor]
packed: Message
scale: Union[torch.Tensor, float]
score_parts: ScoreParts
stop: bool
type: str
unscaled_log_prob: torch.Tensor
value: Optional[pyro.poutine.runtime._T]
exception NonlocalExit(site: pyro.poutine.runtime.Message, *args, **kwargs)[source]

继承自: Exception

用于从 poutine 执行中非局部退出的异常。

由 poutine.EscapeMessenger 用于返回站点信息。

reset_stack() None[source]

重置栈中剩余帧的状态。对于 poutine.queue 中的多次重复执行是必需的。

am_i_wrapped() bool[source]

检查当前计算是否被 poutine 包装。 :returns: bool

apply_stack(initial_msg: pyro.poutine.runtime.Message) None[source]

按照以下方案在单个站点执行效果栈:

  1. 对于栈中从下到上的每个 Messenger,执行 Messenger._process_message 并传入消息;如果消息字段“stop”为 True,则停止;否则继续。

  2. 应用默认行为(default_process_message)完成剩余的站点执行。

  3. 对于栈中从上到下的每个 Messenger,执行 _postprocess_message 以使用站点结果更新消息和内部 messenger 状态。

  4. 如果消息字段“continuation”非 None,则使用消息调用它。

参数

initial_msg (dict) – 跟踪站点的起始版本

返回

None

default_process_message(msg: pyro.poutine.runtime.Message) None[source]

处理推理中消息的默认方法。

参数

msg – 要处理的消息

返回

None

effectful(fn: None = None, type: Optional[str] = None) Callable[[Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T]], Callable[[...], pyro.poutine.runtime._T]][source]
effectful(fn: Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T] = None, type: Optional[str] = None) Callable[[...], pyro.poutine.runtime._T]
参数
  • fn – 执行有效果计算的函数或可调用对象

  • type (str) – 操作的类型标签,例如 “sample”

调用 apply_stack() 应用任何活动效果的包装器。

get_mask() Optional[Union[bool, torch.Tensor]][source]

记录封闭的 poutine.mask 处理器产生的效果。

这对于在预测期间避免昂贵的 pyro.factor() 计算非常有用,因为此时无需计算对数密度,例如

def model():
    # ...
    if poutine.get_mask() is not False:
        log_density = my_expensive_computation()
        pyro.factor("foo", log_density)
    # ...
返回

遮罩。

返回类型

None, booltorch.Tensor

get_plates() Tuple[CondIndepStackFrame, ...][source]

记录封闭的 pyro.plate 上下文产生的效果。

返回

pyro.poutine.indep_messenger.CondIndepStackFrame 对象的元组。

返回类型

tuple

Utilities

all_escape(trace: Trace, msg: Message) bool[source]
参数
  • trace – 部分跟踪(partial trace)

  • msg – Pyro 原语站点处的 message

返回

布尔决策值

一个实用函数,用于检查站点是否尚未在跟踪中。

由 EscapeMessenger 用来决定是否在某个站点进行非局部退出。用于近似积分出变量以减少方差的子例程。

discrete_escape(trace: Trace, msg: Message) bool[source]
参数
  • trace – 部分跟踪(partial trace)

  • msg – Pyro 原语站点处的 message

返回

布尔决策值

一个实用函数,用于检查采样站点是否为离散的且尚未在跟踪中。

由 EscapeMessenger 用来决定是否在某个站点进行非局部退出。用于积分出离散变量以减少方差的子例程。

enable_validation(is_validate: bool) None[source]
enum_extend(trace: Trace, msg: Message, num_samples: Optional[int] = None) List[Trace][source]
参数
  • trace – 部分跟踪(partial trace)

  • msg – Pyro 原语站点处的 message

  • num_samples – 返回的扩展跟踪的最大数量。

返回

一个跟踪列表,输入跟踪的副本,每个副本包含一个额外的站点

一个实用函数,用于复制和扩展跟踪,根据输入站点构建站点,这些站点的取值从输入站点的分布的支持中枚举。

用于精确推断和积分出离散变量。

is_validation_enabled() bool[source]
mc_extend(trace: Trace, msg: Message, num_samples: Optional[int] = None) List[Trace][source]
参数
  • trace – 部分跟踪(partial trace)

  • msg – Pyro 原语站点处的 message

  • num_samples – 返回的扩展跟踪的最大数量。

返回

一个跟踪列表,输入跟踪的副本,每个副本包含一个额外的站点

一个实用函数,用于复制和扩展跟踪,根据输入站点构建站点,这些站点的取值从输入站点的函数中采样。

用于对单个采样站点进行蒙特卡洛边际化。

prune_subsample_sites(trace: Trace) Trace[source]

复制并移除跟踪中的所有 subsample 站点。

site_is_factor(site: Message) bool[source]

确定跟踪站点是否源自 factor 语句。

site_is_subsample(site: Message) bool[source]

确定跟踪站点是否源自 plate 内部的 subsample 语句。

信使(Messengers)

信使(Messenger)对象包含 handlers 暴露的效应的实现。高级用户可以修改现有 handlers 背后的信使实现,或者编写新的信使来实现新的效应,并与库的其余部分正确组合。

信使(Messenger)

class Messenger[source]

继承自: object

一个上下文管理器类,用于修改随机函数(即包含 Pyro 原语语句的可调用对象)的行为并添加副作用。

这是基础 Messenger 类。它为所有 Pyro 原语实现了默认行为,因此随机函数 fn 引发的联合分布与 Messenger()(fn) 引发的联合分布相同。

推断过程中传递的消息的转换器类。大多数推断操作都在此类的子类中实现。

classmethod register(fn: Optional[Callable] = None, type: Optional[str] = None, post: Optional[bool] = None) Callable[source]
参数
  • fn – 实现操作的函数

  • type (str) – 操作的名称(也会传递给 effectful()

  • post (bool) – 如果为 True,则将此操作用作后处理

动态地向效应添加操作。对于为库生成包装器很有用。

示例

@SomeMessengerClass.register
def some_function(msg)
    ...do_something...
    return msg
classmethod unregister(fn: Optional[Callable] = None, type: Optional[str] = None) Optional[Callable][source]
参数
  • fn – 实现操作的函数

  • type (str) – 操作的名称(也会传递给 effectful()

动态地从效应中移除操作。对于移除库中的包装器很有用。

示例

SomeMessengerClass.unregister(some_function, "name")
block_messengers(predicate: Callable[[pyro.poutine.messenger.Messenger], bool]) Iterator[List[pyro.poutine.messenger.Messenger]][source]

实验性上下文管理器,用于临时从 _PYRO_STACK 中移除匹配的信使。注意,这不会调用 .__exit__().__enter__() 方法。

这对于选择性地阻止外部 handler 非常有用。

参数

predicate (callable) – 将信使实例映射到布尔值的谓词。这将静音所有满足 bool(predicate(m)) is True 的信使 m

产生(Yields)

被阻止的匹配信使列表。

unwrap(fn: Callable) Callable[source]

递归地解包 poutine。

BlockMessenger

class BlockMessenger(hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None)[source]

基类: pyro.poutine.messenger.Messenger

此处理器有选择地对外隐藏 Pyro 原语站点。默认行为:阻塞所有内容。

如果满足以下至少一个条件,则站点被隐藏

  1. hide_fn(msg) is True(not expose_fn(msg)) is True

  2. msg["name"]hide

  3. msg["type"]hide_types

  4. msg["name"] 不在 expose 中且 msg["type"] 不在 expose_types

  5. hidehide_typesexpose_types 都为 None

例如,假设随机函数 fn 有两个采样站点“a”和“b”。那么 BlockMessenger(fn, hide=["a"]) 之外的任何效果都不会应用于站点“a”,而只会看到站点“b”

>>> def fn():
...     a = pyro.sample("a", dist.Normal(0., 1.))
...     return pyro.sample("b", dist.Normal(a, 1.))
>>> fn_inner = pyro.poutine.trace(fn)
>>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"]))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer  = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True
参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • hide_fn – 一个接受站点并返回 True 以隐藏站点,或返回 False/None 以暴露站点的函数。如果指定此参数,则忽略所有其他参数。仅指定 hide_fn 或 expose_fn 中的一个,不能同时指定两者。

  • expose_fn – 一个接受站点并返回 True 以暴露站点,或返回 False/None 以隐藏站点的函数。如果指定此参数,则忽略所有其他参数。仅指定 hide_fn 或 expose_fn 中的一个,不能同时指定两者。

  • hide_all (bool) – 隐藏所有站点

  • expose_all (bool) – 正常暴露所有站点

  • hide (list) – 要隐藏的站点名称列表

  • expose (list) – 要暴露的站点名称列表,同时隐藏所有其他站点

  • hide_types (list) – 要隐藏的站点类型列表

  • expose_types (list) – 要暴露的站点类型列表,同时隐藏所有其他站点

返回

一个用 BlockMessenger 装饰的随机函数

BroadcastMessenger

class BroadcastMessenger[source]

基类: pyro.poutine.messenger.Messenger

在单个或嵌套的 plate 上下文中时,自动广播采样站点处随机函数的批形状。现有的 batch_shape 必须与 cond_indep_stack 中安装的 plate 上下文的大小可广播。

请注意,下面的 model_automatic_broadcast 如何自动化地扩展分布的批形状。这使得模块化 Pyro 模型变得容易,因为子组件与外部包装的 plate 上下文无关。

>>> def model_broadcast_by_hand():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5)
...                                                .expand_by(100))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
>>> @poutine.broadcast
... def model_automatic_broadcast():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5)))
...             assert sample.shape == torch.Size((100, 3))
...     return sample

CollapseMessenger

class CollapseMessenger(*args: Any, **kwargs: Any)[source]

基类: pyro.poutine.trace_messenger.TraceMessenger

实验性 - 通过惰性采样并尝试使用共轭关系来折叠上下文中的所有站点。如果不知道共轭关系,这将失败。使用采样站点结果的代码必须编写为接受 Funsors 而非 Tensors。这需要安装 funsor

警告

这与自动猜测 max_plate_nesting 不兼容。如果在折叠的上下文中出现任何 plate,您应该手动向您的推断算法声明 max_plate_nesting (例如 Trace_ELBO(max_plate_nesting=1))。

ConditionMessenger

class ConditionMessenger(data: Union[Dict[str, torch.Tensor], pyro.poutine.trace_struct.Trace])[source]

基类: pyro.poutine.messenger.Messenger

给定一个包含一些采样语句和按名称提供的观察值字典的随机函数,将这些名称处的采样语句更改为使用这些值的观察。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

要观察站点 z 的值,我们可以这样写

>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})

这等效于在 model 中的 pyro.sample(“z”, …) 中将 obs=value 添加为关键字参数。

参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • data – 一个 dict 或一个 Trace

返回

一个用 ConditionMessenger 装饰的随机函数

DoMessenger

class DoMessenger(data: Dict[str, Union[torch.Tensor, numbers.Number]])[source]

基类: pyro.poutine.messenger.Messenger

给定一个包含一些采样语句和按名称提供的值字典的随机函数,将这些站点的返回值设置为等于这些值,就好像它们被硬编码为这些值一样,并引入名称相同但其值不传播的新采样站点。

可以与 condition() 自由组合来表示潜在结果的反事实分布。有关更多详细信息和理论,请参阅 Single World Intervention Graphs [1]。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

要对站点 z 进行干预(设置一个值),我们可以这样写

>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})

这等效于将 z = pyro.sample(“z”, …) 替换为 z = torch.tensor(1.),并引入一个新的采样站点 pyro.sample(“z”, …),其值不用于其他地方。

参考

[1] Single World Intervention Graphs: A Primer,

Thomas Richardson, James Robins

参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • data – 一个将采样站点名称映射到干预值的 dict

返回

一个用 DoMessenger 装饰的随机函数

EnumMessenger

class EnumMessenger(first_available_dim: Optional[int] = None)[source]

基类: pyro.poutine.messenger.Messenger

在标记为 infer={"enumerate": "parallel"} 的离散采样站点上并行枚举。

参数

first_available_dim (int) – 可用于并行枚举的第一个张量维度(从右计数)。此维度及其左侧的所有维度都可能被 Pyro 内部使用。这应该是一个负整数或 None。

enumerate_site(msg: pyro.poutine.runtime.Message) torch.Tensor[source]

EscapeMessenger

class EscapeMessenger(escape_fn: Callable[[pyro.poutine.runtime.Message], bool])[source]

基类: pyro.poutine.messenger.Messenger

一个通过引发 util.NonlocalExit 异常来进行非本地退出的信使

IndepMessenger

class CondIndepStackFrame(name, dim, size, counter, full_size)[source]

基类: tuple

counter: int

字段 3 的别名

dim: Optional[int]

字段 1 的别名

full_size: Optional[int]

字段 4 的别名

name: str

字段 0 的别名

size: int

字段 2 的别名

property vectorized: bool
class IndepMessenger(name: str, size: int, dim: Optional[int] = None, device: Optional[str] = None)[source]

基类: pyro.poutine.messenger.Messenger

此信使跟踪由嵌套的 plate 上下文声明的独立性信息栈。此信息存储在每个 sample/observe 站点中的 cond_indep_stack 中,供 TraceMessenger 使用。

示例

x_axis = IndepMessenger('outer', 320, dim=-1)
y_axis = IndepMessenger('inner', 200, dim=-2)
with x_axis:
    x_noise = sample("x_noise", dist.Normal(loc, scale).expand_by([320]))
with y_axis:
    y_noise = sample("y_noise", dist.Normal(loc, scale).expand_by([200, 1]))
with x_axis, y_axis:
    xy_noise = sample("xy_noise", dist.Normal(loc, scale).expand_by([200, 320]))
property indices: torch.Tensor
next_context() None[source]

递增计数器。

InferConfigMessenger

class InferConfigMessenger(config_fn: Callable[[Message], InferDict])[source]

基类: pyro.poutine.messenger.Messenger

给定一个包含 Pyro 原语调用的可调用对象 fn,以及一个接受轨迹站点并返回字典的可调用对象 config_fn,将采样站点处 infer kwarg 的值更新为 config_fn(site)。

参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • config_fn – 一个接受站点并返回一个 infer 字典的可调用对象

返回

一个用 InferConfigMessenger 装饰的随机函数

LiftMessenger

class LiftMessenger(prior: Union[Callable, pyro.distributions.distribution.Distribution, Dict[str, Union[pyro.distributions.distribution.Distribution, Callable]]])[source]

基类: pyro.poutine.messenger.Messenger

给定一个包含 param 调用和先验分布的随机函数,创建一个随机函数,其中所有 param 调用都被替换为从先验中采样。先验应该是一个可调用对象或一个名称到可调用对象的字典。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})

lift 使用 prior 中的分布使 param 语句的行为类似于 sample 语句。在此示例中,站点 s 现在将表现得好像它被 s = pyro.sample("s", dist.Exponential(0.3)) 替换了一样

>>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> tr.nodes["s"]["type"] == "sample"
True
>>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all())
False
参数
  • fn – 其参数将被提升为随机值的函数

  • prior – 作为 Distribution 或随机函数字典形式的先验函数

返回

一个用 LiftMessenger 装饰的 fn

MarkovMessenger

class MarkovMessenger(history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None)[source]

基类: pyro.poutine.reentrant_messenger.ReentrantMessenger

马尔可夫依赖声明。

这在统计上相当于内存管理域。

参数
  • history (int) – 从当前上下文可见的先前上下文的数量。默认为 1。如果为零,则类似于 pyro.plate

  • keep (bool) – 如果为真,则帧可回放。这在分支时很重要:如果 keep=True,同一级别的相邻分支可以相互依赖;如果 keep=False,则相邻分支是独立的(以其共享祖先为条件)。

  • dim (int) – 可选的维度,用于此独立性索引。接口存根,行为尚未实现。

  • name (str) – 可选的唯一名称,用于帮助推断算法匹配模型和 guide 之间的 pyro.markov() 站点。接口存根,行为尚未实现。

generator(iterable: Iterable[int]) typing_extensions.Self[source]

MaskMessenger

class MaskMessenger(mask: Union[bool, torch.BoolTensor])[source]

基类: pyro.poutine.messenger.Messenger

给定一个包含一些批处理采样语句和掩码张量的随机函数,按元素掩掉一些采样语句。

参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • mask (torch.BoolTensor) – 一个取值为 {0,1} 的掩码张量(1 表示包含站点,0 表示排除站点)

返回

一个用 MaskMessenger 装饰的随机函数

PlateMessenger

class PlateMessenger(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]

基类: pyro.poutine.subsample_messenger.SubsampleMessenger

广播神奇功能的瑞士军刀:结合了形状推断、独立性标注和二次采样

block_plate(name: Optional[str] = None, dim: Optional[int] = None, *, strict: bool = True) Iterator[None][source]

实验性上下文管理器,用于临时阻止单个外部 plate。

这对于在 plated 上下文中对辅助变量或惰性采样全局变量非常有用。例如,以下模型是等效的

示例

def model_1(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        with block_plate("data"):
            scale = pyro.sample("scale", dist.LogNormal(0, 1))
        pyro.sample("x", dist.Normal(loc, scale))

def model_2(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    scale = pyro.sample("scale", dist.LogNormal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(loc, scale))
参数
  • name (str) – 要匹配的 plate 的可选名称。

  • dim (int) – 要匹配的 plate 的可选 dim。必须为负数。

  • strict (bool) – 如果未找到匹配的 plate 是否报错。默认为 True。

引发(Raises)

如果在未找到外部 plate 且 strict=True 时,引发 ValueError。

ReentrantMessenger

class ReentrantMessenger[source]

基类: pyro.poutine.messenger.Messenger

ReparamMessenger

class ReparamHandler(msngr, fn: Callable[[pyro.poutine.reparam_messenger._P], pyro.poutine.reparam_messenger._T])[source]

基类: Generic[pyro.poutine.reparam_messenger._P, pyro.poutine.reparam_messenger._T]

重新参数化 poutine。

class ReparamMessenger(config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]])[source]

基类: pyro.poutine.messenger.Messenger

将每个受影响的采样站点重新参数化为一个或多个辅助采样站点,然后进行确定性变换 [1]。

要指定重新参数化器,请将 config 字典或可调用对象传递给构造函数。有关可用的重新参数化器,请参阅 pyro.infer.reparam 模块。

注意,一些重新参数化器可以检查它们影响的函数的 *args,**kwargs 输入;这些重新参数化器要求将 poutine.reparam 用作装饰器而不是上下文管理器。

[1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

《概率程序的自动重新参数化》 https://arxiv.org/pdf/1906.03028.pdf

参数

config (dictcallable) – 配置,可以是将站点名称映射到 Reparameterizer 的 dict,也可以是将站点映射到 Reparam 或 None 的函数。有关内置的配置策略,请参阅 pyro.infer.reparam.strategies

ReplayMessenger

class ReplayMessenger(trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor] = None)[source]

基类: pyro.poutine.messenger.Messenger

给定一个包含 Pyro 原语调用的可调用对象,返回一个可调用对象,该对象运行原始函数,并在新轨迹中重用轨迹中站点的值

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

replay 使 sample 语句的行为如同它们在轨迹中相应站点处采样了值一样

>>> old_trace = pyro.poutine.trace(model).get_trace(1.0)
>>> replayed_model = pyro.poutine.replay(model, trace=old_trace)
>>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"])
True
参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • trace – 用于回放的 Trace 数据结构

  • params – fn 中 param 站点名称和约束值的字典,用于回放

返回

一个用 ReplayMessenger 装饰的随机函数

ScaleMessenger

class ScaleMessenger(scale: Union[float, torch.Tensor])[source]

基类: pyro.poutine.messenger.Messenger

给定一个包含一些采样语句的随机函数和一个正比例因子,对函数中所有采样(sample)和观测(observe)站点的得分进行缩放。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))

scale 对采样站点的对数概率进行乘性缩放

>>> scaled_model = pyro.poutine.scale(model, scale=0.5)
>>> scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0)
>>> unscaled_tr = pyro.poutine.trace(model).get_trace(0.0)
>>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all())
True
参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • scale – 一个正比例因子

返回

ScaleMessenger 装饰的随机函数

SeedMessenger

class SeedMessenger(rng_seed: int)[source]

基类: pyro.poutine.messenger.Messenger

通过设置随机数生成器的种子,将其设置为预定义的状态。这与在调用 fn 之前调用 pyro.set_rng_seed() 相同。这个处理器在标准 Pyro 后端对原始语句没有额外的效果,但在其他后端(例如 NumPy 后端)中可能会拦截 pyro.sample 调用。

参数
  • fn – 随机函数(包含 Pyro 原始调用)。

  • rng_seed (int) – 随机数生成器种子。

SubsampleMessenger

class SubsampleMessenger(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]

基类: pyro.poutine.indep_messenger.IndepMessenger

IndepMessenger 的扩展,包括二次采样。

SubstituteMessenger

class SubstituteMessenger(data: Dict[str, torch.Tensor])[source]

基类: pyro.poutine.messenger.Messenger

给定一个包含参数调用的随机函数和一组参数值,创建一个随机函数,其中所有参数调用都替换为固定值。data 应该是一个从名称到值的字典。考虑以下 Pyro 程序

>>> def model(x):
...     a = pyro.param("a", torch.tensor(0.5))
...     x = pyro.sample("x", dist.Bernoulli(probs=a))
...     return x
>>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})

在此示例中,站点 a 现在的值将是 torch.tensor(0.3)。 :param data: 站点名称为键的字典值。 :returns: 用 SubstituteMessenger 装饰的 fn

TraceMessenger

class TraceHandler(msngr: pyro.poutine.trace_messenger.TraceMessenger, fn: Callable[[pyro.poutine.reparam_messenger._P], pyro.poutine.reparam_messenger._T])[source]

基类: Generic[pyro.poutine.reparam_messenger._P, pyro.poutine.reparam_messenger._T]

执行跟踪 poutine。

TraceHandler 记录每个 Pyro 原语的输入和输出,并将它们作为站点存储在 Trace() 中。理论上,这对于每种推断算法来说是足够的信息(以及变量中隐式计算图的信息?)。

我们也可以用它来进行可视化。

get_trace(*args, **kwargs) pyro.poutine.trace_struct.Trace[source]
返回

数据结构

返回类型

pyro.poutine.Trace

一个非常常见的用例的辅助方法。调用此 poutine 并返回其跟踪而不是函数的返回值。

property trace: pyro.poutine.trace_struct.Trace
class TraceMessenger(graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None)[source]

基类: pyro.poutine.messenger.Messenger

返回一个处理器,用于记录原始调用的输入和输出及其依赖关系。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

我们可以使用 trace 记录其执行,并使用生成的数据结构计算执行中所有采样站点的对数联合概率或提取所有参数。

>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
参数
  • fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)

  • graph_type – 指定构建图类型的字符串

  • param_only – 如果为 true,则仅记录参数而不记录采样

返回

TraceMessenger 装饰的随机函数

get_trace() pyro.poutine.trace_struct.Trace[source]
返回

数据结构

返回类型

pyro.poutine.Trace

一个非常常见的用例的辅助方法。返回 self.trace 的浅拷贝。

identify_dense_edges(trace: pyro.poutine.trace_struct.Trace) None[source]

通过添加基于存储在每个站点中的 cond_indep_stack 信息的所有边,就地修改跟踪。

UnconditionMessenger

class UnconditionMessenger[source]

基类: pyro.poutine.messenger.Messenger

用于强制观测节点的值从其分布中采样(忽略观测值)的消息处理器。

GuideMessenger

class GuideMessenger(model: Callable)[source]

基类: pyro.poutine.trace_messenger.TraceMessenger, abc.ABC

基于效应的 guides 的抽象基类。

派生类必须实现 get_posterior() 方法。

property model: Callable
__call__(*args, **kwargs) Dict[str, torch.Tensor][source]

从 guide 中抽取后验样本,并根据这些样本重放 model。

返回

一个将采样站点名称映射到采样值的字典。这包括潜在、确定性和观测值。

返回类型

字典(dict)

abstract get_posterior(name: str, prior: TorchDistributionMixin) Union[TorchDistributionMixin, torch.Tensor][source]

抽象方法,用于计算后验分布或给定先验分布并以前向上游后验样本为条件采样后验值。

实现可以在此函数内部使用 pyro.parampyro.sample,但 pyro.sample 语句应设置 infer={"is_auxiliary": True"}

实现可以访问更多信息进行计算

  • value = self.upstream_value(name) 是一个上游

    样本点或确定性点的值。

  • self.trace 是上游站点的轨迹,可能对其他信息有用,例如 self.trace.nodes["my_site"]["fn"]self.trace.nodes["my_site"]["cond_indep_stack"]

  • args, kwargs = self.args_kwargs 是模型的输入,并且

    可能对均摊有用。

参数
  • name (str) – 要采样的样本点的名称。

  • prior (Distribution) – 此样本点的先验分布(以来自后验分布的上游样本为条件)。

返回

后验分布或来自后验分布的样本。

返回类型

Distributiontorch.Tensor

upstream_value(name: str) Optional[torch.Tensor][source]

用于 get_posterior()

返回

一个上游样本点或确定性点的值

返回类型

torch.Tensor

get_traces() Tuple[pyro.poutine.trace_struct.Trace, pyro.poutine.trace_struct.Trace][source]

在运行 __call__() 后可以调用此函数来提取一对轨迹。

与生成一对轨迹的轨迹回放模式相反,GuideMessenger 交错执行模型和引导计算,因此只需要一次 guide(*args, **kwargs) 调用即可创建两个轨迹。此函数仅从该引导的 .trace 属性中提取相关信息。

返回

一对 (model_trace, guide_trace)

返回类型

tuple