PySpark Estimator and Transformer
PySpark's pipeline is a powerful tool that encapsulates machine learning processes. We can build rather complicated pipelines tailored to our needs using the existing estimators/transformers that come with PySpark's library, until we can't. In this article, I will show how we can build custom estimators and transformers to make the pipeline even more powerful.
Imagine that we want to build a model with some high-cardinality categorical features. Upon inspection, we find that only the most frequent values are useful, so we decide to keep those frequent values and mask all other values as "OTHERS". We will implement a CardinalityReducer that keeps only the N most frequent values in a categorical column (or a column of string type). We will implement it in a way that allows it to be fitted on training sets together with other components in a pipeline.
import pandas as pd
import numpy as np
from pyspark import keyword_only
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.ml import Estimator, Transformer
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.param.shared import (
HasInputCol,
HasOutputCol,
Param,
Params,
TypeConverters
)
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
spark = SparkSession.builder.master("local").getOrCreate()
print(f"PySpark Version: {spark.version}")
PySpark Version: 3.4.1
Before the implementation, we need to understand the differences between Estimator and Transformer. Let's take our example CardinalityReducer. It is going to be an Estimator that will know:
- input column: which column to perform the cardinality reduction on;
- output column: where should the results be stored;
- top n: the number of most frequent values to keep.
However, the estimator CardinalityReducer doesn't know the most frequent values until it sees the training data and learns the statistics. So it cannot transform any datasets.
Once CardinalityReducer sees (i.e., is fitted on) the training data, it will return a transformer CardinalityReducerModel that knows the input column, the output column and the top N values to keep. This transformer has all the information needed to actually transform datasets.
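In other words, the contract we are aiming for looks like this. It is only a sketch of the intended usage with the classes we are about to implement; train_df and test_df are placeholder DataFrames.
# an Estimator only knows its configuration; fit() learns from data
reducer = CardinalityReducer(inputCol="a", outputCol="a_reduced", topN=2)
# fit() returns a Transformer that also carries the learned values
model = reducer.fit(train_df)  # a CardinalityReducerModel
# only the fitted Transformer can transform datasets
result_df = model.transform(test_df)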
We will start with CardinalityReducerModel.
class CardinalityReducerModelParams(
HasInputCol, # param: inputCol
HasOutputCol, # param: outputCol
):
# param: keepValues
keepValues = Param(
Params._dummy(), # parent of the param; it must be set to this placeholder
"keepValues", # name of the param
"A list of values to keep for the categorical column", # desc of the param
TypeConverters.toListString # try to convert param to a list of strings
)
def __init__(self):
super().__init__()
# good practice: set default values for params
self._setDefault(inputCol=None)
self._setDefault(outputCol=None)
self._setDefault(keepValues=None)
# HasInputCol has the getter but not the setter
# we add the setter
# same for HasOutputCol
def setInputCol(self, inputCol):
return self._set(inputCol=inputCol)
def setOutputCol(self, outputCol):
return self._set(outputCol=outputCol)
# getter and setter for keepValues
def getKeepValues(self):
return self.getOrDefault(self.keepValues)
def setKeepValues(self, vals):
return self._set(keepValues=vals)
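As a quick sanity check, here is how the params class behaves on its own; this is just an illustration, assuming the class above has been defined.
# a quick illustration of the getters/setters (not part of the pipeline)
params = CardinalityReducerModelParams()
print(params.getInputCol())    # None, the default we set in __init__
params.setInputCol("a")
params.setKeepValues(["a1", "a2"])
print(params.getInputCol())    # 'a'
print(params.getKeepValues())  # ['a1', 'a2']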
I personally prefer to put the parameters in their own class for a Transformer or Estimator, but we can put them directly in the Transformer/Estimator if we want. Going into a bit more detail on CardinalityReducerModelParams, we see three parameters: inputCol (provided by HasInputCol), outputCol (provided by HasOutputCol) and keepValues (custom). You might notice that keepValues is a class-level attribute. Won't that be a problem, since all instances of the class share the same class-level attributes?
The answer is no, and it is handled by some internal tricks. HasInputCol is a subclass of Params, and in its __init__ function:
class _Params: # modified, original: Params(Identifiable, metaclass=ABCMeta)
def __init__(self) -> None:
# other code
...
# Copy the params from the class to the object
self._copy_params()
The class-level params will be copied to instances by Params._copy_params so that each instance gets its own params. We need to make sure our custom params class inherits from Params or subclass(es) of it. In our case, CardinalityReducerModelParams inherits from HasInputCol and HasOutputCol, which are subclasses of Params. A deeper dig into Params._copy_params also shows why we need to use Params._dummy() in our custom param keepValues. I will leave it as an exercise for you to trace the code. Here is a hint: you need to trace back to Param._copy_new_parent and Params._dummy().
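We can verify that instances indeed get their own copies of the params; a small check, assuming the params class above is defined.
# two instances do not share param values, thanks to _copy_params
p1 = CardinalityReducerModelParams()
p2 = CardinalityReducerModelParams()
p1.setKeepValues(["a1"])
p2.setKeepValues(["b1", "b2"])
print(p1.getKeepValues())  # ['a1']
print(p2.getKeepValues())  # ['b1', 'b2']
print(p1.keepValues is p2.keepValues)  # False: each instance owns its own Param copy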
class CardinalityReducerModel(
Transformer, # must inherit Transformer
CardinalityReducerModelParams, # params for the transformer
# the two classes below make the transformer serializable
DefaultParamsReadable,
DefaultParamsWritable
):
@keyword_only
def __init__(self, *, inputCol=None, outputCol=None, keepValues=None):
super().__init__()
kwargs = self._input_kwargs # need to pair with keyword_only decorator
self.setParams(**kwargs)
@keyword_only
def setParams(self, *, inputCol=None, outputCol=None, keepValues=None):
kwargs = self._input_kwargs
return self._set(**kwargs)
def _transform(self, dataset):
# get params for the transformer, provided by CardinalityReducerModelParams
# it is a good time to validate our params here
# for example, the dataset must contain inputCol
# I leave the validation out for simplicity
inputCol = self.getInputCol()
outputCol = self.getOutputCol()
keepVals = set(self.getKeepValues())
# transform: keep the top values and mask the remaining ones as "OTHERS"
dataset = dataset.withColumn(
outputCol,
f.when(f.col(inputCol).isin(keepVals), f.col(inputCol)).otherwise(f.lit("OTHERS"))
)
return dataset
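For reference, the validation I skipped could look something like the following sketch, placed at the top of _transform.
# a sketch of the input validation omitted above
def _transform(self, dataset):
    inputCol = self.getInputCol()
    if inputCol is None or inputCol not in dataset.columns:
        raise ValueError(f"inputCol '{inputCol}' is not a column of the dataset")
    if self.getKeepValues() is None:
        raise ValueError("keepValues is not set")
    # ... the rest of _transform as above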
Excellent, we got our CardinalityReducerModel. At minimum, a custom transformer must inherit the base class Transformer and implement _transform.
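Stripped down to that bare minimum, a custom transformer is just this; a hypothetical no-op example for illustration.
# the smallest possible custom transformer: a hypothetical no-op
class NoOpTransformer(Transformer):
    def _transform(self, dataset):
        # return the dataset untouched
        return dataset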
Let's make some fake data and try it out!
data={
"a": ["a1"] * 4 + ["a2"] * 3 + ["a3"] * 2 + ["a4"] * 1,
"b": ["b1"] * 7 + ["b2"] * 3
}
for val in data.values():
np.random.shuffle(val)
df = spark.createDataFrame(
pd.DataFrame(
data=data
)
)
df.show()
+---+---+
|  a|  b|
+---+---+
| a2| b1|
| a3| b2|
| a1| b1|
| a4| b1|
| a2| b2|
| a3| b1|
| a1| b1|
| a1| b1|
| a2| b2|
| a1| b1|
+---+---+
xformer = CardinalityReducerModel(
inputCol="a",
outputCol="a_reduced",
keepValues=["a1", "a2"]
)
xformer.transform(df).show()
+---+---+---------+
|  a|  b|a_reduced|
+---+---+---------+
| a2| b1|       a2|
| a3| b2|   OTHERS|
| a1| b1|       a1|
| a4| b1|   OTHERS|
| a2| b2|       a2|
| a3| b1|   OTHERS|
| a1| b1|       a1|
| a1| b1|       a1|
| a2| b2|       a2|
| a1| b1|       a1|
+---+---+---------+
Now we are ready to implement CardinalityReducer. We will again separate out the params into a class CardinalityReducerParams. Since it is very similar to CardinalityReducerModelParams, I will skip the explanation.
class CardinalityReducerParams(
HasInputCol,
HasOutputCol,
):
topN = Param(
Params._dummy(),
"topN",
"Keep top N number of values in the categorical column",
TypeConverters.toInt
)
def __init__(self):
super().__init__()
# set default values
self._setDefault(inputCol=None)
self._setDefault(outputCol=None)
self._setDefault(topN=1)
def setInputCol(self, inputCol):
return self._set(inputCol=inputCol)
def setOutputCol(self, outputCol):
return self._set(outputCol=outputCol)
def getTopN(self):
return self.getOrDefault(self.topN)
def setTopN(self, val):
return self._set(topN=val)
A custom estimator must inherit Estimator and implement _fit, which returns a corresponding transformer. Our estimator CardinalityReducer learns the most frequent values in its CardinalityReducer._fit and returns a CardinalityReducerModel.
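The bare-minimum counterpart of the no-op transformer sketch from earlier would be an estimator like this; again purely hypothetical, for illustration.
# the smallest possible custom estimator: learns nothing from the data
# and returns a transformer that does nothing
class NoOpEstimator(Estimator):
    def _fit(self, dataset):
        return NoOpTransformer()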
class CardinalityReducer(
Estimator, # a must for custom estimator
CardinalityReducerParams, # params
# make it serializable by subclassing the following two classes
DefaultParamsReadable,
DefaultParamsWritable
):
@keyword_only
def __init__(self, *, inputCol=None, outputCol=None, topN=None):
super().__init__()
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
def setParams(self, *, inputCol=None, outputCol=None, topN=None):
kwargs = self._input_kwargs
return self._set(**kwargs)
def _fit(self, dataset):
# get params
inputCol = self.getInputCol()
outputCol = self.getOutputCol()
topN = self.getTopN()
# compute the top N values in the inputCol
keepVals = (
dataset
.groupby(inputCol)
.count()
.sort(f.desc("count"))
.select(inputCol)
.toPandas()
[inputCol]
.tolist()
[:topN]
)
# return a CardinalityReducerModel with learned top values to keep
model = CardinalityReducerModel(
inputCol=inputCol,
outputCol=outputCol,
keepValues=keepVals
)
return model
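Before wiring it into a pipeline, we can fit the estimator by itself and inspect what it learned. With our fake data, a1 appears 4 times and a2 appears 3 times, so they are the top 2 values in column a.
# fit the estimator alone and inspect the learned values
reducer = CardinalityReducer(inputCol="a", outputCol="a_reduced", topN=2)
reducer_model = reducer.fit(df)       # returns a CardinalityReducerModel
print(reducer_model.getKeepValues())  # ['a1', 'a2'] for our fake data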
It is not hard, right? Now let's put two CardinalityReducers in a pipeline, one for column a and one for column b!
# make our pipeline with our custom estimator CardinalityReducer
pipeline = Pipeline(stages=[
CardinalityReducer(inputCol="a", outputCol="a_reduced", topN=2),
CardinalityReducer(inputCol="b", outputCol="b_reduced", topN=1)
])
# fit the pipeline on the dataset to get a transformer
model = pipeline.fit(df)
# transform datasets
model.transform(df).show()
+---+---+---------+---------+
|  a|  b|a_reduced|b_reduced|
+---+---+---------+---------+
| a2| b1|       a2|       b1|
| a3| b2|   OTHERS|   OTHERS|
| a1| b1|       a1|       b1|
| a4| b1|   OTHERS|       b1|
| a2| b2|       a2|   OTHERS|
| a3| b1|   OTHERS|       b1|
| a1| b1|       a1|       b1|
| a1| b1|       a1|       b1|
| a2| b2|       a2|   OTHERS|
| a1| b1|       a1|       b1|
+---+---+---------+---------+
We have successfully implemented an estimator CardinalityReducer and its corresponding transformer CardinalityReducerModel. Lastly, we save the pipeline model and load it back to test the capability of persisting the model.
# save the model
model.write().overwrite().save("saved-models/cardinality-reducer")
# load back the model
loaded_model = PipelineModel.load("saved-models/cardinality-reducer")
# use the loaded model
loaded_model.transform(df).show()
+---+---+---------+---------+ | a| b|a_reduced|b_reduced| +---+---+---------+---------+ | a2| b1| a2| b1| | a3| b2| OTHERS| OTHERS| | a1| b1| a1| b1| | a4| b1| OTHERS| b1| | a2| b2| a2| OTHERS| | a3| b1| OTHERS| b1| | a1| b1| a1| b1| | a1| b1| a1| b1| | a2| b2| a2| OTHERS| | a1| b1| a1| b1| +---+---+---------+---------+
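Since CardinalityReducerModel also mixes in DefaultParamsReadable and DefaultParamsWritable, the standalone transformer can be persisted the same way; a quick sketch, where the save path is just an example.
# the standalone transformer is serializable on its own as well
xformer.write().overwrite().save("saved-models/cardinality-reducer-model")
loaded_xformer = CardinalityReducerModel.load("saved-models/cardinality-reducer-model")
loaded_xformer.transform(df).show()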