自动推断向导生成¶
AutoGuide¶
- class AutoGuide(model, *, create_plates=None)[source]¶
-
自动推断向导的基类。
派生类必须实现
forward()
方法,其*args, **kwargs
与基类model
相同。自动推断向导可以单独使用,也可以在
AutoGuideList
对象中组合使用。- 参数
model (callable) – 一个 Pyro 模型。
create_plates (callable) – 一个可选函数,接受与
model()
相同的*args,**kwargs
并返回一个pyro.plate
或 plates 的可迭代对象。未返回的 plates 将照常自动创建。这对于数据子采样很有用。
- property model¶
- call(*args, **kwargs)[source]¶
调用
forward()
并以 tuple 而不是 dict 返回推断向导参数值的方法,这是 JIT 追踪的要求。与forward()
不同,此方法可以被torch.jit.trace_module()
追踪。警告
一旦 PyTorch JIT 追踪器开始接受 dict 作为有效的返回类型,此方法可能会被移除。请参阅 issue <https://github.com/pytorch/pytorch/issues/27743>_。
AutoGuideList¶
- class AutoGuideList(model, *, create_plates=None)[source]¶
基类:
pyro.infer.autoguide.guides.AutoGuide
,torch.nn.modules.container.ModuleList
用于组合多个自动或自定义向导的容器类。
示例用法
guide = AutoGuideList(my_model) guide.append(AutoDiagonalNormal(poutine.block(model, hide=["assignment"]))) guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"]))) svi = SVI(model, guide, optim, Trace_ELBO())
- 参数
model (callable) – 一个 Pyro 模型
- append(part)[source]¶
添加模型一部分的自动或自定义向导。该向导应通过阻塞模型以限制在采样站点的子集上创建。任何两个部分都不应在同一个采样站点上操作。
- 参数
part (AutoGuide 或 callable) – 要添加的部分向导
- forward(*args, **kwargs)[source]¶
一个复合向导,其
*args, **kwargs
与基类model
相同。注意
此方法由
Module
内部使用。用户应改为使用__call__()
。- 返回
一个字典,将采样站点的名称映射到采样值。
- 返回类型
AutoCallable¶
- class AutoCallable(model, guide, median=<function AutoCallable.<lambda>>)[source]¶
基类:
pyro.infer.autoguide.guides.AutoGuide
简单可调用向导的
AutoGuide
包装器。这在内部用于将自动向导与简单的用户定义的可调用向导组合,例如:
def my_local_guide(*args, **kwargs): ... guide = AutoGuideList(model) guide.append(AutoDelta(poutine.block(model, expose=['my_global_param'])) guide.append(my_local_guide) # automatically wrapped in an AutoCallable
要指定一个 median callable,你可以改为
def my_local_median(*args, **kwargs) ... guide.append(AutoCallable(model, my_local_guide, my_local_median))
对于需要访问 plates 等更复杂的向导,用户应该改为继承
AutoGuide
。- 参数
model (callable) – 一个 Pyro 模型
guide (callable) – 一个 Pyro 向导(通常仅作用于模型的一部分)
median (callable) – 一个可选的可调用对象,返回一个将采样站点名称映射到计算得到的中位数张量的字典。
AutoNormal¶
- class AutoNormal(model, *, init_loc_fn=<function init_to_feasible>, init_scale=0.1, create_plates=None)[source]¶
基类:
pyro.infer.autoguide.guides.AutoGuide
此
AutoGuide
实现使用具有对角协方差矩阵的正态分布来构建覆盖整个潜在空间的向导。该向导不依赖于模型的*args, **kwargs
。它应该等同于 :class: AutoDiagonalNormal ,但具有更方便的站点名称和更好的
TraceMeanField_ELBO
支持。在
AutoDiagonalNormal
中,如果你的模型有 N 个命名参数,维度为 k_i,且 sum k_i = D,你会得到一个长度为 D 的均值向量和一个长度为 D 的 sigma 向量。此向导为你提供了 N 个不同的正态分布,你可以按名称调用它们。用法
guide = AutoNormal(model) svi = SVI(model, guide, ...)
- 参数
- scale_constraint = SoftplusPositive(lower_bound=0.0)¶
- forward(*args, **kwargs)[source]¶
一个自动向导,其
*args, **kwargs
与基类model
相同。注意
此方法由
Module
内部使用。用户应改为使用__call__()
。- 返回
一个字典,将采样站点的名称映射到采样值。
- 返回类型
- quantiles(quantiles, *args, **kwargs)[source]¶
返回每个潜在变量的后验分位数。示例
print(guide.quantiles([0.05, 0.5, 0.95]))
- 参数
quantiles (torch.Tensor 或 list) – 请求的分位点列表,值介于 0 和 1 之间。
- 返回
一个字典,将采样站点的名称映射到分位数值张量。
- 返回类型
AutoDelta¶
- class AutoDelta(model, init_loc_fn=<function init_to_median>, *, create_plates=None)[source]¶
基类:
pyro.infer.autoguide.guides.AutoGuide
此
AutoGuide
实现使用 Delta 分布构建覆盖整个潜在空间的 MAP 向导。该向导不依赖于模型的*args, **kwargs
。注意
此类别在约束空间中执行 MAP 推断。
用法
guide = AutoDelta(model) svi = SVI(model, guide, ...)
潜在变量使用
init_loc_fn()
进行初始化。要更改默认行为,可以创建一个自定义的init_loc_fn()
,如 初始化 中所述,例如def my_init_fn(site): if site["name"] == "level": return torch.tensor([-1., 0., 1.]) if site["name"] == "concentration": return torch.ones(k) return init_to_sample(site)
- 参数
model (callable) – 一个 Pyro 模型。
init_loc_fn (callable) – 按站点进行的初始化函数。有关可用函数,请参阅初始化部分。
create_plates (callable) – 一个可选函数,接受与
model()
相同的*args,**kwargs
并返回一个pyro.plate
或 plates 的可迭代对象。未返回的 plates 将照常自动创建。这对于数据子采样很有用。
AutoContinuous¶
- class AutoContinuous(model, init_loc_fn=<function init_to_median>)[source]¶
基类:
pyro.infer.autoguide.guides.AutoGuide
连续值自动微分变分推断 [1] 实现的基类。
这使用了
torch.distributions.transforms
将每个受约束的潜在变量转换为无约束空间,然后将所有变量连接成一个单一的无约束潜在变量。每个派生类都实现一个get_posterior()
方法,返回此单一无约束潜在变量的分布。假设模型结构和潜在维度是固定的,并且所有潜在变量都是连续的。
- 参数
model (callable) – 一个 Pyro 模型
参考文献
- [1] Automatic Differentiation Variational Inference,
Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
- 参数
model (callable) – 一个 Pyro 模型。
init_loc_fn (callable) – 按站点进行的初始化函数。有关可用函数,请参阅初始化部分。
- get_base_dist()[source]¶
当后验被重参数化为
TransformedDistribution
时,返回其基础分布。这不应依赖于模型的 *args, **kwargs。posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))
- 返回
表示基础分布的
TorchDistribution
实例。
- get_transform(*args, **kwargs)[source]¶
当后验被重参数化为
TransformedDistribution
时,返回应用于基础分布的变换。这可能依赖于模型的 *args, **kwargs。posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))
- 返回
一个
Transform
实例。
- forward(*args, **kwargs)[source]¶
一个自动向导,其
*args, **kwargs
与基类model
相同。注意
此方法由
Module
内部使用。用户应改为使用__call__()
。- 返回
一个字典,将采样站点的名称映射到采样值。
- 返回类型
- quantiles(quantiles, *args, **kwargs)[source]¶
返回每个潜在变量的后验分位数。示例
print(guide.quantiles([0.05, 0.5, 0.95]))
- 参数
quantiles (torch.Tensor 或 list) – 请求的分位点列表,值介于 0 和 1 之间。
- 返回
一个字典,将采样站点的名称映射到分位数值张量。
- 返回类型
AutoMultivariateNormal¶
- class AutoMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)[source]¶
基类:
pyro.infer.autoguide.guides.AutoContinuous
此
AutoContinuous
实现使用多元正态分布的乔列斯基分解来构建覆盖整个潜在空间的向导。该向导不依赖于模型的*args, **kwargs
。用法
guide = AutoMultivariateNormal(model) svi = SVI(model, guide, ...)
默认情况下,均值向量由
init_loc_fn()
初始化,乔列斯基因子初始化为单位矩阵乘以一个小的因子。- 参数
- scale_constraint = SoftplusPositive(lower_bound=0.0)¶
- scale_tril_constraint = UnitLowerCholesky()¶
AutoDiagonalNormal¶
- class AutoDiagonalNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)[source]¶
基类:
pyro.infer.autoguide.guides.AutoContinuous
此
AutoContinuous
实现使用具有对角协方差矩阵的正态分布来构建覆盖整个潜在空间的向导。该向导不依赖于模型的*args, **kwargs
。用法
guide = AutoDiagonalNormal(model) svi = SVI(model, guide, ...)
默认情况下,均值向量初始化为零,尺度初始化为单位矩阵乘以一个小的因子。
- 参数
- scale_constraint = SoftplusPositive(lower_bound=0.0)¶
AutoLowRankMultivariateNormal¶
- class AutoLowRankMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1, rank=None)[源码]¶
基类:
pyro.infer.autoguide.guides.AutoContinuous
AutoContinuous
的此实现使用低秩加对角多变量正态分布来构建整个潜在空间的引导。此引导不依赖于模型的*args, **kwargs
。用法
guide = AutoLowRankMultivariateNormal(model, rank=10) svi = SVI(model, guide, ...)
默认情况下,
cov_diag
初始化为一个小的常数,cov_factor
随机初始化,使得平均而言cov_factor.matmul(cov_factor.t())
的尺度与cov_diag
相同。- 参数
- scale_constraint = SoftplusPositive(lower_bound=0.0)¶
AutoNormalizingFlow¶
- class AutoNormalizingFlow(model, init_transform_fn)[源码]¶
基类:
pyro.infer.autoguide.guides.AutoContinuous
AutoContinuous
的此实现使用通过一系列双射变换(例如各种TransformModule
子类)变换的对角正态分布来构建整个潜在空间的引导。此引导不依赖于模型的*args, **kwargs
。用法
transform_init = partial(iterated, block_autoregressive, repeats=2) guide = AutoNormalizingFlow(model, transform_init) svi = SVI(model, guide, ...)
- 参数
model (callable) – 一个生成模型
init_transform_fn – 一个可调用对象,当提供潜在维度时,它返回一个
Transform
的实例,或者如果变换具有可训练参数,则返回一个TransformModule
实例。
AutoIAFNormal¶
- class AutoIAFNormal(model, hidden_dim=None, init_loc_fn=None, num_transforms=1, **init_transform_kwargs)[源码]¶
基类:
pyro.infer.autoguide.guides.AutoNormalizingFlow
AutoContinuous
的此实现使用通过AffineAutoregressive
变换的对角正态分布来构建整个潜在空间的引导。此引导不依赖于模型的*args, **kwargs
。用法
guide = AutoIAFNormal(model, hidden_dim=latent_dim) svi = SVI(model, guide, ...)
- 参数
model (callable) – 一个生成模型
init_loc_fn (callable) –
一个逐点的初始化函数。有关可用函数,请参阅初始化部分。
警告
此参数仅用于保留向后兼容性,在实践中无效。
num_transforms (int) – 按顺序使用的
AffineAutoregressive
变换次数。init_transform_kwargs –
affine_autoregressive()
接受的其他关键字参数。
AutoLaplaceApproximation¶
- class AutoLaplaceApproximation(model, init_loc_fn=<function init_to_median>)[源码]¶
基类:
pyro.infer.autoguide.guides.AutoContinuous
拉普拉斯近似(二次近似)通过无约束空间中的多变量正态分布来近似后验 \(\log p(z | x)\)。在底层,它使用 Delta 分布来构建整个(无约束)潜在空间的 MAP 引导。其协方差由 \(-\log p(x, z)\) 在 z 的 MAP 点处的 Hessian 矩阵的逆给出。
用法
delta_guide = AutoLaplaceApproximation(model) svi = SVI(model, delta_guide, ...) # ...then train the delta_guide... guide = delta_guide.laplace_approximation()
默认情况下,均值向量初始化为经验先验中位数。
- 参数
model (callable) – 一个生成模型
init_loc_fn (callable) – 按站点进行的初始化函数。有关可用函数,请参阅初始化部分。
- laplace_approximation(*args, **kwargs)[源码]¶
返回一个
AutoMultivariateNormal
实例,其后验的 loc 和 scale_tril 由拉普拉斯近似给出。
AutoDiscreteParallel¶
AutoStructured¶
- class AutoStructured(model, *, conditionals: Union[str, Dict[str, Union[str, Callable]]] = 'mvn', dependencies: Union[str, Dict[str, Dict[str, Union[str, Callable]]]] = 'linear', init_loc_fn: Callable = <function init_to_feasible>, init_scale: float = 0.1, create_plates: Optional[Callable] = None)[源码]¶
基类:
pyro.infer.autoguide.guides.AutoGuide
结构化引导,其条件分布是 Delta、Normal、MultivariateNormal 或由可调用对象给出,并且其潜在变量可以彼此线性(在无约束空间中)依赖,或者通过可调用对象进行剪切依赖。
用法
def model(data): x = pyro.sample("x", dist.LogNormal(0, 1)) with pyro.plate("plate", len(data)): y = pyro.sample("y", dist.Normal(0, 1)) pyro.sample("z", dist.Normal(y, x), obs=data) # Either fully automatic... guide = AutoStructured(model) # ...or with specified conditional and dependency types... guide = AutoStructured( model, conditionals="normal", dependencies="linear" ) # ...or with custom dependency structure and distribution types. guide = AutoStructured( model=model, conditionals={"x": "normal", "y": "delta"}, dependencies={"x": {"y": "linear"}}, )
训练完成后,此引导可与
StructuredReparam
一起使用,以预处理模型用于 HMC 和 NUTS 推断。注意
如果你声明了一个高维下游变量对低维上游变量的依赖,你可能希望对该权重使用较低的学习率,例如:
def optim_config(param_name): config = {"lr": 0.01} if "deps.my_downstream.my_upstream" in param_name: config["lr"] *= 0.1 return config adam = pyro.optim.Adam(optim_config)
- 参数
model (callable) – 一个 Pyro 模型。
conditionals – 单个分布类型或一个字典,将每个潜在变量名称映射到分布类型。分布类型可以是 {"delta", "normal", "mvn"} 中的字符串,或一个返回零均值(或近似居中)噪声分布样本的可调用对象(此类可调用对象通常在内部调用
pyro.param()
和pyro.sample()
)。dependencies – 依赖类型,或一个字典,将每个站点名称映射到一个字典,该字典将其上游依赖项映射到依赖类型。如果仅提供依赖类型,则将推断依赖结构。依赖类型是字符串 "linear",或一个将展平的上游扰动映射到展平的下游扰动的可调用对象。字符串 "linear" 等效于
nn.Linear(upstream.numel(), downstream.numel(), bias=False)
。依赖项不能包含循环或自循环。init_loc_fn (callable) – 按站点进行的初始化函数。有关可用函数,请参阅初始化部分。
init_scale (float) – 每个(无约束变换后的)潜在变量的标准差的初始尺度。
create_plates (callable) – 一个可选函数,接受与
model()
相同的*args,**kwargs
并返回一个pyro.plate
或 plates 的可迭代对象。未返回的 plates 将照常自动创建。这对于数据子采样很有用。
- scale_constraint = SoftplusPositive(lower_bound=0.0)¶
- scale_tril_constraint = SoftplusLowerCholesky()¶
- get_deltas()¶
AutoGaussian¶
- class AutoGaussian(*args, **kwargs)[源码]¶
基类:
pyro.infer.autoguide.guides.AutoGuide
具有最优条件独立结构的 Gaussian 引导。
这等效于全秩的
AutoMultivariateNormal
引导,但其稀疏精度矩阵由模型中的依赖项和 plate 决定 [1]。根据模型结构,这可能比AutoMultivariateNormal
在渐近意义上具有更好的统计效率。此引导实现了多种计算后端。所有后端使用相同的统计学最优参数化。默认的“dense”后端具有与
AutoMultivariateNormal
相似的计算复杂度。实验性的“funsor”后端在时间和空间上可能渐近更便宜(使用 Gaussian 张量变量消除 [2,3]),但会产生较大的常数开销。“funsor”后端需要 funsor,可以通过pip install pyro-ppl[funsor]
安装。此引导当前不依赖于模型的
*args, **kwargs
。示例
guide = AutoGaussian(model) svi = SVI(model, guide, ...)
使用实验性 funsor 后端的示例
!pip install pyro-ppl[funsor] guide = AutoGaussian(model, backend="funsor") svi = SVI(model, guide, ...)
参考文献
- [1] S.Webb, A.Goliński, R.Zinkov, N.Siddharth, T.Rainforth, Y.W.Teh, F.Wood (2018)
“Faithful inversion of generative models for effective amortized inference” https://dl.acm.org/doi/10.5555/3327144.3327229
- [2] F.Obermeyer, E.Bingham, M.Jankowiak, J.Chiu, N.Pradhan, A.M.Rush, N.Goodman
(2019) “Tensor Variable Elimination for Plated Factor Graphs” http://proceedings.mlr.press/v97/obermeyer19a/obermeyer19a.pdf
- [3] F. Obermeyer, E. Bingham, M. Jankowiak, D. Phan, J. P. Chen
(2019) “Functional Tensors for Probabilistic Programming” https://arxiv.org/abs/1910.10775
- 参数
- scale_constraint = SoftplusPositive(lower_bound=0.0)¶
- forward(*args, **kwargs) Dict[str, torch.Tensor] [源码]¶
- median(*args, **kwargs) Dict[str, torch.Tensor] [源码]¶
返回每个潜在变量的后验中位数。
- 返回
一个字典,将采样站点的名称映射到中位数张量。
- 返回类型
AutoMessenger¶
- class AutoMessenger(model: Callable, *, amortized_plates: Tuple[str, ...] = ())[源码]¶
基类:
pyro.poutine.guide.GuideMessenger
,pyro.nn.module.PyroModule
GuideMessenger
自动引导的基类。- 参数
model (callable) – 一个 Pyro 模型。
amortized_plates (tuple) – 一个包含 plate 名称的元组,这些 plate 的引导参数应该共享。这对于子采样非常有用,其中一个引导参数可以在所有 plate 中共享。
- call(*args, **kwargs)[源码]¶
调用
forward()
并以 tuple 而不是 dict 返回推断向导参数值的方法,这是 JIT 追踪的要求。与forward()
不同,此方法可以被torch.jit.trace_module()
追踪。警告
一旦 PyTorch JIT 追踪器开始接受 dict 作为有效的返回类型,此方法可能会被移除。请参阅 issue <https://github.com/pytorch/pytorch/issues/27743>_。
AutoNormalMessenger¶
- class AutoNormalMessenger(model: Callable, *, init_loc_fn: Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), init_scale: float = 0.1, amortized_plates: Tuple[str, ...] = ())[源码]¶
基类:
pyro.infer.autoguide.effect.AutoMessenger
具有均值场正态后验的
AutoMessenger
。任意站点的均值场后验是变换后的正态分布。此后验等效于
AutoNormal
或AutoDiagonalNormal
,但允许通过子类化进行自定义。派生类可以覆盖特定站点的
get_posterior()
行为,并简单地使用均值场正态行为作为默认行为,例如:def model(data): a = pyro.sample("a", dist.Normal(0, 1)) b = pyro.sample("b", dist.Normal(0, 1)) c = pyro.sample("c", dist.Normal(a + b, 1)) pyro.sample("obs", dist.Normal(c, 1), obs=data) class MyGuideMessenger(AutoNormalMessenger): def get_posterior(self, name, prior): if name == "c": # Use a custom distribution at site c. bias = pyro.param("c_bias", lambda: torch.zeros(())) weight = pyro.param("c_weight", lambda: torch.ones(()), constraint=constraints.positive) scale = pyro.param("c_scale", lambda: torch.ones(()), constraint=constraints.positive) a = self.upstream_value("a") b = self.upstream_value("b") loc = bias + weight * (a + b) return dist.Normal(loc, scale) # Fall back to mean field. return super().get_posterior(name, prior)
请注意,上面我们手动计算了
loc = bias + weight * (a + b)
。或者,我们可以通过设置loc = bias + weight * prior.loc
来重用模型侧的计算。class MyGuideMessenger_v2(AutoNormalMessenger): def get_posterior(self, name, prior): if name == "c": # Use a custom distribution at site c. bias = pyro.param("c_bias", lambda: torch.zeros(())) scale = pyro.param("c_scale", lambda: torch.ones(()), constraint=constraints.positive) weight = pyro.param("c_weight", lambda: torch.ones(()), constraint=constraints.positive) loc = bias + weight * prior.loc return dist.Normal(loc, scale) # Fall back to mean field. return super().get_posterior(name, prior)
- 参数
- get_posterior(name: str, prior: pyro.distributions.distribution.Distribution) Union[pyro.distributions.distribution.Distribution, torch.Tensor] [源码]¶
AutoHierarchicalNormalMessenger¶
- class AutoHierarchicalNormalMessenger(model: Callable, *, init_loc_fn: Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), init_scale: float = 0.1, amortized_plates: Tuple[str, ...] = (), init_weight: float = 1.0, hierarchical_sites: Optional[list] = None)[源码]¶
基类:
pyro.infer.autoguide.effect.AutoNormalMessenger
具有均值场正态后验的
AutoMessenger
,该后验以所有依赖项为条件。任意站点的均值场后验是变换后的正态分布,其均值取决于该站点在模型中给定其依赖项的值。
loc_total = loc + transform.inv(prior.mean) * weight
其中
prior.mean
的值以模型中的上游站点为条件,loc
是未变换空间中均值的独立分量,weight
是对先验均值进行缩放的逐元素因子。此方法不适用于没有均值的分布。派生类可以覆盖特定站点,并简单地使用此方法作为默认方法,有关示例请参阅
AutoNormalMessenger
的文档。- 参数
model (callable) – 一个 Pyro 模型。
init_loc_fn (callable) – 按站点进行的初始化函数。有关可用函数,请参阅初始化部分。
init_scale (float) – 每个(无约束变换后的)潜在变量的标准差的初始尺度。
init_weight (float) – 每个潜在变量的层级站点对后验均值贡献的权重初始值。
hierarchical_sites (list) – 具有层级依赖关系的潜在变量(模型站点)列表。如果为 None,则假定所有站点都具有层级依赖关系。如果为 None,对于没有上游站点的站点,引导的 loc 和 weight 代表/学习了与先验的偏差。
- weight_type = 'element-wise'¶
- get_posterior(name: str, prior: pyro.distributions.distribution.Distribution) Union[pyro.distributions.distribution.Distribution, torch.Tensor] [源码]¶
AutoRegressiveMessenger¶
- class AutoRegressiveMessenger(model: Callable, *, init_loc_fn: Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), init_scale: float = 0.1, amortized_plates: Tuple[str, ...] = ())[源码]¶
基类:
pyro.infer.autoguide.effect.AutoMessenger
基类:
pyro.infer.autoguide.effect.AutoMessenger
使用先验依赖结构,通过递归仿射变换先验的
AutoMessenger
。任何站点的后验是先验的已学习仿射变换,以前游后验样本为条件。仿射变换在无约束空间中操作。这仅支持连续潜在变量。
class MyGuideMessenger(AutoRegressiveMessenger): def get_posterior(self, name, prior): if name == "x": # Use a custom distribution at site x. loc = pyro.param("x_loc", lambda: torch.zeros(prior.shape())) scale = pyro.param("x_scale", lambda: torch.ones(prior.shape())), constraint=constraints.positive return dist.Normal(loc, scale).to_event(prior.event_dim()) # Fall back to autoregressive. return super().get_posterior(name, prior)
警告
派生类可以覆盖特定站点的
get_posterior()
行为,并简单地使用回归行为作为默认,例如:- 参数
- 此引导当前不支持基于 JIT 的 ELBO。
- get_posterior(name: str, prior: pyro.distributions.distribution.Distribution) Union[pyro.distributions.distribution.Distribution, torch.Tensor] [源码]¶
training: bool¶
初始化¶
pyro.infer.autoguide.initialization 模块包含自动引导的初始化函数。
- 初始化为先验的随机样本。
init_to_median(site=None, num_samples=15, *, fallback: Optional[Callable] = <function init_to_feasible>)[源码]¶
- 参数
初始化为先验中位数;如果均值未定义,则回退到
fallback
(默认为init_to_feasible()
)。- fallback (callable) – 回退初始化策略,用于
values
中未指定的站点。 引发
-
ValueError – 如果
fallback=None
并且values
中未给出站点的任何值。 init_to_mean(site=None, *, fallback: Optional[Callable] = <function init_to_median>)[源码]¶
- 参数
初始化为先验中位数;如果均值未定义,则回退到
fallback
(默认为init_to_feasible()
)。- fallback (callable) – 回退初始化策略,用于
values
中未指定的站点。 引发
-
初始化为先验均值;如果均值未定义,则回退到
fallback
(默认为init_to_median()
)。 init_to_uniform(site: Optional[dict] = None, radius: float = 2.0)[源码]¶
- 参数
初始化到无约束域中区域
(-radius, radius)
的一个随机点。
- radius (float) – 指定无约束域中绘制初始点的范围。
init_to_value(site: Optional[dict] = None, values: dict = {}, *, fallback: Optional[Callable] = <function init_to_uniform>)[源码]¶
- 参数
初始化为
values
中指定的值。对于未出现在values
中的站点,回退到fallback
(默认为init_to_uniform()
)策略。初始化为先验中位数;如果均值未定义,则回退到
fallback
(默认为init_to_feasible()
)。
- fallback (callable) – 回退初始化策略,用于
values
中未指定的站点。 引发
- values (dict) – 初始值字典,键为站点名称。
init_to_generated(site=None, generate=<function <lambda>>)[源码]¶
初始化为回调函数
generate
返回的另一种初始化策略,该回调函数在每次模型执行时被调用一次。def generate(): values = {"x": torch.randn(100), "y": torch.rand(5)} return init_to_value(values=values) my_init_fn = init_to_generated(generate=generate)
- 参数
这类似于
init_to_value()
,但可以在每次模型执行时产生不同的(例如随机的)值。例如,要生成值并返回init_to_value
,你可以定义:
-
generate (callable) – 一个可调用对象,返回另一个初始化函数,例如返回填充了随机样本字典的
init_to_value(values={...})
。 class InitMessenger(init_fn)[源码]¶
基类:
pyro.poutine.messenger.Messenger
- 参数
通过将
.sample()
调用替换为从初始化策略中抽取的值来初始化站点。这主要用于自动引导类的内部使用。