Distributed Machine Learning Fundamentals: MLX

David Crawford / February 04, 2025
This post is part 2 of my Distributed Machine Learning series. You can jump to any of the published posts below:
- Introduction
- Preparing a Model with MLX (👈 this post)
- Dataset Preparation
- How to setup MPI
- Distributed Fine Tuning
- Next Steps
In my last post, I introduced the high-level concepts for us to consider. Now, we can actually start installing and running things to get some preliminary results.
My goal with this post is to get you comfortable running inference on a model with Apple's open source MLX framework, and choosing a model that is right for you. We'll be doing the following:
- Installing MLX
- Installing a good model
- Running inference using MLX
- Creating a simple logistic regression model with only MLX
MLX
Verbatim from their own repo, "MLX is an array framework for machine learning on Apple silicon, brought to you by Apple machine learning research." It is an open source competitor to industry staples like TensorFlow and PyTorch, and is specifically designed to run on Apple Silicon natively.
I wanted to use MLX for this series not just because it's newer and more topical, but because in my experience it's a lot easier to run on Apple Silicon. It also has less documentation, which means we have to figure things out for ourselves and build up fundamental knowledge along the way. Let's get started.
Requirements
MLX has its own requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
- macOS >= 13.5
Make sure to also have pyenv installed, because being able to manage your python versions for individual projects is a good practice.
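As a quick sanity check before installing anything, the requirements above can be tested from Python itself. This is a minimal sketch using only the standard library. The helper name check_requirements is my own (it is not part of MLX), and the thresholds are simply the ones listed above.

```python
import platform
import sys


def check_requirements(machine=None, mac_version=None, python_version=None):
    """Return a dict of pass/fail flags for the requirements listed above.

    All arguments default to values read from the current machine. They are
    parameters only so the checks can be exercised with made-up values.
    """
    machine = platform.machine() if machine is None else machine
    mac_version = platform.mac_ver()[0] if mac_version is None else mac_version
    python_version = sys.version_info[:2] if python_version is None else python_version

    # mac_ver() returns an empty string on non-macOS systems, which yields ().
    mac_parts = tuple(int(p) for p in mac_version.split(".")[:2] if p.isdigit())

    return {
        # A native (non-Rosetta) Python on Apple silicon reports "arm64".
        "apple_silicon_native": machine == "arm64",
        "python_3_9_or_newer": tuple(python_version) >= (3, 9),
        "macos_13_5_or_newer": mac_parts >= (13, 5),
    }


if __name__ == "__main__":
    for name, ok in check_requirements().items():
        print(f"{name}: {'ok' if ok else 'NOT met'}")
```

If apple_silicon_native comes back as not met on an M series Mac, the Python you're running is probably an x86 build running under Rosetta rather than a native one.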
Installing MLX
If you follow MLX's installation guide, you might end up frustrated, so try these steps instead:
- Create a folder and cd to it in terminal
- Create a new python virtual environment (otherwise you'll likely get segmentation fault errors):
mkdir mlx-project
cd mlx-project
python3 -m venv mlx-env
source mlx-env/bin/activate
- Tell pyenv to use the python version that you want in this virtual environment (I'm using 3.13.1):
pyenv local 3.13.1
- Verify that the right version is being aliased in this directory:
python --version
- If it's wrong, you may have a pyenv installation issue. You can check out their docs on your own, or oftentimes the quick temporary fix is to run this in your current terminal session:
eval "$(pyenv init --path)"
- Then install the MLX packages:
pip install -U mlx mlx-lm
- Make sure you're in your parent directory, and just run mlx_lm.lora in the terminal. You should get a bunch of errors, but no errors about the command not being found. This means it's installed and ready to go.
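If you'd rather confirm the install from Python than read through the error output, something like the following works. It's a small sketch using only the standard library. The name is_installed is a hypothetical helper, and it only checks that a package can be found and that a command is on your PATH, not that MLX actually runs.

```python
import importlib.util
import shutil


def is_installed(module_name, command_name=None):
    """Report whether a module is importable and (optionally) a CLI is on PATH."""
    module_found = importlib.util.find_spec(module_name) is not None
    command_found = None
    if command_name is not None:
        command_found = shutil.which(command_name) is not None
    return module_found, command_found


if __name__ == "__main__":
    # "mlx" and "mlx_lm" are the packages installed with pip above;
    # "mlx_lm.lora" is the command we just ran in the terminal.
    print("mlx:", is_installed("mlx"))
    print("mlx_lm:", is_installed("mlx_lm", "mlx_lm.lora"))
```

Run it inside the activated virtual environment, otherwise it will report on whatever Python happens to be first on your PATH.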
Getting a Model
The obvious key to distributed fine tuning is having a model to fine tune. For this series, we'll be working with something that's not super sophisticated, but is also smart enough that we won't have to do a crazy amount of fine tuning to get some results.
I chose Mistral-7B-Instruct-v0.3-4bit, a version of the model whose weights are quantized to 4 bits. This is not to be confused with its big brother, the unquantized Mistral-7B-Instruct-v0.3.
You can use the more sophisticated model if you like; however, the smaller one fulfills a couple of needs:
- It's small and therefore faster to download, fine-tune, and run inference on
- I have a Mac with 32 GB of RAM and a Mac with 16 GB of RAM for my distributed setup. 16 GB simply isn't enough for the larger model, but the 4bit version is perfect. You might be wondering why distributed machine learning doesn't solve this by sharing the load of the larger model across multiple computers. We will tackle this concept later on in the series.
- It's a "dumb" enough model that I've consistently gotten the same responses from it over and over, which is actually really good for testing
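To see roughly why 16 GB is too tight for the larger model, here is a back-of-the-envelope estimate of the memory needed just to hold the weights. The parameter count (about 7.25 billion) and the 16-bit storage for the unquantized model are assumptions, and the estimate ignores activations, the KV cache, and quantization metadata, so treat the numbers as ballpark figures only.

```python
def weight_memory_gb(num_params, bits_per_weight):
    """Approximate memory (in GB, 1 GB = 1e9 bytes) to hold the weights alone."""
    return num_params * bits_per_weight / 8 / 1e9


# Assumed, approximate parameter count for a "7B" model.
NUM_PARAMS = 7.25e9

full = weight_memory_gb(NUM_PARAMS, 16)   # 16-bit weights
quant = weight_memory_gb(NUM_PARAMS, 4)   # 4-bit weights

print(f"16-bit weights: ~{full:.1f} GB")
print(f" 4-bit weights: ~{quant:.1f} GB")
```

The 16-bit estimate already sits close to the full 16 GB before the operating system or anything else gets a share, while the 4-bit estimate leaves plenty of headroom.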
How to download the model
- You'll need a Hugging Face account, and you have to accept the model's terms and conditions to allow your account to download it
- Make a user access token in your Hugging Face account settings
- Install the Hugging Face CLI:
pip install --upgrade huggingface_hub
huggingface-cli login
- Enter your token, then run this command to verify:
huggingface-cli whoami
- Next, install git-lfs because the model is several gigabytes:
git lfs install
- Clone the model (I put it outside of my project folder):
git clone https://huggingface.co/unsloth/mistral-7b-instruct-v0.3-bnb-4bit
Running inference
To run inference on the model, run this command (make sure to point to the model you downloaded):
mlx_lm.generate --model ../Mistral-7B-Instruct-v0.3-4bit --prompt "Tell me about greek and roman history like a pirate"
You should get a response like this:
==========
Ahoy matey! Let's set sail through the annals of Greek and Roman history, like a ship navigating the vast sea of time!
First, we'll anchor at the shores of ancient Greece, in the cradle of Western civilization. The Greeks, they were a clever bunch, with city-states like Athens and Sparta leading the charge. Athens, known for its wisdom, was home to philosophers like Socrates, Plato
==========
Prompt: 17 tokens, 88.261 tokens-per-sec
Generation: 100 tokens, 36.407 tokens-per-sec
Peak memory: 4.119 GB
If you get a response like this, then everything is working well. If you want to know how to do this programmatically, you can accomplish the same thing with python:
from mlx_lm import load, generate
model, tokenizer = load("../Mistral-7B-Instruct-v0.3-4bit")
prompt = "Tell me about greek and roman history like a pirate"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True
)
text = generate(model, tokenizer, prompt=prompt, verbose=True)
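The tokens-per-sec figures in the sample output above translate directly into wall-clock time. As a small worked example (the helper name seconds_for is hypothetical, and the numbers are just the ones from the sample output, so yours will differ):

```python
def seconds_for(tokens, tokens_per_sec):
    """Time spent processing `tokens` at a given throughput."""
    return tokens / tokens_per_sec


prompt_time = seconds_for(17, 88.261)        # prompt processing
generation_time = seconds_for(100, 36.407)   # token generation

print(f"Prompt:     {prompt_time:.2f} s")
print(f"Generation: {generation_time:.2f} s")
```

Generation dominates the total time, which is typical because output tokens are produced one at a time.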
What's going on
This is a really basic example of working with a model in MLX. You'll find a lot of tutorials going over this out there, so I wanted to take this a step further into territory that isn't as documented: starting from scratch.
To build up our fundamentals, let's make a minimal logistic regression model built entirely with MLX arrays and plain Python. This will:
- Show how MLX differs from TensorFlow and PyTorch (if you're familiar with them, you'll notice terminology differences)
- Introduce you to gradients, which will be important later in this series
- Show you how to build a model from scratch that you can train in a couple of seconds
We're going to create a model that can predict the OR function. This is a simple binary function that returns 1 if either of the inputs is 1, and 0 if both are 0.
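For reference, here is the full truth table the model has to learn, written out in plain Python (no MLX needed for this part). The four rows match the inputs and labels used for training.

```python
# Enumerate every combination of two binary inputs and the OR of each pair.
truth_table = [(a, b, a | b) for a in (0, 1) for b in (0, 1)]

for a, b, out in truth_table:
    print(f"{a} OR {b} = {out}")
```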
Let's start by setting up the data using MLX arrays:
import mlx.core as mx
# The input data, a 2D matrix
X = mx.array([
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1]
])
# The output data, or "labels"
y = mx.array([0, 1, 1, 1])
Next, we initialize the model's parameters. We're dealing with two input features. A feature is an individual measurable characteristic of the data being used in the model, and features are the inputs that help the model make a prediction.
In our case, the input features are binary values (0 or 1) in the input array.
# For two input features, we need a weight vector of shape (2,) which is a 1D array with 2 elements
w = mx.zeros(2)
# This is a bias term, an additional parameter that allows the model to fit the data better
# by shifting the decision boundary
b = 0.0
# This determines how much the model's parameters (weights and bias) are adjusted during
# each step of the training process. It determines the size of the steps taken towards
# minimizing the loss function
learning_rate = 0.1
# The number of complete passes the model makes through the entire dataset during training.
# During an epoch, the model processes each training example once and updates its parameters
# (weights and biases) based on the computed gradients
num_epochs = 1000
The learning rate is interesting because if it's too high, the model may take large steps and overshoot the optimal values of the parameters, leading to divergence or oscillation. If it's too low, the model will take very small steps, resulting in a slow convergence and possibly getting stuck.
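To make that concrete, here is a toy illustration in plain Python. It is not the OR model: it runs gradient descent on the one-parameter function f(x) = x**2, which I picked as an assumption because its minimum (x = 0) is obvious. The three learning rates are arbitrary values chosen to show each behavior.

```python
def gradient_descent(learning_rate, steps=50, start=1.0):
    """Minimize f(x) = x**2 with plain gradient descent and return the final x."""
    x = start
    for _ in range(steps):
        grad = 2 * x                      # derivative of x**2
        x = x - learning_rate * grad      # step against the gradient
    return x


for lr in (0.001, 0.1, 1.1):
    print(f"learning_rate={lr}: final x = {gradient_descent(lr):.4g}")
```

With the smallest rate, x has barely moved from its starting value after 50 steps. With 0.1 it ends up very close to 0. With 1.1 every step overshoots the minimum by more than the previous one, so x grows instead of shrinking.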
Next we can define a couple of helper functions:
# Maps any real number to a value between 0 and 1
def sigmoid(z):
    return 1 / (1 + mx.exp(-z))
# Computes the model prediction.
# We input X as the data
# w as the weights which determine how important each input is
# b for bias to make better guesses
def predict(X, w, b):
    b_array = mx.full((X.shape[0],), b)
    return sigmoid(mx.addmm(b_array, X, w))
# Measures how good or bad the model's predictions are compared to the actual labels
def binary_cross_entropy(y_true, y_pred, eps=1e-8):
    return -mx.mean(
        y_true * mx.log(y_pred + eps) + (1 - y_true) * mx.log(1 - y_pred + eps)
    )
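To get a feel for what these helpers return, here is a plain-Python stand-in (using the math module instead of MLX, so it is a sketch of the same formulas rather than the MLX code itself) evaluated at our starting point of zero weights and zero bias.

```python
import math


def sigmoid(z):
    return 1 / (1 + math.exp(-z))


def binary_cross_entropy(y_true, y_pred, eps=1e-8):
    terms = [
        t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps)
        for t, p in zip(y_true, y_pred)
    ]
    return -sum(terms) / len(terms)


y = [0, 1, 1, 1]

# With zero weights and zero bias, z = 0 for every row, so every prediction is 0.5.
y_pred = [sigmoid(0.0) for _ in y]
loss = binary_cross_entropy(y, y_pred)

print(y_pred)
print(loss, math.log(2))
```

An untrained model that answers 0.5 for everything has a loss of about ln 2, which is why the first loss printed during training is roughly 0.693.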
Now, we create our training loop:
for epoch in range(num_epochs):
    # Forward pass which computes predictions and loss
    y_pred = predict(X, w, b)
    loss = binary_cross_entropy(y, y_pred)
    # Backward pass which computes gradients manually. This essentially helps us teach
    # the model how wrong it was in a bad prediction, so that it can learn.
    grad = y_pred - y
    # We look at the effect of each input on the wrong guesses and average these effects
    grad_w = mx.addmm(mx.zeros_like(X.T @ grad), X.T, grad) / X.shape[0]
    # Calculates how much the bias needs to change. It averages the effect of the bias on the wrong guesses
    grad_b = mx.mean(grad)
    # Update our parameters based on the gradients
    w = w - learning_rate * grad_w
    b = b - learning_rate * grad_b
    # Print progress every 100 epochs.
    if epoch % 100 == 0:
        print(f"Epoch {epoch:4d} | Loss: {loss}")
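The manual gradients in the loop (the error y_pred - y, averaged over the examples) can look like they came out of nowhere. One way to convince yourself they are right is to compare them against a numerical estimate. The sketch below does that in plain Python rather than MLX. The test point for w and b is arbitrary, and the eps term is left out of the loss because the predictions here never reach exactly 0 or 1.

```python
import math

X = [[0, 0], [0, 1], [1, 0], [1, 1]]
y = [0, 1, 1, 1]


def sigmoid(z):
    return 1 / (1 + math.exp(-z))


def predict(w, b):
    return [sigmoid(w[0] * x0 + w[1] * x1 + b) for x0, x1 in X]


def loss(w, b):
    preds = predict(w, b)
    total = sum(t * math.log(p) + (1 - t) * math.log(1 - p) for t, p in zip(y, preds))
    return -total / len(y)


def analytic_grads(w, b):
    """Same formulas as the training loop: the error is y_pred - y."""
    err = [p - t for p, t in zip(predict(w, b), y)]
    n = len(y)
    grad_w = [sum(e * row[j] for e, row in zip(err, X)) / n for j in range(2)]
    grad_b = sum(err) / n
    return grad_w, grad_b


def numeric_grads(w, b, h=1e-6):
    """Central finite differences of the loss, one parameter at a time."""
    grad_w = []
    for j in range(2):
        up, down = list(w), list(w)
        up[j] += h
        down[j] -= h
        grad_w.append((loss(up, b) - loss(down, b)) / (2 * h))
    grad_b = (loss(w, b + h) - loss(w, b - h)) / (2 * h)
    return grad_w, grad_b


w, b = [0.3, -0.2], 0.1   # arbitrary test point
print("analytic:", analytic_grads(w, b))
print("numeric: ", numeric_grads(w, b))
```

The two results should agree to several decimal places, which is a cheap test worth running any time you derive gradients by hand.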
Finally, we can output the predictions:
# If the predicted probability is greater than 0.5, it is classified as 1 (true)
# Otherwise, it is classified as 0 (false)
final_preds = predict(X, w, b) > 0.5
print("Final Predictions:", final_preds)
# Calculate the accuracy of the model
accuracy = mx.mean(final_preds == y)
print(f"Accuracy: {accuracy * 100:.2f}%")
Running this script should yield similar results to the following:
python script.py
Epoch 0 | Loss: 0.6931471824645996
Epoch 100 | Loss: 0.342264860868454
Epoch 200 | Loss: 0.2668907940387726
Epoch 300 | Loss: 0.21739481389522552
Epoch 400 | Loss: 0.18253955245018005
Epoch 500 | Loss: 0.1567975878715515
Epoch 600 | Loss: 0.13707442581653595
Epoch 700 | Loss: 0.1215219795703888
Epoch 800 | Loss: 0.10897234827280045
Epoch 900 | Loss: 0.0986524447798729
Final Predictions: array([False, True, True, True], dtype=bool)
Accuracy: 100.00%
Our model has predicted the OR function with 100% accuracy, using purely MLX and python.
Pay close attention to our use of "gradients" in this example, because I mentioned "gradient averaging" in the last post as a foundational element to distributed machine learning:
- The model makes a guess at the OR output for each pair of inputs (the forward pass)
- The model then compares its guess to the actual output, and calculates how far off it was from the right answer
- The model is then told to move in the opposite direction of the error, so that it can learn from its mistakes
- This process is repeated until the epochs are finished
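To connect this back to gradient averaging, here is a single-process simulation in plain Python. It is only a sketch: the two "workers" are just two halves of our four-row dataset, whereas a real distributed setup would run them on separate machines and exchange the gradients over the network. The values of w and b are arbitrary.

```python
import math

X = [[0, 0], [0, 1], [1, 0], [1, 1]]
y = [0, 1, 1, 1]


def sigmoid(z):
    return 1 / (1 + math.exp(-z))


def gradients(X_part, y_part, w, b):
    """Mean gradient of the loss over one shard of the data."""
    err = [sigmoid(w[0] * r[0] + w[1] * r[1] + b) - t for r, t in zip(X_part, y_part)]
    n = len(y_part)
    grad_w = [sum(e * r[j] for e, r in zip(err, X_part)) / n for j in range(2)]
    grad_b = sum(err) / n
    return grad_w, grad_b


w, b = [0.3, -0.2], 0.1

# Two simulated "workers", each holding half of the dataset.
gw_a, gb_a = gradients(X[:2], y[:2], w, b)
gw_b, gb_b = gradients(X[2:], y[2:], w, b)

# Gradient averaging: combine the per-worker gradients into one update.
avg_w = [(a + c) / 2 for a, c in zip(gw_a, gw_b)]
avg_b = (gb_a + gb_b) / 2

# Reference: the gradient computed on the full dataset by a single machine.
full_w, full_b = gradients(X, y, w, b)

print("averaged:", avg_w, avg_b)
print("full:    ", full_w, full_b)
```

Because both shards are the same size, the averaged gradient matches the full-dataset gradient, so each worker only has to look at its own slice of the data to arrive at the same parameter update.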
You'll notice that this training script doesn't output a model file. I left that out because I wanted to keep this post as relevant to our overarching distributed topic as possible. Gradients are very important to understand for future posts, and this minimal example helps shed a little light on what's going on.
What's Next
Armed with MLX and some basic models, we can now move on to acquiring and preparing a dataset for fine tuning. The Mistral-7B-Instruct-v0.3-4bit model is simply not smart enough to talk like a pirate consistently throughout its entire response, and we need to fix that with a great pirate lingo dataset. But you can't just take any piece of data and feed it into a model. It requires formatting, curation, and validation sets to ensure quality results.
Part 3 - Dataset Preparation
Further Reading
Check out this repo that goes over benchmarking MLX vs Pytorch.
Here is the full python script for the logistic regression model:
import mlx.core as mx
# The input data, a 2D matrix
X = mx.array([
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1]
])
# The output data, or "labels"
y = mx.array([0, 1, 1, 1])
# For two input features, we need a weight vector of shape (2,) which is a 1D array with 2 elements
w = mx.zeros(2)
# This is a bias term, an additional parameter that allows the model to fit the data better
# by shifting the decision boundary
b = 0.0
# This determines how much the model's parameters (weights and bias) are adjusted during
# each step of the training process. It determines the size of the steps taken towards
# minimizing the loss function
learning_rate = 0.1
# The number of complete passes the model makes through the entire dataset during training.
# During an epoch, the model processes each training example once and updates its parameters
# (weights and biases) based on the computed gradients
num_epochs = 1000
# Maps any real number to a value between 0 and 1
def sigmoid(z):
    return 1 / (1 + mx.exp(-z))
# Computes the model prediction.
# We input X as the data
# w as the weights which determine how important each input is
# b for bias to make better guesses
def predict(X, w, b):
    b_array = mx.full((X.shape[0],), b)
    return sigmoid(mx.addmm(b_array, X, w))
# Measures how good or bad the model's predictions are compared to the actual labels
def binary_cross_entropy(y_true, y_pred, eps=1e-8):
    return -mx.mean(
        y_true * mx.log(y_pred + eps) + (1 - y_true) * mx.log(1 - y_pred + eps)
    )
for epoch in range(num_epochs):
    # Forward pass which computes predictions and loss
    y_pred = predict(X, w, b)
    loss = binary_cross_entropy(y, y_pred)
    # Backward pass which computes gradients manually. This essentially helps us teach
    # the model how wrong it was in a bad prediction, so that it can learn.
    grad = y_pred - y
    # We look at the effect of each input on the wrong guesses and average these effects
    grad_w = mx.addmm(mx.zeros_like(X.T @ grad), X.T, grad) / X.shape[0]
    # Calculates how much the bias needs to change. It averages the effect of the bias on the wrong guesses
    grad_b = mx.mean(grad)
    # Update our parameters based on the gradients
    w = w - learning_rate * grad_w
    b = b - learning_rate * grad_b
    # Print progress every 100 epochs.
    if epoch % 100 == 0:
        print(f"Epoch {epoch:4d} | Loss: {loss}")
# If the predicted probability is greater than 0.5, it is classified as 1 (true)
# Otherwise, it is classified as 0 (false)
final_preds = predict(X, w, b) > 0.5
print("Final Predictions:", final_preds)
# Calculate the accuracy of the model
accuracy = mx.mean(final_preds == y)
print(f"Accuracy: {accuracy * 100:.2f}%")