Sequence Classification with Apple MLX
MLX is an array framework for machine learning on Apple silicon. Its biggest advantage is unified memory: operations on MLX arrays can run on any supported device without transferring data between them. That makes MLX a strong candidate for inference, and even training, of large models on Apple silicon. The existing LLM examples focus mostly on text completion; because the framework is relatively new, there are few examples of other LLM tasks such as sequence classification. In this article I provide an example of classification inference with MLX, replicating what I did in my previous article Zero-Shot Text Classification with pretrained LLM.
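As a small illustration of that unified-memory model, here is a minimal sketch based on MLX's documentation (not code used later in this article); the stream argument selects which device runs an operation:
import mlx.core as mx

a = mx.random.normal((4096, 4096))
b = mx.random.normal((4096, 4096))

# The same arrays can be consumed by operations on either device;
# no copy is needed because CPU and GPU share Apple's unified memory.
c_gpu = mx.add(a, b, stream=mx.gpu)
c_cpu = mx.add(a, b, stream=mx.cpu)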
Setup
We will need both mlx and mlx-lm:
pip install mlx mlx-lm
from dataclasses import dataclass, asdict
# MLX imports
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten
import mlx_lm
from mlx_lm.models.qwen2 import ModelArgs, Qwen2Model
from mlx_lm import load
import pandas as pd
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
import torch
print(f"mlx version: {mx.__version__}")
print(f"mlx-lm version: {mlx_lm.__version__}")
mlx version: 0.23.2
mlx-lm version: 0.21.5
Load Qwen2 Model
The package mlx-lm provides implementations of popular LLMs, including Qwen2. We can load the model directly from Hugging Face.
model_path = "Qwen/Qwen2.5-0.5B-Instruct"
lm_model, tokenizer = load(model_path)
Following my previous article Zero-Shot Text Classification with pretrained LLM, I will do sentiment classification on financial data.
# https://huggingface.co/datasets/vumichien/financial-sentiment
# use the valid split
data_path = "hf://datasets/vumichien/financial-sentiment/data/valid-00000-of-00001.parquet"
df_data = pd.read_parquet(data_path)
df_data = df_data.rename(columns={"label_experts": "label"})
print(f"Sample size: {df_data.shape[0]}")
print(f"Labels: {df_data['label'].unique().tolist()}")
for label in df_data["label"].unique():
    label_size = df_data[df_data["label"] == label].shape[0]
    print(f"Label sample size for {label}: {label_size}")
Sample size: 453
Labels: ['negative', 'neutral', 'positive']
Label sample size for negative: 61
Label sample size for neutral: 265
Label sample size for positive: 127
# set up prompt template
user_prompt_template = """What is the sentiment of the following text related to finance?
negative, neutral or positive: {text}
Give your answer in one word."""
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": user_prompt_template}
]
prompt_template = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
print(prompt_template)
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the sentiment of the following text related to finance?
negative, neutral or positive: {text}
Give your answer in one word.<|im_end|>
<|im_start|>assistant
# apply the prompt template to all samples
df_data["prompt"] = df_data["text"].apply(lambda text: prompt_template.format(text=text))
Let's see how to run Qwen2 directly on a prompt to generate one token.
sample = df_data.iloc[0]
# step 1: convert the prompt into input ids
inputs = mx.array(tokenizer.encode(sample["prompt"])).reshape(1, -1)
# step 2: run the model on the inputs to get logits
logits = lm_model(inputs)
# step 3: get the next predicted token by looking at the logits of last token
token_id = int(logits[:, -1, :].squeeze().argmax())
# step 4: convert the token id back to token
predicted_label = tokenizer.decode([token_id])
print(f"Text: {sample['text']}")
print(f"Predicted Sentiment: {predicted_label}")
print(f"Ground Truth: {sample['label']}")
Text: In Q2 of 2009 , profit before taxes amounted to EUR 13.6 mn , down from EUR 26.8 mn in Q2 of 2008 .
Predicted Sentiment: negative
Ground Truth: negative
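Optionally, we can also inspect how confident the model is about the next token. The sketch below is my own addition and uses only the logits computed above; it prints the five most likely next tokens with their probabilities, and for this sample the label words should dominate.
# inspect the top candidates for the next token (illustrative sketch)
probs = mx.softmax(logits[:, -1, :].squeeze(), axis=-1)
top_ids = mx.argsort(probs)[-5:].tolist()[::-1]  # five most likely token ids, descending
for tid in top_ids:
    print(f"{tokenizer.decode([tid])!r}: {float(probs[tid]):.3f}")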
Implement Classification Model with Qwen2
The equivalent of Qwen2ForSequenceClassification is not implemented in mlx-lm, so we will need to write it ourselves. I find it both fun and educational to implement it with mlx. We start with the arguments for the classification model. Our classification model is built on top of Qwen2, with two additional arguments:
- num_labels: the number of labels in the classification task
- pad_token_id: the pad token id needed for batch inference
@dataclass
class ClassificationModelArgs(ModelArgs):
    num_labels: int = 2  # number of labels in the classification task
    pad_token_id: int = 151643  # the pad token id of the Qwen2 tokenizer
Following Qwen2ForSequenceClassification, we will use the last token to do classification.
class ClassificationModel(nn.Module):
    """Sequence classification model on top of Qwen2Model."""

    def __init__(self, args: ClassificationModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = Qwen2Model(args)
        self.score = nn.Linear(args.hidden_size, args.num_labels, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        # pass the inputs through Qwen2Model to get embedding
        out = self.model(inputs, mask, cache)
        # get the last non-padding token for classification
        non_pad_mask = (inputs != self.args.pad_token_id)
        token_indices = mx.arange(inputs.shape[-1])
        last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        out = out[mx.arange(inputs.shape[0]), last_non_pad_token].reshape(inputs.shape[0], -1)
        # get the logits from the classification layer
        out = self.score(out)
        return out

    @property
    def layers(self):
        return self.model.layers
# the following function can be used in `mlx_lm.load_model` to load
# the classification model directly from a folder/checkpoint,
# for example, a model saved from Qwen2ForSequenceClassification
def get_qwen2_classification_class(config):
    """Get the classification model and arguments for `load_model`."""
    return ClassificationModel, ClassificationModelArgs
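To make the last-non-padding-token selection in the model above concrete, here is a small toy sketch (with a made-up pad_token_id of 0, not the real Qwen2 value) of what (token_indices * non_pad_mask).argmax(-1) computes:
# toy illustration of selecting the last non-padding token per row
toy_inputs = mx.array([[5, 7, 9, 0, 0],
                       [3, 4, 0, 0, 0]])  # 0 acts as the pad token here
toy_mask = (toy_inputs != 0)
toy_indices = mx.arange(toy_inputs.shape[-1])
print((toy_indices * toy_mask).argmax(-1))  # indices of the last real tokens: [2, 1]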
We can now create an uninitialized classification model for our sentiment analysis.
# keep the original Qwen2Model arguments
args = ClassificationModelArgs.from_dict(asdict(lm_model.args))
# num_labels is 3: positive, neutral and negative
args.num_labels = 3
# set the pad_token_id
args.pad_token_id = tokenizer.pad_token_id
# create the classification model
model = ClassificationModel(args)
We will apply the same idea as in the previous article to initialize the score module of the classification model: fill it with the corresponding rows of the lm_head weights. However, if we inspect the parameters of the lm_model we loaded, there is no lm_head in it. This is because Qwen2 has tie_word_embeddings set to True, meaning the token embeddings are reused as the lm_head weights.
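We can quickly confirm this from the loaded config (a small check of my own; it relies on lm_model.args, which we also use below):
# tie_word_embeddings should be True for this Qwen2 checkpoint
print(lm_model.args.tie_word_embeddings)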
model_parameters = tree_flatten(lm_model.parameters())
print(f"There are {len(model_parameters)} parameters.")
print(f"The first parameter is `{model_parameters[0][0]}`")
print(f"The last parameter is `{model_parameters[-1][0]}`")
There are 290 parameters.
The first parameter is `model.embed_tokens.weight`
The last parameter is `model.norm.weight`
Normally, the last parameter would be lm_head.weight, which we would use to fill score in the classification model. We confirm that it is absent in the MLX implementation, since Qwen2 uses the token embeddings as the LM head. We will therefore initialize score from the first parameter, model.embed_tokens.weight.
labels = ["positive", "negative", "neutral"]
labels_token_ids = [
    tokenizer.encode(label)[0]
    for label in labels
]
# get the embeddings as lm_head for Qwen2
lm_head = model_parameters[0][1]
score = lm_head[labels_token_ids]
# append to model_parameters so that we have a complete list of parameters
# for the classification model
model_parameters.append(("score.weight", score))
To initialize the parameters of an MLX module, we use the update method.
model = model.update(tree_unflatten(model_parameters))
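As an optional sanity check (my addition), we can verify that the classification head now carries the selected embedding rows:
# the score weights should now equal the embedding rows we selected
print(model.score.weight.shape)                # expected: (3, hidden_size)
print(mx.allclose(model.score.weight, score))  # expected: True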
We run the model on the sample:
logits = model(inputs)
label_idx = int(logits.squeeze().argmax())
predicted_label2 = labels[label_idx]
print(f"Text: {sample['text']}")
print(f"Predicted Sentiment: {predicted_label2}")
print(f"Ground Truth: {sample['label']}")
Text: In Q2 of 2009 , profit before taxes amounted to EUR 13.6 mn , down from EUR 26.8 mn in Q2 of 2008 .
Predicted Sentiment: negative
Ground Truth: negative
Batch Inference
Now that we have our classification model initialized, we can run it on all the samples in batches and measure its performance. The annoying part is that mlx_lm.models.qwen2.Qwen2Model doesn't accept the usual 2d attention mask; we need to convert the tokenizer's 2d attention mask to a 4d mask. Luckily, we can use transformers.modeling_attn_mask_utils.AttentionMaskConverter for the conversion.
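To see what that conversion produces, here is a tiny standalone sketch (a toy sequence of length 4 with one padding token; not part of the original pipeline):
# toy 2d mask: one sequence of length 4 whose last token is padding
toy_mask_2d = torch.tensor([[1, 1, 1, 0]])
toy_converter = AttentionMaskConverter(is_causal=True)
toy_mask_4d = toy_converter.to_4d(
    toy_mask_2d, 4, key_value_length=4, dtype=torch.float32
)
print(toy_mask_4d.shape)  # (1, 1, 4, 4): batch, head, query length, key length
# visible positions are 0; future and padded positions get a large negative value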
def score_model(
    model, tokenizer, prompts, batch_size=32, max_length=2048, mask_dtype=mx.bfloat16
):
    """Score the model on the prompts in batches."""
    # set up the mask converter from 2d to 4d
    mask_converter = AttentionMaskConverter(
        is_causal=True
    )
    prompts = list(prompts)
    count = len(prompts)
    results = []
    for batch_begin in range(0, count, batch_size):
        batch_end = min(count, batch_begin + batch_size)
        batch = prompts[batch_begin: batch_end]
        # tokenize a batch of texts
        inputs_np = tokenizer._tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="np"
        )
        inputs = mx.array(inputs_np["input_ids"])
        # convert 2d mask to 4d mask
        token_length = inputs.shape[-1]
        mask_2d = torch.tensor(inputs_np["attention_mask"])
        mask_pt = mask_converter.to_4d(
            mask_2d, token_length,
            key_value_length=token_length,
            dtype=torch.float32
        )
        mask = mx.array(mask_pt.numpy(), dtype=mask_dtype)
        # run the model for logits
        logits = model(inputs=inputs, mask=mask)
        # run the softmax for scores
        scores = mx.softmax(logits, axis=-1)
        results.append(scores)
    scores = mx.concat(results)
    # get the predicted labels
    predictions = []
    for idx in scores.argmax(-1):
        idx = int(idx)
        predictions.append(labels[idx])
    # return both the scores and predicted labels
    return scores, predictions
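Before scoring the full dataset, a quick smoke test on a handful of prompts (my addition) helps verify the shapes and outputs:
# smoke test on the first few prompts
sample_scores, sample_predictions = score_model(
    model, tokenizer, df_data["prompt"].head(4), batch_size=4
)
print(sample_scores.shape)                # expected: (4, 3)
print(sample_predictions)
print(df_data["label"].head(4).tolist())  # ground truth for comparison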
We will see that running the model in different data types yields different performance.
import time

def run_model(dtype=mx.bfloat16):
    model.set_dtype(dtype)
    begin = time.time()
    _, predictions = score_model(model, tokenizer, df_data["prompt"], mask_dtype=dtype)
    elapsed = time.time() - begin
    print(f"Finish inference with dtype {dtype}: {elapsed:.2f} seconds")
    count = 0
    for result, label in zip(predictions, df_data["label"]):
        count += result == label
    print(f"Correct count = {count}/{len(predictions)} = {count/len(predictions):.2%}")
run_model(dtype=mx.bfloat16)
Finish inference with dtype mlx.core.bfloat16: 12.65 seconds
Correct count = 348/453 = 76.82%
run_model(dtype=mx.float32)
Finish inference with dtype mlx.core.float32: 14.01 seconds
Correct count = 356/453 = 78.59%
For reference, the PyTorch equivalent with bfloat16 (see the previous article) took 20 seconds and reached an accuracy of 351/453 = 77.48%; everything ran on my M1 Pro MacBook. The accuracy discrepancy between the two bfloat16 runs comes from bfloat16 precision errors. The two models perform similarly, but MLX is about 40% faster than PyTorch! Switching to float32 reduces precision errors and thus improves accuracy.
That's it for the article. I noticed a bug in mlx_lm.utils.load while writing this article: although it accepts model_config as an argument, it doesn't use it to load the model. Overall, I think both mlx and mlx-lm are good choices if you want to run models on Apple silicon.