Python Aggregate UDFs in PySpark

This post was originally published here

PySpark has a great set of aggregate functions (e.g., count, countDistinct, min, max, avg, sum), but these are not enough for all cases (particularly if you’re trying to avoid costly shuffle operations).

PySpark currently has pandas_udfs, which can create custom aggregators, but you can only “apply” one pandas_udf at a time. If you want to use more than one, you’ll have to perform multiple groupBys… and there goes avoiding those shuffles.

In this post I describe a little hack that lets you create simple Python UDFs which act on aggregated data (this functionality is only supposed to exist in Scala!).

from pyspark.sql import functions as F
from pyspark.sql import types as T

# `sc` is assumed to be the SparkContext that the pyspark shell or a notebook provides.
a = sc.parallelize([[1, 'a'],
                    [1, 'b'],
                    [1, 'b'],
                    [2, 'c']]).toDF(['id', 'value'])

id value
1 ‘a’
1 ‘b’
1 ‘b’
2 ‘c’

I use collect_list to bring all data from a given group into a single row. I print the output of this operation below.

id value_list
1 [‘a’, ‘b’, ‘b’]
2 [‘c’]

I then create a UDF that counts the occurrences of the letter ‘a’ in these lists (this is easy to do without a UDF, but you get the point). The UDF wraps around collect_list, so it acts on the output of collect_list.

def find_a(x):
  """Count 'a's in list."""
  output_count = 0
  for i in x:
    if i == 'a':
      output_count += 1
  return output_count

find_a_udf = F.udf(find_a, T.IntegerType())

id a_count
1 1
2 0

There we go! A UDF that acts on aggregated data! Next, I show the power of this approach when combined with F.when, which lets us control which data enters F.collect_list.

First, let’s create a dataframe with an extra column.

from pyspark.sql import functions as F
from pyspark.sql import types as T

a = sc.parallelize([[1, 1, 'a'],
                    [1, 2, 'a'],
                    [1, 1, 'b'],
                    [1, 2, 'b'],
                    [2, 1, 'c']]).toDF(['id', 'value1', 'value2'])

id value1 value2
1 1 ‘a’
1 2 ‘a’
1 1 ‘b’
1 2 ‘b’
2 1 ‘c’

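Notice how I included a when inside the collect_list. Because this when has no otherwise, rows where value1 is not 1 evaluate to null, and collect_list skips nulls, so only the matching value2 entries reach the UDF. Note that the UDF still wraps around collect_list.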

a.groupBy('id').agg(
    find_a_udf(
        F.collect_list(F.when(F.col('value1') == 1, F.col('value2')))
    ).alias('a_count')
).show()

id a_count
1 1
2 0

There we go! Hope you find this info helpful!
