原语¶
- get_param_store() pyro.params.param_store.ParamStoreDict [source]¶
返回全局的
ParamStoreDict
。
- clear_param_store() None [source]¶
清除全局的
ParamStoreDict
。这在使用 REPL 时尤其有用。我们建议在每个训练循环之前(以避免泄露之前模型的参数)和每个单元测试之前(以避免测试之间泄露参数)调用此函数。
- param(name: str, init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor [source]¶
将变量作为参数保存在参数存储中。要与参数存储交互或写入磁盘,请参阅参数。
- 参数
name (str) – 参数名称
init_tensor (torch.Tensor 或 callable) – 初始张量或返回张量的延迟可调用对象。对于大型张量,编写例如
lambda: torch.randn(100000)
可能更经济,因为它只会在初始声明时被求值。constraint (torch.distributions.constraints.Constraint) – torch 约束,默认为
constraints.real
。event_dim (int) – (可选) 最右边与批量处理无关的维度数量。此维度左侧的维度将被视为批量维度;如果 param 语句位于一个子采样 plate 内,则参数的相应批量维度将进行相应的子采样。如果未指定,所有维度都将被视为事件维度,并且不会执行子采样。
- 返回值
一个受约束的参数。可以通过
pyro.param(...).unconstrained()
访问底层无约束的参数,其中.unconstrained
是一个弱引用属性。- 返回类型
- sample(name: str, fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) torch.Tensor [source]¶
调用随机函数
fn
,并根据name
和外部上下文(例如推断算法)产生额外的副作用。有关讨论,请参阅Pyro 简介。
- factor(name: str, log_factor: torch.Tensor, *, has_rsample: Optional[bool] = None) None [source]¶
Factor 语句,用于向概率模型添加任意对数概率因子。
警告
在 guide 中使用 factor 语句时,需要指定 factor 语句是源自完全重参数化采样(例如重参数化变量变换的 Jacobian 行列式)还是非重参数化采样(例如离散样本)。对于完全重参数化的情况,设置
has_rsample=True
;对于非重参数化的情况,设置has_rsample=False
。这仅在 guide 中需要,模型中不需要。- 参数
name (str) – 简单采样的名称
log_factor (torch.Tensor) – 可能批量化的对数概率因子。
has_rsample (bool) –
log_factor
是否源自完全重参数化分布。在模型中使用时默认为 False,但在 guide 中使用时必须指定。
- deterministic(name: str, value: torch.Tensor, event_dim: Optional[int] = None) torch.Tensor [source]¶
确定性语句,用于向 trace 添加一个名为 name 且值为 value 的
Delta
站点。这在我们想要记录完全由其父节点确定的值时很有用。例如x = pyro.sample("x", dist.Normal(0, 1)) x2 = pyro.deterministic("x2", x ** 2)
注意
该站点不影响模型密度。它目前转换为
sample()
语句,但将来可能会改变。- 参数
name (str) – 站点的名称。
value (torch.Tensor) – 站点的值。
event_dim (int) – 可选事件维度,默认为 value.ndim。
- subsample(data: torch.Tensor, event_dim: int) torch.Tensor [source]¶
子采样语句,用于根据外部的
plate
对数据张量进行子采样。当由
plate
通过传递subsample
或subsample_size
kwarg 自动执行子采样时,通常在model()
的参数上调用此函数。例如,以下两种写法是等效的# Version 1. using indexing def model(data): with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind: data = data[ind] # ... # Version 2. using pyro.subsample() def model(data): with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()): data = pyro.subsample(data, event_dim=0) # ...
- class plate(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]¶
基类:
pyro.poutine.plate_messenger.PlateMessenger
构建条件独立变量序列的结构。
plate
可以顺序地用作生成器,也可以并行地用作上下文管理器(以前分别为irange
和iarange
)。顺序的
plate
类似于内置函数range()
,因为它生成一系列值。矢量化的
plate
类似于torch.arange()
,因为它生成一个索引数组,其他张量可以通过该数组进行索引。plate
与torch.arange()
的不同之处在于,它还会通知推断算法,被索引的变量是条件独立的。为此,plate
作为上下文管理器提供,而不是一个函数,并且用户必须保证plate
上下文中的所有计算都是条件独立的with pyro.plate("name", size) as ind: # ...do conditionally independent stuff with ind...
此外,
plate
可以通过对索引进行子采样并通知推断算法缩放各种计算值来利用条件独立性假设。这通常用于对数据进行小批量子采样with pyro.plate("data", len(data), subsample_size=100) as ind: batch = data[ind] assert len(batch) == 100
默认情况下,
subsample_size=False
,这只会生成一个torch.arange(0, size)
。如果0 < subsample_size <= size
,这将生成一个大小为subsample_size
的随机批量索引,并在该上下文中将所有对数似然项按size/batch_size
进行缩放。警告
只有当上下文中的所有计算都是条件独立时,这才是正确的。
- 参数
size (int) – 可选的待子采样的集合大小(类似于内置函数 range 中的 stop)。
subsample_size (int) – 子采样中使用的小批量大小。默认为 size。
subsample (支持 len() 的任何类型。) – 可选的自定义子采样,用于用户定义的子采样方案。如果指定,则 subsample_size 将设置为 len(subsample)。
dim (int) – 用于此独立性索引的可选维度。如果指定,
dim
应该为负数,即从右侧开始索引。如果未指定,dim
将设置为所有外部plate
上下文左侧的最右边的维度。use_cuda (bool) – 已弃用,请改用 device 参数。可选的布尔值,指定是否对 subsample 和 log_prob 使用 cuda 张量。默认为
torch.Tensor.is_cuda
。device (str) – 可选关键字,指定 subsample 和 log_prob 的结果放置在哪个设备上。默认情况下,结果放置在与默认张量相同的设备上。
- 返回值
一个可重用的上下文管理器,生成一个单维度的索引
torch.Tensor
。
示例
>>> # This version declares sequential independence and subsamples data: >>> for i in pyro.plate('data', 100, subsample_size=10): ... if z[i]: # Control flow in this example prevents vectorization. ... obs = pyro.sample(f'obs_{i}', dist.Normal(loc, scale), ... obs=data[i])
>>> # This version declares vectorized independence: >>> with pyro.plate('data'): ... obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data)
>>> # This version subsamples data in vectorized way: >>> with pyro.plate('data', 100, subsample_size=10) as ind: ... obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data[ind])
>>> # This wraps a user-defined subsampling method for use in pyro: >>> ind = torch.randint(0, 100, (10,)).long() # custom subsample >>> with pyro.plate('data', 100, subsample=ind): ... obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data[ind])
>>> # This reuses two different independence contexts. >>> x_axis = pyro.plate('outer', 320, dim=-1) >>> y_axis = pyro.plate('inner', 200, dim=-2) >>> with x_axis: ... x_noise = pyro.sample("x_noise", dist.Normal(loc, scale)) ... assert x_noise.shape == (320,) >>> with y_axis: ... y_noise = pyro.sample("y_noise", dist.Normal(loc, scale)) ... assert y_noise.shape == (200, 1) >>> with x_axis, y_axis: ... xy_noise = pyro.sample("xy_noise", dist.Normal(loc, scale)) ... assert xy_noise.shape == (200, 320)
有关详细讨论,请参阅SVI Part II。
- plate_stack(prefix: str, sizes: Sequence[int], rightmost_dim: int = - 1) Iterator[None] [source]¶
创建具有以下维度的一个连续的
plate
栈rightmost_dim - len(sizes), ..., rightmost_dim
- module(name: str, nn_module: torch.nn.modules.module.Module, update_module_params: bool = False) torch.nn.modules.module.Module [source]¶
将
torch.nn.Module
的所有参数注册到 Pyro 的param_store
中。结合ParamStoreDict
的save()
和load()
功能,这允许用户保存和加载模块。注意
请考虑改用
PyroModule
,它是pyro.module()
的一个更新的替代方案,对以下方面有更好的支持:jit 编译、在 C++ 中提供服务以及将参数转换为随机变量。详情请参阅模块教程。- 参数
name (str) – 模块名称
nn_module (torch.nn.Module) – 要注册到 Pyro 的模块
update_module_params – 确定 PyTorch 模块中的参数是否会被 ParamStore 中的值覆盖(如果存在)。默认为 False
- 返回值
torch.nn.Module
- random_module(name, nn_module, prior, *args, **kwargs)[source]¶
警告
random_module 原语已弃用,并将在未来版本中移除。请改用
PyroModule
从torch.nn.Module
实例创建贝叶斯模块。有关示例,请参阅贝叶斯回归教程。已弃用 在模块 nn_module 的参数上放置一个先验分布。返回一个关于 nn.Module 的分布(可调用对象),调用时返回一个采样的 nn.Module。
- 参数
name (str) – pyro 模块名称
nn_module (torch.nn.Module) – 要注册到 pyro 的模块
prior – pyro 分布、随机函数或 python 字典,其中参数名称作为键,相应的分布/随机函数作为值。
- 返回值
一个返回采样模块的可调用对象
- barrier(data: torch.Tensor) torch.Tensor [source]¶
实验性 确保
data
中的所有值都是具体值,而不是延迟的 funsor 值。这与pyro.poutine.collapse()
结合使用时很有用。
- enable_validation(is_validate: bool = True) None [source]¶
启用或禁用 Pyro 中的验证检查。验证检查提供有用的警告和错误,例如 NaN 检查、验证分布参数和支持值、检测 ELBO 和 MCMC 的错误使用。由于其中一些检查可能开销较大,您可能希望对成熟模型禁用验证以加快推断速度。
默认行为模仿 Python 的
assert
语句:验证默认开启,但在 Python 以优化模式运行(通过python -O
)时禁用。同样地,默认行为通过pyro.enable_validation(__debug__)
取决于 Python 的全局__debug__
值。在 jit 编译期间,对于所有支持 PyTorch jit 的推断算法,验证会暂时禁用。我们建议使用非 jit 编译的推断算法进行模型开发以简化调试,然后在模型正确后选择性地转向 jit 编译的推断。
- 参数
is_validate (bool) – (可选;默认为 True) 是否启用验证检查。
- validation_enabled(is_validate: bool = True) Iterator[None] [source]¶
上下文管理器,在临时启用/禁用验证检查时很有用。
- 参数
is_validate (bool) – (可选;默认为 True) 临时覆盖验证检查。
- trace(fn=None, ignore_warnings=False, jit_options=None)[source]¶
torch.jit.trace()
的惰性替代品,适用于调用pyro.param()
的 Pyro 函数。实际的编译产物存储在输出的
compiled
属性中。调用此属性上的诊断方法。示例
def model(x): scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive) return pyro.sample("y", dist.Normal(x, scale)) @pyro.ops.jit.trace def model_log_prob_fn(x, y): cond_model = pyro.condition(model, data={"y": y}) tr = pyro.poutine.trace(cond_model).get_trace(x) return tr.log_prob_sum()
- 参数
fn (callable) – 要跟踪的函数。
ignore_warnins (bool) – 是否忽略 jit 警告。
jit_options (dict) – 可选的选项字典,传递给
torch.jit.trace()
,例如{"optimize": False}
。