MLX Guru - files

This commit is contained in:
Elias Bachaalany 2024-01-26 10:03:11 -08:00
parent bd0099ad1f
commit f96f19d404
8 changed files with 1579 additions and 0 deletions

View file

@ -0,0 +1,23 @@
.. _nn_functions:
.. currentmodule:: mlx.nn
Functions
---------
Layers without parameters (e.g. activation functions) are also provided as
simple functions.
.. autosummary::
:toctree: _autosummary_functions
:template: nn-module-template.rst
gelu
gelu_approx
gelu_fast_approx
mish
prelu
relu
selu
silu
step

View file

@ -0,0 +1,45 @@
.. _init:
.. currentmodule:: mlx.nn.init
Initializers
------------
The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.
For example:
.. code:: python
import mlx.core as mx
import mlx.nn as nn
init_fn = nn.init.uniform()
# Produces a [2, 2] uniform matrix
param = init_fn(mx.zeros((2, 2)))
To re-initialize all the parameter in an :obj:`mlx.nn.Module` from say a uniform
distribution, you can do:
.. code:: python
import mlx.nn as nn
model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
init_fn = nn.init.uniform(low=-0.1, high=0.1)
model.apply(init_fn)
.. autosummary::
:toctree: _autosummary
constant
normal
uniform
identity
glorot_normal
glorot_uniform
he_normal
he_uniform

View file

@ -0,0 +1,37 @@
.. _layers:
.. currentmodule:: mlx.nn
Layers
------
.. autosummary::
:toctree: _autosummary
:template: nn-module-template.rst
ALiBi
BatchNorm
Conv1d
Conv2d
Dropout
Dropout2d
Dropout3d
Embedding
GELU
GroupNorm
InstanceNorm
LayerNorm
Linear
Mish
MultiHeadAttention
PReLU
QuantizedLinear
RMSNorm
ReLU
RoPE
SELU
Sequential
SiLU
SinusoidalPositionalEncoding
Step
Transformer

View file

@ -0,0 +1,24 @@
.. _losses:
.. currentmodule:: mlx.nn.losses
Loss Functions
--------------
.. autosummary::
:toctree: _autosummary_functions
:template: nn-module-template.rst
binary_cross_entropy
cosine_similarity_loss
cross_entropy
gaussian_nll_loss
hinge_loss
huber_loss
kl_div_loss
l1_loss
log_cosh_loss
mse_loss
nll_loss
smooth_l1_loss
triplet_loss

View file

@ -0,0 +1,36 @@
Module
======
.. currentmodule:: mlx.nn
.. autoclass:: Module
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Module.training
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Module.apply
Module.apply_to_modules
Module.children
Module.eval
Module.filter_and_map
Module.freeze
Module.leaf_modules
Module.load_weights
Module.modules
Module.named_modules
Module.parameters
Module.save_weights
Module.train
Module.trainable_parameters
Module.unfreeze
Module.update
Module.update_modules

View file

@ -0,0 +1,183 @@
.. _nn:
.. currentmodule:: mlx.nn
Neural Networks
===============
Writing arbitrarily complex neural networks in MLX can be done using only
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the
user to write again and again the same simple neural network operations as well
as handle all the parameter state and initialization manually and explicitly.
The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for finetuning and more.
Quick Start with Neural Networks
---------------------------------
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
class MLP(nn.Module):
def __init__(self, in_dims: int, out_dims: int):
super().__init__()
self.layers = [
nn.Linear(in_dims, 128),
nn.Linear(128, 128),
nn.Linear(128, out_dims),
]
def __call__(self, x):
for i, l in enumerate(self.layers):
x = mx.maximum(x, 0) if i > 0 else x
x = l(x)
return x
# The model is created with all its parameters but nothing is initialized
# yet because MLX is lazily evaluated
mlp = MLP(2, 10)
# We can access its parameters by calling mlp.parameters()
params = mlp.parameters()
print(params["layers"][0]["weight"].shape)
# Printing a parameter will cause it to be evaluated and thus initialized
print(params["layers"][0])
# We can also force evaluate all parameters to initialize the model
mx.eval(mlp.parameters())
# A simple loss function.
# NOTE: It doesn't matter how it uses the mlp model. It currently captures
# it from the local scope. It could be a positional argument or a
# keyword argument.
def l2_loss(x, y):
y_hat = mlp(x)
return (y_hat - y).square().mean()
# Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
.. _module_class:
The Module Class
----------------
The workhorse of any neural network library is the :class:`Module` class. In
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
:class:`Module` instances. Its main function is to provide a way to
recursively **access** and **update** its parameters and those of its
submodules.
Parameters
^^^^^^^^^^
A parameter of a module is any public member of type :class:`mlx.core.array` (its
name should not start with ``_``). It can be arbitrarily nested in other
:class:`Module` instances or lists and dictionaries.
:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.
A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. :meth:`mlx.nn.value_and_grad`
the gradients returned will be with respect to these trainable parameters.
Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^
MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module's parameters. This action is
performed by :meth:`Module.update`.
Inspecting Modules
^^^^^^^^^^^^^^^^^^
The simplest way to see the model architecture is to print it. Following along with
the above example, you can print the ``MLP`` with:
.. code-block:: python
print(mlp)
This will display:
.. code-block:: shell
MLP(
(layers.0): Linear(input_dims=2, output_dims=128, bias=True)
(layers.1): Linear(input_dims=128, output_dims=128, bias=True)
(layers.2): Linear(input_dims=128, output_dims=10, bias=True)
)
To get more detailed information on the arrays in a :class:`Module` you can use
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
all the parameters in a :class:`Module` do:
.. code-block:: python
from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())
As another example, you can count the number of parameters in a :class:`Module`
with:
.. code-block:: python
from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
Value and Grad
--------------
Using a :class:`Module` does not preclude using MLX's high order function
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However,
these function transformations assume pure functions, namely the parameters
should be passed as an argument to the function being transformed.
There is an easy pattern to achieve that with MLX modules
.. code-block:: python
model = ...
def f(params, other_inputs):
model.update(params) # <---- Necessary to make the model use the passed parameters
return model(other_inputs)
f(model.trainable_parameters(), mx.zeros((10,)))
However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only
computes the gradients with respect to the trainable parameters of the model.
In detail:
- it wraps the passed function with a function that calls :meth:`Module.update`
to make sure the model is using the provided parameters.
- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function
that also computes the gradients with respect to the passed parameters.
- it wraps the returned function with a function that passes the trainable
parameters as the first argument to the function returned by
:meth:`mlx.core.value_and_grad`
.. autosummary::
:toctree: _autosummary
value_and_grad
.. toctree::
nn/module
nn/layers
nn/functions
nn/losses
nn/init

View file

@ -0,0 +1,587 @@
.. _array:
Array
=====
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
array
array.astype
array.item
array.tolist
array.dtype
array.ndim
array.shape
array.size
Dtype
array.abs
array.all
array.any
array.argmax
array.argmin
array.cos
array.dtype
array.exp
array.log
array.log1p
array.logsumexp
array.max
array.mean
array.min
array.prod
array.reciprocal
array.reshape
array.round
array.rsqrt
array.sin
array.split
array.sqrt
array.square
array.sum
array.transpose
array.T
array.var
.. _data_types:
:orphan:
Data Types
==========
.. currentmodule:: mlx.core
The default floating point type is ``float32`` and the default integer type is
``int32``. The table below shows supported values for :obj:`Dtype`.
.. list-table:: Supported Data Types
:widths: 5 3 20
:header-rows: 1
* - Type
- Bytes
- Description
* - ``bool_``
- 1
- Boolean (``True``, ``False``) data type
* - ``uint8``
- 1
- 8-bit unsigned integer
* - ``uint16``
- 2
- 16-bit unsigned integer
* - ``uint32``
- 4
- 32-bit unsigned integer
* - ``uint64``
- 8
- 64-bit unsigned integer
* - ``int8``
- 1
- 8-bit signed integer
* - ``int16``
- 2
- 16-bit signed integer
* - ``int32``
- 4
- 32-bit signed integer
* - ``int64``
- 8
- 64-bit signed integer
* - ``float16``
- 2
- 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
* - ``float32``
- 4
- 32-bit float
.. _devices_and_streams:
Devices and Streams
===================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
Device
default_device
set_default_device
Stream
default_stream
new_stream
set_default_stream
.. _fft:
FFT
===
.. currentmodule:: mlx.core.fft
.. autosummary::
:toctree: _autosummary
fft
ifft
fft2
ifft2
fftn
ifftn
rfft
irfft
rfft2
irfft2
rfftn
irfftn
.. _linalg:
Linear Algebra
==============
.. currentmodule:: mlx.core.linalg
.. autosummary::
:toctree: _autosummary
norm
.. _nn:
.. currentmodule:: mlx.nn
Neural Networks
===============
Writing arbitrarily complex neural networks in MLX can be done using only
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the
user to write again and again the same simple neural network operations as well
as handle all the parameter state and initialization manually and explicitly.
The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for finetuning and more.
Quick Start with Neural Networks
---------------------------------
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
class MLP(nn.Module):
def __init__(self, in_dims: int, out_dims: int):
super().__init__()
self.layers = [
nn.Linear(in_dims, 128),
nn.Linear(128, 128),
nn.Linear(128, out_dims),
]
def __call__(self, x):
for i, l in enumerate(self.layers):
x = mx.maximum(x, 0) if i > 0 else x
x = l(x)
return x
# The model is created with all its parameters but nothing is initialized
# yet because MLX is lazily evaluated
mlp = MLP(2, 10)
# We can access its parameters by calling mlp.parameters()
params = mlp.parameters()
print(params["layers"][0]["weight"].shape)
# Printing a parameter will cause it to be evaluated and thus initialized
print(params["layers"][0])
# We can also force evaluate all parameters to initialize the model
mx.eval(mlp.parameters())
# A simple loss function.
# NOTE: It doesn't matter how it uses the mlp model. It currently captures
# it from the local scope. It could be a positional argument or a
# keyword argument.
def l2_loss(x, y):
y_hat = mlp(x)
return (y_hat - y).square().mean()
# Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
.. _module_class:
The Module Class
----------------
The workhorse of any neural network library is the :class:`Module` class. In
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
:class:`Module` instances. Its main function is to provide a way to
recursively **access** and **update** its parameters and those of its
submodules.
Parameters
^^^^^^^^^^
A parameter of a module is any public member of type :class:`mlx.core.array` (its
name should not start with ``_``). It can be arbitrarily nested in other
:class:`Module` instances or lists and dictionaries.
:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.
A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. :meth:`mlx.nn.value_and_grad`
the gradients returned will be with respect to these trainable parameters.
Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^
MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module's parameters. This action is
performed by :meth:`Module.update`.
Inspecting Modules
^^^^^^^^^^^^^^^^^^
The simplest way to see the model architecture is to print it. Following along with
the above example, you can print the ``MLP`` with:
.. code-block:: python
print(mlp)
This will display:
.. code-block:: shell
MLP(
(layers.0): Linear(input_dims=2, output_dims=128, bias=True)
(layers.1): Linear(input_dims=128, output_dims=128, bias=True)
(layers.2): Linear(input_dims=128, output_dims=10, bias=True)
)
To get more detailed information on the arrays in a :class:`Module` you can use
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
all the parameters in a :class:`Module` do:
.. code-block:: python
from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())
As another example, you can count the number of parameters in a :class:`Module`
with:
.. code-block:: python
from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
Value and Grad
--------------
Using a :class:`Module` does not preclude using MLX's high order function
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However,
these function transformations assume pure functions, namely the parameters
should be passed as an argument to the function being transformed.
There is an easy pattern to achieve that with MLX modules
.. code-block:: python
model = ...
def f(params, other_inputs):
model.update(params) # <---- Necessary to make the model use the passed parameters
return model(other_inputs)
f(model.trainable_parameters(), mx.zeros((10,)))
However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only
computes the gradients with respect to the trainable parameters of the model.
In detail:
- it wraps the passed function with a function that calls :meth:`Module.update`
to make sure the model is using the provided parameters.
- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function
that also computes the gradients with respect to the passed parameters.
- it wraps the returned function with a function that passes the trainable
parameters as the first argument to the function returned by
:meth:`mlx.core.value_and_grad`
.. autosummary::
:toctree: _autosummary
value_and_grad
.. toctree::
nn/module
nn/layers
nn/functions
nn/losses
nn/init
.. _ops:
Operations
==========
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
abs
add
all
allclose
any
arange
arccos
arccosh
arcsin
arcsinh
arctan
arctanh
argmax
argmin
argpartition
argsort
array_equal
broadcast_to
ceil
clip
concatenate
convolve
conv1d
conv2d
cos
cosh
dequantize
divide
divmod
equal
erf
erfinv
exp
expand_dims
eye
flatten
floor
floor_divide
full
greater
greater_equal
identity
inner
isnan
isposinf
isneginf
isinf
less
less_equal
linspace
load
log
log2
log10
log1p
logaddexp
logical_not
logical_and
logical_or
logsumexp
matmul
max
maximum
mean
min
minimum
moveaxis
multiply
negative
ones
ones_like
outer
partition
pad
prod
quantize
quantized_matmul
reciprocal
repeat
reshape
round
rsqrt
save
savez
savez_compressed
save_gguf
save_safetensors
sigmoid
sign
sin
sinh
softmax
sort
split
sqrt
square
squeeze
stack
stop_gradient
subtract
sum
swapaxes
take
take_along_axis
tan
tanh
tensordot
transpose
tri
tril
triu
var
where
zeros
zeros_like
.. _optimizers:
Optimizers
==========
The optimizers in MLX can be used both with :mod:`mlx.nn` but also with pure
:mod:`mlx.core` functions. A typical example involves calling
:meth:`Optimizer.update` to update a model's parameters based on the loss
gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the
model's parameters and the **optimizer state**.
.. code-block:: python
# Create a model
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
mx.eval(model.parameters())
# Create the gradient function and the optimizer
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.SGD(learning_rate=learning_rate)
for e in range(num_epochs):
for X, y in batch_iterate(batch_size, train_images, train_labels):
loss, grads = loss_and_grad_fn(model, X, y)
# Update the model with the gradients. So far no computation has happened.
optimizer.update(model, grads)
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
OptimizerState
Optimizer
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion
.. _random:
Random
======
Random sampling functions in MLX use an implicit global PRNG state by default.
However, all function take an optional ``key`` keyword argument for when more
fine-grained control or explicit state management is needed.
For example, you can generate random numbers with:
.. code-block:: python
for _ in range(3):
print(mx.random.uniform())
which will print a sequence of unique pseudo random numbers. Alternatively you
can explicitly set the key:
.. code-block:: python
key = mx.random.key(0)
for _ in range(3):
print(mx.random.uniform(key=key))
which will yield the same pseudo random number at each iteration.
Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_
we use a splittable version of Threefry, which is a counter-based PRNG.
.. currentmodule:: mlx.core.random
.. autosummary::
:toctree: _autosummary
bernoulli
categorical
gumbel
key
normal
randint
seed
split
truncated_normal
uniform
.. _transforms:
Transforms
==========
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
eval
grad
value_and_grad
jvp
vjp
vmap
simplify
.. _utils:
Tree Utils
==========
In MLX we consider a python tree to be an arbitrarily nested collection of
dictionaries, lists and tuples without cycles. Functions in this module that
return python trees will be using the default python ``dict``, ``list`` and
``tuple`` but they can usually process objects that inherit from any of these.
.. note::
Dictionaries should have keys that are valid python identifiers.
.. currentmodule:: mlx.utils
.. autosummary::
:toctree: _autosummary
tree_flatten
tree_unflatten
tree_map

View file

@ -0,0 +1,644 @@
.. _function_transforms:
Function Transforms
===================
.. currentmodule:: mlx.core
MLX uses composable function transformations for automatic differentiation and
vectorization. The key idea behind composable function transformations is that
every transformation returns a function which can be further transformed.
Here is a simple example:
.. code-block:: shell
>>> dfdx = mx.grad(mx.sin)
>>> dfdx(mx.array(mx.pi))
array(-1, dtype=float32)
>>> mx.cos(mx.array(mx.pi))
array(-1, dtype=float32)
The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:
.. code-block:: shell
>>> d2fdx2 = mx.grad(mx.grad(mx.sin))
>>> d2fdx2(mx.array(mx.pi / 2))
array(-1, dtype=float32)
>>> mx.sin(mx.array(mx.pi / 2))
array(1, dtype=float32)
Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. To see the complete list of function transformations check-out the
:ref:`API documentation <transforms>`. See the following sections for more
information on :ref:`automatic differentiaion <auto diff>` and
:ref:`automatic vectorization <vmap>`.
Automatic Differentiation
-------------------------
.. _auto diff:
Automatic differentiation in MLX works on functions rather than on implicit
graphs.
.. note::
If you are coming to MLX from PyTorch, you no longer need functions like
``backward``, ``zero_grad``, and ``detach``, or properties like
``requires_grad``.
The most basic example is taking the gradient of a scalar-valued function as we
saw above. You can use the :func:`grad` and :func:`value_and_grad` function to
compute gradients of more complex functions. By default these functions compute
the gradient with respect to the first argument:
.. code-block:: python
def loss_fn(w, x, y):
return mx.mean(mx.square(w * x - y))
w = mx.array(1.0)
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to w:
grad_fn = mx.grad(loss_fn)
dloss_dw = grad_fn(w, x, y)
# Prints array(-1, dtype=float32)
print(dloss_dw)
# To get the gradient with respect to x we can do:
grad_fn = mx.grad(loss_fn, argnums=1)
dloss_dx = grad_fn(w, x, y)
# Prints array([-1, 1], dtype=float32)
print(dloss_dx)
One way to get the loss and gradient is to call ``loss_fn`` followed by
``grad_fn``, but this can result in a lot of redundant work. Instead, you
should use :func:`value_and_grad`. Continuing the above example:
.. code-block:: python
# Computes the gradient of loss_fn with respect to w:
loss_and_grad_fn = mx.value_and_grad(loss_fn)
loss, dloss_dw = loss_and_grad_fn(w, x, y)
# Prints array(1, dtype=float32)
print(loss)
# Prints array(-1, dtype=float32)
print(dloss_dw)
You can also take the gradient with respect to arbitrarily nested Python
containers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or
:obj:`dict`).
Suppose we wanted a weight and a bias parameter in the above example. A nice
way to do that is the following:
.. code-block:: python
def loss_fn(params, x, y):
w, b = params["weight"], params["bias"]
h = w * x + b
return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to both the
# weight and bias:
grad_fn = mx.grad(loss_fn)
grads = grad_fn(params, x, y)
# Prints
# {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
print(grads)
Notice the tree structure of the parameters is preserved in the gradients.
In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that.
Automatic Vectorization
-----------------------
.. _vmap:
Use :func:`vmap` to automate vectorizing complex functions. Here we'll go
through a basic and contrived example for the sake of clarity, but :func:`vmap`
can be quite powerful for more complex functions which are difficult to optimize
by hand.
.. warning::
Some operations are not yet supported with :func:`vmap`. If you encounter an error
like: ``ValueError: Primitive's vmap not implemented.`` file an `issue
<https://github.com/ml-explore/mlx/issues>`_ and include your function.
We will prioritize including it.
A naive way to add the elements from two sets of vectors is with a loop:
.. code-block:: python
xs = mx.random.uniform(shape=(4096, 100))
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
Instead you can use :func:`vmap` to automatically vectorize the addition:
.. code-block:: python
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.
Let's time these two different versions:
.. code-block:: python
import timeit
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.
Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
.. _indexing:
Indexing Arrays
===============
.. currentmodule:: mlx.core
For the most part, indexing an MLX :obj:`array` works the same as indexing a
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
how that works.
For example, you can use regular integers and slices (:obj:`slice`) to index arrays:
.. code-block:: shell
>>> arr = mx.arange(10)
>>> arr[3]
array(3, dtype=int32)
>>> arr[-2] # negative indexing works
array(8, dtype=int32)
>>> arr[2:8:2] # start, stop, stride
array([2, 4, 6], dtype=int32)
For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:
.. code-block:: shell
>>> arr = mx.arange(8).reshape(2, 2, 2)
>>> arr[:, :, 0]
array(3, dtype=int32)
array([[0, 2],
[4, 6]], dtype=int32
>>> arr[..., 0]
array([[0, 2],
[4, 6]], dtype=int32
You can index with ``None`` to create a new axis:
.. code-block:: shell
>>> arr = mx.arange(8)
>>> arr.shape
[8]
>>> arr[None].shape
[1, 8]
You can also use an :obj:`array` to index another :obj:`array`:
.. code-block:: shell
>>> arr = mx.arange(10)
>>> idx = mx.array([5, 7])
>>> arr[idx]
array([5, 7], dtype=int32)
Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
works just as in NumPy.
Other functions which may be useful for indexing arrays are :func:`take` and
:func:`take_along_axis`.
Differences from NumPy
----------------------
.. Note::
MLX indexing is different from NumPy indexing in two important ways:
* Indexing does not perform bounds checking. Indexing out of bounds is
undefined behavior.
* Boolean mask based indexing is not yet supported.
The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.
Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which outputs
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.
In Place Updates
----------------
In place updates to indexed arrays are possible in MLX. For example:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> a[2] = 0
>>> a
array([1, 2, 0], dtype=int32)
Just as in NumPy, in place updates will be reflected in all references to the
same array:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> b = a
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 0], dtype=int32)
Transformations of functions which use in-place updates are allowed and work as
expected. For example:
.. code-block:: python
def fun(x, idx):
x[idx] = 2.0
return x.sum()
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
print(dfdx) # Prints: array([1, 0, 1], dtype=float32)
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.
.. _lazy eval:
Lazy Evaluation
===============
.. currentmodule:: mlx.core
Why Lazy Evaluation
-------------------
When you perform operations in MLX, no computation actually happens. Instead a
compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed.
MLX uses lazy evaluation because it has some nice features, some of which we
describe below.
Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Lazy evaluation let's us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations like :func:`simplify`.
Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to
integrate compilation for future performance enhancements.
Only Compute What You Use
^^^^^^^^^^^^^^^^^^^^^^^^^
In MLX you do not need to worry as much about computing outputs that are never
used. For example:
.. code-block:: python
def fun(x):
a = fun1(x)
b = expensive_fun(a)
return a, b
y, _ = fun(x)
Here, we never actually compute the output of ``expensive_fun``. Use this
pattern with care though, as the graph of ``expensive_fun`` is still built, and
that has some cost associated to it.
Similarly, lazy evaluation can be beneficial for saving memory while keeping
code simple. Say you have a very large model ``Model`` derived from
:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.
Typically, this will initialize all of the weights as ``float32``, but the
initialization does not actually compute anything until you perform an
:func:`eval`. If you update the model with ``float16`` weights, your maximum
consumed memory will be half that required if eager computation was used
instead.
This pattern is simple to do in MLX thanks to lazy computation:
.. code-block:: python
model = Model() # no memory used yet
model.load_weights("weights_fp16.safetensors")
When to Evaluate
----------------
A common question is when to use :func:`eval`. The trade-off is between
letting graphs get too large and not batching enough useful work.
For example:
.. code-block:: python
for _ in range(100):
a = a + b
mx.eval(a)
b = b * 2
mx.eval(b)
This is a bad idea because there is some fixed overhead with each graph
evaluation. On the other hand, there is some slight overhead which grows with
the compute graph size, so extremely large graphs (while computationally
correct) can be costly.
Luckily, a wide range of compute graph sizes work pretty well with MLX:
anything from a few tens of operations to many thousands of operations per
evaluation should be okay.
Most numerical computations have an iterative outer loop (e.g. the iteration in
stochastic gradient descent). A natural and usually efficient place to use
:func:`eval` is at each iteration of this outer loop.
Here is a concrete example:
.. code-block:: python
for batch in dataset:
# Nothing has been evaluated yet
loss, grad = value_and_grad_fn(model, batch)
# Still nothing has been evaluated
optimizer.update(model, grad)
# Evaluate the loss and the new parameters which will
# run the full gradient computation and optimizer update
mx.eval(loss, model.parameters())
An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.
Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.
Also, calling :func:`eval` on an array or set of arrays multiple times is
perfectly fine. This is effectively a no-op.
.. warning::
Using scalar arrays for control-flow will cause an evaluation.
Here is an example:
.. code-block:: python
def fun(x):
h, y = first_layer(x)
if y > 0: # An evaluation is done here!
z = second_layer_a(h)
else:
z = second_layer_b(h)
return z
Using arrays for control flow should be done with care. The above example works
and can even be used with gradient transformations. However, this can be very
inefficient if evaluations are done too frequently.
.. _numpy:
Conversion to NumPy and Other Frameworks
========================================
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
Let's convert an array to NumPy and back.
.. code-block:: python
import mlx.core as mx
import numpy as np
a = mx.arange(3)
b = np.array(a) # copy of a
c = mx.array(b) # copy of b
.. note::
Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
``np.array(a.astype(mx.float32))``.
Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
By default, NumPy copies data to a new array. This can be prevented by creating an array view:
.. code-block:: python
a = mx.arange(3)
a_view = np.array(a, copy=False)
print(a_view.flags.owndata) # False
a_view[0] = 1
print(a[0].item()) # 1
A NumPy array view is a normal NumPy array, except that it does not own its memory.
This means writing to the view is reflected in the original array.
While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.
Let's demonstrate this in an example:
.. code-block:: python
def f(x):
x_view = np.array(x, copy=False)
x_view[:] *= x_view # modify memory without telling mx
return x.sum()
x = mx.array([3.0])
y, df = mx.value_and_grad(f)(x)
print("f(x) = x² =", y.item()) # 9.0
print("f'(x) = 2x !=", df.item()) # 1.0
The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
representing the gradient of the sum operation alone.
The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
It's important to note that a similar issue arises during array conversion and copying.
For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.
PyTorch
-------
.. warning::
PyTorch Support for :obj:`memoryview` is experimental and can break for
multi-dimensional arrays. Casting to NumPy first is advised for now.
PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
.. code-block:: python
import mlx.core as mx
import torch
a = mx.arange(3)
b = torch.tensor(memoryview(a))
c = mx.array(b.numpy())
Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.
JAX
---
JAX fully supports the buffer protocol.
.. code-block:: python
import mlx.core as mx
import jax.numpy as jnp
a = mx.arange(3)
b = jnp.array(a)
c = mx.array(b)
TensorFlow
----------
TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
.. code-block:: python
import mlx.core as mx
import tensorflow as tf
a = mx.arange(3)
b = tf.constant(memoryview(a))
c = mx.array(b)
.. _saving_and_loading:
Saving and Loading Arrays
=========================
.. currentmodule:: mlx.core
MLX supports multiple array serialization formats.
.. list-table:: Serialization Formats
:widths: 20 8 25 25
:header-rows: 1
* - Format
- Extension
- Function
- Notes
* - NumPy
- ``.npy``
- :func:`save`
- Single arrays only
* - NumPy archive
- ``.npz``
- :func:`savez` and :func:`savez_compressed`
- Multiple arrays
* - Safetensors
- ``.safetensors``
- :func:`save_safetensors`
- Multiple arrays
* - GGUF
- ``.gguf``
- :func:`save_gguf`
- Multiple arrays
The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extensions. The output of
:func:`load` depends on the format.
Here's an example of saving a single array to a file:
.. code-block:: shell
>>> a = mx.array([1.0])
>>> mx.save("array", a)
The array ``a`` will be saved in the file ``array.npy`` (notice the extension
is automatically added). Including the extension is optional; if it is missing
it will be added. You can load the array with:
.. code-block:: shell
>>> mx.load("array.npy", a)
array([1], dtype=float32)
Here's an example of saving several arrays to a single file:
.. code-block:: shell
>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.savez("arrays", a, b=b)
For compatibility with :func:`numpy.savez` the MLX :func:`savez` takes arrays
as arguments. If the keywords are missing, then default names will be
provided. This can be loaded with:
.. code-block:: shell
>>> mx.load("arrays.npz")
{'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}
In this case :func:`load` returns a dictionary of names to arrays.
The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
:func:`savez`, but they take as input a :obj:`dict` of string names to arrays:
.. code-block:: shell
>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.save_safetensors("arrays", {"a": a, "b": b})