Machine Learning with JAX - From Hero to HeroPro+ | Tutorial #2

❤️ Become The AI Epiphany Patreon ❤️

👨‍👩‍👧‍👦 Join our Discord community 👨‍👩‍👧‍👦

This is the second video in the JAX series of tutorials.

JAX is a powerful and increasingly popular ML library built by the Google Research team. The two most popular deep learning frameworks built on top of JAX are Haiku (DeepMind) and Flax (Google Research).

In this video, we continue where we left off and learn the additional components needed to train complex ML models (such as neural networks) on multiple machines!

▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬

⌚️ Timetable:
00:00:00 My get started with JAX repo
00:01:25 Stateful to stateless conversion
00:11:00 PyTrees in depth
00:17:45 Training an MLP in pure JAX
00:27:30 Custom PyTrees
00:32:55 Parallelism in JAX (TPUs example)
00:40:05 Communication between devices
00:46:05 value_and_grad and has_aux
00:48:45 Training an ML model on multiple machines
00:58:50 stop_gradient, per-example grads
01:06:45 Implementing MAML in 3 lines
01:08:35 Outro

▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
💰 BECOME A PATREON OF THE AI EPIPHANY ❤️

If these videos, GitHub projects, and blogs help you,
consider helping me out by supporting me on Patreon!

Huge thank you to these AI Epiphany patrons:
Eli Mahler
Petar Veličković
Bartłomiej Danek
Zvonimir Sabljic

▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬

#jax #machinelearning #framework
Comments

Next video we go from HeroPro+ to ultraHeroProUltimateMaster+

akashraut

Hey everyone!
Just wanted to share that jax.tree_multimap, used at @16:44, has since been deprecated.
In newer versions of JAX, jax.tree.map() handles both single- and multi-argument functions:

Single argument -> jax.tree.map(lambda x: x**2, list_of_trees)
Multiple arguments -> jax.tree.map(lambda x, y: x + y, list_of_trees, list_of_trees)

Hope it helps!

SamarthSharma-de
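
A runnable version of the two one-liners in the comment above, applied to a small nested pytree; the pytree itself is just an illustration, and jax.tree.map assumes a recent JAX release (older ones expose the same function as jax.tree_util.tree_map):

import jax
import jax.numpy as jnp

# A small illustrative pytree: a dict holding a list of arrays and a scalar leaf.
tree = {"layer": [jnp.arange(3.0), jnp.ones(2)], "scale": jnp.array(2.0)}

# Single-argument function: square every leaf.
squared = jax.tree.map(lambda x: x ** 2, tree)

# Multi-argument function over two pytrees with identical structure
# (this is exactly what jax.tree_multimap used to be needed for).
summed = jax.tree.map(lambda x, y: x + y, tree, squared)

print(squared)
print(summed)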

These videos are really great and helpful. Right to the point, no wasting of time. Thanks!!

sprmsv

Excellent video! I'm still near the beginning, about to do those tasks you recommended for ML beginners on MNIST, but this is definitely one of my favorite ML channels :)

objectobjectobject

Thanks for this JAX series ❤️ I am planning to implement a CV research paper in JAX and Flax. It will be of great help, thanks 👍

tauhidkhan

[Combining the gradients of the loss across devices]

Hello guys,

Firstly, thank you so much for the amazing tutorials @TheAIEpiphany!

Secondly, I'd like to clarify the mathematics behind combining the gradients of the loss across multiple devices @55:34. The question is: is it correct to compute the gradient as the average of the gradients from the different devices? In other words, does it give the same gradient as if we were computing it on a single device?

The answer is YES, it is correct, but only if the loss is defined as a weighted sum over the samples. This follows from the fact that the gradient of a weighted sum equals the weighted sum of the gradients.

In this context the loss is a mean over samples (or batches), which is a weighted sum with equal weights, so averaging works. The same principle also applies to the cross-entropy loss.

Additionally, the batch sizes across the devices should be the same. Otherwise the result is not a plain mean but a weighted sum, with each device weighted by its normalized batch size.

Hope my comment is clear and demystifies some questions one might have had :)

PS: For anyone who did not follow the details, the conclusion is: "it is fine to do it the way @TheAIEpiphany does" (because we are dealing with MSE/cross-entropy and the batch size is the same across devices).

jimmyweber
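
As a quick numerical check of the argument above, here is a minimal sketch (toy linear model and MSE loss, not the code from the video) showing that the average of per-shard gradients of a mean loss matches the full-batch gradient when the shards have equal size:

import jax
import jax.numpy as jnp

def mse_loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

key = jax.random.PRNGKey(0)
k_w, k_x, k_y = jax.random.split(key, 3)
w = jax.random.normal(k_w, (4,))
x = jax.random.normal(k_x, (8, 4))
y = jax.random.normal(k_y, (8,))

# Gradient computed on the full batch, as on a single device.
full_grad = jax.grad(mse_loss)(w, x, y)

# Gradient computed as the average of per-shard gradients
# (two equally sized shards, standing in for two devices).
x_shards, y_shards = x.reshape(2, 4, 4), y.reshape(2, 4)
shard_grads = jax.vmap(jax.grad(mse_loss), in_axes=(None, 0, 0))(w, x_shards, y_shards)
avg_grad = jnp.mean(shard_grads, axis=0)

print(jnp.allclose(full_grad, avg_grad, atol=1e-6))  # True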

Both useful knowledge and memes in a single video!

nickkonovalchuk

Great series of tutorials, congrats! It would be nice to see a performance comparison between JAX and PyTorch on some real-world use case (GPU and TPU) :)

juansensio

Great videos, Aleksa! I found the names of the x and y arguments in the MLP forward function confusing, since they are really batches of xs and ys. You could have used vmap there instead of writing it in already-batched form, but I guess it's a good exercise for your viewers to rewrite it in unbatched form and apply vmap :)

mathmo
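
A minimal sketch of the suggestion above (the MLP structure and initializer are illustrative, not the exact code from the video): write the forward pass for a single example and let vmap add the batch dimension.

import jax
import jax.numpy as jnp

def forward(params, x):
    # Forward pass for a SINGLE example x of shape (in_dim,).
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(layer["w"] @ x + layer["b"])
    return last["w"] @ x + last["b"]

def init_mlp(key, sizes):
    # Illustrative initializer: a list of {"w", "b"} dicts, one per layer.
    keys = jax.random.split(key, len(sizes) - 1)
    return [{"w": jax.random.normal(k, (n_out, n_in)) * 0.1, "b": jnp.zeros(n_out)}
            for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])]

params = init_mlp(jax.random.PRNGKey(0), [784, 128, 10])
batch = jnp.ones((32, 784))

# vmap maps forward over the leading axis of the batch and leaves params untouched.
batched_forward = jax.vmap(forward, in_axes=(None, 0))
print(batched_forward(params, batch).shape)  # (32, 10)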

Also, would there be any tutorial on JAXopt? It would be highly appreciated! Thanks for your videos.

lidiias


Just click the Colab button and you're ready to play with the code yourself.

TheAIEpiphany

You are very nice! Thank you for your video :D

jonathansum

You are an inspiration, sir! Thanks for these videos!

eriknorlander

So the parallelism you demoed with pmap, that was data parallelism, correct? Replicating the whole model across all the devices, sending a different batch to each device, and then collecting the averaged model back on the host device after the forward and backward passes? Am I understanding that correctly?

promethesured
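
For reference, a minimal sketch of that data-parallel pattern (the toy loss and update are illustrative, not the video's exact code): parameters are replicated across devices, each device gets its own batch shard, and jax.lax.pmean averages the gradients across devices inside the pmapped step, so every replica applies the same update.

import functools
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Toy linear regression loss.
    return jnp.mean((x @ w - y) ** 2)

@functools.partial(jax.pmap, axis_name="devices")
def update(w, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    # Average the loss and gradients across all devices in the pmap.
    loss = jax.lax.pmean(loss, axis_name="devices")
    grads = jax.lax.pmean(grads, axis_name="devices")
    return w - 0.01 * grads, loss  # identical update on every replica

n_dev = jax.local_device_count()
k_w, k_x, k_y = jax.random.split(jax.random.PRNGKey(0), 3)

# Replicate the parameters and shard the data along a leading device axis.
w = jnp.stack([jax.random.normal(k_w, (4,))] * n_dev)
x = jax.random.normal(k_x, (n_dev, 32, 4))  # one batch of 32 per device
y = jax.random.normal(k_y, (n_dev, 32))

w, loss = update(w, x, y)
print(loss)  # the same value on every device, thanks to pmean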

jax.tree_multimap was deprecated in JAX version 0.3.5 and removed in JAX version 0.3.16. What can we use to replace this function in the "Training an MLP in pure JAX" part of the video?

lidiias
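
jax.tree_util.tree_map (aliased as jax.tree.map in newer releases) is the drop-in replacement: it now accepts multiple pytrees, which is what tree_multimap was needed for. A minimal sketch of an SGD-style parameter update in that spirit (the parameter pytree below is illustrative, not the notebook's exact structure):

import jax
import jax.numpy as jnp

LR = 0.01

def sgd_update(params, grads):
    # tree_map over two pytrees with matching structure covers every
    # former jax.tree_multimap use case.
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

params = [{"w": jnp.ones((3, 2)), "b": jnp.zeros(3)}]
grads = [{"w": jnp.full((3, 2), 0.5), "b": jnp.ones(3)}]
print(sgd_update(params, grads))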

In the middle of the notebook, I saw the comment "# notice how we do jit only at the highest level - XLA will have plenty of space to optimize". Do you have a reference on when to jit only at the highest level versus when to jit individual nested functions, and what the advantages/risks of each approach are? I used to jit every single function until now, so I'm curious what I can gain from a single high-level jit.

arturtoshev
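
A minimal sketch of the trade-off in question (function names are illustrative): jitting only the outermost function lets XLA trace through all the inner calls and fuse them into one compiled program, whereas sprinkling @jit on every small helper typically adds little beyond extra tracing and dispatch overhead.

import jax
import jax.numpy as jnp

def layer(w, x):
    # Small helper deliberately left un-jitted.
    return jax.nn.relu(w @ x)

def loss(params, x, y):
    for w in params:
        x = layer(w, x)
    return jnp.mean((x - y) ** 2)

# jit only at the highest level: XLA sees layer() and loss() together
# and can fuse all intermediate operations.
grad_step = jax.jit(jax.grad(loss))

params = [jnp.ones((8, 8)) for _ in range(3)]
x, y = jnp.ones(8), jnp.zeros(8)
print(grad_step(params, x, y)[0].shape)  # (8, 8)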

Nice video. Do you think the research community will embrace JAX as they did with PyTorch?

santiagogomezhernandez

4:29 Why is it called key and subkey, and not subkey and subsubkey? Aren't the two keys on the same level of the descendant tree?

cheese-power
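
For context, a minimal sketch of the splitting convention the question refers to: both outputs of jax.random.split are fresh keys derived from the old one, and the names only describe how they are used, with key carried forward for future splits and subkey consumed immediately.

import jax

key = jax.random.PRNGKey(42)

# Both outputs are new keys; neither is "more derived" than the other.
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, (3,))

# Splitting the carried key again yields another independent pair.
key, subkey = jax.random.split(key)
print(sample, jax.random.normal(subkey, (3,)))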

8:53 Why do we want to return the state if it doesn't change? (In general, I mean.) Is it just good practice, so that you don't have to think about whether the state may or may not have changed and can always return it?

thomashirtz
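
A minimal sketch of the convention the question is about (the counter example is purely illustrative): if every stateless function returns (output, new_state), even when the state is untouched, callers can thread state through uniformly without tracking which functions modify it.

from typing import NamedTuple, Tuple
import jax.numpy as jnp

class CounterState(NamedTuple):
    count: jnp.ndarray

def increment(state: CounterState) -> Tuple[jnp.ndarray, CounterState]:
    new_state = CounterState(state.count + 1)
    return new_state.count, new_state

def read(state: CounterState) -> Tuple[jnp.ndarray, CounterState]:
    # The state is unchanged, but returning it keeps the same
    # (output, state) signature as every other stateless function.
    return state.count, state

state = CounterState(jnp.array(0))
_, state = increment(state)
value, state = read(state)  # callers thread state identically in both cases
print(value)  # 1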

But how can you train a model in JAX? If you have to set everything up from scratch, then I don't think it is very useful; I'm sure PyTorch/TF are not that far behind in terms of speed.

jawadmansoor