Flax Linen#
使用 JAX 的神经网络
Flax Linen 为使用 JAX 进行神经网络研究的研究人员提供了端到端且灵活的用户体验。Flax 公开了 JAX 的全部功能。它由松散耦合的库组成,这些库通过端到端集成的指南和示例展示。
Flax Linen 被 数百个项目(并且还在增长) 使用,包括开源社区(如 Hugging Face)和 Google(如 Gemini、Imagen、Scenic 和 Big Vision)。
功能特性#
安全性
Flax 的设计考虑了正确性和安全性。由于其不可变的模块和函数式 API,Flax 有助于减轻在 JAX 中处理状态时出现的错误。
控制
Flax 通过其变量集合、RNG 集合和可变性条件,比大多数神经网络框架提供了更细粒度的控制和表达能力。
函数式 API
Flax 的函数式 API 通过 vmap、scan 等提升的变换,从根本上重新定义了模块的功能,同时也实现了与其他 JAX 库(如 Optax 和 Chex)的无缝集成。
简洁的代码
Flax 的 compact
模块允许直接在其调用点定义子模块,从而使代码更易于阅读并避免重复。
安装#
pip install flax
# or to install the latest version of Flax:
pip install --upgrade git+https://github.com/google/flax.git
Flax 安装 JAX 的 vanilla CPU 版本,如果您需要自定义版本,请查看 JAX 的安装页面。
基本用法#
class MLP(nn.Module): # create a Flax Module dataclass
out_dims: int
@nn.compact
def __call__(self, x):
x = x.reshape((x.shape[0], -1))
x = nn.Dense(128)(x) # create inline Flax Module submodules
x = nn.relu(x)
x = nn.Dense(self.out_dims)(x) # shape inference
return x
model = MLP(out_dims=10) # instantiate the MLP model
x = jnp.empty((4, 28, 28, 1)) # generate random data
variables = model.init(random.key(42), x)# initialize the weights
y = model.apply(variables, x) # make forward pass
了解更多#
快速入门
指南
示例
词汇表
开发者笔记
Flax 理念
API 参考
生态系统#
Flax 中的著名示例包括
自然语言处理和计算机视觉模型
用于文本到图像生成的模型
用于文本生成的 5400 亿参数模型
文本到图像扩散模型
用于大规模计算机视觉的库
大规模计算机视觉模型
开源高性能 LLM
大型语言模型
设备上的可微强化学习环境