Poutine (效果处理器)¶
在内置的推断算法之下,Pyro 拥有一个由可组合的效果处理器构成的库,用于创建新的推断算法以及处理概率程序。Pyro 的推断算法都是通过将这些处理器应用于随机函数构建而成。为了对效果处理器是什么以及它们解决了什么问题有一个总体了解,请阅读 Matija Pretnar 所著的《代数效应和处理器简介》。
处理器¶
Poutine 是一个由可组合的效果处理器构成的库,用于记录和修改 Pyro 程序的行为。这些底层组件简化了新推断算法和行为的实现。
处理器可以用作高阶函数、装饰器或上下文管理器来修改函数或代码块的行为
例如,考虑以下 Pyro 程序
>>> def model(x):
... s = pyro.param("s", torch.tensor(0.5))
... z = pyro.sample("z", dist.Normal(x, s))
... return z ** 2
我们可以使用 condition
将采样站点标记为观察值,它返回一个与 model
具有相同输入和输出签名的可调用对象
>>> conditioned_model = poutine.condition(model, data={"z": 1.0})
我们也可以将处理器用作装饰器
>>> @pyro.condition(data={"z": 1.0})
... def model(x):
... s = pyro.param("s", torch.tensor(0.5))
... z = pyro.sample("z", dist.Normal(x, s))
... return z ** 2
或用作上下文管理器
>>> with pyro.condition(data={"z": 1.0}):
... s = pyro.param("s", torch.tensor(0.5))
... z = pyro.sample("z", dist.Normal(0., s))
... y = z ** 2
处理器可以自由组合
>>> conditioned_model = poutine.condition(model, data={"z": 1.0})
>>> traced_model = poutine.trace(conditioned_model)
许多推断算法或算法组件只需几行代码即可实现
guide_tr = poutine.trace(guide).get_trace(...)
model_tr = poutine.trace(poutine.replay(conditioned_model, trace=guide_tr)).get_trace(...)
monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
- block(fn: None = None, hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None) pyro.poutine.block_messenger.BlockMessenger [源代码]¶
- block(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
BlockMessenger
的便捷包装器此处理器有选择地对外隐藏 Pyro 原语站点。默认行为:阻塞所有内容。
如果满足以下至少一个条件,则站点被隐藏
hide_fn(msg) is True
或(not expose_fn(msg)) is True
msg["name"] 在 hide 中
msg["type"] 在 hide_types 中
msg["name"] 不在 expose 中且 msg["type"] 不在 expose_types 中
hide
、hide_types
和expose_types
都为None
例如,假设随机函数 fn 有两个采样站点“a”和“b”。那么
BlockMessenger(fn, hide=["a"])
之外的任何效果都不会应用于站点“a”,而只会看到站点“b”>>> def fn(): ... a = pyro.sample("a", dist.Normal(0., 1.)) ... return pyro.sample("b", dist.Normal(a, 1.)) >>> fn_inner = pyro.poutine.trace(fn) >>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"])) >>> trace_inner = fn_inner.get_trace() >>> trace_outer = fn_outer.get_trace() >>> "a" in trace_inner True >>> "a" in trace_outer False >>> "b" in trace_inner True >>> "b" in trace_outer True
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
hide_fn – 一个接受站点并返回 True 以隐藏站点,或返回 False/None 以暴露站点的函数。如果指定此参数,则忽略所有其他参数。仅指定 hide_fn 或 expose_fn 中的一个,不能同时指定两者。
expose_fn – 一个接受站点并返回 True 以暴露站点,或返回 False/None 以隐藏站点的函数。如果指定此参数,则忽略所有其他参数。仅指定 hide_fn 或 expose_fn 中的一个,不能同时指定两者。
hide_all (bool) – 隐藏所有站点
expose_all (bool) – 正常暴露所有站点
hide (list) – 要隐藏的站点名称列表
expose (list) – 要暴露的站点名称列表,同时隐藏所有其他站点
hide_types (list) – 要隐藏的站点类型列表
expose_types (list) – 要暴露的站点类型列表,同时隐藏所有其他站点
- 返回
一个用
BlockMessenger
装饰的随机函数
- broadcast(fn: None = None) pyro.poutine.broadcast_messenger.BroadcastMessenger [源代码]¶
- broadcast(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
BroadcastMessenger
的便捷包装器在单个或嵌套的
plate
上下文中时,自动广播采样站点处随机函数的批形状。现有的 batch_shape 必须与 cond_indep_stack 中安装的plate
上下文的大小可广播。请注意,下面的 model_automatic_broadcast 如何自动化地扩展分布的批形状。这使得模块化 Pyro 模型变得容易,因为子组件与外部包装的
plate
上下文无关。>>> def model_broadcast_by_hand(): ... with IndepMessenger("batch", 100, dim=-2): ... with IndepMessenger("components", 3, dim=-1): ... sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5) ... .expand_by(100)) ... assert sample.shape == torch.Size((100, 3)) ... return sample
>>> @poutine.broadcast ... def model_automatic_broadcast(): ... with IndepMessenger("batch", 100, dim=-2): ... with IndepMessenger("components", 3, dim=-1): ... sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5))) ... assert sample.shape == torch.Size((100, 3)) ... return sample
- collapse(fn: None = None, *args: Any, **kwargs: Any) pyro.poutine.collapse_messenger.CollapseMessenger [源代码]¶
- collapse(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], *args: Any, **kwargs: Any) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
CollapseMessenger
的便捷包装器实验性 - 通过惰性采样并尝试使用共轭关系来折叠上下文中的所有站点。如果不知道共轭关系,这将失败。使用采样站点结果的代码必须编写为接受 Funsors 而非 Tensors。这需要安装
funsor
。警告
这与自动猜测
max_plate_nesting
不兼容。如果在折叠的上下文中出现任何 plate,您应该手动向您的推断算法声明max_plate_nesting
(例如Trace_ELBO(max_plate_nesting=1)
)。
- condition(data: Union[Dict[str, torch.Tensor], Trace]) pyro.poutine.condition_messenger.ConditionMessenger [源代码]¶
- condition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Union[Dict[str, torch.Tensor], Trace]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
ConditionMessenger
的便捷包装器给定一个包含一些采样语句和按名称提供的观察值字典的随机函数,将这些名称处的采样语句更改为使用这些值的观察。
考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
要观察站点 z 的值,我们可以这样写
>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})
这等效于在 model 中的 pyro.sample(“z”, …) 中将 obs=value 添加为关键字参数。
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
data – 一个 dict 或一个
Trace
- 返回
一个用
ConditionMessenger
装饰的随机函数
- do(data: Dict[str, Union[torch.Tensor, numbers.Number]]) pyro.poutine.do_messenger.DoMessenger [源代码]¶
- do(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, Union[torch.Tensor, numbers.Number]]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
DoMessenger
的便捷包装器给定一个包含一些采样语句和按名称提供的值字典的随机函数,将这些站点的返回值设置为等于这些值,就好像它们被硬编码为这些值一样,并引入名称相同但其值不传播的新采样站点。
可以与
condition()
自由组合来表示潜在结果的反事实分布。有关更多详细信息和理论,请参阅 Single World Intervention Graphs [1]。考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
要对站点 z 进行干预(设置一个值),我们可以这样写
>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})
这等效于将 z = pyro.sample(“z”, …) 替换为 z = torch.tensor(1.),并引入一个新的采样站点 pyro.sample(“z”, …),其值不用于其他地方。
参考
- [1] Single World Intervention Graphs: A Primer,
Thomas Richardson, James Robins
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
data – 一个将采样站点名称映射到干预值的
dict
- 返回
一个用
DoMessenger
装饰的随机函数
- enum(fn: None = None, first_available_dim: Optional[int] = None) pyro.poutine.enum_messenger.EnumMessenger [源代码]¶
- enum(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], first_available_dim: Optional[int] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
EnumMessenger
的便捷包装器在标记为
infer={"enumerate": "parallel"}
的离散采样站点上并行枚举。- 参数
first_available_dim (int) – 可用于并行枚举的第一个张量维度(从右计数)。此维度及其左侧的所有维度都可能被 Pyro 内部使用。这应该是一个负整数或 None。
- escape(escape_fn: Callable[[Message], bool]) pyro.poutine.escape_messenger.EscapeMessenger [源代码]¶
- escape(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], escape_fn: Callable[[Message], bool]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
EscapeMessenger
的便捷包装器一个通过引发 util.NonlocalExit 异常来进行非本地退出的信使
- infer_config(config_fn: Callable[[Message], InferDict]) pyro.poutine.infer_config_messenger.InferConfigMessenger [源代码]¶
- infer_config(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], config_fn: Callable[[Message], InferDict]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
InferConfigMessenger
的便捷包装器给定一个包含 Pyro 原语调用的可调用对象 fn,以及一个接受轨迹站点并返回字典的可调用对象 config_fn,将采样站点处 infer kwarg 的值更新为 config_fn(site)。
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
config_fn – 一个接受站点并返回一个 infer 字典的可调用对象
- 返回
一个用
InferConfigMessenger
装饰的随机函数
- lift(prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]]) pyro.poutine.lift_messenger.LiftMessenger [源代码]¶
- lift(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
LiftMessenger
的便捷包装器给定一个包含 param 调用和先验分布的随机函数,创建一个随机函数,其中所有 param 调用都被替换为从先验中采样。先验应该是一个可调用对象或一个名称到可调用对象的字典。
考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2 >>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})
lift
使用prior
中的分布使param
语句的行为类似于sample
语句。在此示例中,站点 s 现在将表现得好像它被s = pyro.sample("s", dist.Exponential(0.3))
替换了一样>>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0) >>> tr.nodes["s"]["type"] == "sample" True >>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0) >>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all()) False
- 参数
fn – 其参数将被提升为随机值的函数
prior – 作为 Distribution 或随机函数字典形式的先验函数
- 返回
一个用
LiftMessenger
装饰的fn
- markov(fn: None = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) pyro.poutine.markov_messenger.MarkovMessenger [源代码]¶
- markov(fn: Iterable[int] = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) pyro.poutine.markov_messenger.MarkovMessenger
- markov(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T] = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
马尔可夫依赖声明。
这可以通过多种方式使用
用作上下文管理器
用作递归函数的装饰器
用作马尔可夫链的迭代器
- mask(mask: Union[bool, torch.BoolTensor]) pyro.poutine.mask_messenger.MaskMessenger [源代码]¶
- mask(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], mask: Union[bool, torch.BoolTensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
MaskMessenger
的便捷包装器给定一个包含一些批处理采样语句和掩码张量的随机函数,按元素掩掉一些采样语句。
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
mask (torch.BoolTensor) – 一个取值为
{0,1}
的掩码张量(1 表示包含站点,0 表示排除站点)
- 返回
一个用
MaskMessenger
装饰的随机函数
- queue(fn=None, queue=None, max_tries=None, extend_fn=None, escape_fn=None, num_samples=None)[源代码]¶
用于离散变量的顺序枚举。
给定一个随机函数和一个队列,从队列中的一个完整轨迹中返回一个返回值。
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
queue – 一个用于存储部分轨迹的队列数据结构,例如 multiprocessing.Queue
max_tries – 计算一个完整轨迹的最大尝试次数
extend_fn – 一个函数(可能是随机的),它接受一个部分轨迹和一个站点,并返回一个扩展轨迹列表
escape_fn – 一个函数(可能是随机的),它接受一个部分轨迹和一个站点,并返回一个布尔值来决定是否退出
num_samples – extend_fn 返回的扩展轨迹的可选数量
- 返回
一个用 poutine 逻辑装饰的随机函数
- reparam(config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]]) pyro.poutine.reparam_messenger.ReparamMessenger [源代码]¶
- reparam(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]]) pyro.poutine.reparam_messenger.ReparamHandler[pyro.poutine.handlers._P, pyro.poutine.handlers._T]
ReparamMessenger
的便捷包装器将每个受影响的采样站点重新参数化为一个或多个辅助采样站点,然后进行确定性变换 [1]。
要指定重新参数化器,请将
config
字典或可调用对象传递给构造函数。有关可用的重新参数化器,请参阅pyro.infer.reparam
模块。注意,一些重新参数化器可以检查它们影响的函数的
*args,**kwargs
输入;这些重新参数化器要求将poutine.reparam
用作装饰器而不是上下文管理器。- [1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
《概率程序的自动重新参数化》 https://arxiv.org/pdf/1906.03028.pdf
- 参数
config (dict 或 callable) – 配置,可以是将站点名称映射到
Reparameterizer
的 dict,也可以是将站点映射到Reparam
或 None 的函数。有关内置的配置策略,请参阅pyro.infer.reparam.strategies
。
- replay(fn: None = None, trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None) pyro.poutine.replay_messenger.ReplayMessenger [源代码]¶
- replay(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
ReplayMessenger
的便捷包装器给定一个包含 Pyro 原语调用的可调用对象,返回一个可调用对象,该对象运行原始函数,并在新轨迹中重用轨迹中站点的值
考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
replay
使sample
语句的行为如同它们在轨迹中相应站点处采样了值一样>>> old_trace = pyro.poutine.trace(model).get_trace(1.0) >>> replayed_model = pyro.poutine.replay(model, trace=old_trace) >>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"]) True
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
trace – 用于回放的
Trace
数据结构params – fn 中 param 站点名称和约束值的字典,用于回放
- 返回
一个用
ReplayMessenger
装饰的随机函数
- scale(scale: Union[float, torch.Tensor]) pyro.poutine.scale_messenger.ScaleMessenger [source]¶
- scale(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], scale: Union[float, torch.Tensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
便捷的包装器
ScaleMessenger
给定一个包含一些采样语句的随机函数和一个正比例因子,对函数中所有采样(sample)和观测(observe)站点的得分进行缩放。
考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))
scale
对采样站点的对数概率进行乘性缩放>>> scaled_model = pyro.poutine.scale(model, scale=0.5) >>> scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0) >>> unscaled_tr = pyro.poutine.trace(model).get_trace(0.0) >>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all()) True
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
scale – 一个正比例因子
- 返回
用
ScaleMessenger
装饰的随机函数
- seed(rng_seed: int) pyro.poutine.seed_messenger.SeedMessenger [source]¶
- seed(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], rng_seed: int) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
便捷的包装器
SeedMessenger
通过设置随机数生成器的种子,将其设置为预定义的状态。这与在调用 fn 之前调用
pyro.set_rng_seed()
相同。这个处理器在标准 Pyro 后端对原始语句没有额外的效果,但在其他后端(例如 NumPy 后端)中可能会拦截pyro.sample
调用。- 参数
fn – 随机函数(包含 Pyro 原始调用)。
rng_seed (int) – 随机数生成器种子。
- substitute(data: Dict[str, torch.Tensor]) pyro.poutine.substitute_messenger.SubstituteMessenger [source]¶
- substitute(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, torch.Tensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
便捷的包装器
SubstituteMessenger
给定一个包含参数调用的随机函数和一组参数值,创建一个随机函数,其中所有参数调用都替换为固定值。data 应该是一个从名称到值的字典。考虑以下 Pyro 程序
>>> def model(x): ... a = pyro.param("a", torch.tensor(0.5)) ... x = pyro.sample("x", dist.Bernoulli(probs=a)) ... return x >>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})
在此示例中,站点 a 现在的值将是 torch.tensor(0.3)。 :param data: 站点名称为键的字典值。 :returns: 用
SubstituteMessenger
装饰的fn
- trace(fn: None = None, graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None) pyro.poutine.trace_messenger.TraceMessenger [source]¶
- trace(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None) pyro.poutine.trace_messenger.TraceHandler[pyro.poutine.handlers._P, pyro.poutine.handlers._T]
便捷的包装器
TraceMessenger
返回一个处理器,用于记录原始调用的输入和输出及其依赖关系。
考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
我们可以使用
trace
记录其执行,并使用生成的数据结构计算执行中所有采样站点的对数联合概率或提取所有参数。>>> trace = pyro.poutine.trace(model).get_trace(0.0) >>> logp = trace.log_prob_sum() >>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
graph_type – 指定构建图类型的字符串
param_only – 如果为 true,则仅记录参数而不记录采样
- 返回
用
TraceMessenger
装饰的随机函数
- uncondition(fn: None = None) pyro.poutine.uncondition_messenger.UnconditionMessenger [source]¶
- uncondition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
便捷的包装器
UnconditionMessenger
用于强制观测节点的值从其分布中采样(忽略观测值)的消息处理器。
- config_enumerate(guide=None, default='parallel', expand=False, num_samples=None, tmc='diagonal')[source]¶
为 guide 中的所有相关站点配置枚举。这主要与
TraceEnum_ELBO
结合使用。当配置离散变量的穷举枚举时,这将配置所有满足
.has_enumerate_support == True
的采样站点。当通过default="parallel", num_samples=n
配置局部并行蒙特卡洛采样时,这将配置所有采样站点。这不会覆盖现有的注释infer={"enumerate": ...}
。这可以作为函数使用
guide = config_enumerate(guide)
或作为装饰器使用
@config_enumerate def guide1(*args, **kwargs): ... @config_enumerate(default="sequential", expand=True) def guide2(*args, **kwargs): ...
- 参数
guide (callable) – 一个 Pyro 模型,将在
SVI
中用作 guide。default (str) – 要使用的枚举策略,可以是“sequential”、“parallel”或 None。默认为“parallel”。
expand (bool) – 是否展开枚举的采样值。详见
enumerate_support()
。这仅适用于穷举枚举(num_samples=None
)。如果num_samples
非None
,则采样值总是会展开。num_samples (int 或 None) – 如果非
None
,则使用局部蒙特卡洛采样而不是穷举枚举。这适用于连续和离散分布。tmc (string 或 None) – 在 Tensor Monte Carlo 中使用的“mixture”或“diagonal”策略
- 返回
经过注释的 guide
- 返回类型
callable
Trace¶
- class Trace(graph_type: Literal['flat', 'dense'] = 'flat')[source]¶
继承自:
object
表示执行跟踪中不同 Pyro 原始类型之间关系的图数据结构。
Pyro 程序的执行跟踪是对该程序单次执行中每次
pyro.sample()
和pyro.param()
调用的记录。跟踪是定向图,其节点表示原始调用或输入/输出,其边表示这些原始调用之间的条件依赖关系。它们由poutine.trace
创建和填充。跟踪中的每个节点(或站点)包含站点的名称、输入和输出值,以及由推理算法或用户注释添加的额外元数据。对于
pyro.sample
,跟踪还包括站点的随机函数以及用户添加的任何观测数据。考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
我们可以使用
pyro.poutine.trace
记录其执行,并使用生成的数据结构计算执行中所有采样站点的对数联合概率或提取所有参数。>>> trace = pyro.poutine.trace(model).get_trace(0.0) >>> logp = trace.log_prob_sum() >>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
我们还可以检查或操作跟踪中的单个节点。
trace.nodes
包含一个collections.OrderedDict
,其中包含与x
、s
、z
以及返回值对应的站点名称和元数据>>> list(name for name in trace.nodes.keys()) ["_INPUT", "s", "z", "_RETURN"]
trace.nodes
的值是节点元数据的字典>>> trace.nodes["z"] {'type': 'sample', 'name': 'z', 'is_observed': False, 'fn': Normal(), 'value': tensor(0.6480), 'args': (), 'kwargs': {}, 'infer': {}, 'scale': 1.0, 'cond_indep_stack': (), 'done': True, 'stop': False, 'continuation': None}
'infer'
是一个由用户或算法指定的元数据字典。'args'
和'kwargs'
是通过pyro.sample
传递给fn.__call__
或fn.log_prob
的参数。'scale'
用于在计算对数联合概率时缩放站点的对数概率。'cond_indep_stack'
包含与执行中出现的pyro.plate
上下文对应的数据结构。'done'
、'stop'
和'continuation'
仅由 Pyro 内部使用。- 参数
graph_type (string) – 指定构建跟踪图类型的字符串
- add_node(site_name: str, **kwargs: Any) None [source]¶
- 参数
site_name (string) – 要添加的站点名称
向跟踪添加一个站点。
尝试添加重复节点时会引发错误,而不是静默覆盖。
- compute_log_prob(site_filter: Callable[[str, Message], bool] = <function allow_all_sites>) None [source]¶
计算跟踪的站点对数概率。每个
log_prob
的形状等于对应的batch_shape
。每个log_prob_sum
是一个标量。这两个计算都经过 memoization(记忆化)。
- compute_score_parts() None [source]¶
计算跟踪每个站点的批量局部得分部分。每个
log_prob
的形状等于对应的batch_shape
。每个log_prob_sum
是一个标量。所有计算都经过 memoization(记忆化)。
- copy() pyro.poutine.trace_struct.Trace [source]¶
创建自身的浅拷贝,保留节点和边。
- format_shapes(title: str = 'Trace Shapes:', last_site: Optional[str] = None) str [source]¶
返回一个字符串,显示跟踪中所有站点的形状表格。
- log_prob_sum(site_filter: Callable[[str, Message], bool] = <function allow_all_sites>) Union[torch.Tensor, float] [source]¶
计算跟踪的站点对数概率。每个
log_prob
的形状等于对应的batch_shape
。每个log_prob_sum
是一个标量。log_prob_sum
的计算经过 memoization(记忆化)。- 返回
总对数概率。
- 返回类型
- pack_tensors(plate_to_symbol: Optional[Dict[str, str]] = None) None [source]¶
计算跟踪中张量的紧凑表示。这应该在调用
compute_log_prob()
或compute_score_parts()
之后调用。
Runtime¶
- class InferDict[source]¶
继承自:
typing_extensions.TypedDict
包含推理相关信息的字典。
这可用于配置按站点推理策略,例如
pyro.sample( "x", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}, )
- 键
- enumerate (str)
如果是字符串“sequential”或“parallel”,则启用枚举。并行枚举通常更快,但需要安全的广播操作和静态结构。
- expand (bool)
枚举时是否展开分布。如果缺失,默认为 False。
- is_auxiliary (bool)
采样站点是否为辅助站点,例如用于确定性地转换辅助变量的 guide。如果缺失,默认为 False。
- is_observed (bool)
采样站点是否为观测(即非潜在)站点。如果缺失,默认为 False。
- num_samples (int)
抽取的样本数量。如果缺失,默认为 1。
- obs (可选 torch.Tensor)
观测值,对于潜在变量为 None。如果缺失,默认为 None。
- prior (可选 torch.distributions.Distribution)
(内部) 用于 GuideMessenger 中存储模型的先验分布(以其上游站点为条件)。
- tmc (str)
在 TraceTMC_ELBO 中是否使用对角线或混合近似进行 Tensor Monte Carlo。
- was_observed (bool)
(内部) 在通过 Reweighted Wake Sleep 或 Compiled Sequential Importance Sampling 进行推理的上下文中,采样站点最初是否被观测。
- enumerate: typing_extensions.Literal[sequential, parallel]¶
- obs: Optional[torch.Tensor]¶
- prior: TorchDistributionMixin¶
- tmc: typing_extensions.Literal[diagonal, mixture]¶
- class Message[source]¶
继承自:
typing_extensions.TypedDict
,Generic
[pyro.poutine.runtime._P
,pyro.poutine.runtime._T
]Pyro 内部用于效果处理的消息类型。
消息存储在跟踪对象中,例如
trace.nodes["my_site_name"] # This is a Message.
- 键
- type (str)
消息类型,通常是字符串“sample”、“param”、“plate”或“markov”,但也可能是自定义类型。
- name (str)
站点名称,通常是采样或参数的名称。
- fn (可调用对象)
用于生成采样的分布或函数。
- is_observed (bool)
指示值是否被观测的标志。
- args (元组)
传递给分布或函数的按位置参数。
- kwargs (字典)
传递给分布或函数的关键字参数。
- value (torch.Tensor)
采样值(无论是观测值还是采样值)。
- scale (torch.Tensor)
对数概率的缩放因子。
- mask (bool torch.Tensor)
用于遮盖对数概率的布尔值或张量。
- cond_indep_stack (元组)
站点本地的条件独立元数据栈。不可变。
- done (bool)
指示消息是否已处理的标志。
- stop (bool)
停止进一步处理消息的标志。
- continuation (可调用对象)
处理消息后要调用的函数。
- infer (可选 InferDict)
推理参数字典。
- obs (torch.Tensor)
观测值。
- log_prob (torch.Tensor)
采样的对数概率。
- log_prob_sum (torch.Tensor)
对数概率之和。
- unscaled_log_prob (torch.Tensor)
未缩放的对数概率。
- score_parts (pyro.distributions.ScoreParts)
得分部分的集合。
- packed (Message)
经过打包的消息,在枚举过程中使用。
- args: Tuple¶
- cond_indep_stack: Tuple[CondIndepStackFrame, ...]¶
- fn: Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T]¶
- infer: Optional[pyro.poutine.runtime.InferDict]¶
- kwargs: Dict¶
- log_prob: torch.Tensor¶
- log_prob_sum: torch.Tensor¶
- mask: Optional[Union[bool, torch.Tensor]]¶
- obs: Optional[torch.Tensor]¶
- scale: Union[torch.Tensor, float]¶
- score_parts: ScoreParts¶
- unscaled_log_prob: torch.Tensor¶
- value: Optional[pyro.poutine.runtime._T]¶
- exception NonlocalExit(site: pyro.poutine.runtime.Message, *args, **kwargs)[source]¶
继承自:
Exception
用于从 poutine 执行中非局部退出的异常。
由 poutine.EscapeMessenger 用于返回站点信息。
- apply_stack(initial_msg: pyro.poutine.runtime.Message) None [source]¶
按照以下方案在单个站点执行效果栈:
对于栈中从下到上的每个
Messenger
,执行Messenger._process_message
并传入消息;如果消息字段“stop”为 True,则停止;否则继续。应用默认行为(
default_process_message
)完成剩余的站点执行。对于栈中从上到下的每个
Messenger
,执行_postprocess_message
以使用站点结果更新消息和内部 messenger 状态。如果消息字段“continuation”非
None
,则使用消息调用它。
- 参数
initial_msg (dict) – 跟踪站点的起始版本
- 返回
None
- default_process_message(msg: pyro.poutine.runtime.Message) None [source]¶
处理推理中消息的默认方法。
- 参数
msg – 要处理的消息
- 返回
None
- effectful(fn: None = None, type: Optional[str] = None) Callable[[Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T]], Callable[[...], pyro.poutine.runtime._T]] [source]¶
- effectful(fn: Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T] = None, type: Optional[str] = None) Callable[[...], pyro.poutine.runtime._T]
- 参数
fn – 执行有效果计算的函数或可调用对象
type (str) – 操作的类型标签,例如 “sample”
调用
apply_stack()
应用任何活动效果的包装器。
- get_mask() Optional[Union[bool, torch.Tensor]] [source]¶
记录封闭的
poutine.mask
处理器产生的效果。这对于在预测期间避免昂贵的
pyro.factor()
计算非常有用,因为此时无需计算对数密度,例如def model(): # ... if poutine.get_mask() is not False: log_density = my_expensive_computation() pyro.factor("foo", log_density) # ...
- 返回
遮罩。
- 返回类型
None, bool 或 torch.Tensor
Utilities¶
- all_escape(trace: Trace, msg: Message) bool [source]¶
- 参数
trace – 部分跟踪(partial trace)
msg – Pyro 原语站点处的 message
- 返回
布尔决策值
一个实用函数,用于检查站点是否尚未在跟踪中。
由 EscapeMessenger 用来决定是否在某个站点进行非局部退出。用于近似积分出变量以减少方差的子例程。
- discrete_escape(trace: Trace, msg: Message) bool [source]¶
- 参数
trace – 部分跟踪(partial trace)
msg – Pyro 原语站点处的 message
- 返回
布尔决策值
一个实用函数,用于检查采样站点是否为离散的且尚未在跟踪中。
由 EscapeMessenger 用来决定是否在某个站点进行非局部退出。用于积分出离散变量以减少方差的子例程。
- enum_extend(trace: Trace, msg: Message, num_samples: Optional[int] = None) List[Trace] [source]¶
- 参数
trace – 部分跟踪(partial trace)
msg – Pyro 原语站点处的 message
num_samples – 返回的扩展跟踪的最大数量。
- 返回
一个跟踪列表,输入跟踪的副本,每个副本包含一个额外的站点
一个实用函数,用于复制和扩展跟踪,根据输入站点构建站点,这些站点的取值从输入站点的分布的支持中枚举。
用于精确推断和积分出离散变量。
信使(Messengers)¶
信使(Messenger)对象包含 handlers 暴露的效应的实现。高级用户可以修改现有 handlers 背后的信使实现,或者编写新的信使来实现新的效应,并与库的其余部分正确组合。
信使(Messenger)¶
- class Messenger[source]¶
继承自:
object
一个上下文管理器类,用于修改随机函数(即包含 Pyro 原语语句的可调用对象)的行为并添加副作用。
这是基础 Messenger 类。它为所有 Pyro 原语实现了默认行为,因此随机函数 fn 引发的联合分布与
Messenger()(fn)
引发的联合分布相同。推断过程中传递的消息的转换器类。大多数推断操作都在此类的子类中实现。
- classmethod register(fn: Optional[Callable] = None, type: Optional[str] = None, post: Optional[bool] = None) Callable [source]¶
- 参数
fn – 实现操作的函数
type (str) – 操作的名称(也会传递给
effectful()
)post (bool) – 如果为 True,则将此操作用作后处理
动态地向效应添加操作。对于为库生成包装器很有用。
示例
@SomeMessengerClass.register def some_function(msg) ...do_something... return msg
- classmethod unregister(fn: Optional[Callable] = None, type: Optional[str] = None) Optional[Callable] [source]¶
- 参数
fn – 实现操作的函数
type (str) – 操作的名称(也会传递给
effectful()
)
动态地从效应中移除操作。对于移除库中的包装器很有用。
示例
SomeMessengerClass.unregister(some_function, "name")
- block_messengers(predicate: Callable[[pyro.poutine.messenger.Messenger], bool]) Iterator[List[pyro.poutine.messenger.Messenger]] [source]¶
实验性上下文管理器,用于临时从 _PYRO_STACK 中移除匹配的信使。注意,这不会调用
.__exit__()
和.__enter__()
方法。这对于选择性地阻止外部 handler 非常有用。
- 参数
predicate (callable) – 将信使实例映射到布尔值的谓词。这将静音所有满足
bool(predicate(m)) is True
的信使m
。- 产生(Yields)
被阻止的匹配信使列表。
BlockMessenger¶
- class BlockMessenger(hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None)[source]¶
基类:
pyro.poutine.messenger.Messenger
此处理器有选择地对外隐藏 Pyro 原语站点。默认行为:阻塞所有内容。
如果满足以下至少一个条件,则站点被隐藏
hide_fn(msg) is True
或(not expose_fn(msg)) is True
msg["name"] 在 hide 中
msg["type"] 在 hide_types 中
msg["name"] 不在 expose 中且 msg["type"] 不在 expose_types 中
hide
、hide_types
和expose_types
都为None
例如,假设随机函数 fn 有两个采样站点“a”和“b”。那么
BlockMessenger(fn, hide=["a"])
之外的任何效果都不会应用于站点“a”,而只会看到站点“b”>>> def fn(): ... a = pyro.sample("a", dist.Normal(0., 1.)) ... return pyro.sample("b", dist.Normal(a, 1.)) >>> fn_inner = pyro.poutine.trace(fn) >>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"])) >>> trace_inner = fn_inner.get_trace() >>> trace_outer = fn_outer.get_trace() >>> "a" in trace_inner True >>> "a" in trace_outer False >>> "b" in trace_inner True >>> "b" in trace_outer True
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
hide_fn – 一个接受站点并返回 True 以隐藏站点,或返回 False/None 以暴露站点的函数。如果指定此参数,则忽略所有其他参数。仅指定 hide_fn 或 expose_fn 中的一个,不能同时指定两者。
expose_fn – 一个接受站点并返回 True 以暴露站点,或返回 False/None 以隐藏站点的函数。如果指定此参数,则忽略所有其他参数。仅指定 hide_fn 或 expose_fn 中的一个,不能同时指定两者。
hide_all (bool) – 隐藏所有站点
expose_all (bool) – 正常暴露所有站点
hide (list) – 要隐藏的站点名称列表
expose (list) – 要暴露的站点名称列表,同时隐藏所有其他站点
hide_types (list) – 要隐藏的站点类型列表
expose_types (list) – 要暴露的站点类型列表,同时隐藏所有其他站点
- 返回
一个用
BlockMessenger
装饰的随机函数
BroadcastMessenger¶
- class BroadcastMessenger[source]¶
基类:
pyro.poutine.messenger.Messenger
在单个或嵌套的
plate
上下文中时,自动广播采样站点处随机函数的批形状。现有的 batch_shape 必须与 cond_indep_stack 中安装的plate
上下文的大小可广播。请注意,下面的 model_automatic_broadcast 如何自动化地扩展分布的批形状。这使得模块化 Pyro 模型变得容易,因为子组件与外部包装的
plate
上下文无关。>>> def model_broadcast_by_hand(): ... with IndepMessenger("batch", 100, dim=-2): ... with IndepMessenger("components", 3, dim=-1): ... sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5) ... .expand_by(100)) ... assert sample.shape == torch.Size((100, 3)) ... return sample
>>> @poutine.broadcast ... def model_automatic_broadcast(): ... with IndepMessenger("batch", 100, dim=-2): ... with IndepMessenger("components", 3, dim=-1): ... sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5))) ... assert sample.shape == torch.Size((100, 3)) ... return sample
CollapseMessenger¶
- class CollapseMessenger(*args: Any, **kwargs: Any)[source]¶
基类:
pyro.poutine.trace_messenger.TraceMessenger
实验性 - 通过惰性采样并尝试使用共轭关系来折叠上下文中的所有站点。如果不知道共轭关系,这将失败。使用采样站点结果的代码必须编写为接受 Funsors 而非 Tensors。这需要安装
funsor
。警告
这与自动猜测
max_plate_nesting
不兼容。如果在折叠的上下文中出现任何 plate,您应该手动向您的推断算法声明max_plate_nesting
(例如Trace_ELBO(max_plate_nesting=1)
)。
ConditionMessenger¶
- class ConditionMessenger(data: Union[Dict[str, torch.Tensor], pyro.poutine.trace_struct.Trace])[source]¶
基类:
pyro.poutine.messenger.Messenger
给定一个包含一些采样语句和按名称提供的观察值字典的随机函数,将这些名称处的采样语句更改为使用这些值的观察。
考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
要观察站点 z 的值,我们可以这样写
>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})
这等效于在 model 中的 pyro.sample(“z”, …) 中将 obs=value 添加为关键字参数。
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
data – 一个 dict 或一个
Trace
- 返回
一个用
ConditionMessenger
装饰的随机函数
DoMessenger¶
- class DoMessenger(data: Dict[str, Union[torch.Tensor, numbers.Number]])[source]¶
基类:
pyro.poutine.messenger.Messenger
给定一个包含一些采样语句和按名称提供的值字典的随机函数,将这些站点的返回值设置为等于这些值,就好像它们被硬编码为这些值一样,并引入名称相同但其值不传播的新采样站点。
可以与
condition()
自由组合来表示潜在结果的反事实分布。有关更多详细信息和理论,请参阅 Single World Intervention Graphs [1]。考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
要对站点 z 进行干预(设置一个值),我们可以这样写
>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})
这等效于将 z = pyro.sample(“z”, …) 替换为 z = torch.tensor(1.),并引入一个新的采样站点 pyro.sample(“z”, …),其值不用于其他地方。
参考
- [1] Single World Intervention Graphs: A Primer,
Thomas Richardson, James Robins
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
data – 一个将采样站点名称映射到干预值的
dict
- 返回
一个用
DoMessenger
装饰的随机函数
EnumMessenger¶
- class EnumMessenger(first_available_dim: Optional[int] = None)[source]¶
基类:
pyro.poutine.messenger.Messenger
在标记为
infer={"enumerate": "parallel"}
的离散采样站点上并行枚举。- 参数
first_available_dim (int) – 可用于并行枚举的第一个张量维度(从右计数)。此维度及其左侧的所有维度都可能被 Pyro 内部使用。这应该是一个负整数或 None。
- enumerate_site(msg: pyro.poutine.runtime.Message) torch.Tensor [source]¶
EscapeMessenger¶
- class EscapeMessenger(escape_fn: Callable[[pyro.poutine.runtime.Message], bool])[source]¶
基类:
pyro.poutine.messenger.Messenger
一个通过引发 util.NonlocalExit 异常来进行非本地退出的信使
IndepMessenger¶
- class IndepMessenger(name: str, size: int, dim: Optional[int] = None, device: Optional[str] = None)[source]¶
基类:
pyro.poutine.messenger.Messenger
此信使跟踪由嵌套的
plate
上下文声明的独立性信息栈。此信息存储在每个 sample/observe 站点中的cond_indep_stack
中,供TraceMessenger
使用。示例
x_axis = IndepMessenger('outer', 320, dim=-1) y_axis = IndepMessenger('inner', 200, dim=-2) with x_axis: x_noise = sample("x_noise", dist.Normal(loc, scale).expand_by([320])) with y_axis: y_noise = sample("y_noise", dist.Normal(loc, scale).expand_by([200, 1])) with x_axis, y_axis: xy_noise = sample("xy_noise", dist.Normal(loc, scale).expand_by([200, 320]))
- property indices: torch.Tensor¶
InferConfigMessenger¶
- class InferConfigMessenger(config_fn: Callable[[Message], InferDict])[source]¶
基类:
pyro.poutine.messenger.Messenger
给定一个包含 Pyro 原语调用的可调用对象 fn,以及一个接受轨迹站点并返回字典的可调用对象 config_fn,将采样站点处 infer kwarg 的值更新为 config_fn(site)。
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
config_fn – 一个接受站点并返回一个 infer 字典的可调用对象
- 返回
一个用
InferConfigMessenger
装饰的随机函数
LiftMessenger¶
- class LiftMessenger(prior: Union[Callable, pyro.distributions.distribution.Distribution, Dict[str, Union[pyro.distributions.distribution.Distribution, Callable]]])[source]¶
基类:
pyro.poutine.messenger.Messenger
给定一个包含 param 调用和先验分布的随机函数,创建一个随机函数,其中所有 param 调用都被替换为从先验中采样。先验应该是一个可调用对象或一个名称到可调用对象的字典。
考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2 >>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})
lift
使用prior
中的分布使param
语句的行为类似于sample
语句。在此示例中,站点 s 现在将表现得好像它被s = pyro.sample("s", dist.Exponential(0.3))
替换了一样>>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0) >>> tr.nodes["s"]["type"] == "sample" True >>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0) >>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all()) False
- 参数
fn – 其参数将被提升为随机值的函数
prior – 作为 Distribution 或随机函数字典形式的先验函数
- 返回
一个用
LiftMessenger
装饰的fn
MarkovMessenger¶
MaskMessenger¶
- class MaskMessenger(mask: Union[bool, torch.BoolTensor])[source]¶
基类:
pyro.poutine.messenger.Messenger
给定一个包含一些批处理采样语句和掩码张量的随机函数,按元素掩掉一些采样语句。
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
mask (torch.BoolTensor) – 一个取值为
{0,1}
的掩码张量(1 表示包含站点,0 表示排除站点)
- 返回
一个用
MaskMessenger
装饰的随机函数
PlateMessenger¶
- class PlateMessenger(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]¶
基类:
pyro.poutine.subsample_messenger.SubsampleMessenger
广播神奇功能的瑞士军刀:结合了形状推断、独立性标注和二次采样
- block_plate(name: Optional[str] = None, dim: Optional[int] = None, *, strict: bool = True) Iterator[None] [source]¶
实验性上下文管理器,用于临时阻止单个外部 plate。
这对于在 plated 上下文中对辅助变量或惰性采样全局变量非常有用。例如,以下模型是等效的
示例
def model_1(data): loc = pyro.sample("loc", dist.Normal(0, 1)) with pyro.plate("data", len(data)): with block_plate("data"): scale = pyro.sample("scale", dist.LogNormal(0, 1)) pyro.sample("x", dist.Normal(loc, scale)) def model_2(data): loc = pyro.sample("loc", dist.Normal(0, 1)) scale = pyro.sample("scale", dist.LogNormal(0, 1)) with pyro.plate("data", len(data)): pyro.sample("x", dist.Normal(loc, scale))
ReentrantMessenger¶
ReparamMessenger¶
- class ReparamHandler(msngr, fn: Callable[[pyro.poutine.reparam_messenger._P], pyro.poutine.reparam_messenger._T])[source]¶
基类:
Generic
[pyro.poutine.reparam_messenger._P
,pyro.poutine.reparam_messenger._T
]重新参数化 poutine。
- class ReparamMessenger(config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]])[source]¶
基类:
pyro.poutine.messenger.Messenger
将每个受影响的采样站点重新参数化为一个或多个辅助采样站点,然后进行确定性变换 [1]。
要指定重新参数化器,请将
config
字典或可调用对象传递给构造函数。有关可用的重新参数化器,请参阅pyro.infer.reparam
模块。注意,一些重新参数化器可以检查它们影响的函数的
*args,**kwargs
输入;这些重新参数化器要求将poutine.reparam
用作装饰器而不是上下文管理器。- [1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
《概率程序的自动重新参数化》 https://arxiv.org/pdf/1906.03028.pdf
- 参数
config (dict 或 callable) – 配置,可以是将站点名称映射到
Reparameterizer
的 dict,也可以是将站点映射到Reparam
或 None 的函数。有关内置的配置策略,请参阅pyro.infer.reparam.strategies
。
ReplayMessenger¶
- class ReplayMessenger(trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor] = None)[source]¶
基类:
pyro.poutine.messenger.Messenger
给定一个包含 Pyro 原语调用的可调用对象,返回一个可调用对象,该对象运行原始函数,并在新轨迹中重用轨迹中站点的值
考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
replay
使sample
语句的行为如同它们在轨迹中相应站点处采样了值一样>>> old_trace = pyro.poutine.trace(model).get_trace(1.0) >>> replayed_model = pyro.poutine.replay(model, trace=old_trace) >>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"]) True
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
trace – 用于回放的
Trace
数据结构params – fn 中 param 站点名称和约束值的字典,用于回放
- 返回
一个用
ReplayMessenger
装饰的随机函数
ScaleMessenger¶
- class ScaleMessenger(scale: Union[float, torch.Tensor])[source]¶
基类:
pyro.poutine.messenger.Messenger
给定一个包含一些采样语句的随机函数和一个正比例因子,对函数中所有采样(sample)和观测(observe)站点的得分进行缩放。
考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))
scale
对采样站点的对数概率进行乘性缩放>>> scaled_model = pyro.poutine.scale(model, scale=0.5) >>> scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0) >>> unscaled_tr = pyro.poutine.trace(model).get_trace(0.0) >>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all()) True
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
scale – 一个正比例因子
- 返回
用
ScaleMessenger
装饰的随机函数
SeedMessenger¶
- class SeedMessenger(rng_seed: int)[source]¶
基类:
pyro.poutine.messenger.Messenger
通过设置随机数生成器的种子,将其设置为预定义的状态。这与在调用 fn 之前调用
pyro.set_rng_seed()
相同。这个处理器在标准 Pyro 后端对原始语句没有额外的效果,但在其他后端(例如 NumPy 后端)中可能会拦截pyro.sample
调用。- 参数
fn – 随机函数(包含 Pyro 原始调用)。
rng_seed (int) – 随机数生成器种子。
SubsampleMessenger¶
- class SubsampleMessenger(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]¶
基类:
pyro.poutine.indep_messenger.IndepMessenger
IndepMessenger 的扩展,包括二次采样。
SubstituteMessenger¶
- class SubstituteMessenger(data: Dict[str, torch.Tensor])[source]¶
基类:
pyro.poutine.messenger.Messenger
给定一个包含参数调用的随机函数和一组参数值,创建一个随机函数,其中所有参数调用都替换为固定值。data 应该是一个从名称到值的字典。考虑以下 Pyro 程序
>>> def model(x): ... a = pyro.param("a", torch.tensor(0.5)) ... x = pyro.sample("x", dist.Bernoulli(probs=a)) ... return x >>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})
在此示例中,站点 a 现在的值将是 torch.tensor(0.3)。 :param data: 站点名称为键的字典值。 :returns: 用
SubstituteMessenger
装饰的fn
TraceMessenger¶
- class TraceHandler(msngr: pyro.poutine.trace_messenger.TraceMessenger, fn: Callable[[pyro.poutine.reparam_messenger._P], pyro.poutine.reparam_messenger._T])[source]¶
基类:
Generic
[pyro.poutine.reparam_messenger._P
,pyro.poutine.reparam_messenger._T
]执行跟踪 poutine。
TraceHandler 记录每个 Pyro 原语的输入和输出,并将它们作为站点存储在 Trace() 中。理论上,这对于每种推断算法来说是足够的信息(以及变量中隐式计算图的信息?)。
我们也可以用它来进行可视化。
- get_trace(*args, **kwargs) pyro.poutine.trace_struct.Trace [source]¶
- 返回
数据结构
- 返回类型
一个非常常见的用例的辅助方法。调用此 poutine 并返回其跟踪而不是函数的返回值。
- property trace: pyro.poutine.trace_struct.Trace¶
- class TraceMessenger(graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None)[source]¶
基类:
pyro.poutine.messenger.Messenger
返回一个处理器,用于记录原始调用的输入和输出及其依赖关系。
考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
我们可以使用
trace
记录其执行,并使用生成的数据结构计算执行中所有采样站点的对数联合概率或提取所有参数。>>> trace = pyro.poutine.trace(model).get_trace(0.0) >>> logp = trace.log_prob_sum() >>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
- 参数
fn – 一个随机函数 (包含 Pyro 原语调用的可调用对象)
graph_type – 指定构建图类型的字符串
param_only – 如果为 true,则仅记录参数而不记录采样
- 返回
用
TraceMessenger
装饰的随机函数
- get_trace() pyro.poutine.trace_struct.Trace [source]¶
- 返回
数据结构
- 返回类型
一个非常常见的用例的辅助方法。返回
self.trace
的浅拷贝。
- identify_dense_edges(trace: pyro.poutine.trace_struct.Trace) None [source]¶
通过添加基于存储在每个站点中的 cond_indep_stack 信息的所有边,就地修改跟踪。
UnconditionMessenger¶
- class UnconditionMessenger[source]¶
基类:
pyro.poutine.messenger.Messenger
用于强制观测节点的值从其分布中采样(忽略观测值)的消息处理器。
GuideMessenger¶
- class GuideMessenger(model: Callable)[source]¶
基类:
pyro.poutine.trace_messenger.TraceMessenger
,abc.ABC
基于效应的 guides 的抽象基类。
派生类必须实现
get_posterior()
方法。- property model: Callable¶
- __call__(*args, **kwargs) Dict[str, torch.Tensor] [source]¶
从 guide 中抽取后验样本,并根据这些样本重放 model。
- 返回
一个将采样站点名称映射到采样值的字典。这包括潜在、确定性和观测值。
- 返回类型
- abstract get_posterior(name: str, prior: TorchDistributionMixin) Union[TorchDistributionMixin, torch.Tensor] [source]¶
抽象方法,用于计算后验分布或给定先验分布并以前向上游后验样本为条件采样后验值。
实现可以在此函数内部使用
pyro.param
和pyro.sample
,但pyro.sample
语句应设置infer={"is_auxiliary": True"}
。实现可以访问更多信息进行计算
value = self.upstream_value(name)
是一个上游样本点或确定性点的值。
self.trace
是上游站点的轨迹,可能对其他信息有用,例如self.trace.nodes["my_site"]["fn"]
或self.trace.nodes["my_site"]["cond_indep_stack"]
。args, kwargs = self.args_kwargs
是模型的输入,并且可能对均摊有用。
- 参数
name (str) – 要采样的样本点的名称。
prior (Distribution) – 此样本点的先验分布(以来自后验分布的上游样本为条件)。
- 返回
后验分布或来自后验分布的样本。
- 返回类型
- upstream_value(name: str) Optional[torch.Tensor] [source]¶
用于
get_posterior()
。- 返回
一个上游样本点或确定性点的值
- 返回类型
- get_traces() Tuple[pyro.poutine.trace_struct.Trace, pyro.poutine.trace_struct.Trace] [source]¶
在运行
__call__()
后可以调用此函数来提取一对轨迹。与生成一对轨迹的轨迹回放模式相反,
GuideMessenger
交错执行模型和引导计算,因此只需要一次guide(*args, **kwargs)
调用即可创建两个轨迹。此函数仅从该引导的.trace
属性中提取相关信息。- 返回
一对
(model_trace, guide_trace)
- 返回类型