MLX Guru - files

Elias Bachaalany 2024-01-26 10:03:11 -08:00
parent bd0099ad1f
commit f96f19d404
8 changed files with 1579 additions and 0 deletions

.. _array:
Array
=====
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
array
array.astype
array.item
array.tolist
array.dtype
array.ndim
array.shape
array.size
Dtype
array.abs
array.all
array.any
array.argmax
array.argmin
array.cos
array.dtype
array.exp
array.log
array.log1p
array.logsumexp
array.max
array.mean
array.min
array.prod
array.reciprocal
array.reshape
array.round
array.rsqrt
array.sin
array.split
array.sqrt
array.square
array.sum
array.transpose
array.T
array.var
.. _data_types:
:orphan:
Data Types
==========
.. currentmodule:: mlx.core
The default floating point type is ``float32`` and the default integer type is
``int32``. The table below shows supported values for :obj:`Dtype`.
.. list-table:: Supported Data Types
   :widths: 5 3 20
   :header-rows: 1

   * - Type
     - Bytes
     - Description
   * - ``bool_``
     - 1
     - Boolean (``True``, ``False``) data type
   * - ``uint8``
     - 1
     - 8-bit unsigned integer
   * - ``uint16``
     - 2
     - 16-bit unsigned integer
   * - ``uint32``
     - 4
     - 32-bit unsigned integer
   * - ``uint64``
     - 8
     - 64-bit unsigned integer
   * - ``int8``
     - 1
     - 8-bit signed integer
   * - ``int16``
     - 2
     - 16-bit signed integer
   * - ``int32``
     - 4
     - 32-bit signed integer
   * - ``int64``
     - 8
     - 64-bit signed integer
   * - ``float16``
     - 2
     - 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
   * - ``float32``
     - 4
     - 32-bit float
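For example, the defaults and explicit casts can be checked with a few lines (a small sketch; the printed dtypes assume the defaults described above):

.. code-block:: python

   import mlx.core as mx

   # Defaults: float32 for floating point data, int32 for integers
   x = mx.array([1.0, 2.0, 3.0])
   print(x.dtype)  # mlx.core.float32

   # Request a specific dtype at construction time ...
   y = mx.array([1, 2, 3], dtype=mx.uint8)

   # ... or cast an existing array with astype
   z = x.astype(mx.float16)
   print(z.dtype)  # mlx.core.float16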
.. _devices_and_streams:
Devices and Streams
===================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
Device
default_device
set_default_device
Stream
default_stream
new_stream
set_default_stream
.. _fft:
FFT
===
.. currentmodule:: mlx.core.fft
.. autosummary::
:toctree: _autosummary
fft
ifft
fft2
ifft2
fftn
ifftn
rfft
irfft
rfft2
irfft2
rfftn
irfftn
.. _linalg:
Linear Algebra
==============
.. currentmodule:: mlx.core.linalg
.. autosummary::
:toctree: _autosummary
norm
.. _nn:
.. currentmodule:: mlx.nn
Neural Networks
===============
Writing arbitrarily complex neural networks in MLX can be done using only
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the
user to repeatedly write the same simple neural network operations and to
handle all the parameter state and initialization manually and explicitly.
The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for finetuning, and more.
Quick Start with Neural Networks
---------------------------------
.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   class MLP(nn.Module):
       def __init__(self, in_dims: int, out_dims: int):
           super().__init__()

           self.layers = [
               nn.Linear(in_dims, 128),
               nn.Linear(128, 128),
               nn.Linear(128, out_dims),
           ]

       def __call__(self, x):
           for i, l in enumerate(self.layers):
               x = mx.maximum(x, 0) if i > 0 else x
               x = l(x)
           return x

   # The model is created with all its parameters but nothing is initialized
   # yet because MLX is lazily evaluated
   mlp = MLP(2, 10)

   # We can access its parameters by calling mlp.parameters()
   params = mlp.parameters()
   print(params["layers"][0]["weight"].shape)

   # Printing a parameter will cause it to be evaluated and thus initialized
   print(params["layers"][0])

   # We can also force evaluate all parameters to initialize the model
   mx.eval(mlp.parameters())

   # A simple loss function.
   # NOTE: It doesn't matter how it uses the mlp model. It currently captures
   # it from the local scope. It could be a positional argument or a
   # keyword argument.
   def l2_loss(x, y):
       y_hat = mlp(x)
       return (y_hat - y).square().mean()

   # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
   # gradient with respect to `mlp.trainable_parameters()`
   loss_and_grad = nn.value_and_grad(mlp, l2_loss)
.. _module_class:
The Module Class
----------------
The workhorse of any neural network library is the :class:`Module` class. In
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
:class:`Module` instances. Its main function is to provide a way to
recursively **access** and **update** its parameters and those of its
submodules.
Parameters
^^^^^^^^^^
A parameter of a module is any public member of type :class:`mlx.core.array` (its
name should not start with ``_``). It can be arbitrarily nested in other
:class:`Module` instances or lists and dictionaries.
:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.
A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. When using
:meth:`mlx.nn.value_and_grad`, the returned gradients will be with respect to
these trainable (non-frozen) parameters.
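For example, continuing with the ``MLP`` from the quick start above, the nested
parameter dictionary can be inspected and parts of the model frozen for
finetuning (a sketch; the exact keys follow the attribute names used in ``MLP``):

.. code-block:: python

   from mlx.utils import tree_flatten

   params = mlp.parameters()

   # Nested access: attribute name -> list index -> array name
   print(params["layers"][0]["weight"].shape)

   # Freeze everything, then unfreeze only the last layer for finetuning
   mlp.freeze()
   mlp.layers[-1].unfreeze()

   # Only the last layer's parameters remain trainable
   print([k for k, _ in tree_flatten(mlp.trainable_parameters())])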
Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^
MLX modules allow accessing and updating individual parameters. However, most
of the time we need to update large subsets of a module's parameters. This
action is performed by :meth:`Module.update`.
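For example (a sketch, reusing the ``mlp`` from the quick start), a whole
parameter tree can be rebuilt with :func:`mlx.utils.tree_map` and written back
with :meth:`Module.update`:

.. code-block:: python

   from mlx.utils import tree_map

   # Build a new parameter tree with every array scaled by 0.5 ...
   scaled = tree_map(lambda p: 0.5 * p, mlp.parameters())

   # ... and recursively write it back into the module and its submodules
   mlp.update(scaled)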
Inspecting Modules
^^^^^^^^^^^^^^^^^^
The simplest way to see the model architecture is to print it. Following along with
the above example, you can print the ``MLP`` with:
.. code-block:: python

   print(mlp)
This will display:
.. code-block:: shell

   MLP(
     (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
     (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
     (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
   )
To get more detailed information on the arrays in a :class:`Module` you can use
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
all the parameters in a :class:`Module` do:
.. code-block:: python

   from mlx.utils import tree_map

   shapes = tree_map(lambda p: p.shape, mlp.parameters())
As another example, you can count the number of parameters in a :class:`Module`
with:
.. code-block:: python

   from mlx.utils import tree_flatten

   num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
Value and Grad
--------------
Using a :class:`Module` does not preclude using MLX's higher-order function
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However,
these function transformations assume pure functions, namely that the
parameters are passed as an argument to the function being transformed.
There is an easy pattern to achieve that with MLX modules:
.. code-block:: python

   model = ...

   def f(params, other_inputs):
       model.update(params)  # <---- Necessary to make the model use the passed parameters
       return model(other_inputs)

   f(model.trainable_parameters(), mx.zeros((10,)))
However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only
computes the gradients with respect to the trainable parameters of the model.
In detail (a simplified sketch is given after this list):
- it wraps the passed function with a function that calls :meth:`Module.update`
to make sure the model is using the provided parameters.
- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function
that also computes the gradients with respect to the passed parameters.
- it wraps the returned function with a function that passes the trainable
parameters as the first argument to the function returned by
:meth:`mlx.core.value_and_grad`
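The following is a rough illustration of that wrapping, not the library's
actual implementation (``naive_value_and_grad`` is a hypothetical name):

.. code-block:: python

   import mlx.core as mx

   def naive_value_and_grad(model, fn):
       # Hypothetical, simplified stand-in for mlx.nn.value_and_grad.
       def inner_fn(params, *args, **kwargs):
           model.update(params)        # make the model use the passed parameters
           return fn(*args, **kwargs)

       # Differentiate with respect to the first argument (the parameters)
       value_grad_fn = mx.value_and_grad(inner_fn)

       def wrapped(*args, **kwargs):
           # Always pass the trainable parameters as the first argument
           return value_grad_fn(model.trainable_parameters(), *args, **kwargs)

       return wrapped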
.. autosummary::
:toctree: _autosummary
value_and_grad
.. toctree::
nn/module
nn/layers
nn/functions
nn/losses
nn/init
.. _ops:
Operations
==========
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
abs
add
all
allclose
any
arange
arccos
arccosh
arcsin
arcsinh
arctan
arctanh
argmax
argmin
argpartition
argsort
array_equal
broadcast_to
ceil
clip
concatenate
convolve
conv1d
conv2d
cos
cosh
dequantize
divide
divmod
equal
erf
erfinv
exp
expand_dims
eye
flatten
floor
floor_divide
full
greater
greater_equal
identity
inner
isnan
isposinf
isneginf
isinf
less
less_equal
linspace
load
log
log2
log10
log1p
logaddexp
logical_not
logical_and
logical_or
logsumexp
matmul
max
maximum
mean
min
minimum
moveaxis
multiply
negative
ones
ones_like
outer
partition
pad
prod
quantize
quantized_matmul
reciprocal
repeat
reshape
round
rsqrt
save
savez
savez_compressed
save_gguf
save_safetensors
sigmoid
sign
sin
sinh
softmax
sort
split
sqrt
square
squeeze
stack
stop_gradient
subtract
sum
swapaxes
take
take_along_axis
tan
tanh
tensordot
transpose
tri
tril
triu
var
where
zeros
zeros_like
.. _optimizers:
Optimizers
==========
The optimizers in MLX can be used both with :mod:`mlx.nn` and with pure
:mod:`mlx.core` functions. A typical example involves calling
:meth:`Optimizer.update` to update a model's parameters based on the loss
gradients, and subsequently calling :func:`mlx.core.eval` to evaluate both the
model's parameters and the **optimizer state**.
.. code-block:: python

   # Create a model
   model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
   mx.eval(model.parameters())

   # Create the gradient function and the optimizer
   loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
   optimizer = optim.SGD(learning_rate=learning_rate)

   for e in range(num_epochs):
       for X, y in batch_iterate(batch_size, train_images, train_labels):
           loss, grads = loss_and_grad_fn(model, X, y)

           # Update the model with the gradients. So far no computation has happened.
           optimizer.update(model, grads)

           # Compute the new parameters but also the optimizer state.
           mx.eval(model.parameters(), optimizer.state)
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
OptimizerState
Optimizer
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion
.. _random:
Random
======
Random sampling functions in MLX use an implicit global PRNG state by default.
However, all functions take an optional ``key`` keyword argument for when more
fine-grained control or explicit state management is needed.
For example, you can generate random numbers with:
.. code-block:: python

   for _ in range(3):
       print(mx.random.uniform())
which will print a sequence of unique pseudo random numbers. Alternatively you
can explicitly set the key:
.. code-block:: python

   key = mx.random.key(0)
   for _ in range(3):
       print(mx.random.uniform(key=key))
which will yield the same pseudo random number at each iteration.
Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_
we use a splittable version of Threefry, which is a counter-based PRNG.
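Because the generator is counter based, a key can be split into independent
subkeys that are safe to use separately (a small sketch):

.. code-block:: python

   import mlx.core as mx

   key = mx.random.key(0)

   # Split one key into two independent subkeys
   keys = mx.random.split(key, 2)

   a = mx.random.normal(key=keys[0])
   b = mx.random.normal(key=keys[1])  # independent of `a`, reproducible from `key`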
.. currentmodule:: mlx.core.random
.. autosummary::
:toctree: _autosummary
bernoulli
categorical
gumbel
key
normal
randint
seed
split
truncated_normal
uniform
.. _transforms:
Transforms
==========
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
eval
grad
value_and_grad
jvp
vjp
vmap
simplify
.. _utils:
Tree Utils
==========
In MLX we consider a Python tree to be an arbitrarily nested collection of
dictionaries, lists, and tuples without cycles. Functions in this module that
return Python trees use the default Python ``dict``, ``list``, and ``tuple``,
but they can usually process objects that inherit from any of these.
.. note::

   Dictionaries should have keys that are valid Python identifiers.
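For example, a nested tree of arrays can be flattened into ``(path, value)``
pairs, mapped over leaf by leaf, and rebuilt (a minimal sketch):

.. code-block:: python

   import mlx.core as mx
   from mlx.utils import tree_flatten, tree_map, tree_unflatten

   tree = {"layers": [{"w": mx.zeros((2, 2))}, {"w": mx.zeros((4,))}]}

   # tree_flatten produces a list of ("dotted.path", value) pairs
   flat = tree_flatten(tree)
   print([k for k, _ in flat])  # ['layers.0.w', 'layers.1.w']

   # tree_map applies a function to every leaf, preserving the structure
   shapes = tree_map(lambda a: a.shape, tree)

   # tree_unflatten rebuilds the nested structure from the flat list
   rebuilt = tree_unflatten(flat)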
.. currentmodule:: mlx.utils
.. autosummary::
:toctree: _autosummary
tree_flatten
tree_unflatten
tree_map