AI Dev 25 x NYC | Robert Crowe: JAX Made Simple: An Intuitive Guide to Building Fast Neural Networks


Description:

Robert Crowe, a product manager at Google with extensive AI experience, delivered a fast-paced technical overview of JAX at AI Dev 25 NYC, aimed at ML practitioners who already train models and want to understand where the framework fits in the modern deep learning stack. Crowe traces JAX’s origins to Google’s need for something more flexible and modular than TensorFlow—a system built for both massive-scale production training and rapid research iteration, now also used across scientific computing domains including bioinformatics and genomics.

The technical core of the talk covers JAX’s composable function transformations: `jit` for just-in-time compilation, `grad` for automatic differentiation, `vmap` for vectorization, and `shard_map` for fine-grained sharding, all compiled by the XLA layer into optimized machine code for GPUs and TPUs. Crowe then walks through the three main distributed training strategies: distributed data parallelism (DDP) for models that fit in a single accelerator’s memory, fully sharded data parallelism (FSDP) for models that don’t, and tensor parallelism for splitting individual layers when memory constraints are extreme. JAX’s SPMD (single program, multiple data) paradigm abstracts the hardware topology, making multi-device programs look like single-device programs; the sketches below illustrate both ideas.
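
To make the composition concrete, here is a minimal sketch (a toy squared-error loss of our own, not code from the talk) showing how `grad`, `vmap`, and `jit` nest freely:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Toy squared-error loss for a linear model (illustrative only).
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

# grad: differentiate with respect to the first argument (the weights).
grad_loss = jax.grad(loss)

# vmap: vectorize the per-example gradient over a batch dimension.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the whole composition with XLA for GPU/TPU execution.
fast_grads = jax.jit(per_example_grads)

w = jnp.ones(3)
x = jnp.ones((8, 3))
y = jnp.zeros(8)
print(fast_grads(w, x, y).shape)  # (8, 3): one gradient per example
```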
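
The SPMD idea can be sketched the same way: the function below is written as ordinary single-device code, and a sharding annotation (a hypothetical one-axis mesh named `data`, not from the talk) tells XLA to partition the batch across devices, which is data parallelism in miniature:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over whatever accelerators are available.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Shard the leading (batch) dimension across the "data" mesh axis.
x = jax.device_put(jnp.ones((32, 128)), NamedSharding(mesh, P("data", None)))

@jax.jit
def forward(x):
    # Written as if for one device; XLA partitions it across the mesh.
    return jnp.tanh(x @ jnp.ones((128, 64)))

print(forward(x).sharding)  # the output inherits the data-parallel layout
```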

A November 2023 scaling study showing near-ideal throughput scaling to over 50,000 TPUs provides concrete evidence of JAX’s production readiness. Crowe closes with an overview of the broader ecosystem: Flax for building neural networks, Optax for optimization, and other libraries, along with resource links for viewers who want to go deeper on any covered topic.
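
As a taste of those ecosystem pieces, here is a minimal sketch (toy model and training step of our own, not from the talk) pairing Flax for the model definition with Optax for the optimizer:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(32)(x))
        return nn.Dense(1)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

# Optax optimizers are pure pytree-to-pytree transformations.
tx = optax.adam(1e-3)
opt_state = tx.init(params)

def loss_fn(params, x, y):
    return jnp.mean((model.apply(params, x) - y) ** 2)

@jax.jit
def train_step(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = tx.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state
```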


📺 Source: DeepLearningAI · Published December 05, 2025
🏷️ Format: Deep Dive
