MLX is an array framework for machine learning on Apple silicon. Its biggest advantage is its use of Apple's unified memory: operations on MLX arrays can run on any of the supported device types without transferring data. That makes MLX a strong candidate for inference, and even training, of large models on Apple silicon. There are examples specifically designed for LLMs, with a focus on text completion. As of today, there are few MLX examples for other LLM tasks such as sequence classification, since the framework is relatively new. In this article, I will show how to do classification inference with MLX, replicating what I did in my previous article Zero-Shot Text Classification with pretrained LLM.
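To make the unified-memory point concrete, MLX lets you direct individual operations to the CPU or the GPU through the stream argument, with no explicit data transfer in between. A minimal sketch using only core MLX ops:

import mlx.core as mx

a = mx.random.normal((1024, 1024))
b = mx.random.normal((1024, 1024))

# the same arrays can be consumed by either device; no copies are needed
c_gpu = mx.matmul(a, b, stream=mx.gpu)
c_cpu = mx.matmul(a, b, stream=mx.cpu)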

Setup

We will need both mlx and mlx-lm:

pip install mlx mlx-lm
from dataclasses import dataclass, asdict

# MLX imports
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten
import mlx_lm
from mlx_lm.models.qwen2 import ModelArgs, Qwen2Model
from mlx_lm import load

import pandas as pd
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
import torch

print(f"mlx version: {mx.__version__}")
print(f"mlx-lm version: {mlx_lm.__version__}")
mlx version: 0.23.2
mlx-lm version: 0.21.5

Load Qwen2 Model

The package mlx-lm provides implementations of popular LLMs, including Qwen2. We can load the model we need directly from Hugging Face.

model_path = "Qwen/Qwen2.5-0.5B-Instruct"
lm_model, tokenizer = load(model_path)

Following my previous article Zero-Shot Text Classification with pretrained LLM, I will do sentiment classification on the financial data.

# https://huggingface.co/datasets/vumichien/financial-sentiment
# use the valid split
data_path = "hf://datasets/vumichien/financial-sentiment/data/valid-00000-of-00001.parquet"
df_data = pd.read_parquet(data_path)
df_data = df_data.rename(columns={"label_experts": "label"})
print(f"Sample size: {df_data.shape[0]}")
print(f"Labels: {df_data['label'].unique().tolist()}")
for label in df_data["label"].unique():
    label_size = df_data[df_data["label"] == label].shape[0]
    print(f"Label sample size for {label}: {label_size}")
Sample size: 453
Labels: ['negative', 'neutral', 'positive']
Label sample size for negative: 61
Label sample size for neutral: 265
Label sample size for positive: 127
# set up prompt template
user_prompt_template = """What is the sentiment of the following text related to finance?
negative, neutral or positive: {text}
Give your answer in one word."""
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": user_prompt_template}
]
prompt_template = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
print(prompt_template)
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the sentiment of the following text related to finance?
negative, neutral or positive: {text}
Give your answer in one word.<|im_end|>
<|im_start|>assistant

# apply the prompt template to all samples
df_data["prompt"] = df_data["text"].apply(lambda text: prompt_template.format(text=text))

Let's see how to run Qwen2 directly on a prompt to generate one token.

sample = df_data.iloc[0]
# step 1: convert the prompt into input ids
inputs = mx.array(tokenizer.encode(sample["prompt"])).reshape(1, -1)
# step 2: run the model on the inputs to get logits
logits = lm_model(inputs)
# step 3: get the next predicted token by looking at the logits of last token
token_id = int(logits[:, -1, :].squeeze().argmax())
# step 4: convert the token id back to token
predicted_label = tokenizer.decode([token_id])

print(f"Text: {sample['text']}")
print(f"Predicted Sentiment: {predicted_label}")
print(f"Ground Truth: {sample['label']}")
Text: In Q2 of 2009 , profit before taxes amounted to EUR 13.6 mn , down from EUR 26.8 mn in Q2 of 2008 .
Predicted Sentiment: negative
Ground Truth: negative

Implement Classification Model with Qwen2

The equivalent of Qwen2ForSequenceClassification is not implemented in mlx-lm, so we will need to build it ourselves. I find it both fun and educational to implement it with mlx. We will start with the arguments needed for the classification model. Our classification model is built on top of Qwen2, with two additional arguments:

  • num_labels: the number of labels in the classification task
  • pad_token_id: the pad token id needed for batch inference
@dataclass
class ClassificationModelArgs(ModelArgs):
    num_labels: int = 2 # number of labels in the classification task
    pad_token_id: int = 151643 # the pad token id of Qwen2 tokenizer

Following Qwen2ForSequenceClassification, we will use the hidden state of the last non-padding token to do the classification.

class ClassificationModel(nn.Module):
    """Sequence classification model on top of Qwen2Model."""

    def __init__(self, args: ClassificationModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = Qwen2Model(args)
        self.score = nn.Linear(args.hidden_size, args.num_labels, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        # pass the inputs through Qwen2Model to get the hidden states
        out = self.model(inputs, mask, cache)

        # get the last non-padding token for classification
        non_pad_mask = (inputs != self.args.pad_token_id)
        token_indices = mx.arange(inputs.shape[-1])
        last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        out = out[mx.arange(inputs.shape[0]), last_non_pad_token].reshape(inputs.shape[0], -1)

        # get the logits by the classification layer
        out = self.score(out)

        return out

    @property
    def layers(self):
        return self.model.layers

# the following function can be used in `mlx_lm.load_model` to load
# the classification model directly from folder/checkpoint
# for example, the model saved from Qwen2ForSequenceClassification
def get_qwen2_classification_class(config):
    """Get the classification model and arguments for `load_model`."""
    return ClassificationModel, ClassificationModelArgs
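The line (token_indices * non_pad_mask).argmax(-1) above is worth a quick illustration: multiplying the position indices by the non-padding mask zeroes out the padded positions, so the argmax returns the highest position that still holds a real token. A toy sketch with the Qwen2 pad token id:

pad_id = 151643
toy_inputs = mx.array([
    [10, 11, 12, pad_id, pad_id],  # last real token at index 2
    [20, 21, 22, 23, 24],          # no padding, last real token at index 4
])
non_pad = (toy_inputs != pad_id)
print((mx.arange(toy_inputs.shape[-1]) * non_pad).argmax(-1))  # [2, 4]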

We can now create an uninitialized classification model for our sentiment analysis.

# keep the original Qwen2Model arguments
args = ClassificationModelArgs.from_dict(asdict(lm_model.args))

# num_labels is 3: positive, neutral and negative
args.num_labels = 3

# set the pad_token_id
args.pad_token_id = tokenizer.pad_token_id

# create the classification model
model = ClassificationModel(args)

We will apply the same idea to initialize the score module of the classification model: fill it with the corresponding rows of the lm_head weights. However, if we inspect the parameters of the lm_model we loaded, there is no lm_head in it. This is because Qwen2 has tie_word_embeddings set to True, meaning the token embeddings are reused as the lm_head weights.
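We can verify this from the loaded configuration before touching the weights. A quick check (this assumes lm_model.args carries the tie_word_embeddings field of the Qwen2 ModelArgs in mlx-lm):

# tied embeddings mean there is no separate lm_head weight matrix
print(f"tie_word_embeddings: {lm_model.args.tie_word_embeddings}")

Listing the flattened parameters shows the same thing: there is no lm_head.weight.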

model_parameters = tree_flatten(lm_model.parameters())
print(f"There are {len(model_parameters)} parameters.")
print(f"The first parameter is `{model_parameters[0][0]}`")
print(f"The last parameter is `{model_parameters[-1][0]}`")
There are 290 parameters.
The first parameter is `model.embed_tokens.weight`
The last parameter is `model.norm.weight`

Normally, the last parameter would be lm_head.weight, which we would use to fill score in the classification model. We confirm that it is absent in the MLX implementation, as Qwen2 uses the token embeddings as the LM head. We will therefore initialize score from the first parameter, model.embed_tokens.weight.

labels = ["positive", "negative", "neutral"]
labels_token_ids = [
    tokenizer.encode(label)[0]
    for label in labels
]

# get the embeddings as lm_head for Qwen2
lm_head = model_parameters[0][1]
score = lm_head[labels_token_ids]

# append to model_parameters so that we have a complete list of parameters
# for the classification model
model_parameters.append(("score.weight", score))

To initialize the parameters of an MLX module, we use the update method, together with tree_unflatten to turn the flat parameter list back into a nested tree.

model = model.update(tree_unflatten(model_parameters))

We run the model on the sample:

logits = model(inputs)
label_idx = int(logits.squeeze().argmax())
predicted_label2 = labels[label_idx]
print(f"Text: {sample['text']}")
print(f"Predicted Sentiment: {predicted_label2}")
print(f"Ground Truth: {sample['label']}")
Text: In Q2 of 2009 , profit before taxes amounted to EUR 13.6 mn , down from EUR 26.8 mn in Q2 of 2008 .
Predicted Sentiment: negative
Ground Truth: negative

Batch Inference

Now that we have our classification model initialized, we can run it on all the samples in batches and measure its performance. The annoying part is that mlx_lm.models.qwen2.Qwen2Model doesn't accept the usual 2d attention mask, so we need to convert the 2d attention mask from the tokenizer into a 4d mask. Luckily, we can use transformers.modeling_attn_mask_utils.AttentionMaskConverter for the conversion.
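Before wiring the conversion into the batch loop, here is what it produces on a toy 2d mask, using the same AttentionMaskConverter call as in the function below (the shapes are the point, not the values):

converter = AttentionMaskConverter(is_causal=True)
# 2d mask from the tokenizer: 1 for real tokens, 0 for padding
mask_2d = torch.tensor([[1, 1, 1, 0]])
mask_4d = converter.to_4d(
    mask_2d, 4, key_value_length=4, dtype=torch.float32
)
print(mask_2d.shape)  # torch.Size([1, 4])
print(mask_4d.shape)  # torch.Size([1, 1, 4, 4]): (batch, 1, query_len, key_value_len)
# masked positions hold a large negative value that is added to the attention scores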

def score_model(
    model, tokenizer, prompts, batch_size=32, max_length=2048, mask_dtype=mx.bfloat16
):
    """Score the model on the prompts in batches."""
    # set up the mask converter from 2d to 4d
    mask_converter = AttentionMaskConverter(
        is_causal=True
    )
    prompts = list(prompts)
    count = len(prompts)
    results = []
    for batch_begin in range(0, count, batch_size):
        batch_end = min(count, batch_begin + batch_size)
        batch = prompts[batch_begin: batch_end]
        # tokenize a batch of texts
        inputs_np = tokenizer._tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="np"
        )
        inputs = mx.array(inputs_np["input_ids"])
        # convert 2d mask to 4d mask
        token_length = inputs.shape[-1]
        mask_2d = torch.tensor(inputs_np["attention_mask"])
        mask_pt = mask_converter.to_4d(
            mask_2d, token_length,
            key_value_length=token_length,
            dtype=torch.float32
        )
        mask = mx.array(mask_pt.numpy(), dtype=mask_dtype)
        # run the model for logits
        logits = model(inputs=inputs, mask=mask)
        # run the softmax for scores
        scores = mx.softmax(logits, axis=-1)
        results.append(scores)
    scores = mx.concat(results)
    # get the predicted labels
    predictions = []
    for idx in scores.argmax(-1):
        idx = int(idx)
        predictions.append(labels[idx])

    # return both the scores and predicted labels
    return scores, predictions

We will see that running the model with different data types gives different speed and accuracy.

import time

def run_model(dtype=mx.bfloat16):
    model.set_dtype(dtype)
    begin = time.time()
    _, predictions = score_model(model, tokenizer, df_data["prompt"], mask_dtype=dtype)
    elapsed = time.time() - begin
    print(f"Finish inference with dtype {dtype}: {elapsed:.2f} seconds")
    count = 0
    for result, label in zip(predictions, df_data["label"]):
        count += result == label
    print(f"Correct count = {count}/{len(predictions)} = {count/len(predictions):.2%}")
run_model(dtype=mx.bfloat16)
Finish inference with dtype mlx.core.bfloat16: 12.65 seconds
Correct count = 348/453 = 76.82%
run_model(dtype=mx.float32)
Finish inference with dtype mlx.core.float32: 14.01 seconds
Correct count = 356/453 = 78.59%

For reference, the PyTorch equivalent with bfloat16 (see the previous article) took 20 seconds and reached an accuracy of 351/453 = 77.48%. Everything ran on my M1 Pro MacBook. The accuracy discrepancy between the two bfloat16 runs comes from bfloat16 precision errors. The two implementations have similar accuracy, but MLX is about 40% faster than PyTorch! Switching to float32 reduces the precision errors and thus improves the accuracy.

That's it for the article. I noticed a bug in mlx_lm.utils.load while writing this article: although it accepts model_config as an argument, it doesn't use it to load the model. Anyway, I think both mlx and mlx_lm are great options if you want to run models on Apple silicon.