自动名称生成

The pyro.contrib.autoname 模块提供了用于自动生成独特、语义有意义的样本站点名称的工具。

scope(fn=None, prefix=None, inner=None)[source]
参数
  • fn – 一个随机函数(包含 Pyro 原语调用的可调用对象)

  • prefix – 要添加到样本名称前的字符串(如果提供了 fn 则可选)

  • inner – 用于确定重复名称计数器出现位置的开关

返回值

ScopeMessenger 装饰的 fn

scope 在 Pyro 样本站点名称前加上一个前缀,后跟一个 /。它的工作方式类似于 TensorFlow 的 name_scopevariable_scope,可以用作上下文管理器、装饰器或高阶函数。

scope 对于将组合模型与 guide 或数据对齐非常有用。

示例

>>> @scope(prefix="a")
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()

示例

>>> def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()

作用域按预期组合,外部作用域名称出现在内部作用域之前。

>>> @scope(prefix="b")
... def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "b/a/x" in poutine.trace(model).get_trace()

当用作装饰器或高阶函数时,如果未提供用户指定的前缀,scope 将使用输入函数的名称作为前缀。

示例

>>> @scope
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "model/x" in poutine.trace(model).get_trace()
name_count(fn=None)[source]

name_count 是一种非常简单的自动命名方案,它只对在执行中出现多次的名称附加一个后缀 “__” 加上一个计数器。只有重复出现的名称实例会获得后缀;第一次出现的实例不会被修改。

示例

>>> @name_count
... def model():
...     for i in range(3):
...         pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "x" in poutine.trace(model).get_trace()
>>> assert "x__1" in poutine.trace(model).get_trace()
>>> assert "x__2" in poutine.trace(model).get_trace()

name_count 也与 scope() 组合使用,通过为重复的作用域入口添加后缀。

示例

>>> @name_count
... def model():
...     for i in range(3):
...         with pyro.contrib.autoname.scope(prefix="a"):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a__1/x" in poutine.trace(model).get_trace()
>>> assert "a__2/x" in poutine.trace(model).get_trace()

示例

>>> @name_count
... def model():
...     with pyro.contrib.autoname.scope(prefix="a"):
...         for i in range(3):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a/x__1" in poutine.trace(model).get_trace()
>>> assert "a/x__2" in poutine.trace(model).get_trace()
autoname(fn=None, name=None)[source]

AutonameMessenger 的便捷包装器

为随机变量分配唯一的名称。

  1. 对于新变量,如果给定则使用其声明的名称,否则使用分布名称。

    sample("x", dist.Bernoulli ... )  # -> x
    sample(dist.Bernoulli ... )  # -> Bernoulli
    
  2. 对于重复变量名,将计数器作为后缀附加。

    sample(dist.Bernoulli ... )  # -> Bernoulli
    sample(dist.Bernoulli ... )  # -> Bernoulli1
    sample(dist.Bernoulli ... )  # -> Bernoulli2
    
  3. 函数和迭代器可以用作名称作用域。

    @autoname
    def f1():
        sample(dist.Bernoulli ... )
    
    @autoname
    def f2():
        f1()  # -> f2/f1/Bernoulli
        f1()  # -> f2/f1__1/Bernoulli
        sample(dist.Bernoulli ... )  # -> f2/Bernoulli
    
    @autoname(name="model")
    def f3():
        for i in autoname(range(3), name="time"):
            # model/time/Bernoulli .. model/time__1/Bernoulli .. model/time__2/Bernoulli
            sample(dist.Bernoulli ... )
            # model/time/f1/Bernoulli .. model/time__1/f1/Bernoulli .. model/time__2/f1/Bernoulli
            f1()
    
  4. 或者可以使用 with 语句添加作用域。

    def f4():
        with autoname(name="prefix"):
            f1()  # -> prefix/f1/Bernoulli
            f1()  # -> prefix/f1__1/Bernoulli
            sample(dist.Bernoulli ... )  # -> prefix/Bernoulli
    
sample(*args)[source]
sample(name: str, fn, *args, **kwargs)
sample(fn: pyro.distributions.distribution.Distribution, *args, **kwargs)

命名数据结构

The pyro.contrib.named 模块是 Pyro 之上的一层薄薄的语法层。它允许 Pyro 模型看起来像操作 Python 数据结构(例如 latent.x.sample_(...))的程序,而不是带有字符串标签语句(例如 x = pyro.sample("x", ...))的程序。

该模块提供了三种容器数据结构:named.Objectnamed.Listnamed.Dict。这些数据结构旨在互相嵌套。它们一起跟踪每个数据结构中每块数据的地址,以便此地址可以用作 Pyro 站点。例如

>>> state = named.Object("state")
>>> print(str(state))
state

>>> z = state.x.y.z  # z is just a placeholder.
>>> print(str(z))
state.x.y.z

>>> state.xs = named.List()  # Create a contained list.
>>> x0 = state.xs.add()
>>> print(str(x0))
state.xs[0]

>>> state.ys = named.Dict()
>>> foo = state.ys['foo']
>>> print(str(foo))
state.ys['foo']

These addresses can now be used inside sampleobserveparam 语句中使用。这些命名数据结构甚至提供别名 Pyro 语句的原地方法。例如

>>> state = named.Object("state")
>>> loc = state.loc.param_(torch.zeros(1, requires_grad=True))
>>> scale = state.scale.param_(torch.ones(1, requires_grad=True))
>>> z = state.z.sample_(dist.Normal(loc, scale))
>>> obs = state.x.sample_(dist.Normal(loc, scale), obs=z)

有关如何在模型代码中使用这些方法的更深入示例,请参阅树数据混合模型示例。

作者:Fritz Obermeyer, Alexander Rush

class Object(name)[source]

基类:object

用于保存不可变潜在状态的对象。

此对象既可以作为嵌套潜在状态的容器,也可以作为通过 named.sample、named.observe 或 named.param 语句替换为张量的占位符。当用作占位符时,Object 对象取代了普通 pyro.sample 语句中的字符串位置。

参数

name (str) – 对象的名称。

示例

state = named.Object("state")
state.x = 0
state.ys = named.List()
state.zs = named.Dict()
state.a.b.c.d.e.f.g = 0  # Creates a chain of named.Objects.

警告

此数据结构是单次写入的:可以添加数据,但不能修改或移除。尝试修改此数据结构可能会导致静默错误。

sample_(fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) torch.Tensor

调用随机函数 fn,并根据 name 和包围上下文(例如推断算法)产生额外的副作用。有关讨论,请参阅Pyro 简介

参数
  • name – 样本的名称

  • fn – 分布类或函数

  • obs – 观测数据(可选;应仅在推断上下文中使用),可在 kwargs 中可选指定

  • obs_mask (boolTensor) – 可选的布尔张量掩码,形状可与 fn.batch_shape 广播。如果提供,mask=True 的事件将以 obs 为条件,其余事件将通过采样进行估算。这将引入一个名为 name + "_unobserved" 的潜在样本站点,该站点应由 guide 使用。

  • infer (dict) – kwargs 中指定的推断参数的可选字典。有关详细信息,请参阅推断文档。

返回值

样本

param_(init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor

将变量保存为 param store 中的参数。要与 param store 交互或写入磁盘,请参阅Parameters

参数
  • name (str) – 参数的名称

  • init_tensor (torch.Tensorcallable) – 初始张量或返回张量的惰性可调用对象。对于大型张量,写成例如 lambda: torch.randn(100000) 可能更划算,它只在初始语句时求值。

  • constraint (torch.distributions.constraints.Constraint) – torch 约束,默认为 constraints.real

  • event_dim (int) – (可选) 与批处理无关的最右侧维度数量。此左侧的维度将被视为批处理维度;如果 param 语句位于子采样 plate 内,则参数的相应批处理维度也将相应地被子采样。如果未指定,所有维度都将被视为事件维度,并且不会执行子采样。

返回值

一个受约束的参数。可以通过 pyro.param(...).unconstrained() 访问底层无约束参数,其中 .unconstrained 是一个弱引用属性。

返回类型

torch.Tensor

class List(name=None)[source]

基类:list

用于保存不可变潜在状态的类列表对象。

创建时必须给定名称

latent = named.List("root")

或者必须立即存储在 named.Object

latent = named.Object("root")
latent.xs = named.List()  # Must be bound to a Object before use.

警告

此数据结构是单次写入的:可以添加数据,但不能修改或移除。尝试修改此数据结构可能会导致静默错误。

add()[source]

附加一个新的 named.Object。

返回值

在末尾添加一个新的潜在对象

返回类型

named.Object

class Dict(name=None)[source]

基类:dict

用于保存不可变潜在状态的类字典对象。

创建时必须给定名称

latent = named.Dict("root")

或者必须立即存储在 named.Object

latent = named.Object("root")
latent.xs = named.Dict()  # Must be bound to a Object before use.

警告

此数据结构是单次写入的:可以添加数据,但不能修改或移除。尝试修改此数据结构可能会导致静默错误。

作用域

pyro.contrib.autoname.scoping 包含了 pyro.contrib.autoname.scope() 的实现,这是一个用于自动向样本站点名称附加具有语义意义前缀的工具。

class NameCountMessenger[source]

基类:pyro.poutine.messenger.Messenger

NameCountMessengerpyro.contrib.autoname.name_count() 的实现。

class ScopeMessenger(prefix=None, inner=None)[source]

基类:pyro.poutine.messenger.Messenger

ScopeMessengerpyro.contrib.autoname.scope() 的实现。

scope(fn=None, prefix=None, inner=None)[source]
参数
  • fn – 一个随机函数(包含 Pyro 原语调用的可调用对象)

  • prefix – 要添加到样本名称前的字符串(如果提供了 fn 则可选)

  • inner – 用于确定重复名称计数器出现位置的开关

返回值

ScopeMessenger 装饰的 fn

scope 在 Pyro 样本站点名称前加上一个前缀,后跟一个 /。它的工作方式类似于 TensorFlow 的 name_scopevariable_scope,可以用作上下文管理器、装饰器或高阶函数。

scope 对于将组合模型与 guide 或数据对齐非常有用。

示例

>>> @scope(prefix="a")
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()

示例

>>> def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()

作用域按预期组合,外部作用域名称出现在内部作用域之前。

>>> @scope(prefix="b")
... def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "b/a/x" in poutine.trace(model).get_trace()

当用作装饰器或高阶函数时,如果未提供用户指定的前缀,scope 将使用输入函数的名称作为前缀。

示例

>>> @scope
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "model/x" in poutine.trace(model).get_trace()
name_count(fn=None)[source]

name_count 是一种非常简单的自动命名方案,它只对在执行中出现多次的名称附加一个后缀 “__” 加上一个计数器。只有重复出现的名称实例会获得后缀;第一次出现的实例不会被修改。

示例

>>> @name_count
... def model():
...     for i in range(3):
...         pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "x" in poutine.trace(model).get_trace()
>>> assert "x__1" in poutine.trace(model).get_trace()
>>> assert "x__2" in poutine.trace(model).get_trace()

name_count 也与 scope() 组合使用,通过为重复的作用域入口添加后缀。

示例

>>> @name_count
... def model():
...     for i in range(3):
...         with pyro.contrib.autoname.scope(prefix="a"):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a__1/x" in poutine.trace(model).get_trace()
>>> assert "a__2/x" in poutine.trace(model).get_trace()

示例

>>> @name_count
... def model():
...     with pyro.contrib.autoname.scope(prefix="a"):
...         for i in range(3):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a/x__1" in poutine.trace(model).get_trace()
>>> assert "a/x__2" in poutine.trace(model).get_trace()