基于 Funsor 的 Pyro

原语

clear_param_store() None[source]

清除全局 ParamStoreDict

这在使用 REPL 时特别有用。建议在每个训练循环之前调用此函数(以避免泄漏先前模型的参数),并在每个单元测试之前调用(以避免在测试之间泄漏参数)。

condition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Union[Dict[str, torch.Tensor], Trace]) Union[pyro.poutine.condition_messenger.ConditionMessenger, Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]][source]

ConditionMessenger 的便捷包装器

给定一个包含一些采样语句的随机函数以及按名称组织的观测值字典,将这些名称对应的采样语句更改为使用这些观测值的观察语句。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

要观察站点 z 的值,我们可以写成

>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})

这等同于在 model 中将 obs=value 作为关键字参数添加到 pyro.sample(“z”, …)

参数
  • fn – 一个随机函数(包含 Pyro 原语调用的可调用对象)

  • data – 一个 dict 或一个 Trace

返回值

ConditionMessenger 装饰的随机函数

deterministic(name: str, value: torch.Tensor, event_dim: Optional[int] = None) torch.Tensor[source]

确定性语句,用于向跟踪添加一个名为 name、值为 valueDelta 站点。当需要记录完全由其父节点确定的值时,这非常有用。例如

x = pyro.sample("x", dist.Normal(0, 1))
x2 = pyro.deterministic("x2", x ** 2)

注意

该站点不影响模型密度。这目前转换为 sample() 语句,但将来可能会更改。

参数
  • name (str) – 站点的名称。

  • value (torch.Tensor) – 站点的值。

  • event_dim (int) – 可选的事件维度,默认为 value.ndim

do(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, Union[torch.Tensor, numbers.Number]]) Union[pyro.poutine.do_messenger.DoMessenger, Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]][source]

DoMessenger 的便捷包装器

给定一个包含一些采样语句的随机函数以及按名称组织的值字典,将这些站点的返回值设置为等于这些值,就像它们被硬编码为这些值一样,并引入具有相同名称的新采样站点,其值不会传播。

可与 condition() 自由组合,表示潜在结果的反事实分布。更多细节和理论请参见单世界干预图 [1]。

考虑以下 Pyro 程序

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

要对站点 z 进行干预,我们可以写成

>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})

这等同于将 z = pyro.sample(“z”, …) 替换为 z = torch.tensor(1.),并引入一个新的采样站点 pyro.sample(“z”, …),其值不在其他地方使用。

参考文献

[1] 单世界干预图:入门,

Thomas Richardson, James Robins

参数
  • fn – 一个随机函数(包含 Pyro 原语调用的可调用对象)

  • data – 一个 dict,将采样站点名称映射到干预值

返回值

DoMessenger 装饰的随机函数

enable_validation(is_validate: bool = True) None[source]

启用或禁用 Pyro 中的验证检查。验证检查提供了有用的警告和错误,例如 NaN 检查、验证分布参数和支持值、检测 ELBO 和 MCMC 的不正确使用。由于其中一些检查可能很耗时,对于成熟的模型,您可能希望禁用验证以加快推断速度。

默认行为模拟 Python 的 assert 语句:默认情况下验证是开启的,但在 Python 以优化模式运行(通过 python -O)时会禁用。同样,默认行为通过 pyro.enable_validation(__debug__) 取决于 Python 的全局 __debug__ 值。

在 jit 编译期间,对于所有支持 PyTorch jit 的推断算法,验证会暂时禁用。建议使用非 jit 推断算法开发模型以方便调试,然后一旦模型正确,可选择切换到 jit 推断。

参数

is_validate (bool) – (可选;默认为 True)是否启用验证检查。

factor(name: str, log_factor: torch.Tensor, *, has_rsample: Optional[bool] = None) None[source]

因子语句,用于向概率模型添加任意对数概率因子。

警告

在指南中使用因子语句时,需要指定该因子语句是源自完全重参数化采样(例如重参数化变量变换的雅可比行列式)还是源自非重参数化采样(例如离散采样)。对于完全重参数化的情况,设置 has_rsample=True;对于非重参数化的情况,设置 has_rsample=False。这仅在指南中需要,模型中不需要。

参数
  • name (str) – 平凡采样的名称

  • log_factor (torch.Tensor) – 一个可能包含批次的对数概率因子。

  • has_rsample (bool) – log_factor 是否源自完全重参数化分布。在模型中使用时默认为 False,但在指南中使用时必须指定。

get_param_store() pyro.params.param_store.ParamStoreDict[source]

返回全局 ParamStoreDict

markov(fn=None, history=1, keep=False)[source]

MarkovMessenger 的便捷包装器

用于在 Funsor 和与 Pyro 位置批量维度一致的张量之间进行转换的处理程序。

参数
  • history (int) – 从当前上下文可见的先前上下文数量。默认为 1。如果为零,则类似于 pyro.plate

  • keep (bool) – 如果为 True,帧可以重放。这在分支时很重要:如果 keep=True,同一级别的相邻分支可以相互依赖;如果 keep=False,相邻分支是独立的(以它们的共享祖先为条件)。

module(name: str, nn_module: torch.nn.modules.module.Module, update_module_params: bool = False) torch.nn.modules.module.Module[source]

在 Pyro 的 param_store 中注册 torch.nn.Module 的所有参数。结合 ParamStoreDictsave()load() 功能,用户可以保存和加载模块。

注意

建议考虑使用 PyroModule,它是 pyro.module() 的一个更新的替代方案,对 JIT 编译、在 C++ 中服务以及将参数转换为随机变量有更好的支持。详细信息请参见模块教程

参数
  • name (str) – 模块名称

  • nn_module (torch.nn.Module) – 要在 Pyro 中注册的模块

  • update_module_params – 确定 PyTorch 模块中的 Parameters 是否会被 ParamStore 中找到的值覆盖(如果存在)。默认为 False

返回值

torch.nn.Module

param(name: str, init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor[source]

将变量作为参数保存在参数存储中。要与参数存储交互或写入磁盘,请参阅参数

参数
  • name (str) – 参数名称

  • init_tensor (torch.Tensor可调用对象) – 初始张量或返回张量的惰性可调用对象。对于大型张量,写成例如 lambda: torch.randn(100000) 可能更划算,这只会在初始语句时进行评估。

  • constraint (torch.distributions.constraints.Constraint) – torch 约束,默认为 constraints.real

  • event_dim (int) – (可选)与批处理无关的最右侧维度数量。该维度左侧的维度将被视为批处理维度;如果参数语句在子采样 plate 内部,则参数对应的批处理维度将相应地进行子采样。如果未指定,所有维度都将被视为事件维度,并且不会执行子采样。

返回值

一个受约束的参数。底层的无约束参数可以通过 pyro.param(...).unconstrained() 访问,其中 .unconstrained 是一个弱引用属性。

返回值类型

torch.Tensor

random_module(name, nn_module, prior, *args, **kwargs)[source]

警告

random_module 原语已弃用,并将在未来版本中移除。请改用 PyroModuletorch.nn.Module 实例创建贝叶斯模块。示例请参见贝叶斯回归教程

已弃用:在模块 nn_module 的参数上设置先验。返回一个 nn.Module 上的分布(可调用对象),调用该对象时会返回一个采样得到的 nn.Module

参数
  • name (str) – pyro 模块名称

  • nn_module (torch.nn.Module) – 要在 pyro 中注册的模块

  • prior – pyro 分布、随机函数,或一个 python 字典,其中参数名称作为键,对应的分布/随机函数作为值。

返回值

一个返回采样模块的可调用对象

sample(name: str, fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) torch.Tensor[source]

调用随机函数 fn,并根据 name 和封闭上下文(例如推断算法)产生额外的副作用。讨论请参见Pyro 入门

参数
  • name – 采样的名称

  • fn – 分布类或函数

  • obs – 观测数据(可选;仅应在推断上下文中使用),可在 kwargs 中可选指定

  • obs_mask (boolTensor) – 可选的布尔张量掩码,形状可与 fn.batch_shape 广播。如果提供,mask=True 的事件将以 obs 为条件,其余事件将通过采样进行填充。这将引入一个名为 name + "_unobserved" 的潜在采样站点,指南应使用该站点。

  • infer (dict) – kwargs 中指定的推断参数的可选字典。详细信息请参阅推断文档。

返回值

采样

set_rng_seed(rng_seed)[source]

设置 torchtorch.cuda(如果可用)的种子。

参数

rng_seed (int) – 种子值。

subsample(data: torch.Tensor, event_dim: int) torch.Tensor[source]

子采样语句,根据封闭的 plate 对数据张量进行子采样。

当通过传递 subsamplesubsample_size 关键字参数由 plate 自动执行子采样时,通常在 model() 的参数上调用此函数。例如,以下是等价的

# Version 1. using indexing
def model(data):
    with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
        data = data[ind]
        # ...

# Version 2. using pyro.subsample()
def model(data):
    with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
        data = pyro.subsample(data, event_dim=0)
        # ...
参数
  • data (Tensor) – 批处理数据的张量。

  • event_dim (int) – 数据张量的事件维度。左侧的维度被视为批处理维度。

返回值

data 的子采样版本

返回值类型

张量

to_data(x, name_to_dim=None, dim_type=DimType.LOCAL)[source]
to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL)[source]
vectorized_markov(fn=None, name=None, size=None, dim=None, history=1)[source]

VectorizedMarkovMessenger 的便捷包装器

用于变量马尔可夫链的构造,旨在利用并行扫描算法高效消除马尔可夫维度。在允许的情况下,vectorized_markov 可以与 markov 互换使用。

for 循环生成 int 和 1 维的 torch.Tensor 索引:(0, ..., history-1, torch.arange(0, size-history), ..., torch.arange(history, size))int 索引用于启动马尔可夫链,而 torch.Tensor 索引用于构建向量化转移概率,以便通过并行扫描算法进行高效消除。

history==0 时,vectorized_markov 的行为类似于 plate

for 循环运行后,会识别出马尔可夫变量,然后构建 step 信息并添加到跟踪中。step 会告知推断算法哪些变量属于马尔可夫链。

data = torch.ones(3, dtype=torch.float)

def model(data, vectorized=True):

    init = pyro.param("init", lambda: torch.rand(3), constraint=constraints.simplex)
    trans = pyro.param("trans", lambda: torch.rand((3, 3)), constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(3,))

    markov_chain = \
        pyro.vectorized_markov(name="time", size=len(data), dim=-1) if vectorized \
        else pyro.markov(range(len(data)))
    for i in markov_chain:
        x_curr = pyro.sample("x_{}".format(i), dist.Categorical(
            init if isinstance(i, int) and i < 1 else trans[x_prev]),

        pyro.sample("y_{}".format(i),
                    dist.Normal(Vindex(locs)[..., x_curr], 1.),
                    obs=data[i])
        x_prev = x_curr

#  trace.nodes["time"]["value"]
#  frozenset({('x_0', 'x_slice(0, 2, None)', 'x_slice(1, 3, None)')})
#
#  pyro.vectorized_markov trace
#  ...
#  Sample Sites:
#      locs dist               | 3
#          value               | 3
#       log_prob               |
#       x_0 dist               |
#          value     3 1 1 1 1 |
#       log_prob     3 1 1 1 1 |
#       y_0 dist     3 1 1 1 1 |
#          value               |
#       log_prob     3 1 1 1 1 |
#  x_slice(1, 3, None) dist   3 1 1 1 1 2 |
#          value 3 1 1 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 2 |
#  y_slice(1, 3, None) dist 3 1 1 1 1 1 2 |
#          value             2 |
#       log_prob 3 1 1 1 1 1 2 |
#
#  pyro.markov trace
#  ...
#  Sample Sites:
#      locs dist             | 3
#          value             | 3
#       log_prob             |
#       x_0 dist             |
#          value   3 1 1 1 1 |
#       log_prob   3 1 1 1 1 |
#       y_0 dist   3 1 1 1 1 |
#          value             |
#       log_prob   3 1 1 1 1 |
#       x_1 dist   3 1 1 1 1 |
#          value 3 1 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 |
#       y_1 dist 3 1 1 1 1 1 |
#          value             |
#       log_prob 3 1 1 1 1 1 |
#       x_2 dist 3 1 1 1 1 1 |
#          value   3 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 |
#       y_2 dist   3 1 1 1 1 |
#          value             |
#       log_prob   3 1 1 1 1 |

警告

这仅在每个分支只有一个马尔可夫维度时有效。

参数
  • name (str) – 马尔可夫维度的唯一名称,用于帮助推断算法消除马尔可夫链中的变量。

  • size (int) – 马尔可夫链的长度(大小)。

  • dim (int) – 用于此马尔可夫维度的可选维度。如果指定,dim 应为负值,即应从右侧索引。如果未指定,dim 将设置为所有封闭 plate 上下文左侧的最右侧维度。

  • history (int) – 马尔可夫链的记忆(阶数)。也是从当前上下文可见的先前上下文数量。默认为 1。如果为零,则类似于 plate

返回值

返回 int 和 1 维的 torch.Tensor 索引:(0, ..., history-1, torch.arange(size-history), ..., torch.arange(history, size))

效果处理器

enum(fn=None, first_available_dim=None)[source]

EnumMessenger 的便捷包装器

此版本的 EnumMessenger 使用 to_data() 为每个离散采样站点分配一个新的枚举维度。

markov(fn=None, history=1, keep=False)[source]

MarkovMessenger 的便捷包装器

用于在 Funsor 和与 Pyro 位置批量维度一致的张量之间进行转换的处理程序。

参数
  • history (int) – 从当前上下文可见的先前上下文数量。默认为 1。如果为零,则类似于 pyro.plate

  • keep (bool) – 如果为 True,帧可以重放。这在分支时很重要:如果 keep=True,同一级别的相邻分支可以相互依赖;如果 keep=False,相邻分支是独立的(以它们的共享祖先为条件)。

named(fn=None, first_available_dim=None)[source]

NamedMessenger 的便捷包装器

用于 to_funsor()to_data() 原语的基础效果处理程序类。任何在内部调用这些原语或封装调用这些原语代码的效果处理程序都应继承自 NamedMessenger

这种设计确保全局名称-维度映射在处理程序退出时重置,而不是潜在地持续到整个程序终止。

plate(fn=None, name=None, size=None, subsample_size=None, subsample=None, dim=None, use_cuda=None, device=None)[source]

PlateMessenger 的便捷包装器

将新的 IndepMessenger 实现与现有的 pyro.poutine.BroadcastMessenger 结合。最终应能完全替代 pyro.plate

replay(fn=None, trace=None, params=None)[source]

ReplayMessenger 的便捷包装器

此版本的 ReplayMessenger 与原始版本几乎相同,区别在于它对重放的 funsor 值调用 to_data()。这可能会导致不同的解包形状,但应产生正确的分配。

trace(fn=None, graph_type=None, param_only=None, pack_online=True)[source]

TraceMessenger 的便捷包装器

设置 pack_online=True 会在线进行打包,而不是事后进行,一旦分布和值可用,就立即将它们转换为 Funsors。

设置 pack_online=False 会在执行后计算打包所需的信息。每个采样站点都用一个 dim_to_name 字典进行标注,该字典可以直接传递给 to_funsor()

vectorized_markov(fn=None, name=None, size=None, dim=None, history=1)[source]

VectorizedMarkovMessenger 的便捷包装器

用于变量马尔可夫链的构造,旨在利用并行扫描算法高效消除马尔可夫维度。在允许的情况下,vectorized_markov 可以与 markov 互换使用。

for 循环生成 int 和 1 维的 torch.Tensor 索引:(0, ..., history-1, torch.arange(0, size-history), ..., torch.arange(history, size))int 索引用于启动马尔可夫链,而 torch.Tensor 索引用于构建向量化转移概率,以便通过并行扫描算法进行高效消除。

history==0 时,vectorized_markov 的行为类似于 plate

for 循环运行后,会识别出马尔可夫变量,然后构建 step 信息并添加到跟踪中。step 会告知推断算法哪些变量属于马尔可夫链。

data = torch.ones(3, dtype=torch.float)

def model(data, vectorized=True):

    init = pyro.param("init", lambda: torch.rand(3), constraint=constraints.simplex)
    trans = pyro.param("trans", lambda: torch.rand((3, 3)), constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(3,))

    markov_chain = \
        pyro.vectorized_markov(name="time", size=len(data), dim=-1) if vectorized \
        else pyro.markov(range(len(data)))
    for i in markov_chain:
        x_curr = pyro.sample("x_{}".format(i), dist.Categorical(
            init if isinstance(i, int) and i < 1 else trans[x_prev]),

        pyro.sample("y_{}".format(i),
                    dist.Normal(Vindex(locs)[..., x_curr], 1.),
                    obs=data[i])
        x_prev = x_curr

#  trace.nodes["time"]["value"]
#  frozenset({('x_0', 'x_slice(0, 2, None)', 'x_slice(1, 3, None)')})
#
#  pyro.vectorized_markov trace
#  ...
#  Sample Sites:
#      locs dist               | 3
#          value               | 3
#       log_prob               |
#       x_0 dist               |
#          value     3 1 1 1 1 |
#       log_prob     3 1 1 1 1 |
#       y_0 dist     3 1 1 1 1 |
#          value               |
#       log_prob     3 1 1 1 1 |
#  x_slice(1, 3, None) dist   3 1 1 1 1 2 |
#          value 3 1 1 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 2 |
#  y_slice(1, 3, None) dist 3 1 1 1 1 1 2 |
#          value             2 |
#       log_prob 3 1 1 1 1 1 2 |
#
#  pyro.markov trace
#  ...
#  Sample Sites:
#      locs dist             | 3
#          value             | 3
#       log_prob             |
#       x_0 dist             |
#          value   3 1 1 1 1 |
#       log_prob   3 1 1 1 1 |
#       y_0 dist   3 1 1 1 1 |
#          value             |
#       log_prob   3 1 1 1 1 |
#       x_1 dist   3 1 1 1 1 |
#          value 3 1 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 |
#       y_1 dist 3 1 1 1 1 1 |
#          value             |
#       log_prob 3 1 1 1 1 1 |
#       x_2 dist 3 1 1 1 1 1 |
#          value   3 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 |
#       y_2 dist   3 1 1 1 1 |
#          value             |
#       log_prob   3 1 1 1 1 |

警告

这仅在每个分支只有一个马尔可夫维度时有效。

参数
  • name (str) – 马尔可夫维度的唯一名称,用于帮助推断算法消除马尔可夫链中的变量。

  • size (int) – 马尔可夫链的长度(大小)。

  • dim (int) – 用于此马尔可夫维度的可选维度。如果指定,dim 应为负值,即应从右侧索引。如果未指定,dim 将设置为所有封闭 plate 上下文左侧的最右侧维度。

  • history (int) – 马尔可夫链的记忆(阶数)。也是从当前上下文可见的先前上下文数量。默认为 1。如果为零,则类似于 plate

返回值

返回 int 和 1 维的 torch.Tensor 索引:(0, ..., history-1, torch.arange(size-history), ..., torch.arange(history, size))

class NamedMessenger(first_available_dim=None)[source]

基类:pyro.poutine.reentrant_messenger.ReentrantMessenger

用于 to_funsor()to_data() 原语的基础效果处理程序类。任何在内部调用这些原语或封装调用这些原语代码的效果处理程序都应继承自 NamedMessenger

这种设计确保全局名称-维度映射在处理程序退出时重置,而不是潜在地持续到整个程序终止。

class MarkovMessenger(history=1, keep=False)[source]

基类:pyro.contrib.funsor.handlers.named_messenger.NamedMessenger

用于在 Funsor 和与 Pyro 位置批量维度一致的张量之间进行转换的处理程序。

参数
  • history (int) – 从当前上下文可见的先前上下文数量。默认为 1。如果为零,则类似于 pyro.plate

  • keep (bool) – 如果为 True,帧可以重放。这在分支时很重要:如果 keep=True,同一级别的相邻分支可以相互依赖;如果 keep=False,相邻分支是独立的(以它们的共享祖先为条件)。

class GlobalNamedMessenger(first_available_dim=None)[source]

基类:pyro.contrib.funsor.handlers.named_messenger.NamedMessenger

使用 to_funsor()to_data() 原语分配 DimType.GLOBALDimType.VISIBLE 维度的任何新效果处理程序的基础类。

充当不应由 MarkovMessenger 回收的维度的手动“作用域”:全局维度将被视为活动,直到它们最初分配的范围内最里面的 GlobalNamedMessenger 退出。

to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL)[source]
to_data(x, name_to_dim=None, dim_type=DimType.LOCAL)[source]
class StackFrame(name_to_dim, dim_to_name, history=1, keep=False)[source]

基类:object

整数位置维度和名称之间的一致双向映射。可以像字典一样查询(value = frame[key], frame[key] = value)。

class DimType(value)

基类:enum.Enum

列举可分配的维度类型

LOCAL = 0
GLOBAL = 1
VISIBLE = 2
class DimRequest(value, dim_type)

基类:tuple

dim_type

字段编号 1 的别名

value

字段编号 0 的别名

class DimStack[source]

基类:object

跟踪名称和维度之间映射的全局状态的单个实例。

替换了 plate 的 _DimAllocator、enum 的 _EnumAllocatorMarkovMessenger 中的 stackEnumMessenger 中的 _param_dims_value_dims,以及 msg['infer'] 中的 dim_to_symbol

MAX_DIM = -25
DEFAULT_FIRST_DIM = -5
set_first_available_dim(dim)[source]
push_global(frame)[source]
pop_global()[source]
push_iter(frame)[source]
pop_iter()[source]
push_local(frame)[source]
pop_local()[source]
property global_frame
property local_frame
property current_write_env
property current_read_env

收集计算完整名称 <–> 维度映射所需的所有帧,并在计算的任何点解释 Funsor 输入或批量形状。

allocate(key_to_value_request)[source]
names_from_batch_shape(batch_shape, dim_type=DimType.LOCAL)[source]

推断算法

class ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基类:pyro.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]
loss(model, guide, *args, **kwargs)[source]

参见 pyro.infer.traceenum_elbo.TraceEnum_ELBO.loss()

loss_and_grads(model, guide, *args, **kwargs)[source]

参见 pyro.infer.traceenum_elbo.TraceEnum_ELBO.loss_and_grads()

class Jit_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基类: pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]
class Trace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基类: pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

参见 pyro.infer.trace_elbo.Trace_ELBO.differentiable_loss()

class JitTrace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基类: pyro.contrib.funsor.infer.elbo.Jit_ELBO, pyro.contrib.funsor.infer.trace_elbo.Trace_ELBO

apply_optimizer(x)[source]
terms_from_trace(tr)[source]

用于从执行跟踪中提取elbo组件的辅助函数。

class TraceMarkovEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基类: pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

参见 pyro.infer.traceenum_elbo.TraceEnum_ELBO.differentiable_loss()

class TraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基类: pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

参见 pyro.infer.traceenum_elbo.TraceEnum_ELBO.differentiable_loss()

class JitTraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基类: pyro.contrib.funsor.infer.elbo.Jit_ELBO, pyro.contrib.funsor.infer.traceenum_elbo.TraceEnum_ELBO

class JitTraceMarkovEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基类: pyro.contrib.funsor.infer.elbo.Jit_ELBO, pyro.contrib.funsor.infer.traceenum_elbo.TraceMarkovEnum_ELBO

class TraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基类: pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

参见 pyro.infer.tracetmc_elbo.TraceTMC_ELBO.differentiable_loss()

class JitTraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基类: pyro.contrib.funsor.infer.elbo.Jit_ELBO, pyro.contrib.funsor.infer.tracetmc_elbo.TraceTMC_ELBO

infer_discrete(model, first_available_dim=None, temperature=1)[source]