参数¶
Pyro 中的参数本质上是 PyTorch Tensor 的轻量级封装,带有唯一的名称。因此,参数是 Pyro 中主要的状态对象。用户通常通过 Pyro 原语 pyro.param 与参数交互。参数在随机变分推断中扮演着核心角色,用于表示参数化模型和指南中参数的点估计。
参数存储¶
- 类 StateDict[源代码]¶
基类:
typing_extensions.TypedDict
- params: Dict[str, torch.Tensor]¶
- constraints: Dict[str, torch.distributions.constraints.Constraint]¶
- 类 ParamStoreDict[源代码]¶
基类:
object
Pyro 中的全局参数存储。它本质上是一个键值存储。典型用户主要通过原语 pyro.param 与 ParamStore 交互。
请参阅 引言 以获取进一步讨论,并参阅 SVI 第一部分 以获取一些示例。
在 Pyro 中使用参数时需要牢记的一些事项
参数必须分配唯一的名称
传递给 pyro.param 的 init_tensor 参数仅在给定(命名)参数首次注册到 Pyro 时使用。
因此,如果用户在 REPL 中工作,可能需要使用 clear() 方法来获得期望的行为。此方法也可以通过 pyro.clear_param_store() 调用。
注册到 Pyro 的 PyTorch nn.Module 中参数的内部名称会加上模块的 Pyro 名称前缀。因此,这并不妨碍用户拥有两个不同的模块,每个模块都包含一个名为 weight 的参数。相比之下,用户只能拥有一个顶级参数(在任何模块之外)名为 weight。
可以使用 save 和 load 将参数保存到磁盘或从磁盘加载。
通常,参数与 受约束的 值和 无约束的 值相关联。例如,在底层,一个约束为正的参数在对数空间中表示为无约束张量。
- items() Iterator[Tuple[str, torch.Tensor]] [源代码]¶
遍历
(name, constrained_param)
对。注意 constrained_param 处于受约束的(即用户可见的)空间。
- values() Iterator[torch.Tensor] [源代码]¶
遍历受约束的参数值。
- setdefault(name: str, init_constrained_value: Union[torch.Tensor, Callable[[], torch.Tensor]], constraint: torch.distributions.constraints.Constraint = Real()) torch.Tensor [源代码]¶
从
ParamStoreDict
中检索 受约束的 参数值(如果存在),否则设置初始值。注意,这比dict.setdefault()
稍微复杂一些。如果参数已存在,则
init_constrained_tensor
将被忽略。为了避免昂贵地创建init_constrained_tensor
,您可以将其封装在一个lambda
中,该lambda
仅在参数不存在时才会被评估。param_store.get("foo", lambda: (0.001 * torch.randn(1000, 1000)).exp(), constraint=constraints.positive)
- 参数
name (str) – 参数名称
init_constrained_value (torch.Tensor 或 返回 torch.Tensor 的可调用对象) – 初始受约束值
constraint (Constraint) – torch 约束对象
- 返回
受约束的参数值
- 返回类型
- named_parameters() ItemsView[str, torch.Tensor] [源代码]¶
返回 ParamStore 中每个参数的
(name, unconstrained_value)
元组的迭代器。注意,如果参数受到约束,unconstrained_value 将位于约束隐式使用的无约束空间中。
- replace_param(param_name: str, new_param: torch.Tensor, old_param: torch.Tensor) None [源代码]¶
- get_param(name: str, init_tensor: Optional[torch.Tensor] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor [源代码]¶
通过名称获取参数。如果 ParamStore 中不存在,则会创建并存储。Pyro 原语 pyro.param 会分派到此方法。
- 参数
name (str) – 参数名称
init_tensor (torch.Tensor) – 初始张量
constraint (torch.distributions.constraints.Constraint) – torch 约束
event_dim (int) – (忽略)
- 返回
参数
- 返回类型
- match(name: str) Dict[str, torch.Tensor] [源代码]¶
获取所有匹配正则表达式的参数。参数必须存在。
- 参数
name (str) – 正则表达式
- 返回
字典,键为参数名称,值为 torch Tensor
- param_name(p: torch.Tensor) Optional[str] [源代码]¶
从参数获取参数名称
- 参数
p – 参数
- 返回
参数名称
- get_state() pyro.params.param_store.StateDict [源代码]¶
获取 ParamStore 状态。
- set_state(state: pyro.params.param_store.StateDict) None [源代码]¶
使用先前
get_state()
调用获取的状态设置 ParamStore 状态
- load(filename: str, map_location: Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]] = None) None [源代码]¶
从文件加载参数
注意
如果在从磁盘加载的参数上使用
pyro.module()
,请务必设置update_module_params
标志pyro.get_param_store().load('saved_params.save') pyro.module('module', nn, update_module_params=True)
- 参数
filename (str) – 要加载的文件名
map_location (function, torch.device, string 或 字典) – 指定如何重新映射存储位置
- scope(state: Optional[pyro.params.param_store.StateDict] = None) Iterator[pyro.params.param_store.StateDict] [源代码]¶
在同一进程中使用多个参数存储的上下文管理器。
这是
get_state()
,clear()
和set_state()
的轻量级封装。对于内存空间受限的大型模型,您可能需要手动使用save()
,clear()
和load()
。示例用法
param_store = pyro.get_param_store() # Train multiple models, while avoiding param name conflicts. with param_store.scope() as scope1: # ...Train one model,guide pair... with param_store.scope() as scope2: # ...Train another model,guide pair... # Now evaluate each, still avoiding name conflicts. with param_store.scope(scope1): # loads the first model's scope # ...evaluate the first model... with param_store.scope(scope2): # loads the second model's scope # ...evaluate the second model...