.. _array:

Array
=====

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   array
   array.astype
   array.item
   array.tolist
   array.dtype
   array.ndim
   array.shape
   array.size
   Dtype
   array.abs
   array.all
   array.any
   array.argmax
   array.argmin
   array.cos
   array.exp
   array.log
   array.log1p
   array.logsumexp
   array.max
   array.mean
   array.min
   array.prod
   array.reciprocal
   array.reshape
   array.round
   array.rsqrt
   array.sin
   array.split
   array.sqrt
   array.square
   array.sum
   array.transpose
   array.T
   array.var

.. _data_types:

Data Types
==========

.. currentmodule:: mlx.core

The default floating point type is ``float32`` and the default integer type is
``int32``. The table below shows supported values for :obj:`Dtype`.

.. list-table:: Supported Data Types
   :widths: 5 3 20
   :header-rows: 1

   * - Type
     - Bytes
     - Description
   * - ``bool_``
     - 1
     - Boolean (``True``, ``False``) data type
   * - ``uint8``
     - 1
     - 8-bit unsigned integer
   * - ``uint16``
     - 2
     - 16-bit unsigned integer
   * - ``uint32``
     - 4
     - 32-bit unsigned integer
   * - ``uint64``
     - 8
     - 64-bit unsigned integer
   * - ``int8``
     - 1
     - 8-bit signed integer
   * - ``int16``
     - 2
     - 16-bit signed integer
   * - ``int32``
     - 4
     - 32-bit signed integer
   * - ``int64``
     - 8
     - 64-bit signed integer
   * - ``float16``
     - 2
     - 16-bit float, only available with ARM C language extensions
   * - ``float32``
     - 4
     - 32-bit float

.. _devices_and_streams:

Devices and Streams
===================

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   Device
   default_device
   set_default_device
   Stream
   default_stream
   new_stream
   set_default_stream

.. _fft:

FFT
===

.. currentmodule:: mlx.core.fft

.. autosummary::
   :toctree: _autosummary

   fft
   ifft
   fft2
   ifft2
   fftn
   ifftn
   rfft
   irfft
   rfft2
   irfft2
   rfftn
   irfftn

.. _linalg:

Linear Algebra
==============

.. currentmodule:: mlx.core.linalg

.. autosummary::
   :toctree: _autosummary

   norm

.. _nn:

.. currentmodule:: mlx.nn

Neural Networks
===============

Writing arbitrarily complex neural networks in MLX can be done using only
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this
requires the user to repeatedly write the same simple neural network
operations and to handle all the parameter state and initialization manually
and explicitly.

The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for finetuning and more.

Quick Start with Neural Networks
--------------------------------

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   class MLP(nn.Module):
       def __init__(self, in_dims: int, out_dims: int):
           super().__init__()

           self.layers = [
               nn.Linear(in_dims, 128),
               nn.Linear(128, 128),
               nn.Linear(128, out_dims),
           ]

       def __call__(self, x):
           for i, l in enumerate(self.layers):
               x = mx.maximum(x, 0) if i > 0 else x
               x = l(x)
           return x

   # The model is created with all its parameters but nothing is initialized
   # yet because MLX is lazily evaluated
   mlp = MLP(2, 10)

   # We can access its parameters by calling mlp.parameters()
   params = mlp.parameters()
   print(params["layers"][0]["weight"].shape)

   # Printing a parameter will cause it to be evaluated and thus initialized
   print(params["layers"][0])

   # We can also force evaluate all parameters to initialize the model
   mx.eval(mlp.parameters())

   # A simple loss function.
   # NOTE: It doesn't matter how it uses the mlp model. It currently captures
   # it from the local scope. It could be a positional argument or a
   # keyword argument.
   def l2_loss(x, y):
       y_hat = mlp(x)
       return (y_hat - y).square().mean()

   # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
   # gradient with respect to `mlp.trainable_parameters()`
   loss_and_grad = nn.value_and_grad(mlp, l2_loss)
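The resulting ``loss_and_grad`` function is called with the same arguments as
``l2_loss`` and returns both the loss value and a tree of gradients that
mirrors ``mlp.trainable_parameters()``. As a quick sanity check, here is a
minimal sketch; the batch of random inputs and targets below is made up purely
for illustration, with shapes chosen to match ``MLP(2, 10)``:

.. code-block:: python

   # A dummy batch: 4 inputs of dimension 2 and 4 targets of dimension 10
   x = mx.random.uniform(shape=(4, 2))
   y = mx.random.uniform(shape=(4, 10))

   # Returns the scalar loss and a parameter tree of gradients
   loss, grads = loss_and_grad(x, y)

   # As with everything in MLX, nothing is computed until evaluated
   mx.eval(loss, grads)
   print(loss)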
.. _module_class:

The Module Class
----------------

The workhorse of any neural network library is the :class:`Module` class. In
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
:class:`Module` instances. Its main function is to provide a way to
recursively **access** and **update** its parameters and those of its
submodules.

Parameters
^^^^^^^^^^

A parameter of a module is any public member of type :class:`mlx.core.array`
(its name should not start with ``_``). It can be arbitrarily nested in other
:class:`Module` instances or lists and dictionaries.

:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.

A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. When using
:meth:`mlx.nn.value_and_grad`, the gradients returned will be with respect to
these trainable parameters.

Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^

MLX modules allow accessing and updating individual parameters. However, most
of the time we need to update large subsets of a module's parameters. This
action is performed by :meth:`Module.update`.

Inspecting Modules
^^^^^^^^^^^^^^^^^^

The simplest way to see the model architecture is to print it. Following along
with the above example, you can print the ``MLP`` with:

.. code-block:: python

   print(mlp)

This will display:

.. code-block:: shell

   MLP(
     (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
     (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
     (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
   )

To get more detailed information on the arrays in a :class:`Module` you can
use :func:`mlx.utils.tree_map` on the parameters. For example, to see the
shapes of all the parameters in a :class:`Module` do:

.. code-block:: python

   from mlx.utils import tree_map
   shapes = tree_map(lambda p: p.shape, mlp.parameters())

As another example, you can count the number of parameters in a
:class:`Module` with:

.. code-block:: python

   from mlx.utils import tree_flatten
   num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))

Value and Grad
--------------

Using a :class:`Module` does not preclude using MLX's higher-order function
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`,
etc.). However, these function transformations assume pure functions, namely
the parameters should be passed as an argument to the function being
transformed.

There is an easy pattern to achieve that with MLX modules:

.. code-block:: python

   model = ...

   def f(params, other_inputs):
       model.update(params)  # <---- Necessary to make the model use the passed parameters
       return model(other_inputs)

   f(model.trainable_parameters(), mx.zeros((10,)))

However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and
only computes the gradients with respect to the trainable parameters of the
model.

In detail:

- it wraps the passed function with a function that calls
  :meth:`Module.update` to make sure the model is using the provided
  parameters.
- it calls :meth:`mlx.core.value_and_grad` to transform the function into a
  function that also computes the gradients with respect to the passed
  parameters.
- it wraps the returned function with a function that passes the trainable
  parameters as the first argument to the function returned by
  :meth:`mlx.core.value_and_grad`.
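Because the gradient tree returned by the wrapped function has the same
structure as :meth:`Module.trainable_parameters`, it can be combined with the
parameters using :func:`mlx.utils.tree_map` and applied with
:meth:`Module.update`. The following is a minimal, hand-rolled gradient
descent step, assuming the ``mlp``, ``loss_and_grad``, ``x`` and ``y`` from
the quick start above; in practice you would typically use an optimizer from
:mod:`mlx.optimizers` instead:

.. code-block:: python

   from mlx.utils import tree_map

   lr = 0.01

   # x, y: a batch of inputs and targets, as in the quick start above
   loss, grads = loss_and_grad(x, y)

   # The gradient tree mirrors mlp.trainable_parameters(), so tree_map can
   # combine the two trees leaf by leaf
   new_params = tree_map(lambda p, g: p - lr * g, mlp.trainable_parameters(), grads)

   # Module.update swaps in the new values for the matching parameters
   mlp.update(new_params)
   mx.eval(mlp.parameters())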
.. autosummary::
   :toctree: _autosummary

   value_and_grad

.. toctree::

   nn/module
   nn/layers
   nn/functions
   nn/losses
   nn/init

.. _ops:

Operations
==========

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   abs
   add
   all
   allclose
   any
   arange
   arccos
   arccosh
   arcsin
   arcsinh
   arctan
   arctanh
   argmax
   argmin
   argpartition
   argsort
   array_equal
   broadcast_to
   ceil
   clip
   concatenate
   convolve
   conv1d
   conv2d
   cos
   cosh
   dequantize
   divide
   divmod
   equal
   erf
   erfinv
   exp
   expand_dims
   eye
   flatten
   floor
   floor_divide
   full
   greater
   greater_equal
   identity
   inner
   isnan
   isposinf
   isneginf
   isinf
   less
   less_equal
   linspace
   load
   log
   log2
   log10
   log1p
   logaddexp
   logical_not
   logical_and
   logical_or
   logsumexp
   matmul
   max
   maximum
   mean
   min
   minimum
   moveaxis
   multiply
   negative
   ones
   ones_like
   outer
   partition
   pad
   prod
   quantize
   quantized_matmul
   reciprocal
   repeat
   reshape
   round
   rsqrt
   save
   savez
   savez_compressed
   save_gguf
   save_safetensors
   sigmoid
   sign
   sin
   sinh
   softmax
   sort
   split
   sqrt
   square
   squeeze
   stack
   stop_gradient
   subtract
   sum
   swapaxes
   take
   take_along_axis
   tan
   tanh
   tensordot
   transpose
   tri
   tril
   triu
   var
   where
   zeros
   zeros_like

.. _optimizers:

Optimizers
==========

The optimizers in MLX can be used both with :mod:`mlx.nn` and with pure
:mod:`mlx.core` functions. A typical example involves calling
:meth:`Optimizer.update` to update a model's parameters based on the loss
gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the
model's parameters and the **optimizer state**.

.. code-block:: python

   # Create a model
   model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
   mx.eval(model.parameters())

   # Create the gradient function and the optimizer
   loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
   optimizer = optim.SGD(learning_rate=learning_rate)

   for e in range(num_epochs):
       for X, y in batch_iterate(batch_size, train_images, train_labels):
           loss, grads = loss_and_grad_fn(model, X, y)

           # Update the model with the gradients. So far no computation has happened.
           optimizer.update(model, grads)

           # Compute the new parameters but also the optimizer state.
           mx.eval(model.parameters(), optimizer.state)

.. currentmodule:: mlx.optimizers

.. autosummary::
   :toctree: _autosummary
   :template: optimizers-template.rst

   OptimizerState
   Optimizer
   SGD
   RMSprop
   Adagrad
   Adafactor
   AdaDelta
   Adam
   AdamW
   Adamax
   Lion

.. _random:

Random
======

Random sampling functions in MLX use an implicit global PRNG state by default.
However, all functions take an optional ``key`` keyword argument for when more
fine-grained control or explicit state management is needed.

For example, you can generate random numbers with:

.. code-block:: python

   for _ in range(3):
       print(mx.random.uniform())

which will print a sequence of unique pseudo random numbers. Alternatively you
can explicitly set the key:

.. code-block:: python

   key = mx.random.key(0)
   for _ in range(3):
       print(mx.random.uniform(key=key))

which will yield the same pseudo random number at each iteration.

Following JAX's PRNG design, we use a splittable version of Threefry, which is
a counter-based PRNG.

.. currentmodule:: mlx.core.random

.. autosummary::
   :toctree: _autosummary

   bernoulli
   categorical
   gumbel
   key
   normal
   randint
   seed
   split
   truncated_normal
   uniform

.. _transforms:

Transforms
==========

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   eval
   grad
   value_and_grad
   jvp
   vjp
   vmap
   simplify
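The transformations above compose and make up most of the day-to-day API. As
a small sketch of the most common ones, using a toy function ``f`` invented
for this example:

.. code-block:: python

   import mlx.core as mx

   def f(x):
       return (x ** 2).sum()

   x = mx.array([1.0, 2.0, 3.0])

   # grad returns a function that computes df/dx; here the gradient is 2 * x
   print(mx.grad(f)(x))

   # value_and_grad returns both the value and the gradient in one call
   value, dfdx = mx.value_and_grad(f)(x)

   # vmap maps a function over a leading batch axis; here f is applied to
   # each row of xs
   xs = mx.array([[1.0, 2.0], [3.0, 4.0]])
   print(mx.vmap(f)(xs))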
.. _utils:

Tree Utils
==========

In MLX we consider a python tree to be an arbitrarily nested collection of
dictionaries, lists and tuples without cycles. Functions in this module that
return python trees will use the default python ``dict``, ``list`` and
``tuple``, but they can usually process objects that inherit from any of
these.

.. note::
   Dictionaries should have keys that are valid python identifiers.

.. currentmodule:: mlx.utils

.. autosummary::
   :toctree: _autosummary

   tree_flatten
   tree_unflatten
   tree_map
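For instance, :func:`tree_flatten` turns a nested tree into a flat list of
``(key, value)`` pairs with dot-separated keys, :func:`tree_unflatten`
reverses that, and :func:`tree_map` applies a function to every leaf. A small
sketch, using a made-up parameter tree for illustration:

.. code-block:: python

   import mlx.core as mx
   from mlx.utils import tree_flatten, tree_map, tree_unflatten

   # A hypothetical parameter tree
   tree = {"layers": [{"weight": mx.zeros((2, 2))}, {"weight": mx.zeros((2,))}]}

   # Flatten to [("layers.0.weight", array), ("layers.1.weight", array)]
   flat = tree_flatten(tree)
   print([k for k, _ in flat])

   # Apply a function to every leaf while keeping the tree structure
   shapes = tree_map(lambda x: x.shape, tree)

   # Rebuild the nested structure from the flat representation
   rebuilt = tree_unflatten(flat)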