原语

get_param_store() pyro.params.param_store.ParamStoreDict[source]

返回全局的 ParamStoreDict

clear_param_store() None[source]

清除全局的 ParamStoreDict

这在使用 REPL 时尤其有用。我们建议在每个训练循环之前(以避免泄露之前模型的参数)和每个单元测试之前(以避免测试之间泄露参数)调用此函数。

param(name: str, init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor[source]

将变量作为参数保存在参数存储中。要与参数存储交互或写入磁盘,请参阅参数

参数
  • name (str) – 参数名称

  • init_tensor (torch.Tensorcallable) – 初始张量或返回张量的延迟可调用对象。对于大型张量,编写例如 lambda: torch.randn(100000) 可能更经济,因为它只会在初始声明时被求值。

  • constraint (torch.distributions.constraints.Constraint) – torch 约束,默认为 constraints.real

  • event_dim (int) – (可选) 最右边与批量处理无关的维度数量。此维度左侧的维度将被视为批量维度;如果 param 语句位于一个子采样 plate 内,则参数的相应批量维度将进行相应的子采样。如果未指定,所有维度都将被视为事件维度,并且不会执行子采样。

返回值

一个受约束的参数。可以通过 pyro.param(...).unconstrained() 访问底层无约束的参数,其中 .unconstrained 是一个弱引用属性。

返回类型

torch.Tensor

sample(name: str, fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) torch.Tensor[source]

调用随机函数 fn,并根据 name 和外部上下文(例如推断算法)产生额外的副作用。有关讨论,请参阅Pyro 简介

参数
  • name – 采样名称

  • fn – 分布类或函数

  • obs – 观测数据(可选;仅应在推断上下文中使用),可通过 kwargs 指定

  • obs_mask (boolTensor) – 可选的布尔张量掩码,其形状可与 fn.batch_shape 广播。如果提供,mask=True 的事件将以 obs 为条件,其余事件将通过采样进行填充。这将引入一个名为 name + "_unobserved" 的潜在采样站点,该站点应由 guide 使用。

  • infer (dict) – 可选的推断参数字典,在 kwargs 中指定。详情请参阅推断文档。

返回值

sample

factor(name: str, log_factor: torch.Tensor, *, has_rsample: Optional[bool] = None) None[source]

Factor 语句,用于向概率模型添加任意对数概率因子。

警告

在 guide 中使用 factor 语句时,需要指定 factor 语句是源自完全重参数化采样(例如重参数化变量变换的 Jacobian 行列式)还是非重参数化采样(例如离散样本)。对于完全重参数化的情况,设置 has_rsample=True;对于非重参数化的情况,设置 has_rsample=False。这仅在 guide 中需要,模型中不需要。

参数
  • name (str) – 简单采样的名称

  • log_factor (torch.Tensor) – 可能批量化的对数概率因子。

  • has_rsample (bool) – log_factor 是否源自完全重参数化分布。在模型中使用时默认为 False,但在 guide 中使用时必须指定。

deterministic(name: str, value: torch.Tensor, event_dim: Optional[int] = None) torch.Tensor[source]

确定性语句,用于向 trace 添加一个名为 name 且值为 valueDelta 站点。这在我们想要记录完全由其父节点确定的值时很有用。例如

x = pyro.sample("x", dist.Normal(0, 1))
x2 = pyro.deterministic("x2", x ** 2)

注意

该站点不影响模型密度。它目前转换为 sample() 语句,但将来可能会改变。

参数
  • name (str) – 站点的名称。

  • value (torch.Tensor) – 站点的值。

  • event_dim (int) – 可选事件维度,默认为 value.ndim

subsample(data: torch.Tensor, event_dim: int) torch.Tensor[source]

子采样语句,用于根据外部的 plate 对数据张量进行子采样。

当由 plate 通过传递 subsamplesubsample_size kwarg 自动执行子采样时,通常在 model() 的参数上调用此函数。例如,以下两种写法是等效的

# Version 1. using indexing
def model(data):
    with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
        data = data[ind]
        # ...

# Version 2. using pyro.subsample()
def model(data):
    with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
        data = pyro.subsample(data, event_dim=0)
        # ...
参数
  • data (Tensor) – 批量数据的张量。

  • event_dim (int) – 数据张量的事件维度。左侧的维度被视为批量维度。

返回值

子采样版本的 data

返回类型

Tensor

class plate(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]

基类: pyro.poutine.plate_messenger.PlateMessenger

构建条件独立变量序列的结构。

plate 可以顺序地用作生成器,也可以并行地用作上下文管理器(以前分别为 irangeiarange)。

顺序的 plate 类似于内置函数 range(),因为它生成一系列值。

矢量化的 plate 类似于 torch.arange(),因为它生成一个索引数组,其他张量可以通过该数组进行索引。platetorch.arange() 的不同之处在于,它还会通知推断算法,被索引的变量是条件独立的。为此,plate 作为上下文管理器提供,而不是一个函数,并且用户必须保证 plate 上下文中的所有计算都是条件独立的

with pyro.plate("name", size) as ind:
    # ...do conditionally independent stuff with ind...

此外,plate 可以通过对索引进行子采样并通知推断算法缩放各种计算值来利用条件独立性假设。这通常用于对数据进行小批量子采样

with pyro.plate("data", len(data), subsample_size=100) as ind:
    batch = data[ind]
    assert len(batch) == 100

默认情况下,subsample_size=False,这只会生成一个 torch.arange(0, size)。如果 0 < subsample_size <= size,这将生成一个大小为 subsample_size 的随机批量索引,并在该上下文中将所有对数似然项按 size/batch_size 进行缩放。

警告

只有当上下文中的所有计算都是条件独立时,这才是正确的。

参数
  • name (str) – 一个唯一的名称,用于帮助推断算法匹配模型和 guide 之间的 plate 站点。

  • size (int) – 可选的待子采样的集合大小(类似于内置函数 range 中的 stop)。

  • subsample_size (int) – 子采样中使用的小批量大小。默认为 size

  • subsample (支持 len() 的任何类型。) – 可选的自定义子采样,用于用户定义的子采样方案。如果指定,则 subsample_size 将设置为 len(subsample)

  • dim (int) – 用于此独立性索引的可选维度。如果指定,dim 应该为负数,即从右侧开始索引。如果未指定,dim 将设置为所有外部 plate 上下文左侧的最右边的维度。

  • use_cuda (bool) – 已弃用,请改用 device 参数。可选的布尔值,指定是否对 subsamplelog_prob 使用 cuda 张量。默认为 torch.Tensor.is_cuda

  • device (str) – 可选关键字,指定 subsamplelog_prob 的结果放置在哪个设备上。默认情况下,结果放置在与默认张量相同的设备上。

返回值

一个可重用的上下文管理器,生成一个单维度的索引 torch.Tensor

示例

>>> # This version declares sequential independence and subsamples data:
>>> for i in pyro.plate('data', 100, subsample_size=10):
...     if z[i]:  # Control flow in this example prevents vectorization.
...         obs = pyro.sample(f'obs_{i}', dist.Normal(loc, scale),
...                           obs=data[i])
>>> # This version declares vectorized independence:
>>> with pyro.plate('data'):
...     obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data)
>>> # This version subsamples data in vectorized way:
>>> with pyro.plate('data', 100, subsample_size=10) as ind:
...     obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data[ind])
>>> # This wraps a user-defined subsampling method for use in pyro:
>>> ind = torch.randint(0, 100, (10,)).long() # custom subsample
>>> with pyro.plate('data', 100, subsample=ind):
...     obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data[ind])
>>> # This reuses two different independence contexts.
>>> x_axis = pyro.plate('outer', 320, dim=-1)
>>> y_axis = pyro.plate('inner', 200, dim=-2)
>>> with x_axis:
...     x_noise = pyro.sample("x_noise", dist.Normal(loc, scale))
...     assert x_noise.shape == (320,)
>>> with y_axis:
...     y_noise = pyro.sample("y_noise", dist.Normal(loc, scale))
...     assert y_noise.shape == (200, 1)
>>> with x_axis, y_axis:
...     xy_noise = pyro.sample("xy_noise", dist.Normal(loc, scale))
...     assert xy_noise.shape == (200, 320)

有关详细讨论,请参阅SVI Part II

plate_stack(prefix: str, sizes: Sequence[int], rightmost_dim: int = - 1) Iterator[None][source]

创建具有以下维度的一个连续的 plate

rightmost_dim - len(sizes), ..., rightmost_dim
参数
  • prefix (str) – plate 的名称前缀。

  • sizes (iterable) – plate 大小的可迭代对象。

  • rightmost_dim (int) – 最右侧的维度,从右边开始计数。

module(name: str, nn_module: torch.nn.modules.module.Module, update_module_params: bool = False) torch.nn.modules.module.Module[source]

torch.nn.Module 的所有参数注册到 Pyro 的 param_store 中。结合 ParamStoreDictsave()load() 功能,这允许用户保存和加载模块。

注意

请考虑改用 PyroModule,它是 pyro.module() 的一个更新的替代方案,对以下方面有更好的支持:jit 编译、在 C++ 中提供服务以及将参数转换为随机变量。详情请参阅模块教程

参数
  • name (str) – 模块名称

  • nn_module (torch.nn.Module) – 要注册到 Pyro 的模块

  • update_module_params – 确定 PyTorch 模块中的参数是否会被 ParamStore 中的值覆盖(如果存在)。默认为 False

返回值

torch.nn.Module

random_module(name, nn_module, prior, *args, **kwargs)[source]

警告

random_module 原语已弃用,并将在未来版本中移除。请改用 PyroModuletorch.nn.Module 实例创建贝叶斯模块。有关示例,请参阅贝叶斯回归教程

已弃用 在模块 nn_module 的参数上放置一个先验分布。返回一个关于 nn.Module 的分布(可调用对象),调用时返回一个采样的 nn.Module

参数
  • name (str) – pyro 模块名称

  • nn_module (torch.nn.Module) – 要注册到 pyro 的模块

  • prior – pyro 分布、随机函数或 python 字典,其中参数名称作为键,相应的分布/随机函数作为值。

返回值

一个返回采样模块的可调用对象

barrier(data: torch.Tensor) torch.Tensor[source]

实验性 确保 data 中的所有值都是具体值,而不是延迟的 funsor 值。这与 pyro.poutine.collapse() 结合使用时很有用。

enable_validation(is_validate: bool = True) None[source]

启用或禁用 Pyro 中的验证检查。验证检查提供有用的警告和错误,例如 NaN 检查、验证分布参数和支持值、检测 ELBO 和 MCMC 的错误使用。由于其中一些检查可能开销较大,您可能希望对成熟模型禁用验证以加快推断速度。

默认行为模仿 Python 的 assert 语句:验证默认开启,但在 Python 以优化模式运行(通过 python -O)时禁用。同样地,默认行为通过 pyro.enable_validation(__debug__) 取决于 Python 的全局 __debug__ 值。

在 jit 编译期间,对于所有支持 PyTorch jit 的推断算法,验证会暂时禁用。我们建议使用非 jit 编译的推断算法进行模型开发以简化调试,然后在模型正确后选择性地转向 jit 编译的推断。

参数

is_validate (bool) – (可选;默认为 True) 是否启用验证检查。

validation_enabled(is_validate: bool = True) Iterator[None][source]

上下文管理器,在临时启用/禁用验证检查时很有用。

参数

is_validate (bool) – (可选;默认为 True) 临时覆盖验证检查。

trace(fn=None, ignore_warnings=False, jit_options=None)[source]

torch.jit.trace() 的惰性替代品,适用于调用 pyro.param() 的 Pyro 函数。

实际的编译产物存储在输出的 compiled 属性中。调用此属性上的诊断方法。

示例

def model(x):
    scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive)
    return pyro.sample("y", dist.Normal(x, scale))

@pyro.ops.jit.trace
def model_log_prob_fn(x, y):
    cond_model = pyro.condition(model, data={"y": y})
    tr = pyro.poutine.trace(cond_model).get_trace(x)
    return tr.log_prob_sum()
参数
  • fn (callable) – 要跟踪的函数。

  • ignore_warnins (bool) – 是否忽略 jit 警告。

  • jit_options (dict) – 可选的选项字典,传递给 torch.jit.trace(),例如 {"optimize": False}