Pyro 示例

数据集

多 MNIST

该脚本生成一个类似于文献 [1] 中描述的多 MNIST 数据集。

[1] Eslami, SM Ali, et al. “Attend, infer, repeat: Fast scene understanding with generative models.” Advances in Neural Information Processing Systems. 2016.

imresize(arr, size)[source]
sample_one(canvas_size, mnist)[source]
sample_multi(num_digits, canvas_size, mnist)[source]
mk_dataset(n, mnist, max_digits, canvas_size)[source]
load_mnist(root_path)[source]
load(root_path)[source]

BART 乘车人数

load_bart_od()[source]

加载 2011-2019 年间每对 BART 车站之间的每小时始发-终点乘车人数数据集。

来源 https://www.bart.gov/about/reports/ridership

首次调用时,将下载此数据集。后续调用将从本地缓存文件 .pkl.bz2 读取。此功能尝试下载由 Pyro 团队维护的预处理压缩缓存文件。缓存命中时速度应非常快。缓存未命中时,将回退到下载原始数据源并预处理数据集,这需要约 350MB 文件传输,存储数 GB 临时文件,并花费 30 分钟以上的时间。

返回

数据集是一个包含以下字段的字典

  • ”stations”: 车站名称字符串列表

  • ”start_date”: 第一次观察的 datetime.datetime 对象

  • ”counts”: 乘车人数的 torch.FloatTensor,形状为 (num_hours, len(stations), len(stations))

load_fake_od()[source]

创建一个微型合成数据集用于冒烟测试。

Nextstrain SARS-CoV-2 计数

load_nextstrain_counts(map_location=None) dict[source]

加载 SARS-CoV-2 数据集。

原始数据集是可通过 nextstrain 获取的预处理中间文件 metadata.tsv.gz。然后,Broad Institute 的 预处理脚本metadata.tsv.gz 文件聚合到 (月,位置,谱系) 和 (谱系,突变) 分箱中。

工具集

class MNIST(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)[source]

基类: Generic[torch.utils.data.dataset.T_co]

mirrors = ['https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/', 'http://yann.lecun.com/exdb/mnist/', 'https://ossci-datasets.s3.amazonaws.com/mnist/']
get_data_loader(dataset_name, data_dir, batch_size=1, dataset_transforms=None, is_training_set=True, shuffle=True)[source]
print_and_log(logger, msg)[source]
get_data_directory(filepath=None)[source]