从这张GitHubStar趋势图可以看到,自2019年JAX出现到如今保持着一个向上的抛物线走势。
回归并理清这些历史问题有助于开发者了解机器学习的演变逻辑,并了解JAX是如何吸取之前的教训,帮助开发者更方便地实践深度学习或机器学习应用。
在TF引进了Eager模式之后,它会采用更直观的界面,使用自然的Python代码和数据结构,而且享受更加便携的调试,在Eager模式中可以通过直接调用操作来检查和测试模型,而之前Graph这种模式有点类似于C和C++,它的编程是写好程序之后要先进行编译才能运行。
Eager模式有自然控制的流程,使用Python而不是图控制流,以及支持GPU和TPU的加速。做为开发者,我们希望可以客观地看待不同的框架,而不是比较他们的优劣。值得思考的一个问题是:通过了解TF的Eager模式对于Graph模式的改进,它的改进逻辑和思路在JAX中都有身影。
JAX作为现在越来越流行的库,是一种类似于NumPy(使用Python开源的数值计算扩展库)的轻量级用于阵列的计算。JAX最开始的设计不仅仅是为了深度学习而设计的,深度学习只是它的一小部分,它提供了编写NumPy程序的能力,这些程序可以使用GPU/TPU自动拆分和加速。
JAX用于基于阵列的计算时,开发者无需修改代码就可以在CPU/GPU/ASIC上同时运行,并支持原生Python和NumPy函数的四种可组合函数转换:
我们可以通过下面这个简单的测试对比JAX和NumPy的计算性能。
输入一个100X100的二维数组X,选取ml.g4dn.12xlarge计算实例通过NumPy和JAX分别对矩阵的前三次幂求和:
deffn(x):returnx+x*x+x*x*xx=np.random.randn(10000,10000).astype(dtype='float32')%timeit-n5fn(x)436ms±206μsperloop(mean±std.dev.of7runs,5loopseach)我们发现此计算大约需要436毫秒。接下来,我们使用JAX实现以下计算:
jax_fn=jit(fn)x=jnp.array(x)%timeitjax_fn(x).block_until_ready()3.67ms±10.7μsperloop(mean±std.dev.of7runs,1loopeach)JAX仅在3.67毫秒内执行此计算,比NumPy快118倍以上。可见,JAX有可能比NumPy快几个数量级(注意,JAX使用TPU而NumPy正在使用CPU)。
*以上为个人测试结果,非官方提供的数据,仅供研究参考
对比测试结果可得,NumPy完成计算需要436毫秒,而JAX仅需要3.67毫秒,计算速度相差100多倍。这个测试也说明了为什么很多开发者对它的性能赞不绝口。
我们希望通过回答这个问题来解读JAX的动机:
如何使用Python从头开始实现高性能和可扩展的深度神经网络?
通常,Python程序员会从NumPy之类的东西开始,因为它是一种熟悉的、基于数组的数据处理语言,在Python社区中已经使用了几十年。如果你想在NumPy中创建深度学习系统,你可以从预测方法开始。
这里可以用一个详细的例子说明问题,从NumPy上的深度学习的场景说起:
上述代码展示了订阅一个前馈的神经网络,它执行了一系列的点积和激活函数,然后将输入转化为某种可以学习的输出。一旦定义了这样的一个模型,接下来需要做就是要定义损失函数,这个函数将为你提供正在尝试优化的那些指标,来适应最佳的机器学习模型。例如以上代码的损失函数是以均方误差损失函数MSE为例。
现在我们来分析下:在深度学习场景使用NumPy还缺少什么?
硬件加速(GPU/TPU)
自动微分(autodiff)快速优化
添加编译(Compilation)融合操作
向量化操作批处理(batching)
大型数据集并行化(Parallelization)
1)硬件加速(GPU/TPU):首先深度学习需要大量的计算,我们想在加速的硬件上运行它。所以我们想在GPU和TPU/ASIC上运行这个模型,这对于经典的NumPy来说有点困难;
2)自动微分(autodiff)快速优化:接下来我们想要做自动微分,这样就可以有效地拟合这个损失函数,而不必自己来实现数值微分;
3)然后我们需要添加编译(Compilation):这样你就可以将这些操作融合在一起,使它们更加高效;
4)向量化操作批处理(Batching):另外,当我们编写了某些函数后,可能希望将其应用于多个数据片段,而不再需要重写预测和损失函数来处理这些批量数据;
5)大型数据集并行化(Parallelization):最后,如果我们正在处理大型数据集,会希望能够支持跨多个cores或多台machines做并行化操作。
JAX非常重要的一个动机就是XLA和自动定位。让我们来看看JAX可以做些什么,来填补前面分析的在深度学习场景使用NumPy还缺少的功能。
首先,用jax.numpy替换numpy导入模块。在许多情况下,jax.numpy与经典的NumPy具有相同的API,但jax.numpy可以完成前面分析时发现NumPy缺少,但是在深度学习场景却非常需要的的东西。
JAX可以通过XLA后端,来自动定位CPU、GPU和TPU或者ASIC,以便快速计算模型和算法。
第二个重要动机是Autograd。开发者可以通过下面的代码调用Autograd版本:
通过fromjaximportgrad模块,使用Autograd的更新版本,JAX可以自动微分原生Python和NumPy函数。它可以处理Python功能的大子集,包括循环、Ifs、递归等,甚至可以接受导数的导数。
JAX提供了一组可组合的变换,其中之一是grad变换。
例子中,像mse_loss这样的损失函数,通过grad(mse_loss)将其转换为计算梯度的Python函数。
Autograd的主要预期应用是基于梯度的优化。
在使用梯度函数时,开发者希望将其应用于多个数据片段,而在JAX中,你不再需要重写预测和损失函数来处理这些批量数据。
如图中代码最后一行(perexample_grads…)所诠释的那样,如果你通过vmaptransform传递它,这会自动向量化这个代码,这样就可以在多个批次中使用相同的代码。
JAX还有一个重要的组合函数——jit,开发者可以使用jittransform实现即时编译。
jit结合后台可以使用XLA后端编译器将操作融合在一起,来自动定位CPU、GPU和TPU或者ASIC,加速计算模型和算法。
最后,如果想并行化你的代码,有一个和vmap非常相似得转换叫pmap。
通过代码运行pmap,开发者能够本地定位系统中的多个内核或你有权访问的GPU、TPU或ASIC集群。
这最终成为一个非常强大的系统,可以在没有太多额外代码的情况下构建我们用类似于NumPy的熟悉API,做深度学习的快速计算等工作负载。
JAX的关键设计思想
通过上述对比可以看到,JAX不仅为开发者提供了和NumPy相似的API,上述的五大函数转换组合也让JAX可以在不需要额外代码的情况下,帮助开发者构建深度学习应用进行快速计算。
这里的关键思想是:
1)首先,在JAX中,Python代码被追溯到中间表示,JAX知道如何转换这个中间表示。
3)另外,JAX还有基于NumPy和SciPy的面向用户的API,如果开发者一直使用Python的技术栈,应该会对JAX感觉相当熟悉;
4)最后,JAX提供了功能强大的变换:grad,git,vmap,pmap等,来支持深度学习等计算,因此JAX可以做到之前NumPy代码无法做到的事情。