.. _array:

Array
=====

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   array
   array.astype
   array.item
   array.tolist
   array.dtype
   array.ndim
   array.shape
   array.size
   Dtype
   array.abs
   array.all
   array.any
   array.argmax
   array.argmin
   array.cos
   array.dtype
   array.exp
   array.log
   array.log1p
   array.logsumexp
   array.max
   array.mean
   array.min
   array.prod
   array.reciprocal
   array.reshape
   array.round
   array.rsqrt
   array.sin
   array.split
   array.sqrt
   array.square
   array.sum
   array.transpose
   array.T
   array.var
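For illustration, a minimal sketch of working with array attributes and
methods, assuming ``mlx.core`` is imported as ``mx``:

.. code-block:: python

   import mlx.core as mx

   a = mx.array([[1.0, 2.0], [3.0, 4.0]])
   print(a.shape, a.dtype, a.ndim)  # shape, element type and dimensionality
   print(a.sum())                   # reductions are also available as methods
   print(a.T)                       # transpose of the array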
.. _data_types:

:orphan:

Data Types
==========

.. currentmodule:: mlx.core

The default floating point type is ``float32`` and the default integer type is
``int32``. The table below shows supported values for :obj:`Dtype`.

.. list-table:: Supported Data Types
   :widths: 5 3 20
   :header-rows: 1

   * - Type
     - Bytes
     - Description
   * - ``bool_``
     - 1
     - Boolean (``True``, ``False``) data type
   * - ``uint8``
     - 1
     - 8-bit unsigned integer
   * - ``uint16``
     - 2
     - 16-bit unsigned integer
   * - ``uint32``
     - 4
     - 32-bit unsigned integer
   * - ``uint64``
     - 8
     - 64-bit unsigned integer
   * - ``int8``
     - 1
     - 8-bit signed integer
   * - ``int16``
     - 2
     - 16-bit signed integer
   * - ``int32``
     - 4
     - 32-bit signed integer
   * - ``int64``
     - 8
     - 64-bit signed integer
   * - ``float16``
     - 2
     - 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
   * - ``float32``
     - 4
     - 32-bit float
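For illustration, a minimal sketch of selecting dtypes explicitly, assuming the
defaults described above:

.. code-block:: python

   import mlx.core as mx

   x = mx.array([1, 2, 3])        # default integer type: int32
   y = mx.array([1.0, 2.0, 3.0])  # default floating point type: float32
   z = y.astype(mx.float16)       # cast to a different dtype
   print(x.dtype, y.dtype, z.dtype)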
.. _devices_and_streams:

Devices and Streams
===================

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   Device
   default_device
   set_default_device
   Stream
   default_stream
   new_stream
   set_default_stream
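For illustration, a minimal sketch of querying and setting the default device
and stream:

.. code-block:: python

   import mlx.core as mx

   print(mx.default_device())     # the current default device
   mx.set_default_device(mx.cpu)  # run subsequent operations on the CPU

   s = mx.new_stream(mx.default_device())
   mx.set_default_stream(s)       # make the new stream the default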
.. _fft:

FFT
===

.. currentmodule:: mlx.core.fft

.. autosummary::
   :toctree: _autosummary

   fft
   ifft
   fft2
   ifft2
   fftn
   ifftn
   rfft
   irfft
   rfft2
   irfft2
   rfftn
   irfftn
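For illustration, a minimal sketch of a real-to-complex transform and its
inverse:

.. code-block:: python

   import mlx.core as mx

   x = mx.random.uniform(shape=(8,))
   X = mx.fft.rfft(x)   # real input, complex output along the last axis
   y = mx.fft.irfft(X)  # inverse transform recovers the input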
.. _linalg:

Linear Algebra
==============

.. currentmodule:: mlx.core.linalg

.. autosummary::
   :toctree: _autosummary

   norm
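For illustration, a minimal sketch of computing norms:

.. code-block:: python

   import mlx.core as mx

   A = mx.ones((3, 3))
   print(mx.linalg.norm(A))          # Frobenius norm of the matrix
   print(mx.linalg.norm(A, axis=1))  # per-row vector norms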
.. _nn:

.. currentmodule:: mlx.nn

Neural Networks
===============

Writing arbitrarily complex neural networks in MLX can be done using only
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this
requires the user to rewrite the same simple neural network operations over and
over and to handle all of the parameter state and initialization manually and
explicitly.

The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for finetuning, and more.

Quick Start with Neural Networks
--------------------------------
.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   class MLP(nn.Module):
       def __init__(self, in_dims: int, out_dims: int):
           super().__init__()

           self.layers = [
               nn.Linear(in_dims, 128),
               nn.Linear(128, 128),
               nn.Linear(128, out_dims),
           ]

       def __call__(self, x):
           for i, l in enumerate(self.layers):
               x = mx.maximum(x, 0) if i > 0 else x
               x = l(x)
           return x

   # The model is created with all its parameters but nothing is initialized
   # yet because MLX is lazily evaluated
   mlp = MLP(2, 10)

   # We can access its parameters by calling mlp.parameters()
   params = mlp.parameters()
   print(params["layers"][0]["weight"].shape)

   # Printing a parameter will cause it to be evaluated and thus initialized
   print(params["layers"][0])

   # We can also force evaluate all parameters to initialize the model
   mx.eval(mlp.parameters())

   # A simple loss function.
   # NOTE: It doesn't matter how it uses the mlp model. It currently captures
   # it from the local scope. It could be a positional argument or a
   # keyword argument.
   def l2_loss(x, y):
       y_hat = mlp(x)
       return (y_hat - y).square().mean()

   # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
   # gradient with respect to `mlp.trainable_parameters()`
   loss_and_grad = nn.value_and_grad(mlp, l2_loss)
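The wrapped ``loss_and_grad`` takes the same arguments as ``l2_loss`` and
returns both the loss and the gradients with respect to the trainable
parameters. A minimal usage sketch with hypothetical inputs ``x`` and targets
``y``:

.. code-block:: python

   # Hypothetical inputs and targets, purely for illustration
   x = mx.random.uniform(shape=(4, 2))
   y = mx.random.uniform(shape=(4, 10))

   # Loss value and gradients w.r.t. mlp.trainable_parameters()
   loss, grads = loss_and_grad(x, y)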
.. _module_class:

The Module Class
----------------

The workhorse of any neural network library is the :class:`Module` class. In
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
:class:`Module` instances. Its main function is to provide a way to
recursively **access** and **update** its parameters and those of its
submodules.
Parameters
^^^^^^^^^^

A parameter of a module is any public member of type :class:`mlx.core.array`
(its name should not start with ``_``). It can be arbitrarily nested in other
:class:`Module` instances or lists and dictionaries.

:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.

A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. When using
:meth:`mlx.nn.value_and_grad`, the gradients returned will be with respect to
these trainable parameters.


Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^

MLX modules allow accessing and updating individual parameters. However, most
of the time we need to update large subsets of a module's parameters. This
action is performed by :meth:`Module.update`.
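For illustration, a minimal sketch that rescales every parameter and writes the
result back with :meth:`Module.update`, reusing the ``mlp`` model from the
quick start:

.. code-block:: python

   from mlx.utils import tree_map

   # Scale every parameter by 0.9 and write the new values back into the model
   mlp.update(tree_map(lambda p: 0.9 * p, mlp.parameters()))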
Inspecting Modules
^^^^^^^^^^^^^^^^^^

The simplest way to see the model architecture is to print it. Following along
with the above example, you can print the ``MLP`` with:

.. code-block:: python

   print(mlp)
This will display:

.. code-block:: shell

   MLP(
     (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
     (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
     (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
   )
To get more detailed information on the arrays in a :class:`Module` you can use
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
all the parameters in a :class:`Module` do:

.. code-block:: python

   from mlx.utils import tree_map

   shapes = tree_map(lambda p: p.shape, mlp.parameters())
As another example, you can count the number of parameters in a :class:`Module`
with:

.. code-block:: python

   from mlx.utils import tree_flatten

   num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
Value and Grad
--------------

Using a :class:`Module` does not preclude using MLX's higher-order function
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.).
However, these function transformations assume pure functions, namely that the
parameters should be passed as an argument to the function being transformed.

There is an easy pattern to achieve that with MLX modules:
.. code-block:: python

   model = ...

   def f(params, other_inputs):
       model.update(params)  # <---- Necessary to make the model use the passed parameters
       return model(other_inputs)

   f(model.trainable_parameters(), mx.zeros((10,)))
However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only
computes the gradients with respect to the trainable parameters of the model.

In detail:

- it wraps the passed function with a function that calls :meth:`Module.update`
  to make sure the model is using the provided parameters.
- it calls :meth:`mlx.core.value_and_grad` to transform the function into a
  function that also computes the gradients with respect to the passed
  parameters.
- it wraps the returned function with a function that passes the trainable
  parameters as the first argument to the function returned by
  :meth:`mlx.core.value_and_grad`.

.. autosummary::
   :toctree: _autosummary

   value_and_grad
.. toctree::

   nn/module
   nn/layers
   nn/functions
   nn/losses
   nn/init
.. _ops:

Operations
==========

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   abs
   add
   all
   allclose
   any
   arange
   arccos
   arccosh
   arcsin
   arcsinh
   arctan
   arctanh
   argmax
   argmin
   argpartition
   argsort
   array_equal
   broadcast_to
   ceil
   clip
   concatenate
   convolve
   conv1d
   conv2d
   cos
   cosh
   dequantize
   divide
   divmod
   equal
   erf
   erfinv
   exp
   expand_dims
   eye
   flatten
   floor
   floor_divide
   full
   greater
   greater_equal
   identity
   inner
   isnan
   isposinf
   isneginf
   isinf
   less
   less_equal
   linspace
   load
   log
   log2
   log10
   log1p
   logaddexp
   logical_not
   logical_and
   logical_or
   logsumexp
   matmul
   max
   maximum
   mean
   min
   minimum
   moveaxis
   multiply
   negative
   ones
   ones_like
   outer
   partition
   pad
   prod
   quantize
   quantized_matmul
   reciprocal
   repeat
   reshape
   round
   rsqrt
   save
   savez
   savez_compressed
   save_gguf
   save_safetensors
   sigmoid
   sign
   sin
   sinh
   softmax
   sort
   split
   sqrt
   square
   squeeze
   stack
   stop_gradient
   subtract
   sum
   swapaxes
   take
   take_along_axis
   tan
   tanh
   tensordot
   transpose
   tri
   tril
   triu
   var
   where
   zeros
   zeros_like
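For illustration, a minimal sketch combining a few of these operations:

.. code-block:: python

   import mlx.core as mx

   x = mx.arange(5)
   print(mx.sum(x))                            # 10
   mask = mx.greater(x, 2)                     # elementwise comparison
   print(mx.where(mask, x, mx.zeros_like(x)))  # keep only values greater than 2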
.. _optimizers:

Optimizers
==========

The optimizers in MLX can be used both with :mod:`mlx.nn` and with pure
:mod:`mlx.core` functions. A typical example involves calling
:meth:`Optimizer.update` to update a model's parameters based on the loss
gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the
model's parameters and the **optimizer state**.
.. code-block:: python

   # Create a model
   model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
   mx.eval(model.parameters())

   # Create the gradient function and the optimizer
   loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
   optimizer = optim.SGD(learning_rate=learning_rate)

   for e in range(num_epochs):
       for X, y in batch_iterate(batch_size, train_images, train_labels):
           loss, grads = loss_and_grad_fn(model, X, y)

           # Update the model with the gradients. So far no computation has happened.
           optimizer.update(model, grads)

           # Compute the new parameters but also the optimizer state.
           mx.eval(model.parameters(), optimizer.state)
.. currentmodule:: mlx.optimizers

.. autosummary::
   :toctree: _autosummary
   :template: optimizers-template.rst

   OptimizerState
   Optimizer
   SGD
   RMSprop
   Adagrad
   Adafactor
   AdaDelta
   Adam
   AdamW
   Adamax
   Lion
.. _random:

Random
======

Random sampling functions in MLX use an implicit global PRNG state by default.
However, all functions take an optional ``key`` keyword argument for when more
fine-grained control or explicit state management is needed.

For example, you can generate random numbers with:
.. code-block:: python

   for _ in range(3):
       print(mx.random.uniform())

which will print a sequence of unique pseudo random numbers. Alternatively you
can explicitly set the key:
.. code-block:: python

   key = mx.random.key(0)
   for _ in range(3):
       print(mx.random.uniform(key=key))

which will yield the same pseudo random number at each iteration.

Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_
we use a splittable version of Threefry, which is a counter-based PRNG.
.. currentmodule:: mlx.core.random

.. autosummary::
   :toctree: _autosummary

   bernoulli
   categorical
   gumbel
   key
   normal
   randint
   seed
   split
   truncated_normal
   uniform
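For illustration, a minimal sketch of deriving independent subkeys with
``split``:

.. code-block:: python

   key = mx.random.key(42)
   keys = mx.random.split(key)  # two subkeys by default

   a = mx.random.normal(shape=(2,), key=keys[0])
   b = mx.random.normal(shape=(2,), key=keys[1])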
.. _transforms:

Transforms
==========

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   eval
   grad
   value_and_grad
   jvp
   vjp
   vmap
   simplify
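For illustration, a minimal sketch of :func:`grad`:

.. code-block:: python

   import mlx.core as mx

   def f(x):
       return mx.square(x).sum()

   dfdx = mx.grad(f)                       # function returning the gradient of f
   print(dfdx(mx.array([1.0, 2.0, 3.0])))  # the gradient is 2 * x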
.. _utils:

Tree Utils
==========

In MLX we consider a Python tree to be an arbitrarily nested collection of
dictionaries, lists and tuples without cycles. Functions in this module that
return Python trees use the default Python ``dict``, ``list`` and ``tuple``,
but they can usually process objects that inherit from any of these.

.. note::
   Dictionaries should have keys that are valid Python identifiers.

.. currentmodule:: mlx.utils

.. autosummary::
   :toctree: _autosummary

   tree_flatten
   tree_unflatten
   tree_map
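For illustration, a minimal sketch of flattening and rebuilding a parameter
tree:

.. code-block:: python

   import mlx.core as mx
   from mlx.utils import tree_flatten, tree_unflatten

   params = {"layers": [{"weight": mx.zeros((2, 2))}, {"weight": mx.zeros((2,))}]}

   flat = tree_flatten(params)      # list of (dotted key, value) pairs
   print([k for k, _ in flat])      # e.g. ['layers.0.weight', 'layers.1.weight']

   restored = tree_unflatten(flat)  # rebuild the nested structure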