Compiling machine learning programs via high-level tracing

Abstract
We describe JAX, a domain-specific tracing JIT compiler for generating high-performance accelerator code from pure Python and Numpy machine learning programs. JAX uses the XLA compiler infrastructure to generate optimized code for the program subroutines that are most favorable for acceleration, and these optimized subroutines can be called and orchestrated by arbitrary Python. Because the system is fully compatible with Autograd, it allows forward- and reverse-mode automatic differentiation of Python functions to arbitrary order. Because JAX supports structured control flow, it can generate code for sophisticated machine learning algorithms while maintaining high performance. We show that by combining JAX with Autograd and Numpy we get an easily programmable and highly performant ML system that targets CPUs, GPUs, and TPUs, capable of scaling to multi-core Cloud TPUs.
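To illustrate the workflow the abstract describes — writing a pure Python/NumPy-style function, differentiating it with Autograd-style transforms, and JIT-compiling it via XLA — here is a minimal sketch using JAX's public API (`jax.grad`, `jax.jit`, `jax.numpy`). The model, loss, and data here are illustrative choices, not from the paper.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # A tiny model written in pure NumPy-style code.
    return jnp.tanh(jnp.dot(x, w))

def loss(w, x, y):
    # Mean squared error; ordinary Python, traceable by JAX.
    return jnp.mean((predict(w, x) - y) ** 2)

# grad performs reverse-mode automatic differentiation;
# jit traces the function and compiles it through XLA.
grad_loss = jax.jit(jax.grad(loss))

w = jnp.zeros(3)          # parameters
x = jnp.ones((4, 3))      # batch of 4 inputs
y = jnp.ones(4)           # targets

g = grad_loss(w, x, y)    # compiled gradient, same shape as w
```

Because `grad` and `jit` compose, the same pattern extends to higher-order derivatives (`jax.grad(jax.grad(f))`) and to the structured control flow mentioned in the abstract.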