自动名称生成¶
The pyro.contrib.autoname
模块提供了用于自动生成独特、语义有意义的样本站点名称的工具。
- scope(fn=None, prefix=None, inner=None)[source]¶
- 参数
fn – 一个随机函数(包含 Pyro 原语调用的可调用对象)
prefix – 要添加到样本名称前的字符串(如果提供了
fn
则可选)inner – 用于确定重复名称计数器出现位置的开关
- 返回值
用
ScopeMessenger
装饰的fn
scope
在 Pyro 样本站点名称前加上一个前缀,后跟一个/
。它的工作方式类似于 TensorFlow 的name_scope
和variable_scope
,可以用作上下文管理器、装饰器或高阶函数。scope
对于将组合模型与 guide 或数据对齐非常有用。示例
>>> @scope(prefix="a") ... def model(): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace()
示例
>>> def model(): ... with scope(prefix="a"): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace()
作用域按预期组合,外部作用域名称出现在内部作用域之前。
>>> @scope(prefix="b") ... def model(): ... with scope(prefix="a"): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "b/a/x" in poutine.trace(model).get_trace()
当用作装饰器或高阶函数时,如果未提供用户指定的前缀,
scope
将使用输入函数的名称作为前缀。示例
>>> @scope ... def model(): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "model/x" in poutine.trace(model).get_trace()
- name_count(fn=None)[source]¶
name_count
是一种非常简单的自动命名方案,它只对在执行中出现多次的名称附加一个后缀 “__” 加上一个计数器。只有重复出现的名称实例会获得后缀;第一次出现的实例不会被修改。示例
>>> @name_count ... def model(): ... for i in range(3): ... pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "x" in poutine.trace(model).get_trace() >>> assert "x__1" in poutine.trace(model).get_trace() >>> assert "x__2" in poutine.trace(model).get_trace()
name_count
也与scope()
组合使用,通过为重复的作用域入口添加后缀。示例
>>> @name_count ... def model(): ... for i in range(3): ... with pyro.contrib.autoname.scope(prefix="a"): ... pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace() >>> assert "a__1/x" in poutine.trace(model).get_trace() >>> assert "a__2/x" in poutine.trace(model).get_trace()
示例
>>> @name_count ... def model(): ... with pyro.contrib.autoname.scope(prefix="a"): ... for i in range(3): ... pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace() >>> assert "a/x__1" in poutine.trace(model).get_trace() >>> assert "a/x__2" in poutine.trace(model).get_trace()
- autoname(fn=None, name=None)[source]¶
AutonameMessenger
的便捷包装器为随机变量分配唯一的名称。
对于新变量,如果给定则使用其声明的名称,否则使用分布名称。
sample("x", dist.Bernoulli ... ) # -> x sample(dist.Bernoulli ... ) # -> Bernoulli
对于重复变量名,将计数器作为后缀附加。
sample(dist.Bernoulli ... ) # -> Bernoulli sample(dist.Bernoulli ... ) # -> Bernoulli1 sample(dist.Bernoulli ... ) # -> Bernoulli2
函数和迭代器可以用作名称作用域。
@autoname def f1(): sample(dist.Bernoulli ... ) @autoname def f2(): f1() # -> f2/f1/Bernoulli f1() # -> f2/f1__1/Bernoulli sample(dist.Bernoulli ... ) # -> f2/Bernoulli @autoname(name="model") def f3(): for i in autoname(range(3), name="time"): # model/time/Bernoulli .. model/time__1/Bernoulli .. model/time__2/Bernoulli sample(dist.Bernoulli ... ) # model/time/f1/Bernoulli .. model/time__1/f1/Bernoulli .. model/time__2/f1/Bernoulli f1()
或者可以使用 with 语句添加作用域。
def f4(): with autoname(name="prefix"): f1() # -> prefix/f1/Bernoulli f1() # -> prefix/f1__1/Bernoulli sample(dist.Bernoulli ... ) # -> prefix/Bernoulli
- sample(*args)[source]¶
- sample(name: str, fn, *args, **kwargs)
- sample(fn: pyro.distributions.distribution.Distribution, *args, **kwargs)
命名数据结构¶
The pyro.contrib.named
模块是 Pyro 之上的一层薄薄的语法层。它允许 Pyro 模型看起来像操作 Python 数据结构(例如 latent.x.sample_(...)
)的程序,而不是带有字符串标签语句(例如 x = pyro.sample("x", ...)
)的程序。
该模块提供了三种容器数据结构:named.Object
、named.List
和 named.Dict
。这些数据结构旨在互相嵌套。它们一起跟踪每个数据结构中每块数据的地址,以便此地址可以用作 Pyro 站点。例如
>>> state = named.Object("state")
>>> print(str(state))
state
>>> z = state.x.y.z # z is just a placeholder.
>>> print(str(z))
state.x.y.z
>>> state.xs = named.List() # Create a contained list.
>>> x0 = state.xs.add()
>>> print(str(x0))
state.xs[0]
>>> state.ys = named.Dict()
>>> foo = state.ys['foo']
>>> print(str(foo))
state.ys['foo']
These addresses can now be used inside sample
、observe
和 param
语句中使用。这些命名数据结构甚至提供别名 Pyro 语句的原地方法。例如
>>> state = named.Object("state")
>>> loc = state.loc.param_(torch.zeros(1, requires_grad=True))
>>> scale = state.scale.param_(torch.ones(1, requires_grad=True))
>>> z = state.z.sample_(dist.Normal(loc, scale))
>>> obs = state.x.sample_(dist.Normal(loc, scale), obs=z)
有关如何在模型代码中使用这些方法的更深入示例,请参阅树数据和混合模型示例。
作者:Fritz Obermeyer, Alexander Rush
- class Object(name)[source]¶
基类:
object
用于保存不可变潜在状态的对象。
此对象既可以作为嵌套潜在状态的容器,也可以作为通过 named.sample、named.observe 或 named.param 语句替换为张量的占位符。当用作占位符时,Object 对象取代了普通 pyro.sample 语句中的字符串位置。
- 参数
name (str) – 对象的名称。
示例
state = named.Object("state") state.x = 0 state.ys = named.List() state.zs = named.Dict() state.a.b.c.d.e.f.g = 0 # Creates a chain of named.Objects.
警告
此数据结构是单次写入的:可以添加数据,但不能修改或移除。尝试修改此数据结构可能会导致静默错误。
- sample_(fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) torch.Tensor ¶
调用随机函数
fn
,并根据name
和包围上下文(例如推断算法)产生额外的副作用。有关讨论,请参阅Pyro 简介。
- param_(init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor ¶
将变量保存为 param store 中的参数。要与 param store 交互或写入磁盘,请参阅Parameters。
- 参数
name (str) – 参数的名称
init_tensor (torch.Tensor 或 callable) – 初始张量或返回张量的惰性可调用对象。对于大型张量,写成例如
lambda: torch.randn(100000)
可能更划算,它只在初始语句时求值。constraint (torch.distributions.constraints.Constraint) – torch 约束,默认为
constraints.real
。event_dim (int) – (可选) 与批处理无关的最右侧维度数量。此左侧的维度将被视为批处理维度;如果 param 语句位于子采样 plate 内,则参数的相应批处理维度也将相应地被子采样。如果未指定,所有维度都将被视为事件维度,并且不会执行子采样。
- 返回值
一个受约束的参数。可以通过
pyro.param(...).unconstrained()
访问底层无约束参数,其中.unconstrained
是一个弱引用属性。- 返回类型
- class List(name=None)[source]¶
基类:
list
用于保存不可变潜在状态的类列表对象。
创建时必须给定名称
latent = named.List("root")
或者必须立即存储在
named.Object
中latent = named.Object("root") latent.xs = named.List() # Must be bound to a Object before use.
警告
此数据结构是单次写入的:可以添加数据,但不能修改或移除。尝试修改此数据结构可能会导致静默错误。
作用域¶
pyro.contrib.autoname.scoping
包含了 pyro.contrib.autoname.scope()
的实现,这是一个用于自动向样本站点名称附加具有语义意义前缀的工具。
- class NameCountMessenger[source]¶
基类:
pyro.poutine.messenger.Messenger
NameCountMessenger
是pyro.contrib.autoname.name_count()
的实现。
- class ScopeMessenger(prefix=None, inner=None)[source]¶
基类:
pyro.poutine.messenger.Messenger
ScopeMessenger
是pyro.contrib.autoname.scope()
的实现。
- scope(fn=None, prefix=None, inner=None)[source]¶
- 参数
fn – 一个随机函数(包含 Pyro 原语调用的可调用对象)
prefix – 要添加到样本名称前的字符串(如果提供了
fn
则可选)inner – 用于确定重复名称计数器出现位置的开关
- 返回值
用
ScopeMessenger
装饰的fn
scope
在 Pyro 样本站点名称前加上一个前缀,后跟一个/
。它的工作方式类似于 TensorFlow 的name_scope
和variable_scope
,可以用作上下文管理器、装饰器或高阶函数。scope
对于将组合模型与 guide 或数据对齐非常有用。示例
>>> @scope(prefix="a") ... def model(): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace()
示例
>>> def model(): ... with scope(prefix="a"): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace()
作用域按预期组合,外部作用域名称出现在内部作用域之前。
>>> @scope(prefix="b") ... def model(): ... with scope(prefix="a"): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "b/a/x" in poutine.trace(model).get_trace()
当用作装饰器或高阶函数时,如果未提供用户指定的前缀,
scope
将使用输入函数的名称作为前缀。示例
>>> @scope ... def model(): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "model/x" in poutine.trace(model).get_trace()
- name_count(fn=None)[source]¶
name_count
是一种非常简单的自动命名方案,它只对在执行中出现多次的名称附加一个后缀 “__” 加上一个计数器。只有重复出现的名称实例会获得后缀;第一次出现的实例不会被修改。示例
>>> @name_count ... def model(): ... for i in range(3): ... pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "x" in poutine.trace(model).get_trace() >>> assert "x__1" in poutine.trace(model).get_trace() >>> assert "x__2" in poutine.trace(model).get_trace()
name_count
也与scope()
组合使用,通过为重复的作用域入口添加后缀。示例
>>> @name_count ... def model(): ... for i in range(3): ... with pyro.contrib.autoname.scope(prefix="a"): ... pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace() >>> assert "a__1/x" in poutine.trace(model).get_trace() >>> assert "a__2/x" in poutine.trace(model).get_trace()
示例
>>> @name_count ... def model(): ... with pyro.contrib.autoname.scope(prefix="a"): ... for i in range(3): ... pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace() >>> assert "a/x__1" in poutine.trace(model).get_trace() >>> assert "a/x__2" in poutine.trace(model).get_trace()