.. _function_transforms:

Function Transforms
===================

.. currentmodule:: mlx.core

MLX uses composable function transformations for automatic differentiation and
vectorization. The key idea behind composable function transformations is that
every transformation returns a function which can be further transformed.

Here is a simple example:

.. code-block:: shell

   >>> dfdx = mx.grad(mx.sin)
   >>> dfdx(mx.array(mx.pi))
   array(-1, dtype=float32)
   >>> mx.cos(mx.array(mx.pi))
   array(-1, dtype=float32)

The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:

.. code-block:: shell

   >>> d2fdx2 = mx.grad(mx.grad(mx.sin))
   >>> d2fdx2(mx.array(mx.pi / 2))
   array(-1, dtype=float32)
   >>> mx.sin(mx.array(mx.pi / 2))
   array(1, dtype=float32)

Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives. Any of the MLX function transformations can
be composed in any order to any depth. To see the complete list of function
transformations check out the :ref:`API documentation `.
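For example, :func:`grad` and :func:`vmap` compose directly. Here is a minimal
sketch (assuming the default ``in_axes=0``) which evaluates the derivative of
:func:`sin` at several points at once:

.. code-block:: python

   import mlx.core as mx

   # grad produces a scalar derivative function; vmap maps it over axis 0
   # of the input so a whole batch of points is handled in one call.
   batched_dfdx = mx.vmap(mx.grad(mx.sin))

   xs = mx.array([0.0, mx.pi / 2, mx.pi])
   print(batched_dfdx(xs))  # approximately cos at each point: [1, 0, -1]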
See the following sections for more information on :ref:`automatic
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.

.. _auto diff:

Automatic Differentiation
-------------------------

Automatic differentiation in MLX works on functions rather than on implicit
graphs.

.. note::

   If you are coming to MLX from PyTorch, you no longer need functions like
   ``backward``, ``zero_grad``, and ``detach``, or properties like
   ``requires_grad``.

The most basic example is taking the gradient of a scalar-valued function as
we saw above. You can use the :func:`grad` and :func:`value_and_grad`
functions to compute gradients of more complex functions. By default these
functions compute the gradient with respect to the first argument:

.. code-block:: python

   def loss_fn(w, x, y):
       return mx.mean(mx.square(w * x - y))

   w = mx.array(1.0)
   x = mx.array([0.5, -0.5])
   y = mx.array([1.5, -1.5])

   # Computes the gradient of loss_fn with respect to w:
   grad_fn = mx.grad(loss_fn)
   dloss_dw = grad_fn(w, x, y)
   # Prints array(-1, dtype=float32)
   print(dloss_dw)

   # To get the gradient with respect to x we can do:
   grad_fn = mx.grad(loss_fn, argnums=1)
   dloss_dx = grad_fn(w, x, y)
   # Prints array([-1, 1], dtype=float32)
   print(dloss_dx)

One way to get the loss and gradient is to call ``loss_fn`` followed by
``grad_fn``, but this can result in a lot of redundant work. Instead, you
should use :func:`value_and_grad`. Continuing the above example:

.. code-block:: python

   # Computes the gradient of loss_fn with respect to w:
   loss_and_grad_fn = mx.value_and_grad(loss_fn)
   loss, dloss_dw = loss_and_grad_fn(w, x, y)

   # Prints array(1, dtype=float32)
   print(loss)

   # Prints array(-1, dtype=float32)
   print(dloss_dw)

You can also take the gradient with respect to arbitrarily nested Python
containers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or
:obj:`dict`).

Suppose we wanted a weight and a bias parameter in the above example. A nice
way to do that is the following:

.. code-block:: python

   def loss_fn(params, x, y):
       w, b = params["weight"], params["bias"]
       h = w * x + b
       return mx.mean(mx.square(h - y))

   params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
   x = mx.array([0.5, -0.5])
   y = mx.array([1.5, -1.5])

   # Computes the gradient of loss_fn with respect to both the
   # weight and bias:
   grad_fn = mx.grad(loss_fn)
   grads = grad_fn(params, x, y)

   # Prints
   # {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
   print(grads)

Notice the tree structure of the parameters is preserved in the gradients.

In some cases you may want to stop gradients from propagating through a part
of the function. You can use :func:`stop_gradient` for that.
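As a small sketch, the second term below is wrapped in :func:`stop_gradient`,
so it is treated as a constant and contributes nothing to the gradient:

.. code-block:: python

   def fn(x):
       # The stopped term is still part of the value, but not of the gradient.
       return mx.sum(mx.square(x)) + mx.sum(mx.stop_gradient(mx.square(x)))

   x = mx.array([1.0, 2.0])

   # Only the first term is differentiated:
   # array([2, 4], dtype=float32)
   print(mx.grad(fn)(x))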
.. _vmap:

Automatic Vectorization
-----------------------

Use :func:`vmap` to automate vectorizing complex functions. Here we'll go
through a basic and contrived example for the sake of clarity, but :func:`vmap`
can be quite powerful for more complex functions which are difficult to
optimize by hand.

.. warning::

   Some operations are not yet supported with :func:`vmap`. If you encounter
   an error like ``ValueError: Primitive's vmap not implemented.``, file an
   `issue <https://github.com/ml-explore/mlx/issues>`_ and include your
   function. We will prioritize including it.

A naive way to add the elements from two sets of vectors is with a loop:

.. code-block:: python

   xs = mx.random.uniform(shape=(4096, 100))
   ys = mx.random.uniform(shape=(100, 4096))

   def naive_add(xs, ys):
       return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

Instead you can use :func:`vmap` to automatically vectorize the addition:

.. code-block:: python

   # Vectorize over the first dimension of x and the
   # second dimension of y
   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.
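For example, continuing with ``xs`` and ``ys`` from above, here is a small
sketch of ``out_axes``; it only changes where the mapped axis ends up in the
result:

.. code-block:: python

   # With out_axes=0 (the default) the mapped axis is first in the output;
   # with out_axes=1 it is moved to the second position.
   add_rows = mx.vmap(lambda x, y: x + y, in_axes=(0, 1), out_axes=0)
   add_cols = mx.vmap(lambda x, y: x + y, in_axes=(0, 1), out_axes=1)

   print(add_rows(xs, ys).shape)  # 4096 x 100
   print(add_cols(xs, ys).shape)  # 100 x 4096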
Let's time the naive and the vectorized versions:

.. code-block:: python

   import timeit

   print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
   print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))

On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.

Of course, this operation is quite contrived. A better approach is to simply
do ``xs + ys.T``, but for more complex functions :func:`vmap` can be quite
handy.

.. _indexing:

Indexing Arrays
===============

.. currentmodule:: mlx.core

For the most part, indexing an MLX :obj:`array` works the same as indexing a
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
how that works.

For example, you can use regular integers and slices (:obj:`slice`) to index
arrays:

.. code-block:: shell

   >>> arr = mx.arange(10)
   >>> arr[3]
   array(3, dtype=int32)
   >>> arr[-2]  # negative indexing works
   array(8, dtype=int32)
   >>> arr[2:8:2]  # start, stop, stride
   array([2, 4, 6], dtype=int32)

For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as
in NumPy:

.. code-block:: shell

   >>> arr = mx.arange(8).reshape(2, 2, 2)
   >>> arr[:, :, 0]
   array([[0, 2],
          [4, 6]], dtype=int32)
   >>> arr[..., 0]
   array([[0, 2],
          [4, 6]], dtype=int32)

You can index with ``None`` to create a new axis:

.. code-block:: shell

   >>> arr = mx.arange(8)
   >>> arr.shape
   [8]
   >>> arr[None].shape
   [1, 8]

You can also use an :obj:`array` to index another :obj:`array`:

.. code-block:: shell

   >>> arr = mx.arange(10)
   >>> idx = mx.array([5, 7])
   >>> arr[idx]
   array([5, 7], dtype=int32)

Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
works just as in NumPy. Other functions which may be useful for indexing
arrays are :func:`take` and :func:`take_along_axis`.
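For example, here is a short sketch of both (the values and indices are only
illustrative):

.. code-block:: python

   import mlx.core as mx

   a = mx.array([[1, 2, 3],
                 [4, 5, 6]])

   # take: gather whole slices by index along an axis
   print(mx.take(a, mx.array([0, 2]), axis=1))
   # array([[1, 3],
   #        [4, 6]], dtype=int32)

   # take_along_axis: gather one element per row, with indices
   # shaped like the desired output
   idx = mx.array([[2], [0]])
   print(mx.take_along_axis(a, idx, axis=1))
   # array([[3],
   #        [4]], dtype=int32)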
Differences from NumPy
----------------------

.. note::

   MLX indexing is different from NumPy indexing in two important ways:

   * Indexing does not perform bounds checking. Indexing out of bounds is
     undefined behavior.
   * Boolean mask based indexing is not yet supported.

The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching
the kernel would be extremely inefficient.

Indexing with boolean masks is something that MLX may support in the future.
In general, MLX has limited support for operations for which output *shapes*
are dependent on input *data*. Other examples of these types of operations
which MLX does not yet support include :func:`numpy.nonzero` and the single
input version of :func:`numpy.where`.

In Place Updates
----------------

In place updates to indexed arrays are possible in MLX. For example:

.. code-block:: shell

   >>> a = mx.array([1, 2, 3])
   >>> a[2] = 0
   >>> a
   array([1, 2, 0], dtype=int32)

Just as in NumPy, in place updates will be reflected in all references to the
same array:

.. code-block:: shell

   >>> a = mx.array([1, 2, 3])
   >>> b = a
   >>> b[2] = 0
   >>> b
   array([1, 2, 0], dtype=int32)
   >>> a
   array([1, 2, 0], dtype=int32)

Transformations of functions which use in-place updates are allowed and work
as expected. For example:

.. code-block:: python

   def fun(x, idx):
       x[idx] = 2.0
       return x.sum()

   dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
   print(dfdx)  # Prints: array([1, 0, 1], dtype=float32)

In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.

.. _lazy eval:

Lazy Evaluation
===============

.. currentmodule:: mlx.core

Why Lazy Evaluation
-------------------

When you perform operations in MLX, no computation actually happens. Instead,
a compute graph is recorded. The actual computation only happens when an
:func:`eval` is performed.

MLX uses lazy evaluation because it has some nice features, some of which we
describe below.

Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Lazy evaluation lets us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad`
and :func:`vmap` and graph optimizations like :func:`simplify`.

Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to
integrate compilation for future performance enhancements.

Only Compute What You Use
^^^^^^^^^^^^^^^^^^^^^^^^^

In MLX you do not need to worry as much about computing outputs that are never
used. For example:

.. code-block:: python

   def fun(x):
       a = fun1(x)
       b = expensive_fun(a)
       return a, b

   y, _ = fun(x)

Here, we never actually compute the output of ``expensive_fun``. Use this
pattern with care though, as the graph of ``expensive_fun`` is still built,
and that has some cost associated with it.

Similarly, lazy evaluation can be beneficial for saving memory while keeping
code simple. Say you have a very large model ``Model`` derived from
:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.
Typically, this will initialize all of the weights as ``float32``, but the
initialization does not actually compute anything until you perform an
:func:`eval`. If you update the model with ``float16`` weights, your maximum
consumed memory will be half of what would be required if eager computation
were used instead.

This pattern is simple to do in MLX thanks to lazy computation:

.. code-block:: python

   model = Model()  # no memory used yet
   model.load_weights("weights_fp16.safetensors")

When to Evaluate
----------------

A common question is when to use :func:`eval`. The trade-off is between
letting graphs get too large and not batching enough useful work. For example:

.. code-block:: python

   for _ in range(100):
       a = a + b
       mx.eval(a)
       b = b * 2
       mx.eval(b)

This is a bad idea because there is some fixed overhead with each graph
evaluation. On the other hand, there is some slight overhead which grows with
the compute graph size, so extremely large graphs (while computationally
correct) can be costly.

Luckily, a wide range of compute graph sizes work pretty well with MLX:
anything from a few tens of operations to many thousands of operations per
evaluation should be okay.

Most numerical computations have an iterative outer loop (e.g. the iteration
in stochastic gradient descent). A natural and usually efficient place to use
:func:`eval` is at each iteration of this outer loop.

Here is a concrete example:

.. code-block:: python

   for batch in dataset:

       # Nothing has been evaluated yet
       loss, grad = value_and_grad_fn(model, batch)

       # Still nothing has been evaluated
       optimizer.update(model, grad)

       # Evaluate the loss and the new parameters which will
       # run the full gradient computation and optimizer update
       mx.eval(loss, model.parameters())

An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to a
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.

Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar
to a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this will be
a partial evaluation, computing only the forward pass.

Also, calling :func:`eval` on an array or set of arrays multiple times is
perfectly fine. This is effectively a no-op.

.. warning::

   Using scalar arrays for control-flow will cause an evaluation. Here is an
   example:

   .. code-block:: python

      def fun(x):
          h, y = first_layer(x)
          if y > 0:  # An evaluation is done here!
              z = second_layer_a(h)
          else:
              z = second_layer_b(h)
          return z

Using arrays for control flow should be done with care. The above example
works and can even be used with gradient transformations. However, this can be
very inefficient if evaluations are done too frequently.

.. _numpy:

Conversion to NumPy and Other Frameworks
========================================

MLX array implements the `Python Buffer Protocol
<https://docs.python.org/3/c-api/buffer.html>`_. Let's convert an array to
NumPy and back.

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   a = mx.arange(3)
   b = np.array(a)  # copy of a
   c = mx.array(b)  # copy of b

.. note::

   Since NumPy does not support ``bfloat16`` arrays, you will need to convert
   to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
   Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
   buffer format string does not match the dtype V item size 0.``

By default, NumPy copies data to a new array. This can be prevented by
creating an array view:

.. code-block:: python

   a = mx.arange(3)
   a_view = np.array(a, copy=False)
   print(a_view.flags.owndata)  # False
   a_view[0] = 1
   print(a[0].item())  # 1

A NumPy array view is a normal NumPy array, except that it does not own its
memory. This means writing to the view is reflected in the original array.

While this is quite powerful to prevent copying arrays, it should be noted
that external changes to the memory of arrays cannot be reflected in
gradients.

Let's demonstrate this in an example:

.. code-block:: python

   def f(x):
       x_view = np.array(x, copy=False)
       x_view[:] *= x_view  # modify memory without telling mx
       return x.sum()

   x = mx.array([3.0])
   y, df = mx.value_and_grad(f)(x)
   print("f(x) = x² =", y.item())    # 9.0
   print("f'(x) = 2x !=", df.item())  # 1.0

The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the
last line outputting ``1.0``, representing the gradient of the sum operation
alone. The squaring of ``x`` occurs externally to MLX, meaning that no
gradient is incorporated.

It's important to note that a similar issue arises during array conversion and
copying. For instance, a function defined as ``mx.array(np.array(x)**2).sum()``
would also result in an incorrect gradient, even though no in-place operations
on MLX memory are executed.
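As a sketch of the safe pattern: if the squaring is done with MLX operations,
it is recorded in the compute graph and the gradient comes out as expected:

.. code-block:: python

   def g(x):
       # The square is an MLX operation, so it is part of the recorded
       # compute graph and is differentiated correctly.
       return mx.square(x).sum()

   x = mx.array([3.0])
   y, dg = mx.value_and_grad(g)(x)
   print(y.item())   # 9.0
   print(dg.item())  # 6.0, i.e. 2 * x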
PyTorch
-------

.. warning::

   PyTorch support for :obj:`memoryview` is experimental and can break for
   multi-dimensional arrays. Casting to NumPy first is advised for now.

PyTorch supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.

.. code-block:: python

   import mlx.core as mx
   import torch

   a = mx.arange(3)
   b = torch.tensor(memoryview(a))
   c = mx.array(b.numpy())

Conversion from PyTorch tensors back to MLX arrays must be done via
intermediate NumPy arrays with ``numpy()``.

JAX
---

JAX fully supports the buffer protocol.

.. code-block:: python

   import mlx.core as mx
   import jax.numpy as jnp

   a = mx.arange(3)
   b = jnp.array(a)
   c = mx.array(b)

TensorFlow
----------

TensorFlow supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.

.. code-block:: python

   import mlx.core as mx
   import tensorflow as tf

   a = mx.arange(3)
   b = tf.constant(memoryview(a))
   c = mx.array(b)

.. _saving_and_loading:

Saving and Loading Arrays
=========================

.. currentmodule:: mlx.core

MLX supports multiple array serialization formats.

.. list-table:: Serialization Formats
   :widths: 20 8 25 25
   :header-rows: 1

   * - Format
     - Extension
     - Function
     - Notes
   * - NumPy
     - ``.npy``
     - :func:`save`
     - Single arrays only
   * - NumPy archive
     - ``.npz``
     - :func:`savez` and :func:`savez_compressed`
     - Multiple arrays
   * - Safetensors
     - ``.safetensors``
     - :func:`save_safetensors`
     - Multiple arrays
   * - GGUF
     - ``.gguf``
     - :func:`save_gguf`
     - Multiple arrays

The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extension. The output of
:func:`load` depends on the format.

Here's an example of saving a single array to a file:

.. code-block:: shell

   >>> a = mx.array([1.0])
   >>> mx.save("array", a)

The array ``a`` will be saved in the file ``array.npy``. Including the
``.npy`` extension is optional; if it is missing it will be added
automatically. You can load the array with:

.. code-block:: shell

   >>> mx.load("array.npy")
   array([1], dtype=float32)

Here's an example of saving several arrays to a single file:

.. code-block:: shell

   >>> a = mx.array([1.0])
   >>> b = mx.array([2.0])
   >>> mx.savez("arrays", a, b=b)

For compatibility with :func:`numpy.savez` the MLX :func:`savez` takes arrays
as positional or keyword arguments. Arrays passed positionally are given
default names (``arr_0``, ``arr_1``, ...). This can be loaded with:

.. code-block:: shell

   >>> mx.load("arrays.npz")
   {'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}

In this case :func:`load` returns a dictionary of names to arrays.

The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
:func:`savez`, but they take as input a :obj:`dict` of string names to
arrays:

.. code-block:: shell

   >>> a = mx.array([1.0])
   >>> b = mx.array([2.0])
   >>> mx.save_safetensors("arrays", {"a": a, "b": b})
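Assuming the ``.safetensors`` extension is appended just as with :func:`save`,
the file can be read back with :func:`load`, which again returns a dictionary
of names to arrays (the key order may differ):

.. code-block:: shell

   >>> mx.load("arrays.safetensors")
   {'a': array([1], dtype=float32), 'b': array([2], dtype=float32)}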