因果效应 VAE

本模块实现了因果效应变分自编码器 [1],它展示了多项创新,包括

  • 用于存在隐藏混杂因素的因果效应推断的生成模型;

  • 使用孪生神经网络的模型和指南,以应对不平衡的治疗分配;以及

  • 一个自定义的训练损失函数,它包含 ELBO 项以及训练指南以回答反事实查询所需的额外项。

主要接口是 CEVAE 类,但用户可以通过使用组件 ModelGuideTraceCausalEffect_ELBO 和工具类进行定制。

参考文献

[1] C. Louizos, U. Shalit, J. Mooij, D. Sontag, R. Zemel, M. Welling (2017)。

CEVAE 类

class CEVAE(feature_dim, outcome_dist='bernoulli', latent_dim=20, hidden_dim=200, num_layers=3, num_samples=100)[source]

基类: torch.nn.modules.module.Module

实现因果效应 VAE [1] 的主类。这假设了一个图形模型

digraph { Z [pos="1,2!",style=filled]; X [pos="2,1!"]; y [pos="1,0!"]; t [pos="0,1!"]; Z -> X; Z -> t; Z -> y; t -> y; }

其中 t 是二元治疗变量,y 是结果,Z 是未观测到的混杂因素,X 是隐藏混杂因素 Z 的噪声函数。

示例

cevae = CEVAE(feature_dim=5)
cevae.fit(x_train, t_train, y_train)
ite = cevae.ite(x_test)  # individual treatment effect
ate = ite.mean()         # average treatment effect
变量
  • model (Model) – 生成模型。

  • guide (Guide) – 推断模型。

参数
  • feature_dim (int) – 特征空间 x 的维度。

  • outcome_dist (str) – 其中之一:“bernoulli”(默认)、“exponential”、“laplace”、“normal”、“studentt”。

  • latent_dim (int) – 潜变量 z 的维度。默认为 20。

  • hidden_dim (int) – 全连接网络隐藏层的维度。默认为 200。

  • num_layers (int) – 全连接网络中隐藏层的数量。

  • num_samples (int) – ite() 方法的默认样本数量。默认为 100。

fit(x, t, y, num_epochs=100, batch_size=100, learning_rate=0.001, learning_rate_decay=0.1, weight_decay=0.0001, log_every=100)[source]

使用 SVITraceCausalEffect_ELBO 损失函数进行训练。

参数
  • x (Tensor) –

  • t (Tensor) –

  • y (Tensor) –

  • num_epochs (int) – 训练轮次(epoch)的数量。默认为 100。

  • batch_size (int) – 批量大小。默认为 100。

  • learning_rate (float) – 学习率。默认为 1e-3。

  • learning_rate_decay (float) – 跨所有轮次(epoch)的学习率衰减;每步的衰减率将取决于批量大小和轮次数量,使得初始学习率为 learning_rate,最终学习率为 learning_rate * learning_rate_decay。默认为 0.1。

  • weight_decay (float) – 权重衰减。默认为 1e-4。

  • log_every (int) – 每隔多少步记录损失。如果为零,则不记录损失。默认为 100。

返回值

每轮(epoch)损失列表

ite(x, num_samples=None, batch_size=None)[source]

计算数据批量 x 的个体治疗效应 (ITE)。

\[ITE(x) = \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=1) \bigr] - \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=0) \bigr]\]

这具有复杂度 O(len(x) * num_samples ** 2)

参数
  • x (Tensor) – 数据批量。

  • num_samples (int) – 蒙特卡罗样本的数量。默认为 self.num_samples,该值默认为 100

  • batch_size (int) – 批量大小。默认为 len(x)

返回值

一个大小为 len(x) 的估计效应张量。

返回类型

Tensor

to_script_module()[source]

使用 torch.jit.trace_module() 编译此模块,假设 self 已经拟合了数据。

返回值

self 的跟踪版本,带有一个 ite() 方法。

返回类型

torch.jit.ScriptModule

training: bool

CEVAE 组件

class Model(config)[source]

基类: pyro.nn.module.PyroModule

具有潜混杂因素 z 和二元治疗 t 的因果模型的生成模型

z ~ p(z)      # latent confounder
x ~ p(x|z)    # partial noisy observation of z
t ~ p(t|z)    # treatment, whose application is biased by z
y ~ p(y|t,z)  # outcome

这些分布中的每一个都由一个神经网络定义。y 分布由一对不相交的神经网络定义,分别定义了 p(y|t=0,z)p(y|t=1,z);这使得可以处理高度不平衡的治疗分配情况。

参数

config (dict) – 一个字典,指定 feature_dimlatent_dimhidden_dimnum_layersoutcome_dist

forward(x, t=None, y=None, size=None)[source]
y_mean(x, t=None)[source]
z_dist()[source]
x_dist(z)[source]
y_dist(t, z)[source]
t_dist(z)[source]
training: bool
class Guide(config)[source]

基类: pyro.nn.module.PyroModule

用于估计因果效应的推断模型,具有潜混杂因素 z 和二元治疗 t

t ~ q(t|x)      # treatment
y ~ q(y|t,x)    # outcome
z ~ q(z|y,t,x)  # latent confounder, an embedding

这些分布中的每一个都由一个神经网络定义。yz 分布由一对不相交的神经网络定义,分别定义了 p(-|t=0,...)p(-|t=1,...);这使得可以处理高度不平衡的治疗分配情况。

参数

config (dict) – 一个字典,指定 feature_dimlatent_dimhidden_dimnum_layersoutcome_dist

forward(x, t=None, y=None, size=None)[source]
t_dist(x)[source]
y_dist(t, x)[source]
z_dist(y, t, x)[source]
training: bool
class TraceCausalEffect_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基类: pyro.infer.trace_elbo.Trace_ELBO

用于训练 CEVAE 的损失函数。根据 [1],CEVAE 的目标函数(需要最大化)是

-loss = ELBO + log q(t|x) + log q(y|t,x)
loss(model, guide, *args, **kwargs)[source]

工具类

class FullyConnected(sizes, final_activation=None)[source]

基类: torch.nn.modules.container.Sequential

具有 ELU 激活函数的多层全连接网络。

append(layer)[source]
class DistributionNet(*args, **kwargs)[source]

基类: torch.nn.modules.module.Module

分布网络的基类。

static get_class(dtype)[source]

通过其名称前缀获取子类,例如

assert DistributionNet.get_class("bernoulli") is BernoulliNet
training: bool
class BernoulliNet(sizes)[source]

基类: pyro.contrib.cevae.DistributionNet

FullyConnected 网络,输出单个 logits 值。

这用于表示以大小为 sizes[0] 的实数值为条件的单个伯努利随机变量的条件概率分布,例如

net = BernoulliNet([3, 4])
z = torch.randn(3)
logits, = net(z)
t = net.make_dist(logits).sample()
forward(x)[source]
static make_dist(logits)[source]
training: bool
class ExponentialNet(sizes)[source]

基类: pyro.contrib.cevae.DistributionNet

FullyConnected 网络,输出一个受约束的 rate

这用于表示以大小为 sizes[0] 的实数值为条件的单个指数随机变量的条件概率分布,例如

net = ExponentialNet([3, 4])
x = torch.randn(3)
rate, = net(x)
y = net.make_dist(rate).sample()
forward(x)[source]
static make_dist(rate)[source]
training: bool
class LaplaceNet(sizes)[source]

基类: pyro.contrib.cevae.DistributionNet

FullyConnected 网络,输出一对受约束的 loc,scale

这用于表示以大小为 sizes[0] 的实数值为条件的单个拉普拉斯随机变量的条件概率分布,例如

net = LaplaceNet([3, 4])
x = torch.randn(3)
loc, scale = net(x)
y = net.make_dist(loc, scale).sample()
forward(x)[source]
static make_dist(loc, scale)[source]
training: bool
class NormalNet(sizes)[source]

基类: pyro.contrib.cevae.DistributionNet

FullyConnected 网络,输出一对受约束的 loc,scale

这用于表示以大小为 sizes[0] 的实数值为条件的单个指数随机变量的条件概率分布,例如

net = NormalNet([3, 4])
x = torch.randn(3)
loc, scale = net(x)
y = net.make_dist(loc, scale).sample()
forward(x)[source]
static make_dist(loc, scale)[source]
training: bool
class StudentTNet(sizes)[source]

基类: pyro.contrib.cevae.DistributionNet

FullyConnected 网络,输出一个受约束的 df,loc,scale 三元组,其中共享 df > 1

这用于表示以大小为 sizes[0] 的实数值为条件的单个 Student's t 随机变量的条件概率分布,例如

net = StudentTNet([3, 4])
x = torch.randn(3)
df, loc, scale = net(x)
y = net.make_dist(df, loc, scale).sample()
forward(x)[source]
static make_dist(df, loc, scale)[source]
training: bool
class DiagNormalNet(sizes)[source]

基类: torch.nn.modules.module.Module

FullyConnected 网络,输出一对受约束的 loc,scale

这用于表示以大小为 sizes[0] 的实数值为条件的、大小为 sizes[-1] 的对角正态随机变量的条件概率分布,例如

net = DiagNormalNet([3, 4, 5])
z = torch.randn(3)
loc, scale = net(z)
x = dist.Normal(loc, scale).sample()

这适用于潜变量 z 的分布和预白化的 x 特征,并且保守地裁剪 locscale 值。

forward(x)[source]
training: bool