基于 Funsor 的 Pyro¶
原语¶
- clear_param_store() None [source]¶
清除全局
ParamStoreDict
。这在使用 REPL 时特别有用。建议在每个训练循环之前调用此函数(以避免泄漏先前模型的参数),并在每个单元测试之前调用(以避免在测试之间泄漏参数)。
- condition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Union[Dict[str, torch.Tensor], Trace]) Union[pyro.poutine.condition_messenger.ConditionMessenger, Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]] [source]¶
ConditionMessenger
的便捷包装器给定一个包含一些采样语句的随机函数以及按名称组织的观测值字典,将这些名称对应的采样语句更改为使用这些观测值的观察语句。
考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
要观察站点 z 的值,我们可以写成
>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})
这等同于在 model 中将 obs=value 作为关键字参数添加到 pyro.sample(“z”, …)。
- 参数
fn – 一个随机函数(包含 Pyro 原语调用的可调用对象)
data – 一个 dict 或一个
Trace
- 返回值
用
ConditionMessenger
装饰的随机函数
- deterministic(name: str, value: torch.Tensor, event_dim: Optional[int] = None) torch.Tensor [source]¶
确定性语句,用于向跟踪添加一个名为 name、值为 value 的
Delta
站点。当需要记录完全由其父节点确定的值时,这非常有用。例如x = pyro.sample("x", dist.Normal(0, 1)) x2 = pyro.deterministic("x2", x ** 2)
注意
该站点不影响模型密度。这目前转换为
sample()
语句,但将来可能会更改。- 参数
name (str) – 站点的名称。
value (torch.Tensor) – 站点的值。
event_dim (int) – 可选的事件维度,默认为 value.ndim。
- do(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, Union[torch.Tensor, numbers.Number]]) Union[pyro.poutine.do_messenger.DoMessenger, Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]] [source]¶
DoMessenger
的便捷包装器给定一个包含一些采样语句的随机函数以及按名称组织的值字典,将这些站点的返回值设置为等于这些值,就像它们被硬编码为这些值一样,并引入具有相同名称的新采样站点,其值不会传播。
可与
condition()
自由组合,表示潜在结果的反事实分布。更多细节和理论请参见单世界干预图 [1]。考虑以下 Pyro 程序
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
要对站点 z 进行干预,我们可以写成
>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})
这等同于将 z = pyro.sample(“z”, …) 替换为 z = torch.tensor(1.),并引入一个新的采样站点 pyro.sample(“z”, …),其值不在其他地方使用。
参考文献
- [1] 单世界干预图:入门,
Thomas Richardson, James Robins
- 参数
fn – 一个随机函数(包含 Pyro 原语调用的可调用对象)
data – 一个
dict
,将采样站点名称映射到干预值
- 返回值
用
DoMessenger
装饰的随机函数
- enable_validation(is_validate: bool = True) None [source]¶
启用或禁用 Pyro 中的验证检查。验证检查提供了有用的警告和错误,例如 NaN 检查、验证分布参数和支持值、检测 ELBO 和 MCMC 的不正确使用。由于其中一些检查可能很耗时,对于成熟的模型,您可能希望禁用验证以加快推断速度。
默认行为模拟 Python 的
assert
语句:默认情况下验证是开启的,但在 Python 以优化模式运行(通过python -O
)时会禁用。同样,默认行为通过pyro.enable_validation(__debug__)
取决于 Python 的全局__debug__
值。在 jit 编译期间,对于所有支持 PyTorch jit 的推断算法,验证会暂时禁用。建议使用非 jit 推断算法开发模型以方便调试,然后一旦模型正确,可选择切换到 jit 推断。
- 参数
is_validate (bool) – (可选;默认为 True)是否启用验证检查。
- factor(name: str, log_factor: torch.Tensor, *, has_rsample: Optional[bool] = None) None [source]¶
因子语句,用于向概率模型添加任意对数概率因子。
警告
在指南中使用因子语句时,需要指定该因子语句是源自完全重参数化采样(例如重参数化变量变换的雅可比行列式)还是源自非重参数化采样(例如离散采样)。对于完全重参数化的情况,设置
has_rsample=True
;对于非重参数化的情况,设置has_rsample=False
。这仅在指南中需要,模型中不需要。- 参数
name (str) – 平凡采样的名称
log_factor (torch.Tensor) – 一个可能包含批次的对数概率因子。
has_rsample (bool) –
log_factor
是否源自完全重参数化分布。在模型中使用时默认为 False,但在指南中使用时必须指定。
- get_param_store() pyro.params.param_store.ParamStoreDict [source]¶
返回全局
ParamStoreDict
。
- markov(fn=None, history=1, keep=False)[source]¶
MarkovMessenger
的便捷包装器用于在 Funsor 和与 Pyro 位置批量维度一致的张量之间进行转换的处理程序。
- module(name: str, nn_module: torch.nn.modules.module.Module, update_module_params: bool = False) torch.nn.modules.module.Module [source]¶
在 Pyro 的
param_store
中注册torch.nn.Module
的所有参数。结合ParamStoreDict
的save()
和load()
功能,用户可以保存和加载模块。注意
建议考虑使用
PyroModule
,它是pyro.module()
的一个更新的替代方案,对 JIT 编译、在 C++ 中服务以及将参数转换为随机变量有更好的支持。详细信息请参见模块教程。- 参数
name (str) – 模块名称
nn_module (torch.nn.Module) – 要在 Pyro 中注册的模块
update_module_params – 确定 PyTorch 模块中的 Parameters 是否会被 ParamStore 中找到的值覆盖(如果存在)。默认为 False
- 返回值
torch.nn.Module
- param(name: str, init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor [source]¶
将变量作为参数保存在参数存储中。要与参数存储交互或写入磁盘,请参阅参数。
- 参数
name (str) – 参数名称
init_tensor (torch.Tensor 或 可调用对象) – 初始张量或返回张量的惰性可调用对象。对于大型张量,写成例如
lambda: torch.randn(100000)
可能更划算,这只会在初始语句时进行评估。constraint (torch.distributions.constraints.Constraint) – torch 约束,默认为
constraints.real
。event_dim (int) – (可选)与批处理无关的最右侧维度数量。该维度左侧的维度将被视为批处理维度;如果参数语句在子采样 plate 内部,则参数对应的批处理维度将相应地进行子采样。如果未指定,所有维度都将被视为事件维度,并且不会执行子采样。
- 返回值
一个受约束的参数。底层的无约束参数可以通过
pyro.param(...).unconstrained()
访问,其中.unconstrained
是一个弱引用属性。- 返回值类型
- random_module(name, nn_module, prior, *args, **kwargs)[source]¶
警告
该 random_module 原语已弃用,并将在未来版本中移除。请改用
PyroModule
从torch.nn.Module
实例创建贝叶斯模块。示例请参见贝叶斯回归教程。已弃用:在模块 nn_module 的参数上设置先验。返回一个 nn.Module 上的分布(可调用对象),调用该对象时会返回一个采样得到的 nn.Module。
- 参数
name (str) – pyro 模块名称
nn_module (torch.nn.Module) – 要在 pyro 中注册的模块
prior – pyro 分布、随机函数,或一个 python 字典,其中参数名称作为键,对应的分布/随机函数作为值。
- 返回值
一个返回采样模块的可调用对象
- sample(name: str, fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) torch.Tensor [source]¶
调用随机函数
fn
,并根据name
和封闭上下文(例如推断算法)产生额外的副作用。讨论请参见Pyro 入门。
- subsample(data: torch.Tensor, event_dim: int) torch.Tensor [source]¶
子采样语句,根据封闭的
plate
对数据张量进行子采样。当通过传递
subsample
或subsample_size
关键字参数由plate
自动执行子采样时,通常在model()
的参数上调用此函数。例如,以下是等价的# Version 1. using indexing def model(data): with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind: data = data[ind] # ... # Version 2. using pyro.subsample() def model(data): with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()): data = pyro.subsample(data, event_dim=0) # ...
- vectorized_markov(fn=None, name=None, size=None, dim=None, history=1)[source]¶
VectorizedMarkovMessenger
的便捷包装器用于变量马尔可夫链的构造,旨在利用并行扫描算法高效消除马尔可夫维度。在允许的情况下,
vectorized_markov
可以与markov
互换使用。for 循环生成
int
和 1 维的torch.Tensor
索引:(0, ..., history-1, torch.arange(0, size-history), ..., torch.arange(history, size))
。int
索引用于启动马尔可夫链,而torch.Tensor
索引用于构建向量化转移概率,以便通过并行扫描算法进行高效消除。当
history==0
时,vectorized_markov
的行为类似于plate
。for 循环运行后,会识别出马尔可夫变量,然后构建
step
信息并添加到跟踪中。step
会告知推断算法哪些变量属于马尔可夫链。data = torch.ones(3, dtype=torch.float) def model(data, vectorized=True): init = pyro.param("init", lambda: torch.rand(3), constraint=constraints.simplex) trans = pyro.param("trans", lambda: torch.rand((3, 3)), constraint=constraints.simplex) locs = pyro.param("locs", lambda: torch.rand(3,)) markov_chain = \ pyro.vectorized_markov(name="time", size=len(data), dim=-1) if vectorized \ else pyro.markov(range(len(data))) for i in markov_chain: x_curr = pyro.sample("x_{}".format(i), dist.Categorical( init if isinstance(i, int) and i < 1 else trans[x_prev]), pyro.sample("y_{}".format(i), dist.Normal(Vindex(locs)[..., x_curr], 1.), obs=data[i]) x_prev = x_curr # trace.nodes["time"]["value"] # frozenset({('x_0', 'x_slice(0, 2, None)', 'x_slice(1, 3, None)')}) # # pyro.vectorized_markov trace # ... # Sample Sites: # locs dist | 3 # value | 3 # log_prob | # x_0 dist | # value 3 1 1 1 1 | # log_prob 3 1 1 1 1 | # y_0 dist 3 1 1 1 1 | # value | # log_prob 3 1 1 1 1 | # x_slice(1, 3, None) dist 3 1 1 1 1 2 | # value 3 1 1 1 1 1 1 | # log_prob 3 3 1 1 1 1 2 | # y_slice(1, 3, None) dist 3 1 1 1 1 1 2 | # value 2 | # log_prob 3 1 1 1 1 1 2 | # # pyro.markov trace # ... # Sample Sites: # locs dist | 3 # value | 3 # log_prob | # x_0 dist | # value 3 1 1 1 1 | # log_prob 3 1 1 1 1 | # y_0 dist 3 1 1 1 1 | # value | # log_prob 3 1 1 1 1 | # x_1 dist 3 1 1 1 1 | # value 3 1 1 1 1 1 | # log_prob 3 3 1 1 1 1 | # y_1 dist 3 1 1 1 1 1 | # value | # log_prob 3 1 1 1 1 1 | # x_2 dist 3 1 1 1 1 1 | # value 3 1 1 1 1 | # log_prob 3 3 1 1 1 1 | # y_2 dist 3 1 1 1 1 | # value | # log_prob 3 1 1 1 1 |
警告
这仅在每个分支只有一个马尔可夫维度时有效。
- 参数
- 返回值
返回
int
和 1 维的torch.Tensor
索引:(0, ..., history-1, torch.arange(size-history), ..., torch.arange(history, size))
。
效果处理器¶
- enum(fn=None, first_available_dim=None)[source]¶
EnumMessenger
的便捷包装器此版本的
EnumMessenger
使用to_data()
为每个离散采样站点分配一个新的枚举维度。
- markov(fn=None, history=1, keep=False)[source]¶
MarkovMessenger
的便捷包装器用于在 Funsor 和与 Pyro 位置批量维度一致的张量之间进行转换的处理程序。
- named(fn=None, first_available_dim=None)[source]¶
NamedMessenger
的便捷包装器用于
to_funsor()
和to_data()
原语的基础效果处理程序类。任何在内部调用这些原语或封装调用这些原语代码的效果处理程序都应继承自NamedMessenger
。这种设计确保全局名称-维度映射在处理程序退出时重置,而不是潜在地持续到整个程序终止。
- plate(fn=None, name=None, size=None, subsample_size=None, subsample=None, dim=None, use_cuda=None, device=None)[source]¶
PlateMessenger
的便捷包装器将新的
IndepMessenger
实现与现有的pyro.poutine.BroadcastMessenger
结合。最终应能完全替代pyro.plate
。
- replay(fn=None, trace=None, params=None)[source]¶
ReplayMessenger
的便捷包装器此版本的
ReplayMessenger
与原始版本几乎相同,区别在于它对重放的 funsor 值调用to_data()
。这可能会导致不同的解包形状,但应产生正确的分配。
- trace(fn=None, graph_type=None, param_only=None, pack_online=True)[source]¶
TraceMessenger
的便捷包装器设置
pack_online=True
会在线进行打包,而不是事后进行,一旦分布和值可用,就立即将它们转换为 Funsors。设置
pack_online=False
会在执行后计算打包所需的信息。每个采样站点都用一个dim_to_name
字典进行标注,该字典可以直接传递给to_funsor()
。
- vectorized_markov(fn=None, name=None, size=None, dim=None, history=1)[source]¶
VectorizedMarkovMessenger
的便捷包装器用于变量马尔可夫链的构造,旨在利用并行扫描算法高效消除马尔可夫维度。在允许的情况下,
vectorized_markov
可以与markov
互换使用。for 循环生成
int
和 1 维的torch.Tensor
索引:(0, ..., history-1, torch.arange(0, size-history), ..., torch.arange(history, size))
。int
索引用于启动马尔可夫链,而torch.Tensor
索引用于构建向量化转移概率,以便通过并行扫描算法进行高效消除。当
history==0
时,vectorized_markov
的行为类似于plate
。for 循环运行后,会识别出马尔可夫变量,然后构建
step
信息并添加到跟踪中。step
会告知推断算法哪些变量属于马尔可夫链。data = torch.ones(3, dtype=torch.float) def model(data, vectorized=True): init = pyro.param("init", lambda: torch.rand(3), constraint=constraints.simplex) trans = pyro.param("trans", lambda: torch.rand((3, 3)), constraint=constraints.simplex) locs = pyro.param("locs", lambda: torch.rand(3,)) markov_chain = \ pyro.vectorized_markov(name="time", size=len(data), dim=-1) if vectorized \ else pyro.markov(range(len(data))) for i in markov_chain: x_curr = pyro.sample("x_{}".format(i), dist.Categorical( init if isinstance(i, int) and i < 1 else trans[x_prev]), pyro.sample("y_{}".format(i), dist.Normal(Vindex(locs)[..., x_curr], 1.), obs=data[i]) x_prev = x_curr # trace.nodes["time"]["value"] # frozenset({('x_0', 'x_slice(0, 2, None)', 'x_slice(1, 3, None)')}) # # pyro.vectorized_markov trace # ... # Sample Sites: # locs dist | 3 # value | 3 # log_prob | # x_0 dist | # value 3 1 1 1 1 | # log_prob 3 1 1 1 1 | # y_0 dist 3 1 1 1 1 | # value | # log_prob 3 1 1 1 1 | # x_slice(1, 3, None) dist 3 1 1 1 1 2 | # value 3 1 1 1 1 1 1 | # log_prob 3 3 1 1 1 1 2 | # y_slice(1, 3, None) dist 3 1 1 1 1 1 2 | # value 2 | # log_prob 3 1 1 1 1 1 2 | # # pyro.markov trace # ... # Sample Sites: # locs dist | 3 # value | 3 # log_prob | # x_0 dist | # value 3 1 1 1 1 | # log_prob 3 1 1 1 1 | # y_0 dist 3 1 1 1 1 | # value | # log_prob 3 1 1 1 1 | # x_1 dist 3 1 1 1 1 | # value 3 1 1 1 1 1 | # log_prob 3 3 1 1 1 1 | # y_1 dist 3 1 1 1 1 1 | # value | # log_prob 3 1 1 1 1 1 | # x_2 dist 3 1 1 1 1 1 | # value 3 1 1 1 1 | # log_prob 3 3 1 1 1 1 | # y_2 dist 3 1 1 1 1 | # value | # log_prob 3 1 1 1 1 |
警告
这仅在每个分支只有一个马尔可夫维度时有效。
- 参数
- 返回值
返回
int
和 1 维的torch.Tensor
索引:(0, ..., history-1, torch.arange(size-history), ..., torch.arange(history, size))
。
- class NamedMessenger(first_available_dim=None)[source]¶
基类:
pyro.poutine.reentrant_messenger.ReentrantMessenger
用于
to_funsor()
和to_data()
原语的基础效果处理程序类。任何在内部调用这些原语或封装调用这些原语代码的效果处理程序都应继承自NamedMessenger
。这种设计确保全局名称-维度映射在处理程序退出时重置,而不是潜在地持续到整个程序终止。
- class MarkovMessenger(history=1, keep=False)[source]¶
基类:
pyro.contrib.funsor.handlers.named_messenger.NamedMessenger
用于在 Funsor 和与 Pyro 位置批量维度一致的张量之间进行转换的处理程序。
- class GlobalNamedMessenger(first_available_dim=None)[source]¶
基类:
pyro.contrib.funsor.handlers.named_messenger.NamedMessenger
使用
to_funsor()
和to_data()
原语分配DimType.GLOBAL
或DimType.VISIBLE
维度的任何新效果处理程序的基础类。充当不应由
MarkovMessenger
回收的维度的手动“作用域”:全局维度将被视为活动,直到它们最初分配的范围内最里面的GlobalNamedMessenger
退出。
- class StackFrame(name_to_dim, dim_to_name, history=1, keep=False)[source]¶
基类:
object
整数位置维度和名称之间的一致双向映射。可以像字典一样查询(
value = frame[key]
,frame[key] = value
)。
- class DimStack[source]¶
基类:
object
跟踪名称和维度之间映射的全局状态的单个实例。
替换了 plate 的
_DimAllocator
、enum 的_EnumAllocator
、MarkovMessenger
中的stack
、EnumMessenger
中的_param_dims
和_value_dims
,以及msg['infer']
中的dim_to_symbol
。- MAX_DIM = -25¶
- DEFAULT_FIRST_DIM = -5¶
- property global_frame¶
- property local_frame¶
- property current_write_env¶
- property current_read_env¶
收集计算完整名称 <–> 维度映射所需的所有帧,并在计算的任何点解释 Funsor 输入或批量形状。
推断算法¶
- class ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
-
- loss_and_grads(model, guide, *args, **kwargs)[source]¶
参见
pyro.infer.traceenum_elbo.TraceEnum_ELBO.loss_and_grads()
- class Jit_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
- class Trace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
- class JitTrace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基类:
pyro.contrib.funsor.infer.elbo.Jit_ELBO
,pyro.contrib.funsor.infer.trace_elbo.Trace_ELBO
- class TraceMarkovEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基类:
pyro.contrib.funsor.infer.elbo.ELBO
- differentiable_loss(model, guide, *args, **kwargs)[source]¶
参见
pyro.infer.traceenum_elbo.TraceEnum_ELBO.differentiable_loss()
- class TraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基类:
pyro.contrib.funsor.infer.elbo.ELBO
- differentiable_loss(model, guide, *args, **kwargs)[source]¶
参见
pyro.infer.traceenum_elbo.TraceEnum_ELBO.differentiable_loss()
- class JitTraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基类:
pyro.contrib.funsor.infer.elbo.Jit_ELBO
,pyro.contrib.funsor.infer.traceenum_elbo.TraceEnum_ELBO
- class JitTraceMarkovEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基类:
pyro.contrib.funsor.infer.elbo.Jit_ELBO
,pyro.contrib.funsor.infer.traceenum_elbo.TraceMarkovEnum_ELBO
- class TraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基类:
pyro.contrib.funsor.infer.elbo.ELBO
- differentiable_loss(model, guide, *args, **kwargs)[source]¶
参见
pyro.infer.tracetmc_elbo.TraceTMC_ELBO.differentiable_loss()
- class JitTraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基类:
pyro.contrib.funsor.infer.elbo.Jit_ELBO
,pyro.contrib.funsor.infer.tracetmc_elbo.TraceTMC_ELBO