Recently I came across an issue when saving and loading BERT models with TensorFlow. The BERT models are provided by the Transformers library, and I used the TensorFlow backend. Saving a model with model.save(path) and then loading it with tf.keras.models.load_model(path) raised the following TypeError or ValueError:

TypeError/ValueError: The two structures don't have the same nested structure.

This article documents several ways to solve the issue.

import numpy as np
import pandas as pd

import tensorflow as tf
import transformers
from transformers import TFBertForSequenceClassification

# suppress warning messages as they are irrelevant to the issue in discussion
import logging
tf.get_logger().setLevel(logging.FATAL)
transformers.logging.set_verbosity_error()

print(f"TensorFlow version: {tf.__version__}")
print(f"Transformers version: {transformers.__version__}")
TensorFlow version: 2.9.2
Transformers version: 4.22.0

Reproducing the Issue

Let's reproduce the issue. Imagine that we want to build a binary classification model that predicts whether a given sentence is positive or negative. We will fine-tune a BERT model for the classification.

def create_model(bert_model, input_len=50):

    # model input
    input_ids = tf.keras.layers.Input(shape=(input_len, ), dtype=tf.int32)

    # the bert model as clf
    try:
        clf = TFBertForSequenceClassification.from_pretrained(bert_model)
    except Exception:
        # fall back to converting the PyTorch weights if no TF weights are available
        clf = TFBertForSequenceClassification.from_pretrained(bert_model, from_pt=True)

    # output score for positive class
    output = clf(input_ids).logits
    score = tf.nn.softmax(output)[:, 1]

    model = tf.keras.Model(inputs=[input_ids], outputs=[score])
    return model

bert_model = "google/bert_uncased_L-2_H-128_A-2" # use the tiny BERT as demo
input_len = 50
model = create_model(bert_model, input_len=input_len)

At this point, we would normally fine-tune the model on our own data. Fine-tuning is irrelevant to the issue we are trying to solve, so we will skip it and simply pretend the model has been fine-tuned; a rough sketch of what that step could look like is shown below, after which we try the model on some input.
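
The sketch below uses made-up training data (hypothetical sentences and 0/1 labels) and assumed hyperparameters; we do not actually run it in this walkthrough, so the printed outputs that follow still come from the randomly initialized classification head.

from transformers import BertTokenizerFast

# made-up training data; in practice this would be your own labeled sentences
train_texts = ["what a great movie", "the plot is terrible"]
train_labels = np.array([1, 0])

# tokenize to fixed-length input_ids matching the model's input layer
tokenizer = BertTokenizerFast.from_pretrained(bert_model)
encoded = tokenizer(
    train_texts,
    padding="max_length",
    truncation=True,
    max_length=input_len,
    return_tensors="np",
)

# the model outputs a single positive-class score, so binary cross-entropy fits
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss="binary_crossentropy",
)
model.fit(encoded["input_ids"].astype(np.int32), train_labels, epochs=1, batch_size=2)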

test_input = np.array([[0] * input_len])
test_output = model(test_input).numpy()
print(test_output[0])
0.47826624

We are ready to reproduce the error.

# save the model, it will run successfully as expected
model.save("test/sent-clf")

# load back the model, and here comes the TypeError
try:
    loaded_model = tf.keras.models.load_model("test/sent-clf")
except (TypeError, ValueError) as e:
    print(f"A {e.__class__.__name__} is raised. Error message:")
    print(e)
WARNING:absl:Found untraced functions such as embeddings_layer_call_fn, embeddings_layer_call_and_return_conditional_losses, encoder_layer_call_fn, encoder_layer_call_and_return_conditional_losses, pooler_layer_call_fn while saving (showing 5 of 80). These functions will not be directly callable after loading.
A ValueError is raised. Error message:
The two structures don't have the same nested structure.

First structure: type=tuple str=(({'input_ids': TensorSpec(shape=(None, 5), dtype=tf.int32, name=None)}, None, None, None, None, None, None, None, None, None, False), {})

Second structure: type=tuple str=((TensorSpec(shape=(None, 50), dtype=tf.int32, name='input_ids'), None, None, None, None, None, None, None, None, None, False), {})

More specifically: Substructure "type=dict str={'input_ids': TensorSpec(shape=(None, 5), dtype=tf.int32, name=None)}" is a sequence, while substructure "type=TensorSpec str=TensorSpec(shape=(None, 50), dtype=tf.int32, name='input_ids')" is not
Entire first structure:
(({'input_ids': .}, ., ., ., ., ., ., ., ., ., .), {})
Entire second structure:
((., ., ., ., ., ., ., ., ., ., .), {})

A Simple Fix

A very simple fix is to save the weights of the model instead.

model.save_weights("test/sent-clf-weights")

loaded_model_from_weights = create_model(bert_model, input_len=input_len)
loaded_model_from_weights.load_weights("test/sent-clf-weights")

test_output_simple = loaded_model_from_weights(test_input).numpy()
print(test_output_simple[0])
0.47826624
# make sure we get the same result on the same input
assert test_output_simple[0] == test_output[0]

Another Fix

While the simple fix works, we might still want to save the model via model.save, for example to log the model with mlflow. When we run mlflow.keras.log_model(model, some_model_uri), mlflow internally calls model.save or something equivalent. It is painful to find out later that, while we can log the model successfully with mlflow, we cannot load it back, at least not easily.
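
To make the pain point concrete, here is a rough sketch of that mlflow round trip, with a placeholder artifact name; the exact calls and default save format can vary between mlflow versions, so treat this as an illustration rather than a recipe.

import mlflow

# logging succeeds, because under the hood it boils down to model.save
with mlflow.start_run() as run:
    mlflow.keras.log_model(model, "sent-clf")

# loading the logged model back later hits the same nested-structure error
model_uri = f"runs:/{run.info.run_id}/sent-clf"
loaded = mlflow.keras.load_model(model_uri)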

So what is the problem with model.save and tf.keras.models.load_model? If we look closer at the error message above, we see something like TensorSpec(shape=(None, 5), dtype=tf.int32, name=None). Wait a minute, didn't we say we want the input length to be 50? Where did this 5 come from?

A quick Google search gives the answer:

Keras saves the input specs on the first call of the model here. When loading a pretrained model with transformers using the from_pretrained classmethod of TFPretrainedModel, the network is first fed dummy inputs here. So the saved models expect their input tensors to be of sequence length 5, because that is the length of the dummy inputs. (source)

A further dig shows that the dummy input is

DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
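
We can sanity-check this on a classifier instance: TF models in transformers expose these dummy inputs through a dummy_inputs property (this is my reading of the library; the shape comment is what we expect to see).

# load the classifier the same way as in create_model
try:
    clf = TFBertForSequenceClassification.from_pretrained(bert_model)
except Exception:
    clf = TFBertForSequenceClassification.from_pretrained(bert_model, from_pt=True)

print(clf.dummy_inputs["input_ids"].shape)  # expect (3, 5), i.e. sequence length 5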

The last dimension of DUMMY_INPUTS is 5, and that is why we see the mysterious 5. Given this, the natural way to fix the problem is to make the dummy input have the length we desire. We can modify our create_model to do so.

def create_model_fix(bert_model, input_len=50):

    # model input
    input_ids = tf.keras.layers.Input(shape=(input_len, ), dtype=tf.int32)

    # the bert model as clf
    try:
        clf = TFBertForSequenceClassification.from_pretrained(bert_model)
    except Exception:
        # fall back to converting the PyTorch weights if no TF weights are available
        clf = TFBertForSequenceClassification.from_pretrained(bert_model, from_pt=True)

    # change dummy input for bert
    features = tf.constant([[0] * input_len])
    clf._saved_model_inputs_spec = None
    clf._set_save_spec(features)

    # output score for positive class
    output = clf(input_ids).logits
    score = tf.nn.softmax(output)[:, 1]

    model = tf.keras.Model(inputs=[input_ids], outputs=[score])
    return model

bert_model = "google/bert_uncased_L-2_H-128_A-2" # use the tiny BERT as demo
input_len = 50
model_fix = create_model_fix(bert_model, input_len=input_len)
test_output_2 = model_fix(test_input).numpy()
print(test_output_2[0])
0.4662352

A note on why we get different results from model vs model_fix: TFBertForSequenceClassification consists of two main parts, the BERT encoder followed by a dense classification layer. While both models load the same pretrained BERT weights into the encoder, the dense classification layers are randomly initialized.
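
If we wanted model_fix to reproduce model's output exactly, we could copy the weights across, since the two models share the same architecture. A small sketch using a throwaway copy (model_check is just an illustrative name), so that model_fix itself stays untouched and the numbers below remain as printed:

model_check = create_model_fix(bert_model, input_len=input_len)
model_check.set_weights(model.get_weights())

# with identical weights, the two models agree on the test input
assert model_check(test_input).numpy()[0] == test_output[0]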

We can now save/load as usual.

# save the model, it will run successfully as expected
model_fix.save("test/sent-clf-fix")

# load back the model
loaded_model_fix = tf.keras.models.load_model("test/sent-clf-fix")

# test the loaded_model
test_output_fix = loaded_model_fix(test_input).numpy()
print(test_output_fix[0])
WARNING:absl:Found untraced functions such as embeddings_layer_call_fn, embeddings_layer_call_and_return_conditional_losses, encoder_layer_call_fn, encoder_layer_call_and_return_conditional_losses, pooler_layer_call_fn while saving (showing 5 of 80). These functions will not be directly callable after loading.
0.46623516

I like this fix better because it addresses the problem instead of avoiding it, and the saved model is now also compatible with mlflow. Before encountering this issue, I had another issue with TFBertModel. Here is a reference to that related issue.

Restore the Model

The two fixes are good, but what if I have already saved the model without noticing the issue? Is there a way to restore it? The good news is YES! Let's remind ourselves of the error.

try:
    loaded_model = tf.keras.models.load_model("test/sent-clf")
except (TypeError, ValueError) as e:
    print(f"A {e.__class__.__name__} is raised. Error message:")
    print(e)
A ValueError is raised. Error message:
The two structures don't have the same nested structure.

First structure: type=tuple str=(({'input_ids': TensorSpec(shape=(None, 5), dtype=tf.int32, name=None)}, None, None, None, None, None, None, None, None, None, False), {})

Second structure: type=tuple str=((TensorSpec(shape=(None, 50), dtype=tf.int32, name='input_ids'), None, None, None, None, None, None, None, None, None, False), {})

More specifically: Substructure "type=dict str={'input_ids': TensorSpec(shape=(None, 5), dtype=tf.int32, name=None)}" is a sequence, while substructure "type=TensorSpec str=TensorSpec(shape=(None, 50), dtype=tf.int32, name='input_ids')" is not
Entire first structure:
(({'input_ids': .}, ., ., ., ., ., ., ., ., ., .), {})
Entire second structure:
((., ., ., ., ., ., ., ., ., ., .), {})

Let's take a look at what is saved in the folder test/sent-clf.

!tree -n test/sent-clf
test/sent-clf
├── assets
├── keras_metadata.pb
├── saved_model.pb
└── variables
    ├── variables.data-00000-of-00001
    └── variables.index

2 directories, 4 files

Inspecting the folder, we see that the weights of the model are saved in the variables subfolder, so we can restore the model using load_weights.

restored_model = create_model(bert_model, input_len=input_len)
restored_model.load_weights("test/sent-clf/variables/variables")

test_output_restored = restored_model(test_input).numpy()
assert test_output[0] == test_output_restored[0]

print(test_output_restored[0])
0.47826624