使用 MuE 的生物序列模型

警告

pyro.contrib.mue 中的代码正在开发中。此代码不保证保持向后兼容。

pyro.contrib.mue 提供了处理生物序列数据的建模工具。特别是它实现了 MuE 分布,作为基于多序列比对预处理的完全生成式替代方案。

参考:MuE 模型在 Weinstein 和 Marks (2021) 中描述,https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2

MuE 模型示例

MuE 观测模型示例。

class ProfileHMM(latent_seq_length, alphabet_length, prior_scale=1.0, indel_prior_bias=10.0, cuda=False, pin_memory=False)[source]

基类:torch.nn.modules.module.Module

Profile HMM。

该模型由一个常数分布(Delta 函数)在回归序列上,外加一个 MuE 观测分布组成。先验分布都是正态分布,并通过 Softmax 函数推到 Simplex 上。

参数
  • latent_seq_length (int) – 潜在回归序列 M 的长度。必须大于或等于 1。

  • alphabet_length (int) – 序列字母表长度(例如,氨基酸为 20)。

  • prior_scale (float) – 先验分布的标准差。

  • indel_prior_bias (float) – indel 未发生的对数概率的先验分布的均值。值越高,发生 indel 的概率越低。

  • cuda (bool) – 训练期间将数据传输到 GPU 上。

  • pin_memory (bool) – 锁定内存以加快 GPU 传输。

fit_svi(dataset, epochs=2, batch_size=1, scheduler=None, jit=False)[source]

使用随机变分推断推断近似后验。

这会运行 SVI。这是一种近似推断方法,可用于快速迭代概率模型。

参数
  • dataset (Dataset) – 训练数据集。

  • epochs (int) – 训练轮数。

  • batch_size (int) – Mini-batch 大小(序列数量)。

  • scheduler (pyro.optim.MultiStepLR) – 优化调度器。(默认:Adam 优化器,学习率常数 0.01。)

  • jit (bool) – 是否使用 jit 编译的 ELBO。

evaluate(dataset_train, dataset_test=None, jit=False)[source]

评估训练集和测试集上的性能(对数概率和每个残基的困惑度)。

参数
  • dataset (Dataset) – 训练数据集。

  • dataset – 测试数据集。

  • jit (bool) – 是否使用 jit 编译的 ELBO。

class FactorMuE(data_length, alphabet_length, z_dim, batch_size=10, latent_seq_length=None, indel_factor_dependence=False, indel_prior_scale=1.0, indel_prior_bias=10.0, inverse_temp_prior=100.0, weights_prior_scale=1.0, offset_prior_scale=1.0, z_prior_distribution='Normal', ARD_prior=False, substitution_matrix=True, substitution_prior_scale=10.0, latent_alphabet_length=None, cuda=False, pin_memory=False, epsilon=1e-32)[source]

基类:torch.nn.modules.module.Module

该模型由概率 PCA 加上一个 MuE 输出分布组成。

先验分布都是正态分布,并在相关情况下通过 Softmax 推到 Simplex 上。

参数
  • data_length (int) – 输入序列矩阵的长度,包括末尾的零填充。

  • alphabet_length (int) – 序列字母表长度(例如,氨基酸为 20)。

  • z_dim (int) – z 空间的维度数量。

  • batch_size (int) – Mini-batch 大小。

  • latent_seq_length (int) – 潜在回归序列 (M) 的长度。必须大于或等于 1。(默认:1.1 x data_length。)

  • indel_factor_dependence (bool) – Indel 概率取决于潜在变量 z。

  • indel_prior_scale (float) – indel 参数先验分布的标准差。

  • indel_prior_bias (float) – indel 未发生的对数概率的先验分布的均值。值越高,发生 indel 的概率越低。

  • inverse_temp_prior (float) – 逆温度参数先验分布的均值。

  • weights_prior_scale (float) – 因子先验分布的标准差。

  • offset_prior_scale (float) – pPCA 模型中偏移量(常数)先验分布的标准差。

  • z_prior_distribution (str) – 潜在变量 z 的先验分布。可以是 'Normal' (pPCA 模型) 或 'Laplace' (ICA 模型)。

  • ARD_prior (bool) – 对因子使用自动相关性确定先验。

  • substitution_matrix (bool) – 使用可学习的替换矩阵而不是单位矩阵。

  • substitution_prior_scale (float) – 替换矩阵参数先验分布的标准差(当 substitution_matrix 为 True 时)。

  • latent_alphabet_length (int) – 潜在回归序列中字母表的长度。

  • cuda (bool) – 训练期间将数据传输到 GPU 上。

  • pin_memory (bool) – 锁定内存以加快 GPU 传输。

  • epsilon (float) – 一个用于数值稳定的小值。

fit_svi(dataset, epochs=2, anneal_length=1.0, batch_size=None, scheduler=None, jit=False)[source]

使用随机变分推断推断近似后验。

这会运行 SVI。这是一种近似推断方法,可用于快速迭代概率模型。

参数
  • dataset (Dataset) – 训练数据集。

  • epochs (int) – 训练轮数。

  • anneal_length (float) – 先验 KL 散度权重从 0 到 1 线性退火的轮数,以改进训练。

  • batch_size (int) – Mini-batch 大小(序列数量)。

  • scheduler (pyro.optim.MultiStepLR) – 优化调度器。(默认:Adam 优化器,学习率常数 0.01。)

  • jit (bool) – 是否使用 jit 编译的 ELBO。

evaluate(dataset_train, dataset_test=None, jit=False)[source]

评估训练集和测试集上的性能(对数概率和每个残基的困惑度)。

参数
  • dataset (Dataset) – 训练数据集。

  • dataset – 测试数据集(可选)。

  • jit (bool) – 是否使用 jit 编译的 ELBO。

embed(dataset, batch_size=None)[source]

获取潜在空间嵌入(z 的后验均值)。

参数
  • dataset (Dataset) – 要嵌入的数据集。

  • batch_size (int) – Mini-batch 大小(序列数量)。(默认:模型对象的 batch_size。)

用于参数化 MuE 的状态编排器

class Profile(M, epsilon=1e-32)[source]

基类:torch.nn.modules.module.Module

Profile HMM 状态编排。根据 [1] 中公式 S40(对于 j ∈ {0, 1, 2},r_{M+1,j} = 1 且 u_{M+1,j} = 0)参数化 HMM。关于 Profile HMM 的更多背景信息,请参阅 [2]。

参考文献

[1] E. N. Weinstein, D. S. Marks (2021) “考虑突变变异的生成概率生物序列模型” https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf

[2] R. Durbin, S. R. Eddy, A. Krogh, and G. Mitchison (1998) “生物序列分析:蛋白质和核酸的概率模型” 剑桥大学出版社

参数
  • M (int) – 回归序列的长度。

  • epsilon (float) – 一个用于数值稳定的小值。

forward(precursor_seq_logits, insert_seq_logits, insert_logits, delete_logits, substitute_logits=None)[source]

给定 profile 参数,组装 HMM 参数。

参数
  • precursor_seq_logits (Tensor) – 回归序列 log(x)。最右侧维度应为 (M, D),并且可广播到 (batch_size, M, D),其中 D 是潜在字母表大小。应沿最终轴归一化为 1,即 precursor_seq_logits.logsumexp(-1) = zeros

  • insert_seq_logits (Tensor) – 插入序列 log(c)。最右侧维度应为 (M+1, D),并且可广播到 (batch_size, M+1, D)。应沿最终轴归一化。

  • insert_logits (Tensor) – 插入概率 log(r)。最右侧维度应为 (M, 3, 2),并且可广播到 (batch_size, M, 3, 2)。应沿最终轴归一化。

  • delete_logits (Tensor) – 删除概率 log(u)。最右侧维度应为 (M, 3, 2),并且可广播到 (batch_size, M, 3, 2)。应沿最终轴归一化。

  • substitute_logits (Tensor) – 替换概率 log(l)。最右侧维度应为 (D, B),其中 B 是数据的字母表大小,并且可广播到 (batch_size, D, B)。必须沿最终轴归一化。

返回

initial_logitstransition_logitsobservation_logits。这些参数可用于直接初始化 MissingDataDiscreteHMM 分布。

返回类型

Tensor, Tensor, Tensor

mg2k(m, g, M)[source]

将 (m, g) 索引转换为 k 索引。

缺失或可变长度数据 HMM

class MissingDataDiscreteHMM(initial_logits, transition_logits, observation_logits, validate_args=None)[source]

基类:pyro.distributions.distribution.Distribution, Callable

具有离散潜在状态和离散观测的 HMM,允许缺失数据或可变长度序列。观测假定为独热编码;全为零的行表示缺失数据。

警告

与 Pyro 的 pyro.distributions.DiscreteHMM 不同(后者计算第一个状态的概率为 initial.T @ transition @ emission),该分布使用标准的 HMM 约定,initial.T @ emission

参数
  • initial_logits (Tensor) – 潜在状态初始分类分布的 logits 张量。最右侧尺寸应为 state_dim,并且可广播到 (batch_size, state_dim)

  • transition_logits (Tensor) – 潜在状态之间转移条件分布的 logits 张量。最右侧形状应为 (state_dim, state_dim) (旧状态, 新状态),并且可广播到 (batch_size, state_dim, state_dim)

  • observation_logits (Tensor) – 来自潜在状态的观测分布的 logits 张量。最右侧形状应为 (state_dim, categorical_size),其中 categorical_size 是分类输出的维度,并且可广播到 (batch_size, state_dim, categorical_size)

log_prob(value)[source]
参数

value (Tensor) – 独热编码的观测。必须是实值 (float),并且可广播到 (batch_size, num_steps, categorical_size),其中 categorical_size 是分类输出的维度。缺失数据用零表示,即 value[batch, step, :] == tensor([0, ..., 0])。可通过在序列末尾填充零来处理可变长度观测序列。

sample(sample_shape=torch.Size([]))[source]
参数

sample_shape (Size) – 采样形状,最后一个维度必须是 num_steps,并且必须可广播到 (batch_size, num_steps)。batch_size 必须是 int,不能是 tuple。

filter(value)[source]

计算每个步骤中状态变量基于先前观测的边缘概率。

参数

value (Tensor) – 独热编码的观测。必须是实值 (float),并且可广播到 (batch_size, num_steps, categorical_size),其中 categorical_size 是分类输出的维度。

smooth(value)[source]

计算每个位置状态的后验期望值(平滑)。

参数

value (Tensor) – 独热编码的观测。必须是实值 (float),并且可广播到 (batch_size, num_steps, categorical_size),其中 categorical_size 是分类输出的维度。

sample_states(value)[source]

使用前向过滤-后向采样算法采样状态。

参数

value (Tensor) – 独热编码的观测。必须是实值 (float),并且可广播到 (batch_size, num_steps, categorical_size),其中 categorical_size 是分类输出的维度。

map_states(value)[source]

使用 Viterbi 算法计算状态变量的最大后验 (MAP) 估计。

参数

value (Tensor) – 独热编码的观测。必须是实值 (float),并且可广播到 (batch_size, num_steps, categorical_size),其中 categorical_size 是分类输出的维度。

given_states(states)[source]

以状态变量为条件的分布。

参数

map_states (Tensor) – 状态轨迹。必须是整数值 (long),并且可广播到 (batch_size, num_steps)

sample_given_states(states)[source]

在给定状态变量的情况下采样观测。

参数

map_states (Tensor) – 状态轨迹。必须是整数值 (long),并且可广播到 (batch_size, num_steps)

生物序列数据集加载

class BiosequenceDataset(source, source_type='list', alphabet='amino-acid', max_length=None, include_stop=False, device=None)[source]

基类:Generic[torch.utils.data.dataset.T_co]

加载生物序列数据,可以从 fasta 文件或 Python 列表加载。

参数
  • source – 输入 fasta 文件路径 (str) 或输入序列列表 (str 列表)。

  • source_type (str) – 输入类型,可以是 'list' 或 'fasta'。

  • alphabet (str) – 要使用的字母表。字母表 'amino-acid' 和 'dna' 已预设;任何其他输入将被解释为字母表本身,例如,您可以使用 'ACGU' 作为 RNA。

  • max_length (int) – 序列独热表示的总长度,包括零填充。默认为数据集中最大序列长度。

  • include_stop (bool) – 在每个序列末尾附加停止符号,并将停止符号添加到字母表中。

  • device (torch.device) – 数据在内存中存储的设备。

write(x, alphabet, file, truncate_stop=False, append=False, scores=None)[source]

将序列样本写入文件。

参数
  • x (Tensor) – 独热编码序列,尺寸为 (data_size, seq_length, alphabet_length)。对于可变长度序列,可能用零填充。

  • alphabet (array) – 字母表。

  • file (str) – 输出文件,序列将以 fasta 格式写入。

  • truncate_stop (bool) – 如果为 True,序列将在第一个停止符号处截断(即,停止符号及之后的部分不会写入)。如果为 False,将写入整个序列,包括任何内部停止符号。

  • append (bool) – 如果为 True,序列将附加到输出文件末尾。如果为 False,则先擦除文件。