Neural Networks in pure JAX (with automatic differentiation)

Timestamps:
00:00 Intro
01:18 Dataset that somehow looks like a sine function
01:56 Forward pass of the Multilayer Perceptron
03:22 Weight initialization due to Xavier Glorot
04:20 Idea of "Learning" as approximate optimization
04:49 Reverse-mode autodiff requires us to only write the forward pass
05:34 Imports
05:52 Constants and Hyperparameters
06:19 Producing the random toy dataset
08:33 Draw initial parameter guesses
12:05 Implementing the forward/primal pass
13:58 Implementing the loss metric
14:57 Transform forward pass to get gradients by autodiff
20:03 Training loop (using plain gradient descent)
23:21 Improving training speed by JIT compilation
24:25 Plotting loss history
24:47 Plotting final network prediction & Discussion
25:44 Summary
26:59 Outro
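
Below is a minimal sketch of the first steps the timestamps outline: the noisy sine-like toy dataset and the Glorot/Xavier-style weight initialization. The sample count, layer sizes, and noise level here are assumptions for illustration, not necessarily the values used in the video.

import jax
import jax.numpy as jnp

# Hyperparameters (assumed values; the video's may differ)
N_SAMPLES = 200
LAYER_SIZES = [1, 16, 16, 1]   # 1-D input, two hidden layers, 1-D output

key = jax.random.PRNGKey(42)

# Toy dataset that (somehow) looks like a sine function: y = sin(x) + noise
key, x_key, noise_key = jax.random.split(key, 3)
x = jax.random.uniform(x_key, (N_SAMPLES, 1), minval=0.0, maxval=2.0 * jnp.pi)
y = jnp.sin(x) + 0.1 * jax.random.normal(noise_key, (N_SAMPLES, 1))

# Glorot/Xavier-style initialization: weights scaled by sqrt(2 / (fan_in + fan_out)),
# biases start at zero
def init_params(key, layer_sizes):
    params = []
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, w_key = jax.random.split(key)
        scale = jnp.sqrt(2.0 / (fan_in + fan_out))
        params.append((scale * jax.random.normal(w_key, (fan_in, fan_out)),
                       jnp.zeros(fan_out)))
    return params

params = init_params(key, LAYER_SIZES)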
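
Continuing that sketch: the forward/primal pass, the mean-squared-error loss metric, and the gradient function obtained by reverse-mode autodiff. As the video stresses, only the forward pass is written by hand; jax.grad produces the backward pass automatically. The tanh activation is an assumption.

# Forward/primal pass of the MLP: tanh hidden layers, linear output layer
def forward(params, x):
    activation = x
    for W, b in params[:-1]:
        activation = jnp.tanh(activation @ W + b)
    W_last, b_last = params[-1]
    return activation @ W_last + b_last

# Loss metric: mean squared error against the targets
def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

# Reverse-mode autodiff: jax.grad returns d(loss)/d(params) with the same
# pytree structure as params -- no manual backward pass needed
grad_fn = jax.grad(loss_fn)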
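
Finally, a plain gradient-descent training loop. Wrapping the update step in jax.jit compiles it with XLA, which is where the speedup mentioned at 23:21 comes from. The learning rate and epoch count below are placeholder choices, and the loop reuses loss_fn, params, x, and y from the sketch above.

LEARNING_RATE = 0.1
N_EPOCHS = 5_000

# One full-batch gradient-descent step; jit-compiling it fuses the forward pass,
# the backward pass, and the parameter update into a single XLA computation
@jax.jit
def update(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

loss_history = []
for _ in range(N_EPOCHS):
    loss_history.append(loss_fn(params, x, y))   # track loss for plotting later
    params = update(params, x, y)

Because the dataset is tiny and training is full-batch, jit-compiling this single update function is enough; there is no minibatching in the sketch.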
Comments

Thanks a bunch.
I can never find enough JAX tuts lol
Even large language models like ChatGPT have outdated info a lot of the time due to their current knowledge cutoff dates.
It's just this and the docs for me hehe...

DefinitelyNotAMachineCultist

Thanks a lot for the nice video! I have a naive question: the layers of the neural network seem similar to a polynomial. How is a neural network better than a polynomial fit?

minhphan