Efficient UDFs on Databricks with unpickleable objects

How to avoid PicklingError on custom UDFs on Databricks/Spark, while keeping optimal performance.


I often run into a problem when writing UDFs on Databricks: the UDF needs to access some object that pickle can’t serialize. Oftentimes the object comes from an external library, so fixing its code is not a practical solution.

An easy solution is to initialize the object inside the UDF itself. This avoids the need for serialization, but it introduces a new problem: the object is re-initialized on every invocation of the UDF, which hurts performance.

The solution that addresses both problems is to cache the object initialization. Then each executor process initializes the object only once.

The Problem

Here is a simple example:

from lxml.etree import HTMLParser
from pyspark.sql.functions import udf

# `spark` is the SparkSession; on Databricks it is a predefined global variable
df = spark.createDataFrame([{"n": n} for n in range(10000)])

class Slow:
    def __init__(self):
        self.parser = HTMLParser()

    def double(self, x: int) -> int:
        return 2 * x

slow_global = Slow()

@udf("long")
def f_error(n):
    return slow_global.double(n)

When actually executing the UDF

df.select("n", f_error("n")).collect()

we get the error

PicklingError: Could not serialize object: TypeError: can't pickle lxml.etree.HTMLParser objects
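The root cause is independent of Spark: `pickle` simply cannot serialize certain C-backed objects. You can reproduce the same class of failure in plain Python; here is a minimal sketch that uses `threading.Lock` from the standard library as a stand-in for `HTMLParser`:

```python
import pickle
import threading

# A lock, like lxml's HTMLParser, is a C-backed object that pickle rejects.
try:
    pickle.dumps(threading.Lock())
    print("pickled fine")
except TypeError as e:
    print(f"TypeError: {e}")
```

Spark hits exactly this `TypeError` when it tries to ship `slow_global` (and with it the `HTMLParser`) to the executors, and wraps it in the `PicklingError` above.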

Naive Solution

The naive solution is to initialize the object in each run of the UDF:

@udf("long")
def f(n):
    slow = Slow()
    return slow.double(n)

This works

df.select("n", f("n")).collect()

but it’s very inefficient.

On a cluster with 2 i3.xlarge workers on AWS, executing this took me around 25 seconds.
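The cost comes purely from repeating the initialization. A plain-Python micro-benchmark shows the same effect without a cluster; the `time.sleep(0.001)` below is an assumed stand-in for an expensive constructor like `HTMLParser`'s, so the absolute numbers are illustrative only:

```python
import time

class Slow:
    def __init__(self):
        time.sleep(0.001)  # stand-in for an expensive constructor

    def double(self, x: int) -> int:
        return 2 * x

# Naive: re-initialize on every call, like the UDF above.
start = time.perf_counter()
for n in range(100):
    Slow().double(n)
naive = time.perf_counter() - start

# Initialize once and reuse, which is what caching achieves per executor.
start = time.perf_counter()
slow = Slow()
for n in range(100):
    slow.double(n)
reused = time.perf_counter() - start

print(f"naive: {naive:.3f}s, reused: {reused:.3f}s")
```

The per-call version pays the constructor cost 100 times; the reused version pays it once.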

Optimized Solution

The solution is then to cache the object initialization. For this, we need the cachetools library. On Databricks, you can install it by running the following cell

%pip install cachetools

We can’t use `lru_cache` from the standard library, because the wrapper it creates can’t be serialized. Trying it gives us the error:

PicklingError: Could not serialize object: AttributeError: 'functools._lru_cache_wrapper' object has no attribute '__bases__'

Usage is very simple:

from cachetools import cached

@cached(cache={})
def get_slow():
    return Slow()

@udf("long")
def f_cached(n):
    slow = get_slow()
    return slow.double(n)

Executing it

df.select("n", f_cached("n")).collect()

took around 0.5 seconds on the same cluster as above.
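If you’d rather avoid the extra dependency, the same effect can be achieved with a hand-rolled lazy initializer. This is a sketch of the pattern (with a counter added just to demonstrate it), not Databricks-specific API; it works with Spark because the function only references a module-level global that is `None` at serialization time:

```python
init_count = 0

class Slow:
    def __init__(self):
        global init_count
        init_count += 1  # track how many times we actually initialize

_slow = None

def get_slow():
    # Lazy singleton: initialize at most once per Python worker process.
    global _slow
    if _slow is None:
        _slow = Slow()
    return _slow

for _ in range(1000):
    get_slow()

print(init_count)  # 1
```

`cachetools.cached` does essentially this for you, with the added flexibility of keyed caches and eviction policies when you need them.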