Google Research 示例#

Google Research 使用 Flax 进行的研究集合。

注意力机制#

快速注意力(FAVOR+)和使用 Performers 重新思考注意力#

  • GitHub 上的代码

  • 研究论文

    • 使用 Performers 重新思考注意力(Choromanski 等人,2020 年)

      • 引入“Performers,Transformer 架构,它可以以可证明的准确性估计常规(softmax)全秩注意力 Transformer,但仅使用线性(而不是二次)空间和时间复杂度,而不依赖于任何先验,例如稀疏性或低秩。为了近似 softmax 注意力核,Performers 使用一种新的通过正交随机特征的方法 (FAVOR+) 实现的快速注意力,这对于可扩展的核方法可能具有独立的意义。FAVOR+ 也可用于有效地建模超出 softmax 的可核化注意力机制。”

自注意力不需要 O(n^2) 内存#

  • GitHub 上的代码

  • Colab 笔记本

  • 研究论文

    • 自注意力不需要 O(n^2) 内存(Rabe 和 Staats,2021 年)

      • “我们提出了一种非常简单的注意力算法,该算法相对于序列长度需要 O(1) 内存,以及一个扩展的自注意力算法,该算法需要 O(log n) 内存。这与经常出现的自注意力需要 O(n^2) 内存的观点形成对比。虽然时间复杂度仍然是 O(n^2),但在现代加速器上,设备内存而不是计算能力通常是限制因素。因此,减少注意力的内存需求可以处理比其他方式更长的序列……”

计算机视觉#

着色变换器 (ColTran)#

  • GitHub 上的代码

  • 研究论文

    • 着色变换器(Kumar 等人,2020 年)

      • “我们提出了着色变换器 (ColTran),这是一种完全依赖于自注意力进行图像着色的架构。我们引入了条件变换器层,一种基于自注意力条件生成模型的新颖构建块。我们的消融研究表明,采用这种机制优于许多不同的基线。最后,我们证明了 ColTran 可以在 ImageNet 上生成多样化、高保真的着色,即使对于人类评分者来说,这些着色也与真实值基本无法区分。”

视觉变换器 (ViT)、MLP-Mixer 架构 Big Vision#

  • GitHub 上的代码

    • 视觉变换器和 MLP-Mixer 架构

    • Big Vision

      • “此代码库旨在用于使用 Cloud TPU VM 或 GPU 机器训练大规模视觉模型。它基于 Jax/Flax 库,并使用 tf.data 和 TensorFlow Datasets 来实现可扩展且可重现的输入管道。”

  • Colab 笔记本:

    • 视觉变换器和 MLP Mixers 的 JAX 代码

    • 用于生成“如何训练 ViT?”数据超过 5 万个视觉变换器和混合检查点

  • 研究论文

    • 图像价值 16x16 个单词:用于大规模图像识别的变换器(Dosovitskiy 等人,2020 年)

      • “在视觉方面,注意力要么与卷积网络结合使用,要么用于替换卷积网络的某些组件,同时保持其整体结构不变。我们表明,对 CNN 的这种依赖不是必要的,直接应用于图像补丁序列的纯变换器可以在图像分类任务中表现出色。当在大量数据上进行预训练并转移到多个中小型图像识别基准(ImageNet、CIFAR-100、VTAB 等)时,视觉变换器 (ViT) 与最先进的卷积网络相比,取得了出色的结果,同时训练所需的计算资源也少得多。”

    • MLP-Mixer:用于视觉的全 MLP 架构(Tolstikhin 等人,2021 年)

      • “在本文中,我们表明,虽然卷积和注意力都足以获得良好的性能,但它们都不是必需的。我们提出了 MLP-Mixer,一种完全基于多层感知器 (MLP) 的架构。MLP-Mixer 包含两种类型的层:一种是对图像补丁独立应用 MLP 的层(即“混合”每个位置的特征),另一种是对补丁应用 MLP 的层(即“混合”空间信息)。当在大型数据集上训练或使用现代正则化方案时,MLP-Mixer 在图像分类基准上获得了具有竞争力的分数,预训练和推理成本与最先进的模型相当。”

    • 如何训练 ViT?视觉变换器中的数据、增强和正则化(Steiner 等人,2021 年)

      • “视觉变换器 (ViT) 已被证明在广泛的视觉应用(例如图像分类、对象检测和语义图像分割)中获得了极具竞争力的性能。与卷积神经网络相比,人们通常发现,当在较小的训练数据集上进行训练时,视觉变换器的较弱归纳偏置会导致更加依赖模型正则化或数据增强(简称“AugReg”)。我们进行了一项系统的实证研究,以便更好地了解训练数据量、AugReg、模型大小和计算预算之间的相互作用。”

    • 当视觉变换器在没有预训练或强大的数据增强的情况下优于 ResNet 时(X. Chen 等人,2021 年)

      • “视觉变换器 (ViT) 和 MLP 表明了进一步努力用通用神经网络架构取代手动连接的特征或归纳偏置。现有工作通过海量数据(例如大规模预训练和/或重复强数据增强)来增强模型,并且仍然报告与优化相关的问题(例如,对初始化和学习率的敏感性)。因此,本文从损失几何的角度研究 ViT 和 MLP-Mixer,旨在提高模型在训练中的数据效率和在推理中的泛化能力。”

    • LiT:使用锁定图像文本微调的零样本传输(X. Zhai 等人,2021 年)

      • “本文提出了一种对比微调方法,该方法采用对比训练来对齐图像和文本模型,同时仍然利用它们的预训练优势。在我们的实证研究中,我们发现锁定预训练的图像模型,而解锁文本模型的效果最佳。我们将这种对比微调的实例称为“锁定图像微调”(LiT),它只是教文本模型从预训练的图像模型中读取出适用于新任务的良好表示。LiT模型获得了零样本迁移到新的视觉任务的能力,例如图像分类或检索。所提出的LiT具有广泛的适用性;它可以使用多种预训练方法(监督和无监督),并使用三个不同的图像-文本数据集跨多种架构(ResNet、Vision Transformers 和 MLP-Mixer)可靠地工作。”

使用稀疏混合专家 (MoE) 扩展视觉模型#

  • GitHub 上的代码

  • 研究论文

    • 使用稀疏混合专家扩展视觉模型 (Riquelme et al., 2021)

      • “稀疏门控混合专家网络 (MoE) 在自然语言处理中展示了出色的可扩展性。然而,在计算机视觉中,几乎所有高性能网络都是“密集的”,也就是说,每个输入都由每个参数处理。我们提出了一种 Vision MoE (V-MoE),它是 Vision Transformer 的稀疏版本,具有可扩展性,并且可以与最大的密集网络竞争……我们展示了 V-MoE 扩展视觉模型的潜力,并训练了一个 150 亿参数的模型,在 ImageNet 上达到了 90.35% 的准确率……”

扩散#

变分扩散模型#

  • GitHub 上的代码

  • Colab 笔记本

  • 研究论文

    • 变分扩散模型 (Kingma et al., 2021)

      • “基于扩散的生成模型已经展示了令人印象深刻的感知合成能力,但它们也能成为优秀的基于似然的模型吗?我们对此给出了肯定的答案,并引入了一系列基于扩散的生成模型,这些模型在标准图像密度估计基准上获得了最先进的似然度。与其他基于扩散的模型不同,我们的方法允许在模型的其余部分中有效地优化噪声计划。我们表明,变分下界 (VLB) 可以简化为扩散数据的信噪比的非常短的表达式,从而提高我们对此模型类的理论理解。利用这种洞察力,我们证明了文献中提出的几种模型之间的等价性。此外,我们表明连续时间 VLB 对噪声计划是不变的,除了其端点的信噪比。这使我们能够学习一种噪声计划,该计划可以最小化生成的 VLB 估计器的方差,从而加快优化速度……”

领域自适应#

GIFT (特征向目标方向的逐步插值)#

  • GitHub 上的代码

  • 研究论文

    • 野外场景中的逐步领域自适应:当缺少中间分布时 (Abnar et al., 2021)

      • “我们专注于领域自适应问题,当目标是将模型转移到目标分布,而不是学习领域不变的表示时。已经表明,在以下两个假设下:(a)可以访问来自中间分布的样本,并且 (b) 样本被标记了与源分布的变化量,自训练可以成功应用于逐步转移的样本,以使模型适应目标分布。我们假设拥有 (a) 就足以使迭代自训练能够通过利用隐式课程缓慢地使模型适应目标分布。在 (a) 不成立的情况下,我们观察到迭代自训练效果不佳。我们提出了 GIFT,一种通过插值来自源域和目标域的示例表示来从中间分布创建虚拟样本的方法……”

泛化#

替代差距最小化改进了锐度感知训练#

  • GitHub 上的代码

  • 研究论文

    • 替代差距最小化改进了锐度感知训练 (J. Zhuang et al., 2022)

      • “最近提出的锐度感知最小化 (SAM) 通过最小化在参数空间邻域内的最大损失来改进泛化。然而,我们表明,尖锐最小值和平坦最小值都可以具有较低的扰动损失,这意味着 SAM 并不总是喜欢平坦最小值。相反,我们定义了一个替代差距,当邻域半径(导出扰动损失)较小时,它等效于局部最小值处 Hessian 的主特征值。替代差距易于计算,并且在训练期间可以直接最小化。基于上述观察,我们提出了替代差距引导的锐度感知最小化 (GSAM),这是对 SAM 的一种新颖改进,计算开销可忽略不计……”

元学习#

learned_optimization#

  • GitHub 上的代码:learned_optimization

  • Colab 笔记本

  • 研究论文

    • 具有持久进化策略的展开计算图中的无偏梯度估计 (Vicol et al., 2021)

      • “我们引入了一种称为持久进化策略 (PES) 的方法,该方法将计算图划分为一系列截断的展开,并在每次展开后执行基于进化策略的更新步骤。PES 通过在整个展开序列中累积校正项来消除这些截断的偏差。PES 允许快速参数更新,内存使用率低,无偏且具有合理的方差特征。”

    • 梯度并非您所需的一切 (Metz et al., 2021)

      • “……在这份简短的报告中,我们讨论了一种常见的基于混沌的失效模式,该模式出现在各种可微的情况下,从循环神经网络和数值物理模拟到训练学习的优化器。我们将这种失效追溯到研究系统雅可比矩阵的谱,并为从业者何时可能期望这种失效破坏他们基于微分的优化算法提供标准。”

模型效率#

高效扩展 Transformer 推理#

  • GitHub 上的代码

  • 研究论文

    • 高效扩展 Transformer 推理 (Pope et al., 2022)

      • “我们开发了一个简单的推理效率分析模型,以根据应用程序需求选择针对 TPU v4 切片优化的最佳多维分区技术。我们将这些与一套低级优化相结合,以在 500B+ 参数模型的延迟和模型 FLOP 利用率 (MFU) 折衷方面实现新的帕累托前沿,其性能优于 FasterTransformer 基准测试套件。我们进一步表明,通过适当的分区,多查询注意力(即多个查询头共享单个键/值头)的较低内存要求可以将上下文长度扩展到 32 倍。”

神经渲染 / NeRF#

可泛化的基于补丁的神经渲染#

  • GitHub 上的代码

  • 研究论文

    • 可泛化的基于补丁的神经渲染 (Suhail et al., 2022)

      • “……我们提出了一种不同的范例,其中不需要深度特征,也不需要类似 NeRF 的体积渲染。我们的方法能够直接从场景中采样的补丁集合中预测新场景中目标光线的颜色。”

JAX 和 Flax 中基于体素的辐射场#

  • Colab 笔记本 (Velez and Dellaert, 2022)

    • “在此笔记本中,我们展示了使用 JAX/Flax,相对容易快速启动并运行基于体素的 NeRF 变体。具体来说,我们将开发一个简化的 DVGO 版本,该版本直接回归颜色而不是使用小型 MLP。它的效果非常好。”

优化#

Amos 优化器 JEstimator#

  • GitHub 上的代码

    • Amos 和 JEstimator

      • “……实现了 Amos,一个与 optax 库兼容的优化器,以及 JEstimator,一个轻量级库,它具有类似 tf.Estimator 的接口,用于管理 JAX 中机器学习程序的 T5X 兼容检查点,我们用它来运行论文中的实验。”

  • 研究论文

    • Amos:一种具有自适应权重衰减的 Adam 式优化器,面向模型规模 (Tian and Parikh, 2022)

      • 提出“Amos,一个与 optax 库兼容的优化器,以及 JEstimator,一个轻量级库,它具有类似 tf.Estimator 的接口,用于管理 JAX 中机器学习程序的 T5X 兼容检查点。” “当用于预训练 BERT 变体和 T5 时,Amos 始终比 AdamW 的最先进设置更快地收敛,在 <=70% 的训练步骤和时间内获得更好的验证损失,同时槽变量需要的内存 <=51%。”

量化#

帕累托最优量化 ResNet 大多是 4 位AQT:准确量化训练#

  • GitHub 上的代码

  • 研究论文

    • 帕累托最优量化 ResNet 大多是 4 位 (Abdolrashidi et al., 2021)

      • “在这项工作中,我们使用 ResNet 作为案例研究,以系统地研究量化对推理计算成本-质量权衡曲线的影响。我们的结果表明,对于每个 bfloat16 ResNet 模型,都存在成本更低且准确率更高的量化模型;换句话说,bfloat16 计算成本-质量权衡曲线由 4 位和 8 位曲线帕累托主导,其中主要量化为 4 位的模型产生最佳帕累托曲线……我们使用的量化方法针对实用性进行了优化:它几乎不需要调整,并且在设计时考虑了硬件功能……作为这项工作的一部分,我们贡献了一个用 JAX 编写的量化库……”

强化学习#

通过演示进行动作量化的连续控制 (AQuaDem)#

  • GitHub 上的代码

  • 研究论文

    • 通过演示进行动作量化的连续控制 (Dadashi et al., 2021)

      • 提出了一种新的强化学习 (RL) 框架,用于解决具有连续动作空间的问题:基于演示的动作量化 (AQuaDem)。所提出的方法包括从人类演示中学习连续动作空间的离散化。这种离散化为每个输入状态返回一组合理的动作(根据演示),从而捕获演示者的先验知识及其多模态行为。通过离散化动作空间,任何离散动作深度 RL 技术都可以很容易地应用于连续控制问题。实验表明,所提出的方法在 RL 设置中优于 SAC 等最先进的方法,在模仿学习设置中优于 GAIL。

序列模型 / 模型并行#

T5X:使用 t5xseqio 扩展模型和数据#

  • GitHub 上的代码

    • “T5X 是一个模块化、可组合、研究友好的框架,用于以多种规模对序列模型(从语言开始)进行高性能、可配置、自助式的训练、评估和推理。”

  • 研究论文

    • T5X:使用 t5x 和 seqio 扩展模型和数据 (Roberts 等人,2022)

      • “最近基于神经网络的语言模型从扩大训练数据集的大小和模型本身的参数数量中获益匪浅。扩展可能很复杂,原因包括需要在超级计算机集群(例如,TPU)上分配计算、防止数据输入时的瓶颈以及确保结果的可重复性。在这项工作中,我们提出了两个简化这些问题的软件库:t5x 简化了大规模构建和训练大型语言模型的过程,同时保持了易用性,而 seqio 提供了一个基于任务的 API,用于简单地创建快速且可重现的训练数据和评估管道。这些开源库已用于在具有多个 TB 训练数据的数据集上训练具有数千亿个参数的模型。与这些库一起,我们发布了 T5 类编码器-解码器模型以及 GPT 类仅解码器架构的配置和说明。”

模拟#

Brax - 用于大规模刚体模拟的可微分物理引擎#

  • GitHub 上的代码

  • Colab 笔记本

  • 研究论文

    • Brax - 用于大规模刚体模拟的可微分物理引擎 (Freeman 等人,2021)

      • “我们介绍了 Brax,这是一个开源库,用于刚体模拟,重点是加速器上的性能和并行性,用 JAX 编写。我们展示了一套受现有强化学习文献启发的任务的结果,但在我们的引擎中重新制作。此外,我们提供了在 JAX 中重新实现的 PPO、SAC、ES 和直接策略优化,它们与我们的环境一起编译,允许学习算法和环境处理在同一设备上进行,并在加速器上无缝扩展。”