重参数化器

pyro.infer.reparam 模块包含用于 pyro.poutine.handlers.reparam() 效果的重参数化策略。这些策略有助于改变条件较差的参数空间的几何形状,以改善后验的形状。它们可以与各种推断算法一起使用,例如 Auto*Normal 导引和 MCMC。

class ReparamMessage[source]
name: str
fn: Callable
value: Optional[torch.Tensor]
is_observed: Optional[bool]
class ReparamResult[source]
fn: Callable
value: Optional[torch.Tensor]
is_observed: bool
class Reparam[source]

重参数化器的抽象基类。

派生类应实现 apply() 方法。

apply(msg: pyro.infer.reparam.reparam.ReparamMessage) pyro.infer.reparam.reparam.ReparamResult[source]

应用重参数化器的抽象方法。

参数

name (dict) – 一个简化的 Pyro 消息,包含以下字段: - name: str 采样站点的名称 - fn: Callable 一个分布 - value: Optional[torch.Tensor] 观察值或初始值 - is_observed: bool value 是否为观察值

返回值

一个简化的 Pyro 消息,包含字段 fn, valueis_observed

返回类型

dict

__call__(name, fn, obs)[source]

已废弃。子类应实现 apply() 方法。此方法将在未来的版本中移除。

自动策略

这些重参数化策略通过 register_reparam_strategy() 注册,并通过 poutine.reparam(config=name_of_strategy) 按名称访问。请参阅 reparam() 了解用法。

class Strategy[source]

基类: abc.ABC

重参数化器配置策略的抽象基类。

派生类必须实现 configure() 方法。

变量

config (dict) – 字典配置。在模型首次运行时填充。之后可用作 poutine.reparam(config=___) 的参数。

abstract configure(msg: dict) Optional[pyro.infer.reparam.reparam.Reparam][source]

输入一个采样站点,并返回 None 或一个 Reparam 实例。

此方法仅在模型首次执行时调用;后续执行将使用存储在 self.config 中的重参数化器。

参数

msg (dict) – 一个可能需要重参数化的采样站点。

返回值

一个可选的重参数化器实例。

__call__(msg_or_fn: Union[dict, Callable])[source]

策略可用作装饰器来重参数化模型。

参数

msg_or_fn – 公共用途:需要装饰的模型。(内部用途:需要配置重参数化的站点)。

class MinimalReparam[source]

基类: pyro.infer.reparam.strategies.Strategy

最小重参数化策略,仅对那些否则会导致错误的站点进行重参数化,例如 StableProjectedNormal 随机变量。

示例

@MinimalReparam()
def model(...):
    ...

这等效于

@poutine.reparam(config=MinimalReparam())
def model(...):
    ...
configure(msg: dict) Optional[pyro.infer.reparam.reparam.Reparam][source]
class AutoReparam(*, centered: Optional[float] = None)[source]

基类: pyro.infer.reparam.strategies.Strategy

应用推荐的重参数化器集合。目前包括:MinimalReparamTransformReparam、完全可学习的 LocScaleReparamGumbelSoftmaxReparam

示例

@AutoReparam()
def model(...):
    ...

这等效于

@poutine.reparam(config=AutoReparam())
def model(...):
    ...

警告

此策略在不同的 Pyro 版本中可能会改变行为。要检查或保存特定行为,请在模型至少运行一次后提取 .config 字典。

参数

centeredLocScaleReparam 重参数化器的可选中心化参数。如果为 None(默认),则学习中心化。如果在 [0.0,1.0] 范围内的一个浮点数,则为固定中心化。要完全去中心化(例如在 MCMC 中),设置为 0.0。

configure(msg: dict) Optional[pyro.infer.reparam.reparam.Reparam][source]

共轭更新

class ConjugateReparam(guide)[source]

基类: pyro.infer.reparam.reparam.Reparam

实验性(EXPERIMENTAL)重参数化为共轭更新分布。

这使用 conjugate_update() 方法更新先验分布 fn。导引可以是一个分布对象,也可以是一个接收模型 *args,**kwargs 作为输入并返回一个分布对象的可调用对象。导引可以是近似的或学习的。

例如,考虑模型和朴素变分导引

total = torch.tensor(10.)
count = torch.tensor(2.)

def model():
    prob = pyro.sample("prob", dist.Beta(0.5, 1.5))
    pyro.sample("count", dist.Binomial(total, prob), obs=count)

guide = AutoDiagonalNormal(model)  # learns the posterior over prob

代替使用这个学习的导引,我们可以手动计算关于“prob”的共轭后验分布,然后在推断期间使用一个更简单的导引,在这种情况下是一个空导引

reparam_model = poutine.reparam(model, {
    "prob": ConjugateReparam(dist.Beta(1 + count, 1 + total - count))
})

def reparam_guide():
    pass  # nothing remains to be modeled!
参数

guide (Distributioncallable) – 一个似然分布或一个返回导引分布的可调用对象。仅支持少数分布,具体取决于先验分布的 conjugate_update() 实现。

apply(msg)[source]

位置-尺度去中心化

class LocScaleReparam(centered=None, shape_params=None)[source]

基类: pyro.infer.reparam.reparam.Reparam

用于潜在变量的通用去中心化重参数化器 [1],这些变量由 locscale(可能还有额外的 shape_params)参数化。

此重参数化仅适用于潜在变量,不适用于似然。

[1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

“概率程序的自动重参数化”(Automatic Reparameterisation of Probabilistic Programs) https://arxiv.org/pdf/1906.03028.pdf

参数
  • centered (float) – 可选的中心化参数。如果为 None(默认),则学习一个每个站点、每个元素的中心化参数,范围在 [0,1]。如果为 0,则完全去中心化分布;如果为 1,则保持中心化分布不变。

  • shape_params (tuplelist) – 可选列表,包含需要从中心化分布不变地复制到去中心化分布的额外参数名称。如果未指定,将复制分布 .arg_constraints 中的所有参数。

apply(msg)[source]

Gumbel-Softmax

class GumbelSoftmaxReparam[source]

基类: pyro.infer.reparam.reparam.Reparam

用于 RelaxedOneHotCategorical 潜在变量的重参数化器。

这对于将多峰后验转换为单峰后验非常有用。请注意,这会使每个事件的潜在维度增加 1。

此重参数化仅适用于潜在变量,不适用于似然。

apply(msg)[source]

变换分布

class TransformReparam[source]

基类: pyro.infer.reparam.reparam.Reparam

用于 pyro.distributions.torch.TransformedDistribution 潜在变量的重参数化器。

这对于具有复杂、改变几何形状的变换的变换分布非常有用,其中后验在 base_dist 空间中具有简单的形状。

此重参数化仅适用于潜在变量,不适用于似然。

apply(msg)[source]

离散余弦变换

class DiscreteCosineReparam(dim=- 1, smooth=0.0, *, experimental_allow_batch=False)[source]

基类: pyro.infer.reparam.unit_jacobian.UnitJacobianReparam

离散余弦重参数化器,使用 DiscreteCosineTransform

这对于序列模型很有用,其中沿着类似时间的轴(例如带状精度矩阵)的耦合会引入长程相关性。这会重参数化到频域表示,其中后验协方差应更接近对角线,从而提高 SVI 中对角线导引的准确性,并提高 HMC 中对角线质量矩阵的有效性。

当对沿着时间维度近似连续的变量进行重参数化时,设置 smooth=1。对于沿着时间轴近似连续可微的变量,设置 smooth=2

此重参数化仅适用于潜在变量,不适用于似然。

参数
  • dim (int) – 进行变换的维度。必须为负数。这是一个从右边开始计数的绝对维度。

  • smooth (float) – 平滑参数。当为 0 时,将白噪声转换为白噪声;当为 1 时,将布朗噪声转换为白噪声;当为 -1 时,将紫噪声转换为白噪声;等等。允许任何实数。https://en.wikipedia.org/wiki/Colors_of_noise

  • experimental_allow_batch (bool) – 实验性(EXPERIMENTAL)允许跨批量维度进行耦合。目标批量维度及其右侧的所有批量维度将被转换为事件维度。默认为 False。

Haar 变换

class HaarReparam(dim=- 1, flip=False, *, experimental_allow_batch=False)[source]

基类: pyro.infer.reparam.unit_jacobian.UnitJacobianReparam

Haar 小波重参数化器,使用 HaarTransform

这对于序列模型很有用,其中沿着类似时间的轴(例如带状精度矩阵)的耦合会引入长程相关性。这会重参数化到频域表示,其中后验协方差应更接近对角线,从而提高 SVI 中对角线导引的准确性,并提高 HMC 中对角线质量矩阵的有效性。

此重参数化仅适用于潜在变量,不适用于似然。

参数
  • dim (int) – 进行变换的维度。必须为负数。这是一个从右边开始计数的绝对维度。

  • flip (bool) – 在应用 Haar 变换之前是否翻转时间轴。默认为 False。

  • experimental_allow_batch (bool) – 实验性(EXPERIMENTAL)允许跨批量维度进行耦合。目标批量维度及其右侧的所有批量维度将被转换为事件维度。默认为 False。

单位雅可比变换

class UnitJacobianReparam(transform, suffix='transformed', *, experimental_allow_batch=False)[source]

基类: pyro.infer.reparam.reparam.Reparam

用于雅可比行列式为一的 Transform 对象的重参数化器。

参数
  • transform (Transform) – 雅可比行列式为 1 的变换。

  • suffix (str) – 要附加到变换后站点的后缀。

  • experimental_allow_batch (bool) – 实验性(EXPERIMENTAL)允许跨批量维度进行耦合。目标批量维度及其右侧的所有批量维度将被转换为事件维度。默认为 False。

apply(msg)[source]

学生T分布

class StudentTReparam[source]

基类: pyro.infer.reparam.reparam.Reparam

用于 StudentT 随机变量的辅助变量重参数化器。

这与 LinearHMMReparam 结合使用时非常有用,因为它允许将 StudentT 过程视为条件高斯过程,从而可以通过 GaussianHMM 进行廉价推断。

这通过引入一个辅助 Gamma 变量来重参数化一个 StudentT,在该变量的条件下,结果是 Normal

apply(msg)[source]

稳定分布

class LatentStableReparam[source]

基类: pyro.infer.reparam.reparam.Reparam

用于 Stable 潜在变量的辅助变量重参数化器。

这对于推断潜在的 Stable 变量非常有用,因为 log_prob() 方法未实现。

这使用 Chambers-Mallows-Stuck 方法 [1],创建一对无参数的辅助分布(Uniform(-pi/2,pi/2)Exponential(1)),它们具有明确定义的 .log_prob() 方法,从而允许在基于似然的推断算法(如 SVI 和 MCMC)中使用重参数化的稳定分布。

此重参数化仅适用于潜在变量,不适用于似然。对于兼容似然的重参数化,请参阅 SymmetricStableReparamStableReparam

[1] J.P. Nolan (2017)。

稳定分布:重尾数据的模型 (Stable Distributions: Models for Heavy Tailed Data)。 https://edspace.american.edu/jpnolan/wp-content/uploads/sites/1720/2020/09/Chap1.pdf

apply(msg)[source]
class SymmetricStableReparam[source]

基类: pyro.infer.reparam.reparam.Reparam

用于对称 Stable 随机变量(即 skew=0 的变量)的辅助变量重参数化器。

这对于推断对称 Stable 变量非常有用,因为 log_prob() 未实现。

这通过将对称 Stable 随机变量重参数化为完全偏斜(skew=1)的 Stable 尺度混合 Normal 随机变量来实现。请参阅 [1] 中的命题 3.(但请注意,我们有所不同,因为 Stable 使用 Nolan 的连续 S0 参数化)。

[1] Alvaro Cartea and Sam Howison (2009)

“使用 Levy-Stable 过程的期权定价”(Option Pricing with Levy-Stable Processes) https://pdfs.semanticscholar.org/4d66/c91b136b2a38117dd16c2693679f5341c616.pdf

apply(msg)[source]
class StableReparam[source]

基类: pyro.infer.reparam.reparam.Reparam

用于任意 Stable 随机变量的辅助变量重参数化器。

这对于推断非对称 Stable 变量非常有用,因为 log_prob() 未实现。

这通过将一个 Stable 随机变量重参数化为另外两个稳定随机变量之和来实现,一个对称,另一个完全偏斜(应用 [1] 的性质 2.3.a)。完全偏斜的变量按照 LatentStableReparam 的方式采样,对称变量按照 SymmetricStableReparam 的方式分解。

[1] V. M. Zolotarev (1986)

“一维稳定分布”(One-dimensional stable distributions)

apply(msg)[source]

投影正态分布

class ProjectedNormalReparam[source]

基类: pyro.infer.reparam.reparam.Reparam

用于 ProjectedNormal 潜在变量的重参数化器。

此重参数化仅适用于潜在变量,不适用于似然。

apply(msg)[source]

隐马尔可夫模型

class LinearHMMReparam(init=None, trans=None, obs=None)[source]

基类: pyro.infer.reparam.reparam.Reparam

用于 LinearHMM 随机变量的辅助变量重参数化器。

这委托给组件重参数化器来创建辅助随机变量,在该变量的条件下,过程变为 GaussianHMM。如果 observation_distTransformedDistribution,这将重新排序这些变换,以便结果是 GaussianHMMTransformedDistribution

这对于训练 LinearHMM 分布的参数非常有用,其 log_prob() 方法未定义。为了在存在非高斯因子(如 Stable()StudentT()LogNormal())的情况下执行推断,请使用 StudentTReparamStableReparamSymmetricStableReparam 等组件重参数化器配置 inittransscale。例如:

hmm = LinearHMM(
    init_dist=Stable(1,0,1,0).expand([2]).to_event(1),
    trans_matrix=torch.eye(2),
    trans_dist=MultivariateNormal(torch.zeros(2), torch.eye(2)),
    obs_matrix=torch.eye(2),
    obs_dist=TransformedDistribution(
        Stable(1.5,-0.5,1.0).expand([2]).to_event(1),
        ExpTransform()))

rep = LinearHMMReparam(init=SymmetricStableReparam(),
                       obs=StableReparam())

with poutine.reparam(config={"hmm": rep}):
    pyro.sample("hmm", hmm, obs=data)
参数
  • init (Reparam) – 初始分布的可选重参数化器。

  • trans (Reparam) – 转移分布的可选重参数化器。

  • obs (Reparam) – 观测分布的可选重参数化器。

apply(msg)[source]

站点分割

class SplitReparam(sections, dim)[source]

基类: pyro.infer.reparam.reparam.Reparam

用于沿维度分割随机变量的重参数化器,类似于 torch.split()

这对于使用不同的重参数化器或推断方法处理张量的不同部分很有用。例如,在对时间序列执行 HMC 推断时,您可以先应用 DiscreteCosineReparamHaarReparam,然后应用 SplitReparam 将其分割为低频和高频分量,最后将低频分量与全局变量一起添加到 full_mass 矩阵中。

参数
  • sections – 单个块的大小或每个块大小的列表。

  • dim (int) – 进行分割的维度。默认为 -1。

类型

list(int)

apply(msg)[source]

神经传输

class NeuTraReparam(guide)[source]

基类: pyro.infer.reparam.reparam.Reparam

多个潜在变量的神经传输重参数化器 [1]。

这使用训练好的 AutoContinuous 导引来改变模型的几何形状,通常用于 MCMC 等。示例用法:

# Step 1. Train a guide
guide = AutoIAFNormal(model)
svi = SVI(model, guide, ...)
# ...train the guide...

# Step 2. Use trained guide in NeuTra MCMC
neutra = NeuTraReparam(guide)
model = poutine.reparam(model, config=lambda _: neutra)
nuts = NUTS(model)
# ...now use the model in HMC or NUTS...

此重参数化仅适用于潜在变量,不适用于似然。请注意,所有站点必须共享一个共同的 NeuTraReparam 实例,并且模型必须具有静态结构。

[1] Hoffman, M. et al. (2019)

“使用神经传输中和哈密顿蒙特卡洛中的不良几何”(NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport) https://arxiv.org/abs/1903.03704

参数

guide (AutoContinuous) – 一个训练好的导引。

reparam(fn=None)[source]
apply(msg)[source]
transform_sample(latent)[source]

给定来自扭曲后验的潜在样本(可能带有批量维度),返回一个 dict,其中包含模型中潜在站点的样本。

参数

latent – 来自扭曲后验的样本(可能带有批量维度)。请注意,批量维度不得与模型中的 plate 维度冲突,即任何批量维度 d < - max_plate_nesting

返回值

一个以模型中潜在站点为键的样本 dict

返回类型

dict

结构化预处理

class StructuredReparam(guide: pyro.infer.autoguide.structured.AutoStructured)[source]

基类: pyro.infer.reparam.reparam.Reparam

多个潜在变量的预处理重参数化器。

这使用训练好的 AutoStructured 导引来改变模型的几何形状,通常用于 MCMC 等。示例用法:

# Step 1. Train a guide
guide = AutoStructured(model, ...)
svi = SVI(model, guide, ...)
# ...train the guide...

# Step 2. Use trained guide in preconditioned MCMC
model = StructuredReparam(guide).reparam(model)
nuts = NUTS(model)
# ...now use the model in HMC or NUTS...

此重参数化仅适用于潜在变量,不适用于似然。请注意,所有站点必须共享一个共同的 StructuredReparam 实例,并且模型必须具有静态结构。

注意

这可以看作是 NeuTraReparam [1] 的受限结构化版本,并结合了在 MAP 估计站点上的 poutine.condition(NeuTra 变换是精确的重参数化器,但点估计的条件化引入了模型近似)。

[1] Hoffman, M. et al. (2019)

“使用神经传输中和哈密顿蒙特卡洛中的不良几何”(NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport) https://arxiv.org/abs/1903.03704

参数

guide (AutoStructured) – 一个训练好的导引。

reparam(fn=None)[source]
apply(msg)[source]
transform_samples(aux_samples, save_params=None)[source]

给定来自扭曲后验的潜在样本(可能带有批量维度),返回一个 dict,其中包含模型中潜在站点的样本。

参数
  • aux_samples (dict) – 字典,将站点名称映射到每个潜在辅助站点的张量值(如果指定了 save_params,则仅包含计算请求参数所需的潜在辅助站点)。

  • save_params (list) – 可选列表,包含需要保存的站点名称。这对于具有大量无关变量的模型很有用。默认为 None,保存所有参数。

返回值

一个以模型中潜在站点为键的样本 dict

返回类型

dict