Mirror of https://github.com/LouisShark/chatgpt_system_prompt.git (synced 2025-07-06 14:50:31 -04:00)

MLX Guru - files

This commit is contained in: parent bd0099ad1f, commit f96f19d404
8 changed files with 1579 additions and 0 deletions

23  prompts/gpts/knowledge/MLX Guru/functions.txt  Normal file
@@ -0,0 +1,23 @@

.. _nn_functions:

.. currentmodule:: mlx.nn

Functions
---------

Layers without parameters (e.g. activation functions) are also provided as
simple functions.

.. autosummary::
   :toctree: _autosummary_functions
   :template: nn-module-template.rst

   gelu
   gelu_approx
   gelu_fast_approx
   mish
   prelu
   relu
   selu
   silu
   step

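For instance, a minimal sketch of calling an activation either as a module or as
one of these functions (the input values are only illustrative):

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   x = mx.array([-1.0, 0.0, 2.0])

   # Module form, e.g. for use inside a model definition
   relu_layer = nn.ReLU()
   y_module = relu_layer(x)

   # Equivalent function form, no module object required
   y_fn = nn.relu(x)
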
45  prompts/gpts/knowledge/MLX Guru/init.txt  Normal file
@@ -0,0 +1,45 @@

.. _init:

.. currentmodule:: mlx.nn.init

Initializers
------------

The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.

For example:

.. code:: python

   import mlx.core as mx
   import mlx.nn as nn

   init_fn = nn.init.uniform()

   # Produces a [2, 2] uniform matrix
   param = init_fn(mx.zeros((2, 2)))

To re-initialize all the parameters in an :obj:`mlx.nn.Module` from, say, a
uniform distribution, you can do:

.. code:: python

   import mlx.nn as nn
   model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
   init_fn = nn.init.uniform(low=-0.1, high=0.1)
   model.apply(init_fn)

.. autosummary::
   :toctree: _autosummary

   constant
   normal
   uniform
   identity
   glorot_normal
   glorot_uniform
   he_normal
   he_uniform

37  prompts/gpts/knowledge/MLX Guru/layers.txt  Normal file
@@ -0,0 +1,37 @@

.. _layers:

.. currentmodule:: mlx.nn

Layers
------

.. autosummary::
   :toctree: _autosummary
   :template: nn-module-template.rst

   ALiBi
   BatchNorm
   Conv1d
   Conv2d
   Dropout
   Dropout2d
   Dropout3d
   Embedding
   GELU
   GroupNorm
   InstanceNorm
   LayerNorm
   Linear
   Mish
   MultiHeadAttention
   PReLU
   QuantizedLinear
   RMSNorm
   ReLU
   RoPE
   SELU
   Sequential
   SiLU
   SinusoidalPositionalEncoding
   Step
   Transformer

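As a minimal sketch, a few of the layers above can be composed with
:class:`Sequential` (the sizes here are only illustrative):

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   model = nn.Sequential(
       nn.Linear(16, 32),
       nn.ReLU(),
       nn.Dropout(0.1),
       nn.Linear(32, 4),
   )

   x = mx.random.uniform(shape=(8, 16))
   y = model(x)  # shape (8, 4)
   mx.eval(y)
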
24  prompts/gpts/knowledge/MLX Guru/losses.txt  Normal file
@@ -0,0 +1,24 @@

.. _losses:

.. currentmodule:: mlx.nn.losses

Loss Functions
--------------

.. autosummary::
   :toctree: _autosummary_functions
   :template: nn-module-template.rst

   binary_cross_entropy
   cosine_similarity_loss
   cross_entropy
   gaussian_nll_loss
   hinge_loss
   huber_loss
   kl_div_loss
   l1_loss
   log_cosh_loss
   mse_loss
   nll_loss
   smooth_l1_loss
   triplet_loss

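A minimal sketch of calling two of the losses above (the inputs and the explicit
``reduction`` argument are only illustrative):

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   logits = mx.array([[2.0, -1.0, 0.5]])
   targets = mx.array([0])

   # Cross entropy over class logits
   ce = nn.losses.cross_entropy(logits, targets, reduction="mean")

   # Mean squared error between predictions and targets
   mse = nn.losses.mse_loss(mx.array([1.0, 2.0]), mx.array([1.5, 1.5]), reduction="mean")
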
36  prompts/gpts/knowledge/MLX Guru/module.txt  Normal file
@@ -0,0 +1,36 @@

Module
======

.. currentmodule:: mlx.nn

.. autoclass:: Module

.. rubric:: Attributes

.. autosummary::
   :toctree: _autosummary

   Module.training

.. rubric:: Methods

.. autosummary::
   :toctree: _autosummary

   Module.apply
   Module.apply_to_modules
   Module.children
   Module.eval
   Module.filter_and_map
   Module.freeze
   Module.leaf_modules
   Module.load_weights
   Module.modules
   Module.named_modules
   Module.parameters
   Module.save_weights
   Module.train
   Module.trainable_parameters
   Module.unfreeze
   Module.update
   Module.update_modules

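A minimal sketch of a few of these methods in use (the file name is only
illustrative):

.. code-block:: python

   import mlx.nn as nn

   model = nn.Linear(4, 2)

   # Freeze all parameters, e.g. before finetuning only part of a model
   model.freeze()
   print(model.trainable_parameters())  # typically empty while frozen
   model.unfreeze()

   # Persist and restore the parameters
   model.save_weights("linear.npz")
   model.load_weights("linear.npz")
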
183  prompts/gpts/knowledge/MLX Guru/nn.txt  Normal file
@@ -0,0 +1,183 @@

.. _nn:

.. currentmodule:: mlx.nn

Neural Networks
===============

Writing arbitrarily complex neural networks in MLX can be done using only
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the
user to write again and again the same simple neural network operations as well
as handle all the parameter state and initialization manually and explicitly.

The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for finetuning and more.

Quick Start with Neural Networks
---------------------------------

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   class MLP(nn.Module):
       def __init__(self, in_dims: int, out_dims: int):
           super().__init__()

           self.layers = [
               nn.Linear(in_dims, 128),
               nn.Linear(128, 128),
               nn.Linear(128, out_dims),
           ]

       def __call__(self, x):
           for i, l in enumerate(self.layers):
               x = mx.maximum(x, 0) if i > 0 else x
               x = l(x)
           return x

   # The model is created with all its parameters but nothing is initialized
   # yet because MLX is lazily evaluated
   mlp = MLP(2, 10)

   # We can access its parameters by calling mlp.parameters()
   params = mlp.parameters()
   print(params["layers"][0]["weight"].shape)

   # Printing a parameter will cause it to be evaluated and thus initialized
   print(params["layers"][0])

   # We can also force evaluate all parameters to initialize the model
   mx.eval(mlp.parameters())

   # A simple loss function.
   # NOTE: It doesn't matter how it uses the mlp model. It currently captures
   # it from the local scope. It could be a positional argument or a
   # keyword argument.
   def l2_loss(x, y):
       y_hat = mlp(x)
       return (y_hat - y).square().mean()

   # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
   # gradient with respect to `mlp.trainable_parameters()`
   loss_and_grad = nn.value_and_grad(mlp, l2_loss)

.. _module_class:

The Module Class
----------------

The workhorse of any neural network library is the :class:`Module` class. In
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
:class:`Module` instances. Its main function is to provide a way to
recursively **access** and **update** its parameters and those of its
submodules.

Parameters
^^^^^^^^^^

A parameter of a module is any public member of type :class:`mlx.core.array` (its
name should not start with ``_``). It can be arbitrarily nested in other
:class:`Module` instances or lists and dictionaries.

:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.

A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. When using :meth:`mlx.nn.value_and_grad`,
the gradients returned will be with respect to these trainable parameters.

Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^

MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module's parameters. This action is
performed by :meth:`Module.update`.

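For example, a minimal sketch of pushing a whole parameter tree back into a
module with :meth:`Module.update` (the scaling is only illustrative):

.. code-block:: python

   import mlx.nn as nn
   from mlx.utils import tree_map

   model = nn.Linear(4, 4)

   # Build a structurally identical parameter tree and update the module with it
   new_params = tree_map(lambda p: 0.5 * p, model.parameters())
   model.update(new_params)
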
Inspecting Modules
^^^^^^^^^^^^^^^^^^

The simplest way to see the model architecture is to print it. Following along with
the above example, you can print the ``MLP`` with:

.. code-block:: python

   print(mlp)

This will display:

.. code-block:: shell

   MLP(
     (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
     (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
     (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
   )

To get more detailed information on the arrays in a :class:`Module` you can use
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
all the parameters in a :class:`Module` do:

.. code-block:: python

   from mlx.utils import tree_map
   shapes = tree_map(lambda p: p.shape, mlp.parameters())

As another example, you can count the number of parameters in a :class:`Module`
with:

.. code-block:: python

   from mlx.utils import tree_flatten
   num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))

Value and Grad
--------------

Using a :class:`Module` does not preclude using MLX's high order function
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However,
these function transformations assume pure functions, namely the parameters
should be passed as an argument to the function being transformed.

There is an easy pattern to achieve that with MLX modules

.. code-block:: python

   model = ...

   def f(params, other_inputs):
       model.update(params)  # <---- Necessary to make the model use the passed parameters
       return model(other_inputs)

   f(model.trainable_parameters(), mx.zeros((10,)))

However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only
computes the gradients with respect to the trainable parameters of the model.

In detail:

- it wraps the passed function with a function that calls :meth:`Module.update`
  to make sure the model is using the provided parameters.
- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function
  that also computes the gradients with respect to the passed parameters.
- it wraps the returned function with a function that passes the trainable
  parameters as the first argument to the function returned by
  :meth:`mlx.core.value_and_grad`

.. autosummary::
   :toctree: _autosummary

   value_and_grad

.. toctree::

   nn/module
   nn/layers
   nn/functions
   nn/losses
   nn/init

587  prompts/gpts/knowledge/MLX Guru/python_api.txt  Normal file
@@ -0,0 +1,587 @@

.. _array:

Array
=====

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   array
   array.astype
   array.item
   array.tolist
   array.dtype
   array.ndim
   array.shape
   array.size
   Dtype
   array.abs
   array.all
   array.any
   array.argmax
   array.argmin
   array.cos
   array.dtype
   array.exp
   array.log
   array.log1p
   array.logsumexp
   array.max
   array.mean
   array.min
   array.prod
   array.reciprocal
   array.reshape
   array.round
   array.rsqrt
   array.sin
   array.split
   array.sqrt
   array.square
   array.sum
   array.transpose
   array.T
   array.var

.. _data_types:

:orphan:

Data Types
==========

.. currentmodule:: mlx.core

The default floating point type is ``float32`` and the default integer type is
``int32``. The table below shows supported values for :obj:`Dtype`.

.. list-table:: Supported Data Types
   :widths: 5 3 20
   :header-rows: 1

   * - Type
     - Bytes
     - Description
   * - ``bool_``
     - 1
     - Boolean (``True``, ``False``) data type
   * - ``uint8``
     - 1
     - 8-bit unsigned integer
   * - ``uint16``
     - 2
     - 16-bit unsigned integer
   * - ``uint32``
     - 4
     - 32-bit unsigned integer
   * - ``uint64``
     - 8
     - 64-bit unsigned integer
   * - ``int8``
     - 1
     - 8-bit signed integer
   * - ``int16``
     - 2
     - 16-bit signed integer
   * - ``int32``
     - 4
     - 32-bit signed integer
   * - ``int64``
     - 8
     - 64-bit signed integer
   * - ``float16``
     - 2
     - 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
   * - ``float32``
     - 4
     - 32-bit float

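For example, a short sketch of the default types and an explicit cast:

.. code-block:: python

   import mlx.core as mx

   x = mx.array([1, 2, 3])        # int32 by default
   y = mx.array([1.0, 2.0, 3.0])  # float32 by default

   half = y.astype(mx.float16)    # explicit cast to a 16-bit float
   print(x.dtype, y.dtype, half.dtype)
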
.. _devices_and_streams:

Devices and Streams
===================

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   Device
   default_device
   set_default_device
   Stream
   default_stream
   new_stream
   set_default_stream

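A minimal sketch of the device API (the printed default device depends on your
machine):

.. code-block:: python

   import mlx.core as mx

   print(mx.default_device())

   # Run one operation on the CPU without changing the default device;
   # the ``stream`` argument accepts a stream or a device
   a = mx.random.uniform(shape=(64, 64))
   b = mx.add(a, a, stream=mx.cpu)

   # Or change the default device globally
   mx.set_default_device(mx.cpu)
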
.. _fft:

FFT
===

.. currentmodule:: mlx.core.fft

.. autosummary::
   :toctree: _autosummary

   fft
   ifft
   fft2
   ifft2
   fftn
   ifftn
   rfft
   irfft
   rfft2
   irfft2
   rfftn
   irfftn

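For instance, a small sketch of a real-input transform and its inverse:

.. code-block:: python

   import mlx.core as mx

   signal = mx.random.uniform(shape=(128,))

   spectrum = mx.fft.rfft(signal)     # complex coefficients
   recovered = mx.fft.irfft(spectrum)
   mx.eval(recovered)
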
.. _linalg:

Linear Algebra
==============

.. currentmodule:: mlx.core.linalg

.. autosummary::
   :toctree: _autosummary

   norm

.. _nn:

.. currentmodule:: mlx.nn

Neural Networks
===============

Writing arbitrarily complex neural networks in MLX can be done using only
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the
user to write again and again the same simple neural network operations as well
as handle all the parameter state and initialization manually and explicitly.

The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for finetuning and more.

Quick Start with Neural Networks
---------------------------------

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   class MLP(nn.Module):
       def __init__(self, in_dims: int, out_dims: int):
           super().__init__()

           self.layers = [
               nn.Linear(in_dims, 128),
               nn.Linear(128, 128),
               nn.Linear(128, out_dims),
           ]

       def __call__(self, x):
           for i, l in enumerate(self.layers):
               x = mx.maximum(x, 0) if i > 0 else x
               x = l(x)
           return x

   # The model is created with all its parameters but nothing is initialized
   # yet because MLX is lazily evaluated
   mlp = MLP(2, 10)

   # We can access its parameters by calling mlp.parameters()
   params = mlp.parameters()
   print(params["layers"][0]["weight"].shape)

   # Printing a parameter will cause it to be evaluated and thus initialized
   print(params["layers"][0])

   # We can also force evaluate all parameters to initialize the model
   mx.eval(mlp.parameters())

   # A simple loss function.
   # NOTE: It doesn't matter how it uses the mlp model. It currently captures
   # it from the local scope. It could be a positional argument or a
   # keyword argument.
   def l2_loss(x, y):
       y_hat = mlp(x)
       return (y_hat - y).square().mean()

   # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
   # gradient with respect to `mlp.trainable_parameters()`
   loss_and_grad = nn.value_and_grad(mlp, l2_loss)

.. _module_class:

The Module Class
----------------

The workhorse of any neural network library is the :class:`Module` class. In
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
:class:`Module` instances. Its main function is to provide a way to
recursively **access** and **update** its parameters and those of its
submodules.

Parameters
^^^^^^^^^^

A parameter of a module is any public member of type :class:`mlx.core.array` (its
name should not start with ``_``). It can be arbitrarily nested in other
:class:`Module` instances or lists and dictionaries.

:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.

A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. When using :meth:`mlx.nn.value_and_grad`,
the gradients returned will be with respect to these trainable parameters.

Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^

MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module's parameters. This action is
performed by :meth:`Module.update`.

Inspecting Modules
^^^^^^^^^^^^^^^^^^

The simplest way to see the model architecture is to print it. Following along with
the above example, you can print the ``MLP`` with:

.. code-block:: python

   print(mlp)

This will display:

.. code-block:: shell

   MLP(
     (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
     (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
     (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
   )

To get more detailed information on the arrays in a :class:`Module` you can use
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
all the parameters in a :class:`Module` do:

.. code-block:: python

   from mlx.utils import tree_map
   shapes = tree_map(lambda p: p.shape, mlp.parameters())

As another example, you can count the number of parameters in a :class:`Module`
with:

.. code-block:: python

   from mlx.utils import tree_flatten
   num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))

Value and Grad
--------------

Using a :class:`Module` does not preclude using MLX's high order function
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However,
these function transformations assume pure functions, namely the parameters
should be passed as an argument to the function being transformed.

There is an easy pattern to achieve that with MLX modules

.. code-block:: python

   model = ...

   def f(params, other_inputs):
       model.update(params)  # <---- Necessary to make the model use the passed parameters
       return model(other_inputs)

   f(model.trainable_parameters(), mx.zeros((10,)))

However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only
computes the gradients with respect to the trainable parameters of the model.

In detail:

- it wraps the passed function with a function that calls :meth:`Module.update`
  to make sure the model is using the provided parameters.
- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function
  that also computes the gradients with respect to the passed parameters.
- it wraps the returned function with a function that passes the trainable
  parameters as the first argument to the function returned by
  :meth:`mlx.core.value_and_grad`

.. autosummary::
   :toctree: _autosummary

   value_and_grad

.. toctree::

   nn/module
   nn/layers
   nn/functions
   nn/losses
   nn/init

.. _ops:

Operations
==========

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   abs
   add
   all
   allclose
   any
   arange
   arccos
   arccosh
   arcsin
   arcsinh
   arctan
   arctanh
   argmax
   argmin
   argpartition
   argsort
   array_equal
   broadcast_to
   ceil
   clip
   concatenate
   convolve
   conv1d
   conv2d
   cos
   cosh
   dequantize
   divide
   divmod
   equal
   erf
   erfinv
   exp
   expand_dims
   eye
   flatten
   floor
   floor_divide
   full
   greater
   greater_equal
   identity
   inner
   isnan
   isposinf
   isneginf
   isinf
   less
   less_equal
   linspace
   load
   log
   log2
   log10
   log1p
   logaddexp
   logical_not
   logical_and
   logical_or
   logsumexp
   matmul
   max
   maximum
   mean
   min
   minimum
   moveaxis
   multiply
   negative
   ones
   ones_like
   outer
   partition
   pad
   prod
   quantize
   quantized_matmul
   reciprocal
   repeat
   reshape
   round
   rsqrt
   save
   savez
   savez_compressed
   save_gguf
   save_safetensors
   sigmoid
   sign
   sin
   sinh
   softmax
   sort
   split
   sqrt
   square
   squeeze
   stack
   stop_gradient
   subtract
   sum
   swapaxes
   take
   take_along_axis
   tan
   tanh
   tensordot
   transpose
   tri
   tril
   triu
   var
   where
   zeros
   zeros_like

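A minimal sketch combining a few of the operations above (the values are only
illustrative):

.. code-block:: python

   import mlx.core as mx

   a = mx.arange(6).reshape(2, 3)

   total = a.sum()                                # reduce over all elements
   clipped = mx.clip(a, 1, 4)                     # elementwise clamp
   masked = mx.where(a > 2, a, mx.zeros_like(a))  # select based on a condition
   mx.eval(total, clipped, masked)
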
.. _optimizers:

Optimizers
==========

The optimizers in MLX can be used both with :mod:`mlx.nn` and with pure
:mod:`mlx.core` functions. A typical example involves calling
:meth:`Optimizer.update` to update a model's parameters based on the loss
gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the
model's parameters and the **optimizer state**.

.. code-block:: python

   # Create a model
   model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
   mx.eval(model.parameters())

   # Create the gradient function and the optimizer
   loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
   optimizer = optim.SGD(learning_rate=learning_rate)

   for e in range(num_epochs):
       for X, y in batch_iterate(batch_size, train_images, train_labels):
           loss, grads = loss_and_grad_fn(model, X, y)

           # Update the model with the gradients. So far no computation has happened.
           optimizer.update(model, grads)

           # Compute the new parameters but also the optimizer state.
           mx.eval(model.parameters(), optimizer.state)

.. currentmodule:: mlx.optimizers

.. autosummary::
   :toctree: _autosummary
   :template: optimizers-template.rst

   OptimizerState
   Optimizer
   SGD
   RMSprop
   Adagrad
   Adafactor
   AdaDelta
   Adam
   AdamW
   Adamax
   Lion

.. _random:

Random
======

Random sampling functions in MLX use an implicit global PRNG state by default.
However, all functions take an optional ``key`` keyword argument for when more
fine-grained control or explicit state management is needed.

For example, you can generate random numbers with:

.. code-block:: python

   for _ in range(3):
       print(mx.random.uniform())

which will print a sequence of unique pseudo random numbers. Alternatively you
can explicitly set the key:

.. code-block:: python

   key = mx.random.key(0)
   for _ in range(3):
       print(mx.random.uniform(key=key))

which will yield the same pseudo random number at each iteration.

Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_
we use a splittable version of Threefry, which is a counter-based PRNG.

.. currentmodule:: mlx.core.random

.. autosummary::
   :toctree: _autosummary

   bernoulli
   categorical
   gumbel
   key
   normal
   randint
   seed
   split
   truncated_normal
   uniform

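For example, a minimal sketch of deriving independent subkeys with :func:`split`:

.. code-block:: python

   import mlx.core as mx

   key = mx.random.key(42)
   keys = mx.random.split(key, 2)  # two independent subkeys

   x = mx.random.normal(shape=(3,), key=keys[0])
   y = mx.random.normal(shape=(3,), key=keys[1])
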
.. _transforms:

Transforms
==========

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   eval
   grad
   value_and_grad
   jvp
   vjp
   vmap
   simplify

.. _utils:

Tree Utils
==========

In MLX we consider a python tree to be an arbitrarily nested collection of
dictionaries, lists and tuples without cycles. Functions in this module that
return python trees will be using the default python ``dict``, ``list`` and
``tuple`` but they can usually process objects that inherit from any of these.

.. note::
   Dictionaries should have keys that are valid python identifiers.

.. currentmodule:: mlx.utils

.. autosummary::
   :toctree: _autosummary

   tree_flatten
   tree_unflatten
   tree_map

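A minimal sketch of the tree utilities on a small parameter-like tree:

.. code-block:: python

   import mlx.core as mx
   from mlx.utils import tree_flatten, tree_map, tree_unflatten

   tree = {"w": mx.zeros((2, 2)), "layers": [{"b": mx.ones((3,))}]}

   flat = tree_flatten(tree)        # list of (dotted key, leaf) pairs, e.g. ("layers.0.b", ...)
   rebuilt = tree_unflatten(flat)   # back to the nested structure

   shapes = tree_map(lambda x: x.shape, tree)
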
644  prompts/gpts/knowledge/MLX Guru/usage.txt  Normal file
@@ -0,0 +1,644 @@

.. _function_transforms:

Function Transforms
===================

.. currentmodule:: mlx.core

MLX uses composable function transformations for automatic differentiation and
vectorization. The key idea behind composable function transformations is that
every transformation returns a function which can be further transformed.

Here is a simple example:

.. code-block:: shell

   >>> dfdx = mx.grad(mx.sin)
   >>> dfdx(mx.array(mx.pi))
   array(-1, dtype=float32)
   >>> mx.cos(mx.array(mx.pi))
   array(-1, dtype=float32)

The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:

.. code-block:: shell

   >>> d2fdx2 = mx.grad(mx.grad(mx.sin))
   >>> d2fdx2(mx.array(mx.pi / 2))
   array(-1, dtype=float32)
   >>> mx.sin(mx.array(mx.pi / 2))
   array(1, dtype=float32)

Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives.

Any of the MLX function transformations can be composed in any order to any
depth. To see the complete list of function transformations check out the
:ref:`API documentation <transforms>`. See the following sections for more
information on :ref:`automatic differentiation <auto diff>` and
:ref:`automatic vectorization <vmap>`.

Automatic Differentiation
-------------------------

.. _auto diff:

Automatic differentiation in MLX works on functions rather than on implicit
graphs.

.. note::

   If you are coming to MLX from PyTorch, you no longer need functions like
   ``backward``, ``zero_grad``, and ``detach``, or properties like
   ``requires_grad``.

The most basic example is taking the gradient of a scalar-valued function as we
saw above. You can use the :func:`grad` and :func:`value_and_grad` functions to
compute gradients of more complex functions. By default these functions compute
the gradient with respect to the first argument:

.. code-block:: python

   def loss_fn(w, x, y):
       return mx.mean(mx.square(w * x - y))

   w = mx.array(1.0)
   x = mx.array([0.5, -0.5])
   y = mx.array([1.5, -1.5])

   # Computes the gradient of loss_fn with respect to w:
   grad_fn = mx.grad(loss_fn)
   dloss_dw = grad_fn(w, x, y)
   # Prints array(-1, dtype=float32)
   print(dloss_dw)

   # To get the gradient with respect to x we can do:
   grad_fn = mx.grad(loss_fn, argnums=1)
   dloss_dx = grad_fn(w, x, y)
   # Prints array([-1, 1], dtype=float32)
   print(dloss_dx)

One way to get the loss and gradient is to call ``loss_fn`` followed by
``grad_fn``, but this can result in a lot of redundant work. Instead, you
should use :func:`value_and_grad`. Continuing the above example:

.. code-block:: python

   # Computes the gradient of loss_fn with respect to w:
   loss_and_grad_fn = mx.value_and_grad(loss_fn)
   loss, dloss_dw = loss_and_grad_fn(w, x, y)

   # Prints array(1, dtype=float32)
   print(loss)

   # Prints array(-1, dtype=float32)
   print(dloss_dw)

You can also take the gradient with respect to arbitrarily nested Python
containers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or
:obj:`dict`).

Suppose we wanted a weight and a bias parameter in the above example. A nice
way to do that is the following:

.. code-block:: python

   def loss_fn(params, x, y):
       w, b = params["weight"], params["bias"]
       h = w * x + b
       return mx.mean(mx.square(h - y))

   params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
   x = mx.array([0.5, -0.5])
   y = mx.array([1.5, -1.5])

   # Computes the gradient of loss_fn with respect to both the
   # weight and bias:
   grad_fn = mx.grad(loss_fn)
   grads = grad_fn(params, x, y)

   # Prints
   # {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
   print(grads)

Notice the tree structure of the parameters is preserved in the gradients.

In some cases you may want to stop gradients from propagating through a
part of the function. You can use :func:`stop_gradient` for that.

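For example, a minimal sketch in which one term is treated as a constant:

.. code-block:: python

   import mlx.core as mx

   def fun(x):
       # No gradient flows through the stop_gradient term
       return mx.square(x).sum() + mx.stop_gradient(x).sum()

   dfdx = mx.grad(fun)(mx.array([1.0, 2.0]))
   print(dfdx)  # gradient of the squared term only, i.e. 2 * x
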
Automatic Vectorization
-----------------------

.. _vmap:

Use :func:`vmap` to automate vectorizing complex functions. Here we'll go
through a basic and contrived example for the sake of clarity, but :func:`vmap`
can be quite powerful for more complex functions which are difficult to optimize
by hand.

.. warning::

   Some operations are not yet supported with :func:`vmap`. If you encounter an error
   like: ``ValueError: Primitive's vmap not implemented.`` file an `issue
   <https://github.com/ml-explore/mlx/issues>`_ and include your function.
   We will prioritize including it.

A naive way to add the elements from two sets of vectors is with a loop:

.. code-block:: python

   xs = mx.random.uniform(shape=(4096, 100))
   ys = mx.random.uniform(shape=(100, 4096))

   def naive_add(xs, ys):
       return [xs[i] + ys[:, i] for i in range(xs.shape[1])]

Instead you can use :func:`vmap` to automatically vectorize the addition:

.. code-block:: python

   # Vectorize over the second dimension of x and the
   # first dimension of y
   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))

The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.

Let's time these two different versions:

.. code-block:: python

   import timeit

   print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
   print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))

On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.

Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.

.. _indexing:

Indexing Arrays
===============

.. currentmodule:: mlx.core

For the most part, indexing an MLX :obj:`array` works the same as indexing a
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
how that works.

For example, you can use regular integers and slices (:obj:`slice`) to index arrays:

.. code-block:: shell

   >>> arr = mx.arange(10)
   >>> arr[3]
   array(3, dtype=int32)
   >>> arr[-2]  # negative indexing works
   array(8, dtype=int32)
   >>> arr[2:8:2]  # start, stop, stride
   array([2, 4, 6], dtype=int32)

For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:

.. code-block:: shell

   >>> arr = mx.arange(8).reshape(2, 2, 2)
   >>> arr[:, :, 0]
   array([[0, 2],
          [4, 6]], dtype=int32)
   >>> arr[..., 0]
   array([[0, 2],
          [4, 6]], dtype=int32)

You can index with ``None`` to create a new axis:

.. code-block:: shell

   >>> arr = mx.arange(8)
   >>> arr.shape
   [8]
   >>> arr[None].shape
   [1, 8]

You can also use an :obj:`array` to index another :obj:`array`:

.. code-block:: shell

   >>> arr = mx.arange(10)
   >>> idx = mx.array([5, 7])
   >>> arr[idx]
   array([5, 7], dtype=int32)

Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
works just as in NumPy.

Other functions which may be useful for indexing arrays are :func:`take` and
:func:`take_along_axis`.

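A minimal sketch of both functions (the index values are only illustrative):

.. code-block:: python

   import mlx.core as mx

   a = mx.array([[1, 2, 3], [4, 5, 6]])

   # Gather whole rows by index
   rows = mx.take(a, mx.array([1, 0]), axis=0)

   # Gather one entry per row using per-row column indices
   idx = mx.array([[2], [0]])
   picked = mx.take_along_axis(a, idx, axis=1)
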
Differences from NumPy
----------------------

.. note::

   MLX indexing is different from NumPy indexing in two important ways:

   * Indexing does not perform bounds checking. Indexing out of bounds is
     undefined behavior.
   * Boolean mask based indexing is not yet supported.

The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.

Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which outputs
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.

In Place Updates
----------------

In place updates to indexed arrays are possible in MLX. For example:

.. code-block:: shell

   >>> a = mx.array([1, 2, 3])
   >>> a[2] = 0
   >>> a
   array([1, 2, 0], dtype=int32)

Just as in NumPy, in place updates will be reflected in all references to the
same array:

.. code-block:: shell

   >>> a = mx.array([1, 2, 3])
   >>> b = a
   >>> b[2] = 0
   >>> b
   array([1, 2, 0], dtype=int32)
   >>> a
   array([1, 2, 0], dtype=int32)

Transformations of functions which use in-place updates are allowed and work as
expected. For example:

.. code-block:: python

   def fun(x, idx):
       x[idx] = 2.0
       return x.sum()

   dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
   print(dfdx)  # Prints: array([1, 0, 1], dtype=float32)

In the above, ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.

.. _lazy eval:

Lazy Evaluation
===============

.. currentmodule:: mlx.core

Why Lazy Evaluation
-------------------

When you perform operations in MLX, no computation actually happens. Instead a
compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed.

MLX uses lazy evaluation because it has some nice features, some of which we
describe below.

Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Lazy evaluation lets us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations like :func:`simplify`.

Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to
integrate compilation for future performance enhancements.

Only Compute What You Use
^^^^^^^^^^^^^^^^^^^^^^^^^

In MLX you do not need to worry as much about computing outputs that are never
used. For example:

.. code-block:: python

   def fun(x):
       a = fun1(x)
       b = expensive_fun(a)
       return a, b

   y, _ = fun(x)

Here, we never actually compute the output of ``expensive_fun``. Use this
pattern with care though, as the graph of ``expensive_fun`` is still built, and
that has some cost associated to it.

Similarly, lazy evaluation can be beneficial for saving memory while keeping
code simple. Say you have a very large model ``Model`` derived from
:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.
Typically, this will initialize all of the weights as ``float32``, but the
initialization does not actually compute anything until you perform an
:func:`eval`. If you update the model with ``float16`` weights, your maximum
consumed memory will be half that required if eager computation was used
instead.

This pattern is simple to do in MLX thanks to lazy computation:

.. code-block:: python

   model = Model()  # no memory used yet
   model.load_weights("weights_fp16.safetensors")

When to Evaluate
----------------

A common question is when to use :func:`eval`. The trade-off is between
letting graphs get too large and not batching enough useful work.

For example:

.. code-block:: python

   for _ in range(100):
       a = a + b
       mx.eval(a)
       b = b * 2
       mx.eval(b)

This is a bad idea because there is some fixed overhead with each graph
evaluation. On the other hand, there is some slight overhead which grows with
the compute graph size, so extremely large graphs (while computationally
correct) can be costly.

Luckily, a wide range of compute graph sizes work pretty well with MLX:
anything from a few tens of operations to many thousands of operations per
evaluation should be okay.

Most numerical computations have an iterative outer loop (e.g. the iteration in
stochastic gradient descent). A natural and usually efficient place to use
:func:`eval` is at each iteration of this outer loop.

Here is a concrete example:

.. code-block:: python

   for batch in dataset:

       # Nothing has been evaluated yet
       loss, grad = value_and_grad_fn(model, batch)

       # Still nothing has been evaluated
       optimizer.update(model, grad)

       # Evaluate the loss and the new parameters which will
       # run the full gradient computation and optimizer update
       mx.eval(loss, model.parameters())

An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.

Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.

Also, calling :func:`eval` on an array or set of arrays multiple times is
perfectly fine. This is effectively a no-op.

.. warning::

   Using scalar arrays for control-flow will cause an evaluation.

Here is an example:

.. code-block:: python

   def fun(x):
       h, y = first_layer(x)
       if y > 0:  # An evaluation is done here!
           z = second_layer_a(h)
       else:
           z = second_layer_b(h)
       return z

Using arrays for control flow should be done with care. The above example works
and can even be used with gradient transformations. However, this can be very
inefficient if evaluations are done too frequently.

.. _numpy:

Conversion to NumPy and Other Frameworks
========================================

MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
Let's convert an array to NumPy and back.

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   a = mx.arange(3)
   b = np.array(a)  # copy of a
   c = mx.array(b)  # copy of b

.. note::

   Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
   ``np.array(a.astype(mx.float32))``.
   Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``

By default, NumPy copies data to a new array. This can be prevented by creating an array view:

.. code-block:: python

   a = mx.arange(3)
   a_view = np.array(a, copy=False)
   print(a_view.flags.owndata)  # False
   a_view[0] = 1
   print(a[0].item())  # 1

A NumPy array view is a normal NumPy array, except that it does not own its memory.
This means writing to the view is reflected in the original array.

While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.

Let's demonstrate this in an example:

.. code-block:: python

   def f(x):
       x_view = np.array(x, copy=False)
       x_view[:] *= x_view  # modify memory without telling mx
       return x.sum()

   x = mx.array([3.0])
   y, df = mx.value_and_grad(f)(x)
   print("f(x) = x² =", y.item())  # 9.0
   print("f'(x) = 2x !=", df.item())  # 1.0

The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
representing the gradient of the sum operation alone.
The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
It's important to note that a similar issue arises during array conversion and copying.
For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.

PyTorch
-------

.. warning::

   PyTorch support for :obj:`memoryview` is experimental and can break for
   multi-dimensional arrays. Casting to NumPy first is advised for now.

PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.

.. code-block:: python

   import mlx.core as mx
   import torch

   a = mx.arange(3)
   b = torch.tensor(memoryview(a))
   c = mx.array(b.numpy())

Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.

JAX
---

JAX fully supports the buffer protocol.

.. code-block:: python

   import mlx.core as mx
   import jax.numpy as jnp

   a = mx.arange(3)
   b = jnp.array(a)
   c = mx.array(b)

TensorFlow
----------

TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.

.. code-block:: python

   import mlx.core as mx
   import tensorflow as tf

   a = mx.arange(3)
   b = tf.constant(memoryview(a))
   c = mx.array(b)

.. _saving_and_loading:

Saving and Loading Arrays
=========================

.. currentmodule:: mlx.core

MLX supports multiple array serialization formats.

.. list-table:: Serialization Formats
   :widths: 20 8 25 25
   :header-rows: 1

   * - Format
     - Extension
     - Function
     - Notes
   * - NumPy
     - ``.npy``
     - :func:`save`
     - Single arrays only
   * - NumPy archive
     - ``.npz``
     - :func:`savez` and :func:`savez_compressed`
     - Multiple arrays
   * - Safetensors
     - ``.safetensors``
     - :func:`save_safetensors`
     - Multiple arrays
   * - GGUF
     - ``.gguf``
     - :func:`save_gguf`
     - Multiple arrays

The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extension. The output of
:func:`load` depends on the format.

Here's an example of saving a single array to a file:

.. code-block:: shell

   >>> a = mx.array([1.0])
   >>> mx.save("array", a)

The array ``a`` will be saved in the file ``array.npy`` (notice the extension
is automatically added). Including the extension is optional; if it is missing
it will be added. You can load the array with:

.. code-block:: shell

   >>> mx.load("array.npy")
   array([1], dtype=float32)

Here's an example of saving several arrays to a single file:

.. code-block:: shell

   >>> a = mx.array([1.0])
   >>> b = mx.array([2.0])
   >>> mx.savez("arrays", a, b=b)

For compatibility with :func:`numpy.savez` the MLX :func:`savez` takes arrays
as arguments. If the keywords are missing, then default names will be
provided. This can be loaded with:

.. code-block:: shell

   >>> mx.load("arrays.npz")
   {'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}

In this case :func:`load` returns a dictionary of names to arrays.

The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
:func:`savez`, but they take as input a :obj:`dict` of string names to arrays:

.. code-block:: shell

   >>> a = mx.array([1.0])
   >>> b = mx.array([2.0])
   >>> mx.save_safetensors("arrays", {"a": a, "b": b})
