diff --git a/prompts/gpts/knowledge/MLX Guru/functions.txt b/prompts/gpts/knowledge/MLX Guru/functions.txt new file mode 100644 index 0000000..fd4302e --- /dev/null +++ b/prompts/gpts/knowledge/MLX Guru/functions.txt @@ -0,0 +1,23 @@ +.. _nn_functions: + +.. currentmodule:: mlx.nn + +Functions +--------- + +Layers without parameters (e.g. activation functions) are also provided as +simple functions. + +.. autosummary:: + :toctree: _autosummary_functions + :template: nn-module-template.rst + + gelu + gelu_approx + gelu_fast_approx + mish + prelu + relu + selu + silu + step diff --git a/prompts/gpts/knowledge/MLX Guru/init.txt b/prompts/gpts/knowledge/MLX Guru/init.txt new file mode 100644 index 0000000..610d767 --- /dev/null +++ b/prompts/gpts/knowledge/MLX Guru/init.txt @@ -0,0 +1,45 @@ +.. _init: + +.. currentmodule:: mlx.nn.init + +Initializers +------------ + +The ``mlx.nn.init`` package contains commonly used initializers for neural +network parameters. Initializers return a function which can be applied to any +input :obj:`mlx.core.array` to produce an initialized output. + +For example: + +.. code:: python + + import mlx.core as mx + import mlx.nn as nn + + init_fn = nn.init.uniform() + + # Produces a [2, 2] uniform matrix + param = init_fn(mx.zeros((2, 2))) + +To re-initialize all the parameter in an :obj:`mlx.nn.Module` from say a uniform +distribution, you can do: + +.. code:: python + + import mlx.nn as nn + model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5)) + init_fn = nn.init.uniform(low=-0.1, high=0.1) + model.apply(init_fn) + + +.. autosummary:: + :toctree: _autosummary + + constant + normal + uniform + identity + glorot_normal + glorot_uniform + he_normal + he_uniform diff --git a/prompts/gpts/knowledge/MLX Guru/layers.txt b/prompts/gpts/knowledge/MLX Guru/layers.txt new file mode 100644 index 0000000..fc8848c --- /dev/null +++ b/prompts/gpts/knowledge/MLX Guru/layers.txt @@ -0,0 +1,37 @@ +.. _layers: + +.. currentmodule:: mlx.nn + +Layers +------ + +.. autosummary:: + :toctree: _autosummary + :template: nn-module-template.rst + + ALiBi + BatchNorm + Conv1d + Conv2d + Dropout + Dropout2d + Dropout3d + Embedding + GELU + GroupNorm + InstanceNorm + LayerNorm + Linear + Mish + MultiHeadAttention + PReLU + QuantizedLinear + RMSNorm + ReLU + RoPE + SELU + Sequential + SiLU + SinusoidalPositionalEncoding + Step + Transformer diff --git a/prompts/gpts/knowledge/MLX Guru/losses.txt b/prompts/gpts/knowledge/MLX Guru/losses.txt new file mode 100644 index 0000000..6c4327e --- /dev/null +++ b/prompts/gpts/knowledge/MLX Guru/losses.txt @@ -0,0 +1,24 @@ +.. _losses: + +.. currentmodule:: mlx.nn.losses + +Loss Functions +-------------- + +.. autosummary:: + :toctree: _autosummary_functions + :template: nn-module-template.rst + + binary_cross_entropy + cosine_similarity_loss + cross_entropy + gaussian_nll_loss + hinge_loss + huber_loss + kl_div_loss + l1_loss + log_cosh_loss + mse_loss + nll_loss + smooth_l1_loss + triplet_loss \ No newline at end of file diff --git a/prompts/gpts/knowledge/MLX Guru/module.txt b/prompts/gpts/knowledge/MLX Guru/module.txt new file mode 100644 index 0000000..042a880 --- /dev/null +++ b/prompts/gpts/knowledge/MLX Guru/module.txt @@ -0,0 +1,36 @@ +Module +====== + +.. currentmodule:: mlx.nn + +.. autoclass:: Module + + .. rubric:: Attributes + + .. autosummary:: + :toctree: _autosummary + + Module.training + + .. rubric:: Methods + + .. 
autosummary::
+      :toctree: _autosummary
+
+      Module.apply
+      Module.apply_to_modules
+      Module.children
+      Module.eval
+      Module.filter_and_map
+      Module.freeze
+      Module.leaf_modules
+      Module.load_weights
+      Module.modules
+      Module.named_modules
+      Module.parameters
+      Module.save_weights
+      Module.train
+      Module.trainable_parameters
+      Module.unfreeze
+      Module.update
+      Module.update_modules
diff --git a/prompts/gpts/knowledge/MLX Guru/nn.txt b/prompts/gpts/knowledge/MLX Guru/nn.txt
new file mode 100644
index 0000000..2a253ab
--- /dev/null
+++ b/prompts/gpts/knowledge/MLX Guru/nn.txt
@@ -0,0 +1,183 @@
+.. _nn:
+
+.. currentmodule:: mlx.nn
+
+Neural Networks
+===============
+
+Writing arbitrarily complex neural networks in MLX can be done using only
+:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the
+user to write the same simple neural network operations again and again, as well
+as to handle all the parameter state and initialization manually and explicitly.
+
+The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
+composing neural network layers, initializing their parameters, freezing them
+for fine-tuning, and more.
+
+Quick Start with Neural Networks
+--------------------------------
+
+.. code-block:: python
+
+    import mlx.core as mx
+    import mlx.nn as nn
+
+    class MLP(nn.Module):
+        def __init__(self, in_dims: int, out_dims: int):
+            super().__init__()
+
+            self.layers = [
+                nn.Linear(in_dims, 128),
+                nn.Linear(128, 128),
+                nn.Linear(128, out_dims),
+            ]
+
+        def __call__(self, x):
+            for i, l in enumerate(self.layers):
+                x = mx.maximum(x, 0) if i > 0 else x
+                x = l(x)
+            return x
+
+    # The model is created with all its parameters but nothing is initialized
+    # yet because MLX is lazily evaluated
+    mlp = MLP(2, 10)
+
+    # We can access its parameters by calling mlp.parameters()
+    params = mlp.parameters()
+    print(params["layers"][0]["weight"].shape)
+
+    # Printing a parameter will cause it to be evaluated and thus initialized
+    print(params["layers"][0])
+
+    # We can also force evaluate all parameters to initialize the model
+    mx.eval(mlp.parameters())
+
+    # A simple loss function.
+    # NOTE: It doesn't matter how it uses the mlp model. It currently captures
+    #       it from the local scope. It could be a positional argument or a
+    #       keyword argument.
+    def l2_loss(x, y):
+        y_hat = mlp(x)
+        return (y_hat - y).square().mean()
+
+    # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
+    # gradient with respect to `mlp.trainable_parameters()`
+    loss_and_grad = nn.value_and_grad(mlp, l2_loss)
+
+.. _module_class:
+
+The Module Class
+----------------
+
+The workhorse of any neural network library is the :class:`Module` class. In
+MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
+:class:`Module` instances. Its main function is to provide a way to
+recursively **access** and **update** its parameters and those of its
+submodules.
+
+Parameters
+^^^^^^^^^^
+
+A parameter of a module is any public member of type :class:`mlx.core.array` (its
+name should not start with ``_``). It can be arbitrarily nested in other
+:class:`Module` instances or lists and dictionaries.
+
+:meth:`Module.parameters` can be used to extract a nested dictionary with all
+the parameters of a module and its submodules.
+
+A :class:`Module` can also keep track of "frozen" parameters. See the
+:meth:`Module.freeze` method for more details. When using
+:meth:`mlx.nn.value_and_grad`, the gradients returned will be with respect to
+these trainable parameters.
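+
+For example, here is a minimal sketch of how freezing interacts with the
+trainable parameters. It continues with the ``mlp`` model from the quick start
+above and uses :func:`mlx.utils.tree_flatten` to count parameter arrays; the
+counts in the comments are what you would expect for three ``Linear`` layers
+with biases:
+
+.. code-block:: python
+
+    from mlx.utils import tree_flatten
+
+    # Freeze every parameter in the model. Frozen parameters are still
+    # returned by parameters() but are excluded from trainable_parameters().
+    mlp.freeze()
+    print(len(tree_flatten(mlp.trainable_parameters())))  # 0
+    print(len(tree_flatten(mlp.parameters())))            # 6 (3 weights + 3 biases)
+
+    # Unfreeze everything to make all parameters trainable again
+    mlp.unfreeze()
+    print(len(tree_flatten(mlp.trainable_parameters())))  # 6
+
+Freezing a subset of a model and then calling :meth:`mlx.nn.value_and_grad` is
+the usual way to fine-tune only part of a model.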
+ + +Updating the Parameters +^^^^^^^^^^^^^^^^^^^^^^^ + +MLX modules allow accessing and updating individual parameters. However, most +times we need to update large subsets of a module's parameters. This action is +performed by :meth:`Module.update`. + + +Inspecting Modules +^^^^^^^^^^^^^^^^^^ + +The simplest way to see the model architecture is to print it. Following along with +the above example, you can print the ``MLP`` with: + +.. code-block:: python + + print(mlp) + +This will display: + +.. code-block:: shell + + MLP( + (layers.0): Linear(input_dims=2, output_dims=128, bias=True) + (layers.1): Linear(input_dims=128, output_dims=128, bias=True) + (layers.2): Linear(input_dims=128, output_dims=10, bias=True) + ) + +To get more detailed information on the arrays in a :class:`Module` you can use +:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of +all the parameters in a :class:`Module` do: + +.. code-block:: python + + from mlx.utils import tree_map + shapes = tree_map(lambda p: p.shape, mlp.parameters()) + +As another example, you can count the number of parameters in a :class:`Module` +with: + +.. code-block:: python + + from mlx.utils import tree_flatten + num_params = sum(v.size for _, v in tree_flatten(mlp.parameters())) + + +Value and Grad +-------------- + +Using a :class:`Module` does not preclude using MLX's high order function +transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However, +these function transformations assume pure functions, namely the parameters +should be passed as an argument to the function being transformed. + +There is an easy pattern to achieve that with MLX modules + +.. code-block:: python + + model = ... + + def f(params, other_inputs): + model.update(params) # <---- Necessary to make the model use the passed parameters + return model(other_inputs) + + f(model.trainable_parameters(), mx.zeros((10,))) + +However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only +computes the gradients with respect to the trainable parameters of the model. + +In detail: + +- it wraps the passed function with a function that calls :meth:`Module.update` + to make sure the model is using the provided parameters. +- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function + that also computes the gradients with respect to the passed parameters. +- it wraps the returned function with a function that passes the trainable + parameters as the first argument to the function returned by + :meth:`mlx.core.value_and_grad` + +.. autosummary:: + :toctree: _autosummary + + value_and_grad + +.. toctree:: + + nn/module + nn/layers + nn/functions + nn/losses + nn/init diff --git a/prompts/gpts/knowledge/MLX Guru/python_api.txt b/prompts/gpts/knowledge/MLX Guru/python_api.txt new file mode 100644 index 0000000..8aff6d2 --- /dev/null +++ b/prompts/gpts/knowledge/MLX Guru/python_api.txt @@ -0,0 +1,587 @@ +.. _array: + +Array +===== + +.. currentmodule:: mlx.core + +.. autosummary:: + :toctree: _autosummary + + array + array.astype + array.item + array.tolist + array.dtype + array.ndim + array.shape + array.size + Dtype + array.abs + array.all + array.any + array.argmax + array.argmin + array.cos + array.dtype + array.exp + array.log + array.log1p + array.logsumexp + array.max + array.mean + array.min + array.prod + array.reciprocal + array.reshape + array.round + array.rsqrt + array.sin + array.split + array.sqrt + array.square + array.sum + array.transpose + array.T + array.var +.. 
_data_types: + +:orphan: + +Data Types +========== + +.. currentmodule:: mlx.core + +The default floating point type is ``float32`` and the default integer type is +``int32``. The table below shows supported values for :obj:`Dtype`. + +.. list-table:: Supported Data Types + :widths: 5 3 20 + :header-rows: 1 + + * - Type + - Bytes + - Description + * - ``bool_`` + - 1 + - Boolean (``True``, ``False``) data type + * - ``uint8`` + - 1 + - 8-bit unsigned integer + * - ``uint16`` + - 2 + - 16-bit unsigned integer + * - ``uint32`` + - 4 + - 32-bit unsigned integer + * - ``uint64`` + - 8 + - 64-bit unsigned integer + * - ``int8`` + - 1 + - 8-bit signed integer + * - ``int16`` + - 2 + - 16-bit signed integer + * - ``int32`` + - 4 + - 32-bit signed integer + * - ``int64`` + - 8 + - 64-bit signed integer + * - ``float16`` + - 2 + - 16-bit float, only available with `ARM C language extensions `_ + * - ``float32`` + - 4 + - 32-bit float +.. _devices_and_streams: + +Devices and Streams +=================== + +.. currentmodule:: mlx.core + +.. autosummary:: + :toctree: _autosummary + + Device + default_device + set_default_device + Stream + default_stream + new_stream + set_default_stream +.. _fft: + +FFT +=== + +.. currentmodule:: mlx.core.fft + +.. autosummary:: + :toctree: _autosummary + + fft + ifft + fft2 + ifft2 + fftn + ifftn + rfft + irfft + rfft2 + irfft2 + rfftn + irfftn +.. _linalg: + +Linear Algebra +============== + +.. currentmodule:: mlx.core.linalg + +.. autosummary:: + :toctree: _autosummary + + norm +.. _nn: + +.. currentmodule:: mlx.nn + +Neural Networks +=============== + +Writing arbitrarily complex neural networks in MLX can be done using only +:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the +user to write again and again the same simple neural network operations as well +as handle all the parameter state and initialization manually and explicitly. + +The module :mod:`mlx.nn` solves this problem by providing an intuitive way of +composing neural network layers, initializing their parameters, freezing them +for finetuning and more. + +Quick Start with Neural Networks +--------------------------------- + +.. code-block:: python + + import mlx.core as mx + import mlx.nn as nn + + class MLP(nn.Module): + def __init__(self, in_dims: int, out_dims: int): + super().__init__() + + self.layers = [ + nn.Linear(in_dims, 128), + nn.Linear(128, 128), + nn.Linear(128, out_dims), + ] + + def __call__(self, x): + for i, l in enumerate(self.layers): + x = mx.maximum(x, 0) if i > 0 else x + x = l(x) + return x + + # The model is created with all its parameters but nothing is initialized + # yet because MLX is lazily evaluated + mlp = MLP(2, 10) + + # We can access its parameters by calling mlp.parameters() + params = mlp.parameters() + print(params["layers"][0]["weight"].shape) + + # Printing a parameter will cause it to be evaluated and thus initialized + print(params["layers"][0]) + + # We can also force evaluate all parameters to initialize the model + mx.eval(mlp.parameters()) + + # A simple loss function. + # NOTE: It doesn't matter how it uses the mlp model. It currently captures + # it from the local scope. It could be a positional argument or a + # keyword argument. + def l2_loss(x, y): + y_hat = mlp(x) + return (y_hat - y).square().mean() + + # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the + # gradient with respect to `mlp.trainable_parameters()` + loss_and_grad = nn.value_and_grad(mlp, l2_loss) + +.. 
_module_class: + +The Module Class +---------------- + +The workhorse of any neural network library is the :class:`Module` class. In +MLX the :class:`Module` class is a container of :class:`mlx.core.array` or +:class:`Module` instances. Its main function is to provide a way to +recursively **access** and **update** its parameters and those of its +submodules. + +Parameters +^^^^^^^^^^ + +A parameter of a module is any public member of type :class:`mlx.core.array` (its +name should not start with ``_``). It can be arbitrarily nested in other +:class:`Module` instances or lists and dictionaries. + +:meth:`Module.parameters` can be used to extract a nested dictionary with all +the parameters of a module and its submodules. + +A :class:`Module` can also keep track of "frozen" parameters. See the +:meth:`Module.freeze` method for more details. :meth:`mlx.nn.value_and_grad` +the gradients returned will be with respect to these trainable parameters. + + +Updating the Parameters +^^^^^^^^^^^^^^^^^^^^^^^ + +MLX modules allow accessing and updating individual parameters. However, most +times we need to update large subsets of a module's parameters. This action is +performed by :meth:`Module.update`. + + +Inspecting Modules +^^^^^^^^^^^^^^^^^^ + +The simplest way to see the model architecture is to print it. Following along with +the above example, you can print the ``MLP`` with: + +.. code-block:: python + + print(mlp) + +This will display: + +.. code-block:: shell + + MLP( + (layers.0): Linear(input_dims=2, output_dims=128, bias=True) + (layers.1): Linear(input_dims=128, output_dims=128, bias=True) + (layers.2): Linear(input_dims=128, output_dims=10, bias=True) + ) + +To get more detailed information on the arrays in a :class:`Module` you can use +:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of +all the parameters in a :class:`Module` do: + +.. code-block:: python + + from mlx.utils import tree_map + shapes = tree_map(lambda p: p.shape, mlp.parameters()) + +As another example, you can count the number of parameters in a :class:`Module` +with: + +.. code-block:: python + + from mlx.utils import tree_flatten + num_params = sum(v.size for _, v in tree_flatten(mlp.parameters())) + + +Value and Grad +-------------- + +Using a :class:`Module` does not preclude using MLX's high order function +transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However, +these function transformations assume pure functions, namely the parameters +should be passed as an argument to the function being transformed. + +There is an easy pattern to achieve that with MLX modules + +.. code-block:: python + + model = ... + + def f(params, other_inputs): + model.update(params) # <---- Necessary to make the model use the passed parameters + return model(other_inputs) + + f(model.trainable_parameters(), mx.zeros((10,))) + +However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only +computes the gradients with respect to the trainable parameters of the model. + +In detail: + +- it wraps the passed function with a function that calls :meth:`Module.update` + to make sure the model is using the provided parameters. +- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function + that also computes the gradients with respect to the passed parameters. +- it wraps the returned function with a function that passes the trainable + parameters as the first argument to the function returned by + :meth:`mlx.core.value_and_grad` + +.. 
autosummary:: + :toctree: _autosummary + + value_and_grad + +.. toctree:: + + nn/module + nn/layers + nn/functions + nn/losses + nn/init +.. _ops: + +Operations +========== + +.. currentmodule:: mlx.core + +.. autosummary:: + :toctree: _autosummary + + abs + add + all + allclose + any + arange + arccos + arccosh + arcsin + arcsinh + arctan + arctanh + argmax + argmin + argpartition + argsort + array_equal + broadcast_to + ceil + clip + concatenate + convolve + conv1d + conv2d + cos + cosh + dequantize + divide + divmod + equal + erf + erfinv + exp + expand_dims + eye + flatten + floor + floor_divide + full + greater + greater_equal + identity + inner + isnan + isposinf + isneginf + isinf + less + less_equal + linspace + load + log + log2 + log10 + log1p + logaddexp + logical_not + logical_and + logical_or + logsumexp + matmul + max + maximum + mean + min + minimum + moveaxis + multiply + negative + ones + ones_like + outer + partition + pad + prod + quantize + quantized_matmul + reciprocal + repeat + reshape + round + rsqrt + save + savez + savez_compressed + save_gguf + save_safetensors + sigmoid + sign + sin + sinh + softmax + sort + split + sqrt + square + squeeze + stack + stop_gradient + subtract + sum + swapaxes + take + take_along_axis + tan + tanh + tensordot + transpose + tri + tril + triu + var + where + zeros + zeros_like +.. _optimizers: + +Optimizers +========== + +The optimizers in MLX can be used both with :mod:`mlx.nn` but also with pure +:mod:`mlx.core` functions. A typical example involves calling +:meth:`Optimizer.update` to update a model's parameters based on the loss +gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the +model's parameters and the **optimizer state**. + +.. code-block:: python + + # Create a model + model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes) + mx.eval(model.parameters()) + + # Create the gradient function and the optimizer + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + optimizer = optim.SGD(learning_rate=learning_rate) + + for e in range(num_epochs): + for X, y in batch_iterate(batch_size, train_images, train_labels): + loss, grads = loss_and_grad_fn(model, X, y) + + # Update the model with the gradients. So far no computation has happened. + optimizer.update(model, grads) + + # Compute the new parameters but also the optimizer state. + mx.eval(model.parameters(), optimizer.state) + +.. currentmodule:: mlx.optimizers + +.. autosummary:: + :toctree: _autosummary + :template: optimizers-template.rst + + OptimizerState + Optimizer + SGD + RMSprop + Adagrad + Adafactor + AdaDelta + Adam + AdamW + Adamax + Lion +.. _random: + +Random +====== + +Random sampling functions in MLX use an implicit global PRNG state by default. +However, all function take an optional ``key`` keyword argument for when more +fine-grained control or explicit state management is needed. + +For example, you can generate random numbers with: + +.. code-block:: python + + for _ in range(3): + print(mx.random.uniform()) + +which will print a sequence of unique pseudo random numbers. Alternatively you +can explicitly set the key: + +.. code-block:: python + + key = mx.random.key(0) + for _ in range(3): + print(mx.random.uniform(key=key)) + +which will yield the same pseudo random number at each iteration. + +Following `JAX's PRNG design `_ +we use a splittable version of Threefry, which is a counter-based PRNG. + +.. currentmodule:: mlx.core.random + +.. 
autosummary::
+   :toctree: _autosummary
+
+   bernoulli
+   categorical
+   gumbel
+   key
+   normal
+   randint
+   seed
+   split
+   truncated_normal
+   uniform
+.. _transforms:
+
+Transforms
+==========
+
+.. currentmodule:: mlx.core
+
+.. autosummary::
+   :toctree: _autosummary
+
+   eval
+   grad
+   value_and_grad
+   jvp
+   vjp
+   vmap
+   simplify
+.. _utils:
+
+Tree Utils
+==========
+
+In MLX we consider a python tree to be an arbitrarily nested collection of
+dictionaries, lists and tuples without cycles. Functions in this module that
+return python trees will be using the default python ``dict``, ``list`` and
+``tuple`` but they can usually process objects that inherit from any of these.
+
+.. note::
+   Dictionaries should have keys that are valid python identifiers.
+
+.. currentmodule:: mlx.utils
+
+.. autosummary::
+   :toctree: _autosummary
+
+   tree_flatten
+   tree_unflatten
+   tree_map
diff --git a/prompts/gpts/knowledge/MLX Guru/usage.txt b/prompts/gpts/knowledge/MLX Guru/usage.txt
new file mode 100644
index 0000000..7cce765
--- /dev/null
+++ b/prompts/gpts/knowledge/MLX Guru/usage.txt
@@ -0,0 +1,644 @@
+.. _function_transforms:
+
+Function Transforms
+===================
+
+.. currentmodule:: mlx.core
+
+MLX uses composable function transformations for automatic differentiation and
+vectorization. The key idea behind composable function transformations is that
+every transformation returns a function which can be further transformed.
+
+Here is a simple example:
+
+.. code-block:: shell
+
+    >>> dfdx = mx.grad(mx.sin)
+    >>> dfdx(mx.array(mx.pi))
+    array(-1, dtype=float32)
+    >>> mx.cos(mx.array(mx.pi))
+    array(-1, dtype=float32)
+
+
+The output of :func:`grad` on :func:`sin` is simply another function. In this
+case it is the gradient of the sine function which is exactly the cosine
+function. To get the second derivative you can do:
+
+.. code-block:: shell
+
+    >>> d2fdx2 = mx.grad(mx.grad(mx.sin))
+    >>> d2fdx2(mx.array(mx.pi / 2))
+    array(-1, dtype=float32)
+    >>> mx.sin(mx.array(mx.pi / 2))
+    array(1, dtype=float32)
+
+Using :func:`grad` on the output of :func:`grad` is always ok. You keep
+getting higher order derivatives.
+
+Any of the MLX function transformations can be composed in any order to any
+depth. To see the complete list of function transformations check out the
+:ref:`API documentation <transforms>`. See the following sections for more
+information on :ref:`automatic differentiation <auto diff>` and
+:ref:`automatic vectorization <vmap>`.
+
+Automatic Differentiation
+-------------------------
+
+.. _auto diff:
+
+Automatic differentiation in MLX works on functions rather than on implicit
+graphs.
+
+.. note::
+
+   If you are coming to MLX from PyTorch, you no longer need functions like
+   ``backward``, ``zero_grad``, and ``detach``, or properties like
+   ``requires_grad``.
+
+The most basic example is taking the gradient of a scalar-valued function as we
+saw above. You can use the :func:`grad` and :func:`value_and_grad` functions to
+compute gradients of more complex functions. By default these functions compute
+the gradient with respect to the first argument:
+
+.. 
code-block:: python + + def loss_fn(w, x, y): + return mx.mean(mx.square(w * x - y)) + + w = mx.array(1.0) + x = mx.array([0.5, -0.5]) + y = mx.array([1.5, -1.5]) + + # Computes the gradient of loss_fn with respect to w: + grad_fn = mx.grad(loss_fn) + dloss_dw = grad_fn(w, x, y) + # Prints array(-1, dtype=float32) + print(dloss_dw) + + # To get the gradient with respect to x we can do: + grad_fn = mx.grad(loss_fn, argnums=1) + dloss_dx = grad_fn(w, x, y) + # Prints array([-1, 1], dtype=float32) + print(dloss_dx) + + +One way to get the loss and gradient is to call ``loss_fn`` followed by +``grad_fn``, but this can result in a lot of redundant work. Instead, you +should use :func:`value_and_grad`. Continuing the above example: + + +.. code-block:: python + + # Computes the gradient of loss_fn with respect to w: + loss_and_grad_fn = mx.value_and_grad(loss_fn) + loss, dloss_dw = loss_and_grad_fn(w, x, y) + + # Prints array(1, dtype=float32) + print(loss) + + # Prints array(-1, dtype=float32) + print(dloss_dw) + + +You can also take the gradient with respect to arbitrarily nested Python +containers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or +:obj:`dict`). + +Suppose we wanted a weight and a bias parameter in the above example. A nice +way to do that is the following: + +.. code-block:: python + + def loss_fn(params, x, y): + w, b = params["weight"], params["bias"] + h = w * x + b + return mx.mean(mx.square(h - y)) + + params = {"weight": mx.array(1.0), "bias": mx.array(0.0)} + x = mx.array([0.5, -0.5]) + y = mx.array([1.5, -1.5]) + + # Computes the gradient of loss_fn with respect to both the + # weight and bias: + grad_fn = mx.grad(loss_fn) + grads = grad_fn(params, x, y) + + # Prints + # {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)} + print(grads) + +Notice the tree structure of the parameters is preserved in the gradients. + +In some cases you may want to stop gradients from propagating through a +part of the function. You can use the :func:`stop_gradient` for that. + + +Automatic Vectorization +----------------------- + +.. _vmap: + +Use :func:`vmap` to automate vectorizing complex functions. Here we'll go +through a basic and contrived example for the sake of clarity, but :func:`vmap` +can be quite powerful for more complex functions which are difficult to optimize +by hand. + +.. warning:: + + Some operations are not yet supported with :func:`vmap`. If you encounter an error + like: ``ValueError: Primitive's vmap not implemented.`` file an `issue + `_ and include your function. + We will prioritize including it. + +A naive way to add the elements from two sets of vectors is with a loop: + +.. code-block:: python + + xs = mx.random.uniform(shape=(4096, 100)) + ys = mx.random.uniform(shape=(100, 4096)) + + def naive_add(xs, ys): + return [xs[i] + ys[:, i] for i in range(xs.shape[1])] + +Instead you can use :func:`vmap` to automatically vectorize the addition: + +.. code-block:: python + + # Vectorize over the second dimension of x and the + # first dimension of y + vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0)) + +The ``in_axes`` parameter can be used to specify which dimensions of the +corresponding input to vectorize over. Similarly, use ``out_axes`` to specify +where the vectorized axes should be in the outputs. + +Let's time these two different versions: + +.. 
code-block:: python
+
+    import timeit
+
+    print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
+    print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
+
+On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
+vectorized version takes only ``0.025`` seconds, more than ten times faster.
+
+Of course, this operation is quite contrived. A better approach is to simply do
+``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
+.. _indexing:
+
+Indexing Arrays
+===============
+
+.. currentmodule:: mlx.core
+
+For the most part, indexing an MLX :obj:`array` works the same as indexing a
+NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
+`_ for more details on
+how that works.
+
+For example, you can use regular integers and slices (:obj:`slice`) to index arrays:
+
+.. code-block:: shell
+
+    >>> arr = mx.arange(10)
+    >>> arr[3]
+    array(3, dtype=int32)
+    >>> arr[-2]  # negative indexing works
+    array(8, dtype=int32)
+    >>> arr[2:8:2]  # start, stop, stride
+    array([2, 4, 6], dtype=int32)
+
+For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:
+
+.. code-block:: shell
+
+    >>> arr = mx.arange(8).reshape(2, 2, 2)
+    >>> arr[:, :, 0]
+    array([[0, 2],
+           [4, 6]], dtype=int32)
+    >>> arr[..., 0]
+    array([[0, 2],
+           [4, 6]], dtype=int32)
+
+You can index with ``None`` to create a new axis:
+
+.. code-block:: shell
+
+    >>> arr = mx.arange(8)
+    >>> arr.shape
+    [8]
+    >>> arr[None].shape
+    [1, 8]
+
+
+You can also use an :obj:`array` to index another :obj:`array`:
+
+.. code-block:: shell
+
+    >>> arr = mx.arange(10)
+    >>> idx = mx.array([5, 7])
+    >>> arr[idx]
+    array([5, 7], dtype=int32)
+
+Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
+works just as in NumPy.
+
+Other functions which may be useful for indexing arrays are :func:`take` and
+:func:`take_along_axis`.
+
+Differences from NumPy
+----------------------
+
+.. Note::
+
+   MLX indexing is different from NumPy indexing in two important ways:
+
+   * Indexing does not perform bounds checking. Indexing out of bounds is
+     undefined behavior.
+   * Boolean mask based indexing is not yet supported.
+
+The reason for the lack of bounds checking is that exceptions cannot propagate
+from the GPU. Performing bounds checking for array indices before launching the
+kernel would be extremely inefficient.
+
+Indexing with boolean masks is something that MLX may support in the future. In
+general, MLX has limited support for operations for which output
+*shapes* are dependent on input *data*. Other examples of these types of
+operations which MLX does not yet support include :func:`numpy.nonzero` and the
+single input version of :func:`numpy.where`.
+
+In Place Updates
+----------------
+
+In place updates to indexed arrays are possible in MLX. For example:
+
+.. code-block:: shell
+
+    >>> a = mx.array([1, 2, 3])
+    >>> a[2] = 0
+    >>> a
+    array([1, 2, 0], dtype=int32)
+
+Just as in NumPy, in place updates will be reflected in all references to the
+same array:
+
+.. code-block:: shell
+
+    >>> a = mx.array([1, 2, 3])
+    >>> b = a
+    >>> b[2] = 0
+    >>> b
+    array([1, 2, 0], dtype=int32)
+    >>> a
+    array([1, 2, 0], dtype=int32)
+
+Transformations of functions which use in-place updates are allowed and work as
+expected. For example:
+
+.. 
code-block:: python + + def fun(x, idx): + x[idx] = 2.0 + return x.sum() + + dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1])) + print(dfdx) # Prints: array([1, 0, 1], dtype=float32) + +In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx`` +and ones elsewhere. +.. _lazy eval: + +Lazy Evaluation +=============== + +.. currentmodule:: mlx.core + +Why Lazy Evaluation +------------------- + +When you perform operations in MLX, no computation actually happens. Instead a +compute graph is recorded. The actual computation only happens if an +:func:`eval` is performed. + +MLX uses lazy evaluation because it has some nice features, some of which we +describe below. + +Transforming Compute Graphs +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Lazy evaluation let's us record a compute graph without actually doing any +computations. This is useful for function transformations like :func:`grad` and +:func:`vmap` and graph optimizations like :func:`simplify`. + +Currently, MLX does not compile and rerun compute graphs. They are all +generated dynamically. However, lazy evaluation makes it much easier to +integrate compilation for future performance enhancements. + +Only Compute What You Use +^^^^^^^^^^^^^^^^^^^^^^^^^ + +In MLX you do not need to worry as much about computing outputs that are never +used. For example: + +.. code-block:: python + + def fun(x): + a = fun1(x) + b = expensive_fun(a) + return a, b + + y, _ = fun(x) + +Here, we never actually compute the output of ``expensive_fun``. Use this +pattern with care though, as the graph of ``expensive_fun`` is still built, and +that has some cost associated to it. + +Similarly, lazy evaluation can be beneficial for saving memory while keeping +code simple. Say you have a very large model ``Model`` derived from +:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``. +Typically, this will initialize all of the weights as ``float32``, but the +initialization does not actually compute anything until you perform an +:func:`eval`. If you update the model with ``float16`` weights, your maximum +consumed memory will be half that required if eager computation was used +instead. + +This pattern is simple to do in MLX thanks to lazy computation: + +.. code-block:: python + + model = Model() # no memory used yet + model.load_weights("weights_fp16.safetensors") + +When to Evaluate +---------------- + +A common question is when to use :func:`eval`. The trade-off is between +letting graphs get too large and not batching enough useful work. + +For example: + +.. code-block:: python + + for _ in range(100): + a = a + b + mx.eval(a) + b = b * 2 + mx.eval(b) + +This is a bad idea because there is some fixed overhead with each graph +evaluation. On the other hand, there is some slight overhead which grows with +the compute graph size, so extremely large graphs (while computationally +correct) can be costly. + +Luckily, a wide range of compute graph sizes work pretty well with MLX: +anything from a few tens of operations to many thousands of operations per +evaluation should be okay. + +Most numerical computations have an iterative outer loop (e.g. the iteration in +stochastic gradient descent). A natural and usually efficient place to use +:func:`eval` is at each iteration of this outer loop. + +Here is a concrete example: + +.. 
code-block:: python + + for batch in dataset: + + # Nothing has been evaluated yet + loss, grad = value_and_grad_fn(model, batch) + + # Still nothing has been evaluated + optimizer.update(model, grad) + + # Evaluate the loss and the new parameters which will + # run the full gradient computation and optimizer update + mx.eval(loss, model.parameters()) + + +An important behavior to be aware of is when the graph will be implicitly +evaluated. Anytime you ``print`` an array, convert it to an +:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`, +the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX +saving functions) will also evaluate the array. + + +Calling :func:`array.item` on a scalar array will also evaluate it. In the +example above, printing the loss (``print(loss)``) or adding the loss scalar to +a list (``losses.append(loss.item())``) would cause a graph evaluation. If +these lines are before ``mx.eval(loss, model.parameters())`` then this +will be a partial evaluation, computing only the forward pass. + +Also, calling :func:`eval` on an array or set of arrays multiple times is +perfectly fine. This is effectively a no-op. + +.. warning:: + + Using scalar arrays for control-flow will cause an evaluation. + +Here is an example: + +.. code-block:: python + + def fun(x): + h, y = first_layer(x) + if y > 0: # An evaluation is done here! + z = second_layer_a(h) + else: + z = second_layer_b(h) + return z + +Using arrays for control flow should be done with care. The above example works +and can even be used with gradient transformations. However, this can be very +inefficient if evaluations are done too frequently. +.. _numpy: + +Conversion to NumPy and Other Frameworks +======================================== + +MLX array implements the `Python Buffer Protocol `_. +Let's convert an array to NumPy and back. + +.. code-block:: python + + import mlx.core as mx + import numpy as np + + a = mx.arange(3) + b = np.array(a) # copy of a + c = mx.array(b) # copy of b + +.. note:: + + Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first: + ``np.array(a.astype(mx.float32))``. + Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.`` + +By default, NumPy copies data to a new array. This can be prevented by creating an array view: + +.. code-block:: python + + a = mx.arange(3) + a_view = np.array(a, copy=False) + print(a_view.flags.owndata) # False + a_view[0] = 1 + print(a[0].item()) # 1 + +A NumPy array view is a normal NumPy array, except that it does not own its memory. +This means writing to the view is reflected in the original array. + +While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients. + +Let's demonstrate this in an example: + +.. code-block:: python + + def f(x): + x_view = np.array(x, copy=False) + x_view[:] *= x_view # modify memory without telling mx + return x.sum() + + x = mx.array([3.0]) + y, df = mx.value_and_grad(f)(x) + print("f(x) = x² =", y.item()) # 9.0 + print("f'(x) = 2x !=", df.item()) # 1.0 + + +The function ``f`` indirectly modifies the array ``x`` through a memory view. +However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``, +representing the gradient of the sum operation alone. 
+The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
+It's important to note that a similar issue arises during array conversion and copying.
+For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
+even though no in-place operations on MLX memory are executed.
+
+PyTorch
+-------
+
+.. warning::
+
+   PyTorch support for :obj:`memoryview` is experimental and can break for
+   multi-dimensional arrays. Casting to NumPy first is advised for now.
+
+PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
+
+.. code-block:: python
+
+    import mlx.core as mx
+    import torch
+
+    a = mx.arange(3)
+    b = torch.tensor(memoryview(a))
+    c = mx.array(b.numpy())
+
+Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.
+
+JAX
+---
+
+JAX fully supports the buffer protocol.
+
+.. code-block:: python
+
+    import mlx.core as mx
+    import jax.numpy as jnp
+
+    a = mx.arange(3)
+    b = jnp.array(a)
+    c = mx.array(b)
+
+TensorFlow
+----------
+
+TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
+
+.. code-block:: python
+
+    import mlx.core as mx
+    import tensorflow as tf
+
+    a = mx.arange(3)
+    b = tf.constant(memoryview(a))
+    c = mx.array(b)
+.. _saving_and_loading:
+
+Saving and Loading Arrays
+=========================
+
+.. currentmodule:: mlx.core
+
+MLX supports multiple array serialization formats.
+
+.. list-table:: Serialization Formats
+   :widths: 20 8 25 25
+   :header-rows: 1
+
+   * - Format
+     - Extension
+     - Function
+     - Notes
+   * - NumPy
+     - ``.npy``
+     - :func:`save`
+     - Single arrays only
+   * - NumPy archive
+     - ``.npz``
+     - :func:`savez` and :func:`savez_compressed`
+     - Multiple arrays
+   * - Safetensors
+     - ``.safetensors``
+     - :func:`save_safetensors`
+     - Multiple arrays
+   * - GGUF
+     - ``.gguf``
+     - :func:`save_gguf`
+     - Multiple arrays
+
+The :func:`load` function will load any of the supported serialization
+formats. It determines the format from the extension. The output of
+:func:`load` depends on the format.
+
+Here's an example of saving a single array to a file:
+
+.. code-block:: shell
+
+    >>> a = mx.array([1.0])
+    >>> mx.save("array", a)
+
+The array ``a`` will be saved in the file ``array.npy``. Including the
+extension is optional; if it is missing it will be added automatically. You
+can load the array with:
+
+.. code-block:: shell
+
+    >>> mx.load("array.npy")
+    array([1], dtype=float32)
+
+Here's an example of saving several arrays to a single file:
+
+.. code-block:: shell
+
+    >>> a = mx.array([1.0])
+    >>> b = mx.array([2.0])
+    >>> mx.savez("arrays", a, b=b)
+
+For compatibility with :func:`numpy.savez` the MLX :func:`savez` takes arrays
+as arguments. If the keywords are missing, then default names will be
+provided. This can be loaded with:
+
+.. code-block:: shell
+
+    >>> mx.load("arrays.npz")
+    {'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}
+
+In this case :func:`load` returns a dictionary of names to arrays.
+
+The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
+:func:`savez`, but they take as input a :obj:`dict` of string names to arrays:
+
+.. code-block:: shell
+
+    >>> a = mx.array([1.0])
+    >>> b = mx.array([2.0])
+    >>> mx.save_safetensors("arrays", {"a": a, "b": b})
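+
+As a final sketch, loading the resulting ``arrays.safetensors`` file with
+:func:`load` again returns a dictionary mapping the saved names to arrays
+(the exact printed representation of the arrays may differ):
+
+.. code-block:: shell
+
+    >>> mx.load("arrays.safetensors")
+    {'a': array([1], dtype=float32), 'b': array([2], dtype=float32)}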