因果效应 VAE¶
本模块实现了因果效应变分自编码器 [1],它展示了多项创新,包括
用于存在隐藏混杂因素的因果效应推断的生成模型;
使用孪生神经网络的模型和指南,以应对不平衡的治疗分配;以及
一个自定义的训练损失函数,它包含 ELBO 项以及训练指南以回答反事实查询所需的额外项。
主要接口是 CEVAE
类,但用户可以通过使用组件 Model
、Guide
、TraceCausalEffect_ELBO
和工具类进行定制。
参考文献
- [1] C. Louizos, U. Shalit, J. Mooij, D. Sontag, R. Zemel, M. Welling (2017)。
- Causal Effect Inference with Deep Latent-Variable Models.
CEVAE 类¶
- class CEVAE(feature_dim, outcome_dist='bernoulli', latent_dim=20, hidden_dim=200, num_layers=3, num_samples=100)[source]¶
基类:
torch.nn.modules.module.Module
实现因果效应 VAE [1] 的主类。这假设了一个图形模型
digraph { Z [pos="1,2!",style=filled]; X [pos="2,1!"]; y [pos="1,0!"]; t [pos="0,1!"]; Z -> X; Z -> t; Z -> y; t -> y; }其中 t 是二元治疗变量,y 是结果,Z 是未观测到的混杂因素,X 是隐藏混杂因素 Z 的噪声函数。
示例
cevae = CEVAE(feature_dim=5) cevae.fit(x_train, t_train, y_train) ite = cevae.ite(x_test) # individual treatment effect ate = ite.mean() # average treatment effect
- 变量
- 参数
- fit(x, t, y, num_epochs=100, batch_size=100, learning_rate=0.001, learning_rate_decay=0.1, weight_decay=0.0001, log_every=100)[source]¶
使用
SVI
和TraceCausalEffect_ELBO
损失函数进行训练。- 参数
x (Tensor) –
t (Tensor) –
y (Tensor) –
num_epochs (int) – 训练轮次(epoch)的数量。默认为 100。
batch_size (int) – 批量大小。默认为 100。
learning_rate (float) – 学习率。默认为 1e-3。
learning_rate_decay (float) – 跨所有轮次(epoch)的学习率衰减;每步的衰减率将取决于批量大小和轮次数量,使得初始学习率为
learning_rate
,最终学习率为learning_rate * learning_rate_decay
。默认为 0.1。weight_decay (float) – 权重衰减。默认为 1e-4。
log_every (int) – 每隔多少步记录损失。如果为零,则不记录损失。默认为 100。
- 返回值
每轮(epoch)损失列表
- ite(x, num_samples=None, batch_size=None)[source]¶
计算数据批量
x
的个体治疗效应 (ITE)。\[ITE(x) = \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=1) \bigr] - \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=0) \bigr]\]这具有复杂度
O(len(x) * num_samples ** 2)
。
- to_script_module()[source]¶
使用
torch.jit.trace_module()
编译此模块,假设 self 已经拟合了数据。- 返回值
self 的跟踪版本,带有一个
ite()
方法。- 返回类型
CEVAE 组件¶
- class Model(config)[source]¶
-
具有潜混杂因素
z
和二元治疗t
的因果模型的生成模型z ~ p(z) # latent confounder x ~ p(x|z) # partial noisy observation of z t ~ p(t|z) # treatment, whose application is biased by z y ~ p(y|t,z) # outcome
这些分布中的每一个都由一个神经网络定义。
y
分布由一对不相交的神经网络定义,分别定义了p(y|t=0,z)
和p(y|t=1,z)
;这使得可以处理高度不平衡的治疗分配情况。- 参数
config (dict) – 一个字典,指定
feature_dim
、latent_dim
、hidden_dim
、num_layers
和outcome_dist
。
- class Guide(config)[source]¶
-
用于估计因果效应的推断模型,具有潜混杂因素
z
和二元治疗t
t ~ q(t|x) # treatment y ~ q(y|t,x) # outcome z ~ q(z|y,t,x) # latent confounder, an embedding
这些分布中的每一个都由一个神经网络定义。
y
和z
分布由一对不相交的神经网络定义,分别定义了p(-|t=0,...)
和p(-|t=1,...)
;这使得可以处理高度不平衡的治疗分配情况。- 参数
config (dict) – 一个字典,指定
feature_dim
、latent_dim
、hidden_dim
、num_layers
和outcome_dist
。
- class TraceCausalEffect_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基类:
pyro.infer.trace_elbo.Trace_ELBO
用于训练
CEVAE
的损失函数。根据 [1],CEVAE 的目标函数(需要最大化)是-loss = ELBO + log q(t|x) + log q(y|t,x)
工具类¶
- class FullyConnected(sizes, final_activation=None)[source]¶
基类:
torch.nn.modules.container.Sequential
具有 ELU 激活函数的多层全连接网络。
- class DistributionNet(*args, **kwargs)[source]¶
基类:
torch.nn.modules.module.Module
分布网络的基类。
- class BernoulliNet(sizes)[source]¶
基类:
pyro.contrib.cevae.DistributionNet
FullyConnected
网络,输出单个logits
值。这用于表示以大小为
sizes[0]
的实数值为条件的单个伯努利随机变量的条件概率分布,例如net = BernoulliNet([3, 4]) z = torch.randn(3) logits, = net(z) t = net.make_dist(logits).sample()
- class ExponentialNet(sizes)[source]¶
基类:
pyro.contrib.cevae.DistributionNet
FullyConnected
网络,输出一个受约束的rate
。这用于表示以大小为
sizes[0]
的实数值为条件的单个指数随机变量的条件概率分布,例如net = ExponentialNet([3, 4]) x = torch.randn(3) rate, = net(x) y = net.make_dist(rate).sample()
- class LaplaceNet(sizes)[source]¶
基类:
pyro.contrib.cevae.DistributionNet
FullyConnected
网络,输出一对受约束的loc,scale
。这用于表示以大小为
sizes[0]
的实数值为条件的单个拉普拉斯随机变量的条件概率分布,例如net = LaplaceNet([3, 4]) x = torch.randn(3) loc, scale = net(x) y = net.make_dist(loc, scale).sample()
- class NormalNet(sizes)[source]¶
基类:
pyro.contrib.cevae.DistributionNet
FullyConnected
网络,输出一对受约束的loc,scale
。这用于表示以大小为
sizes[0]
的实数值为条件的单个指数随机变量的条件概率分布,例如net = NormalNet([3, 4]) x = torch.randn(3) loc, scale = net(x) y = net.make_dist(loc, scale).sample()
- class StudentTNet(sizes)[source]¶
基类:
pyro.contrib.cevae.DistributionNet
FullyConnected
网络,输出一个受约束的df,loc,scale
三元组,其中共享df > 1
。这用于表示以大小为
sizes[0]
的实数值为条件的单个 Student's t 随机变量的条件概率分布,例如net = StudentTNet([3, 4]) x = torch.randn(3) df, loc, scale = net(x) y = net.make_dist(df, loc, scale).sample()
- class DiagNormalNet(sizes)[source]¶
基类:
torch.nn.modules.module.Module
FullyConnected
网络,输出一对受约束的loc,scale
。这用于表示以大小为
sizes[0]
的实数值为条件的、大小为sizes[-1]
的对角正态随机变量的条件概率分布,例如net = DiagNormalNet([3, 4, 5]) z = torch.randn(3) loc, scale = net(z) x = dist.Normal(loc, scale).sample()
这适用于潜变量
z
的分布和预白化的x
特征,并且保守地裁剪loc
和scale
值。