Creating a Survival Function in PySpark

Traditionally, survival functions have been used in medical research to visualize the proportion of people who remain alive following a treatment. I often use them to understand the length of time between users creating and cancelling their subscription accounts.

Here, I describe how to create a survival function using PySpark. This is not a post about creating a Kaplan-Meier estimator or fitting mathematical functions to survival functions. Instead, I demonstrate how to acquire the data necessary for plotting a survival function.

I begin by creating a SparkContext and a SparkSession.

from pyspark.sql import SparkSession
from pyspark import SparkContext
sc = SparkContext("local", "Example")
spark = SparkSession(sc)

Next, I load fake data into a Spark DataFrame; this is the data used throughout the example. Each row is a different user, and the columns hold each user’s start and end dates: start_date is when the user created their account and end_date is when the user canceled it.

from pyspark.sql import functions as F
from pyspark.sql import types as T

user_table = (sc.parallelize([[1, '2018-11-01', '2018-11-03'],
                              [2, '2018-01-01', '2018-08-17'],
                              [3, '2017-12-31', '2018-01-06'],
                              [4, '2018-11-15', '2018-11-16'],
                              [5, '2018-04-02', '2018-04-12']])
              .toDF(['id', 'start_date', 'end_date'])
             )
user_table.show()

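+---+----------+----------+
| id|start_date|  end_date|
+---+----------+----------+
|  1|2018-11-01|2018-11-03|
|  2|2018-01-01|2018-08-17|
|  3|2017-12-31|2018-01-06|
|  4|2018-11-15|2018-11-16|
|  5|2018-04-02|2018-04-12|
+---+----------+----------+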

I use F.datediff on end_date and start_date to compute how many days each user remained active after their start_date.

days_till_cancel = (user_table
                    .withColumn('days_till_cancel', F.datediff(F.col('end_date'), F.col('start_date')))
                   )

days_till_cancel.show()

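+---+----------+----------+----------------+
| id|start_date|  end_date|days_till_cancel|
+---+----------+----------+----------------+
|  1|2018-11-01|2018-11-03|               2|
|  2|2018-01-01|2018-08-17|             228|
|  3|2017-12-31|2018-01-06|               6|
|  4|2018-11-15|2018-11-16|               1|
|  5|2018-04-02|2018-04-12|              10|
+---+----------+----------+----------------+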
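As a sanity check, the datediff values can be reproduced with Python’s standard datetime module. This is a minimal sketch for verification only, not part of the Spark pipeline; it assumes Python 3.7+ (for date.fromisoformat), and the helper name days_between is hypothetical.

```python
from datetime import date

# The same five fake users as in user_table: (id, start_date, end_date)
users = [
    (1, "2018-11-01", "2018-11-03"),
    (2, "2018-01-01", "2018-08-17"),
    (3, "2017-12-31", "2018-01-06"),
    (4, "2018-11-15", "2018-11-16"),
    (5, "2018-04-02", "2018-04-12"),
]


def days_between(start: str, end: str) -> int:
    """Hypothetical helper: whole days from start to end, like F.datediff(end, start)."""
    return (date.fromisoformat(end) - date.fromisoformat(start)).days


for user_id, start, end in users:
    print(user_id, days_between(start, end))
```

The printed values should match the days_till_cancel column above.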

Next, I use a Python UDF to create an array of the numbers 0 through 13, representing our period of interest: day 0 is the user’s start_date and day 13 is 13 days after it. I chose a 13-day window for no particular reason.

I then use explode to expand each array into separate rows, so each user has one row for each day in the period of interest.

The output below shows the rows for user 1.

# Zero-argument UDF that returns the same list [0, 1, ..., 13] for every row
create_day_list = F.udf(lambda: list(range(14)), T.ArrayType(T.IntegerType()))

relevant_days = (days_till_cancel
                 .withColumn('day_list', create_day_list())
                 .withColumn('day', F.explode(F.col('day_list')))
                 .drop('day_list')
                )

relevant_days.filter(F.col('id') == 1).show()

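+---+----------+----------+----------------+---+
| id|start_date|  end_date|days_till_cancel|day|
+---+----------+----------+----------------+---+
|  1|2018-11-01|2018-11-03|               2|  0|
|  1|2018-11-01|2018-11-03|               2|  1|
|  1|2018-11-01|2018-11-03|               2|  2|
|  1|2018-11-01|2018-11-03|               2|  3|
|  1|2018-11-01|2018-11-03|               2|  4|
|  1|2018-11-01|2018-11-03|               2|  5|
|  1|2018-11-01|2018-11-03|               2|  6|
|  1|2018-11-01|2018-11-03|               2|  7|
|  1|2018-11-01|2018-11-03|               2|  8|
|  1|2018-11-01|2018-11-03|               2|  9|
|  1|2018-11-01|2018-11-03|               2| 10|
|  1|2018-11-01|2018-11-03|               2| 11|
|  1|2018-11-01|2018-11-03|               2| 12|
|  1|2018-11-01|2018-11-03|               2| 13|
+---+----------+----------+----------------+---+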
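The UDF-plus-explode step effectively pairs every user with every day from 0 through 13. A small pure-Python sketch of that shape, using itertools.product on the days_till_cancel values from the earlier table, makes the expected row count easy to check; the names dtc_by_user, rows, and PERIOD_DAYS are illustrative, not part of the Spark code.

```python
from itertools import product

PERIOD_DAYS = 14  # days 0 through 13, as in the post

# Illustrative copy of (id -> days_till_cancel) from the table above
dtc_by_user = {1: 2, 2: 228, 3: 6, 4: 1, 5: 10}

# One (id, days_till_cancel, day) tuple per user per day,
# the same shape the UDF + explode step produces in Spark
rows = [
    (user_id, dtc, day)
    for (user_id, dtc), day in product(dtc_by_user.items(), range(PERIOD_DAYS))
]

print(len(rows))  # 5 users x 14 days = 70
```

Within Spark itself, the Python UDF is not strictly necessary: F.array(*[F.lit(i) for i in range(14)]) builds the same array of literals, and on Spark 2.4+ F.sequence(F.lit(0), F.lit(13)) does as well. Either can be passed to F.explode and keeps the work in the JVM instead of calling into Python for every row.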

We want the proportion of users who are active X days after start_date. I create a column, active, that indicates whether a user is active on a given day (1) or not (0). I first assign every row a 1, then overwrite the 1 with a 0 once the user is no longer active. I determine this by comparing day with days_till_cancel: when day is greater than or equal to days_till_cancel, the user is no longer active, so the cancellation day itself counts as inactive.

The output below shows the rows for user 1.

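days_active = (relevant_days
               # Start by marking every row as active
               .withColumn('active', F.lit(1))
               # From the cancellation day onward, mark the user as inactive
               .withColumn('active', F.when(F.col('day') >= F.col('days_till_cancel'), 0).otherwise(F.col('active')))
              )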

days_active.filter(F.col('id') == 1).show()

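+---+----------+----------+----------------+---+------+
| id|start_date|  end_date|days_till_cancel|day|active|
+---+----------+----------+----------------+---+------+
|  1|2018-11-01|2018-11-03|               2|  0|     1|
|  1|2018-11-01|2018-11-03|               2|  1|     1|
|  1|2018-11-01|2018-11-03|               2|  2|     0|
|  1|2018-11-01|2018-11-03|               2|  3|     0|
|  1|2018-11-01|2018-11-03|               2|  4|     0|
|  1|2018-11-01|2018-11-03|               2|  5|     0|
|  1|2018-11-01|2018-11-03|               2|  6|     0|
|  1|2018-11-01|2018-11-03|               2|  7|     0|
|  1|2018-11-01|2018-11-03|               2|  8|     0|
|  1|2018-11-01|2018-11-03|               2|  9|     0|
|  1|2018-11-01|2018-11-03|               2| 10|     0|
|  1|2018-11-01|2018-11-03|               2| 11|     0|
|  1|2018-11-01|2018-11-03|               2| 12|     0|
|  1|2018-11-01|2018-11-03|               2| 13|     0|
+---+----------+----------+----------------+---+------+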
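The active column is simply an indicator that a given day falls before the user’s cancellation day. The plain-Python sketch below mirrors the two withColumn calls; is_active and dtc_by_user are hypothetical names used only for illustration.

```python
# Illustrative copy of (id -> days_till_cancel) from the table above
dtc_by_user = {1: 2, 2: 228, 3: 6, 4: 1, 5: 10}


def is_active(day: int, days_till_cancel: int) -> int:
    """Hypothetical helper: 1 before the cancellation day, 0 from it onward.

    Mirrors F.when(day >= days_till_cancel, 0).otherwise(1).
    """
    return 0 if day >= days_till_cancel else 1


# Active flags for user 1 over days 0..13
print([is_active(day, dtc_by_user[1]) for day in range(14)])
# [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```

In Spark, the two steps can also be collapsed into a single expression, for example F.when(F.col('day') >= F.col('days_till_cancel'), 0).otherwise(1), which yields the same column without the intermediate F.lit(1).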

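Finally, to obtain the survival function data, I group by day (the number of days after start_date) and average active. Because active is always 0 or 1, its mean is the proportion of users who are still active X days after start_date. Despite its name, percent_active holds a proportion between 0 and 1 rather than a percentage.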

survival_curve = (days_active
                  .groupby('day')
                  .agg(
                      F.count('*').alias('user_count'),
                      F.avg('active').alias('percent_active'),
                  )
                  .orderBy('day')
                 )

survival_curve.show()

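+---+----------+--------------+
|day|user_count|percent_active|
+---+----------+--------------+
|  0|         5|           1.0|
|  1|         5|           0.8|
|  2|         5|           0.6|
|  3|         5|           0.6|
|  4|         5|           0.6|
|  5|         5|           0.6|
|  6|         5|           0.4|
|  7|         5|           0.4|
|  8|         5|           0.4|
|  9|         5|           0.4|
| 10|         5|           0.2|
| 11|         5|           0.2|
| 12|         5|           0.2|
| 13|         5|           0.2|
+---+----------+--------------+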
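Because averaging a 0/1 column gives a proportion, the survival value at day t is the number of users whose days_till_cancel is greater than t, divided by the number of users. For example, at day 2 only the users with 228, 6, and 10 days remain, giving 3/5 = 0.6. The plain-Python sketch below recomputes the whole table from the five days_till_cancel values as a cross-check; survival_proportions is a hypothetical helper, not a Spark API.

```python
from collections import defaultdict

# Illustrative copy of (id -> days_till_cancel) from the earlier table
dtc_by_user = {1: 2, 2: 228, 3: 6, 4: 1, 5: 10}
PERIOD_DAYS = 14


def survival_proportions(dtc_by_user, period_days):
    """Hypothetical helper: {day: (user_count, proportion_active)}.

    Mirrors groupby('day') with count('*') and avg('active').
    """
    flags_by_day = defaultdict(list)
    for dtc in dtc_by_user.values():
        for day in range(period_days):
            # Same rule as the active column: active only before the cancel day
            flags_by_day[day].append(1 if day < dtc else 0)
    return {
        day: (len(flags), sum(flags) / len(flags))
        for day, flags in sorted(flags_by_day.items())
    }


for day, (count, share) in survival_proportions(dtc_by_user, PERIOD_DAYS).items():
    print(day, count, share)
```

The printed rows should agree with the survival_curve output above.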