参数

Pyro 中的参数本质上是 PyTorch Tensor 的轻量级封装,带有唯一的名称。因此,参数是 Pyro 中主要的状态对象。用户通常通过 Pyro 原语 pyro.param 与参数交互。参数在随机变分推断中扮演着核心角色,用于表示参数化模型和指南中参数的点估计。

参数存储

StateDict[源代码]

基类: typing_extensions.TypedDict

params: Dict[str, torch.Tensor]
constraints: Dict[str, torch.distributions.constraints.Constraint]
ParamStoreDict[源代码]

基类: object

Pyro 中的全局参数存储。它本质上是一个键值存储。典型用户主要通过原语 pyro.param 与 ParamStore 交互。

请参阅 引言 以获取进一步讨论,并参阅 SVI 第一部分 以获取一些示例。

在 Pyro 中使用参数时需要牢记的一些事项

  • 参数必须分配唯一的名称

  • 传递给 pyro.paraminit_tensor 参数仅在给定(命名)参数首次注册到 Pyro 时使用。

  • 因此,如果用户在 REPL 中工作,可能需要使用 clear() 方法来获得期望的行为。此方法也可以通过 pyro.clear_param_store() 调用。

  • 注册到 Pyro 的 PyTorch nn.Module 中参数的内部名称会加上模块的 Pyro 名称前缀。因此,这并不妨碍用户拥有两个不同的模块,每个模块都包含一个名为 weight 的参数。相比之下,用户只能拥有一个顶级参数(在任何模块之外)名为 weight

  • 可以使用 saveload 将参数保存到磁盘或从磁盘加载。

  • 通常,参数与 受约束的 值和 无约束的 值相关联。例如,在底层,一个约束为正的参数在对数空间中表示为无约束张量。

clear() None[源代码]

清空 ParamStore

items() Iterator[Tuple[str, torch.Tensor]][源代码]

遍历 (name, constrained_param) 对。注意 constrained_param 处于受约束的(即用户可见的)空间。

keys() KeysView[str][源代码]

遍历参数名称。

values() Iterator[torch.Tensor][源代码]

遍历受约束的参数值。

setdefault(name: str, init_constrained_value: Union[torch.Tensor, Callable[[], torch.Tensor]], constraint: torch.distributions.constraints.Constraint = Real()) torch.Tensor[源代码]

ParamStoreDict 中检索 受约束的 参数值(如果存在),否则设置初始值。注意,这比 dict.setdefault() 稍微复杂一些。

如果参数已存在,则 init_constrained_tensor 将被忽略。为了避免昂贵地创建 init_constrained_tensor,您可以将其封装在一个 lambda 中,该 lambda 仅在参数不存在时才会被评估。

param_store.get("foo", lambda: (0.001 * torch.randn(1000, 1000)).exp(),
                constraint=constraints.positive)
参数
  • name (str) – 参数名称

  • init_constrained_value (torch.Tensor返回 torch.Tensor 的可调用对象) – 初始受约束值

  • constraint (Constraint) – torch 约束对象

返回

受约束的参数值

返回类型

torch.Tensor

named_parameters() ItemsView[str, torch.Tensor][源代码]

返回 ParamStore 中每个参数的 (name, unconstrained_value) 元组的迭代器。注意,如果参数受到约束,unconstrained_value 将位于约束隐式使用的无约束空间中。

get_all_param_names() KeysView[str][源代码]
replace_param(param_name: str, new_param: torch.Tensor, old_param: torch.Tensor) None[源代码]
get_param(name: str, init_tensor: Optional[torch.Tensor] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor[源代码]

通过名称获取参数。如果 ParamStore 中不存在,则会创建并存储。Pyro 原语 pyro.param 会分派到此方法。

参数
返回

参数

返回类型

torch.Tensor

match(name: str) Dict[str, torch.Tensor][源代码]

获取所有匹配正则表达式的参数。参数必须存在。

参数

name (str) – 正则表达式

返回

字典,键为参数名称,值为 torch Tensor

param_name(p: torch.Tensor) Optional[str][源代码]

从参数获取参数名称

参数

p – 参数

返回

参数名称

get_state() pyro.params.param_store.StateDict[源代码]

获取 ParamStore 状态。

set_state(state: pyro.params.param_store.StateDict) None[源代码]

使用先前 get_state() 调用获取的状态设置 ParamStore 状态

save(filename: str) None[源代码]

将参数保存到文件

参数

filename (str) – 要保存到的文件名

load(filename: str, map_location: Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]] = None) None[源代码]

从文件加载参数

注意

如果在从磁盘加载的参数上使用 pyro.module(),请务必设置 update_module_params 标志

pyro.get_param_store().load('saved_params.save')
pyro.module('module', nn, update_module_params=True)
参数
  • filename (str) – 要加载的文件名

  • map_location (function, torch.device, string 字典) – 指定如何重新映射存储位置

scope(state: Optional[pyro.params.param_store.StateDict] = None) Iterator[pyro.params.param_store.StateDict][源代码]

在同一进程中使用多个参数存储的上下文管理器。

这是 get_state(), clear()set_state() 的轻量级封装。对于内存空间受限的大型模型,您可能需要手动使用 save(), clear()load()

示例用法

param_store = pyro.get_param_store()

# Train multiple models, while avoiding param name conflicts.
with param_store.scope() as scope1:
    # ...Train one model,guide pair...
with param_store.scope() as scope2:
    # ...Train another model,guide pair...

# Now evaluate each, still avoiding name conflicts.
with param_store.scope(scope1):  # loads the first model's scope
   # ...evaluate the first model...
with param_store.scope(scope2):  # loads the second model's scope
   # ...evaluate the second model...
param_with_module_name(pyro_name: str, param_name: str) str[源代码]
module_from_param_with_module_name(param_name: str) str[源代码]
user_param_name(param_name: str) str[源代码]
normalize_param_name(name: str) str[源代码]