Pyro 示例¶
数据集¶
多 MNIST¶
该脚本生成一个类似于文献 [1] 中描述的多 MNIST 数据集。
[1] Eslami, SM Ali, et al. “Attend, infer, repeat: Fast scene understanding with generative models.” Advances in Neural Information Processing Systems. 2016.
BART 乘车人数¶
- load_bart_od()[source]¶
加载 2011-2019 年间每对 BART 车站之间的每小时始发-终点乘车人数数据集。
来源 https://www.bart.gov/about/reports/ridership
首次调用时,将下载此数据集。后续调用将从本地缓存文件
.pkl.bz2
读取。此功能尝试下载由 Pyro 团队维护的预处理压缩缓存文件。缓存命中时速度应非常快。缓存未命中时,将回退到下载原始数据源并预处理数据集,这需要约 350MB 文件传输,存储数 GB 临时文件,并花费 30 分钟以上的时间。- 返回
数据集是一个包含以下字段的字典
”stations”: 车站名称字符串列表
”start_date”: 第一次观察的
datetime.datetime
对象”counts”: 乘车人数的
torch.FloatTensor
,形状为(num_hours, len(stations), len(stations))
。
Nextstrain SARS-CoV-2 计数¶
- load_nextstrain_counts(map_location=None) dict [source]¶
加载 SARS-CoV-2 数据集。
原始数据集是可通过 nextstrain 获取的预处理中间文件
metadata.tsv.gz
。然后,Broad Institute 的 预处理脚本 将metadata.tsv.gz
文件聚合到 (月,位置,谱系) 和 (谱系,突变) 分箱中。
工具集¶
- class MNIST(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)[source]¶
基类:
Generic
[torch.utils.data.dataset.T_co
]- mirrors = ['https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/', 'http://yann.lecun.com/exdb/mnist/', 'https://ossci-datasets.s3.amazonaws.com/mnist/']¶