How to predict churn with Pyspark and AWS Elastic MapReduce?

Krzysztof Wołowiec
Feb 15, 2021

This article summarizes the churn prediction project I completed as part of a Udacity Nanodegree program. The project's repository is available on GitHub.

Project Definition

Project Overview

The main goal of the project is to predict whether a user is likely to churn, given traffic data from Sparkify, a fictional music streaming platform provided by Udacity.

Predicting churn successfully makes it possible to take preventive actions that could reduce its scope. Besides the prediction itself, the model could be a great source of knowledge about users' problems and needs. This is a good starting point for creating special offers that encourage users who are likely to churn to stay. Moreover, the predictive model could inspire future product improvements.

The most challenging part of the project is the volume of data. Udacity shared two datasets: a 128 MB sample dataset (about 100k rows) and a full dataset with more than 30 million records and about 14 GB of data.

One of the project requirements is to use Spark in data processing and modelling.

Problem Statement

Predict whether a user is likely to churn, given their traffic data. Churn is defined as service cancellation and is marked in the traffic data by the ‘Cancellation Confirmation’ event.

In general, this is a binary classification problem: in the training set, users who churned are labeled 1 and all others 0. Plenty of methods are designed for such problems, e.g. Logistic Regression, Random Forest Classifiers, Gradient Boosted Trees Classifiers and Artificial Neural Networks. Each has advantages and disadvantages, and sometimes the only way to choose is to test them all and compare the results. This is the approach I decided to take.

Another challenge related to churn data is class imbalance: churn events are rare (at least in the Sparkify dataset). Hence, it's important to choose evaluation metrics that are less sensitive to imbalanced data. Details are in the “Metrics” section.

To train a machine learning model, a proper training set is required. The traffic data contains hundreds of rows for each user, so treating each row as a model input isn't feasible. That's why the cleaning and feature engineering stage has to aggregate the data. The desired structure is one row per user.
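A minimal plain-Python sketch of that reshaping, assuming a toy event log that has only `userId` and `page` fields. The project itself did this in PySpark, where `df.groupBy("userId").pivot("page").count()` yields comparable per-user page counts (pages a user never visited come out as null, so a `fillna(0)` usually follows). The function name `aggregate_per_user` is hypothetical:

```python
from collections import Counter, defaultdict

def aggregate_per_user(events):
    """Collapse event-level rows into one dict of page counts per user."""
    counts = defaultdict(Counter)
    for event in events:
        counts[event["userId"]][event["page"]] += 1
    # One entry per user: {userId: {page: count, ...}}
    return {user: dict(pages) for user, pages in counts.items()}

# Toy event log: many rows per user in, one row per user out
events = [
    {"userId": "1", "page": "NextSong"},
    {"userId": "1", "page": "NextSong"},
    {"userId": "1", "page": "Thumbs Down"},
    {"userId": "2", "page": "NextSong"},
]
per_user = aggregate_per_user(events)
print(per_user)  # {'1': {'NextSong': 2, 'Thumbs Down': 1}, '2': {'NextSong': 1}}
```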

Metrics

Binary classification can be evaluated with many metrics, e.g. accuracy, precision, recall, ROC AUC and F1-score. All of them build on the confusion matrix (ROC AUC sweeps it across classification thresholds). The confusion matrix simply shows the number of predictions broken down by their relationship to the actual label:

Structure of confusion matrix
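To make the four cells concrete, here is a small sketch that counts them from two lists of binary labels (1 = churn, 0 = no churn). The function and the toy labels are illustrative only; in PySpark the same counts can be obtained by grouping predictions by label and prediction:

```python
def confusion_matrix(y_true, y_pred):
    """Count TP, FP, FN and TN for binary labels (1 = churn, 0 = no churn)."""
    tp = fp = fn = tn = 0
    for actual, predicted in zip(y_true, y_pred):
        if predicted == 1 and actual == 1:
            tp += 1  # churner correctly flagged
        elif predicted == 1 and actual == 0:
            fp += 1  # active user wrongly flagged
        elif predicted == 0 and actual == 1:
            fn += 1  # churner missed
        else:
            tn += 1  # active user correctly left alone
    return {"TP": tp, "FP": fp, "FN": fn, "TN": tn}

y_true = [1, 1, 0, 0, 0, 1]
y_pred = [1, 0, 0, 1, 0, 1]
print(confusion_matrix(y_true, y_pred))  # {'TP': 2, 'FP': 1, 'FN': 1, 'TN': 2}
```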

Accuracy is probably the worst choice for an imbalanced dataset like ours. That's because it is simply the sum of TP and TN divided by the total number of labels, accuracy = (TP + TN) / (TP + TN + FP + FN), so the majority class dominates it.

If we have an imbalanced dataset with only 10 positive labels and 90 negative ones, we can easily achieve 90% accuracy with a dummy model that always returns 0.
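The dummy-model example can be checked directly. This sketch uses the 10/90 split from the text and also computes recall (the share of actual churners the model catches), which exposes the useless model immediately:

```python
# 10 churners (1) and 90 non-churners (0), as in the example above
y_true = [1] * 10 + [0] * 90
# A dummy model that always predicts "no churn"
y_pred = [0] * len(y_true)

tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

accuracy = (tp + tn) / len(y_true)
recall = tp / (tp + fn)

print(f"accuracy={accuracy:.2f}, recall={recall:.2f}")  # accuracy=0.90, recall=0.00
```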

Precision, TP / (TP + FP), and recall, TP / (TP + FN), are better choices. However, a natural question arises: how do we decide which one is more important?

This is where the F1-score comes in. It is arguably the best single metric here, because it combines precision and recall: the F1-score is their harmonic mean. The formula is as follows:

F1 = 2 · (precision · recall) / (precision + recall)

The F1-score is the key measure of model performance in this project. However, that doesn't mean we should be limited to a single metric and ignore the others during model evaluation.
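The harmonic mean matters because it punishes a model that is good at only one of the two components. A short sketch with illustrative values (not project results) compares it with a plain arithmetic mean:

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall; 0 if both are 0."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# A lopsided model: perfect precision, poor recall (illustrative values)
p, r = 1.0, 0.1
print(round(f1_score(p, r), 3))  # 0.182
print(round((p + r) / 2, 3))     # 0.55 (arithmetic mean, for comparison)
```

The arithmetic mean would rate this model as mediocre, while F1 stays close to the weaker component.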

Exploratory Data Analysis

The question that accompanied me throughout the analysis was: what causes users to churn? I broke it down into more detailed questions:

Why do users leave the service? What may disturb them or discourage them from using the platform? How much time do users spend on the platform, and does it influence churn? How long have users been using the service? And so on. Let's start with what I had. The general structure of the dataset was as follows:

root 
|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

The most promising field is page, which records the various pages (events) visited by users. During data exploration, I spotted some relations between churn and daily page views per user (each row stands for a separate event type):

Churn seems to be most strongly correlated with Roll Advert and Thumbs Down events. This isn't surprising: users who churned may have been dissatisfied with the music, and they may also have been put off by annoying advertisements. The relation is clearly visible in the time series chart:

There are some differences between users who churned and active users in terms of the average number of items in session and the number of days from registration.

Moreover, the gender structure also differs by churn status: males dominate among users who churned, while among active users the shares of males and females are similar.

All of the above insights have been taken into account in the feature engineering stage and the modeling.

Methodology

Data preprocessing

Data cleaning and feature engineering were performed using PySpark. The following steps were taken:

  • exclude records that have no userId assigned (include only traffic generated by logged-in users)
  • flag the churn event using the Cancellation Confirmation event
  • mark all events of users who have a flagged churn event
  • convert all UNIX timestamps (fields: ts and registration) to datetime format
  • calculate the number of days between registration and last activity
  • for each user, limit the data to the 30-day period before their last activity
  • encode gender as a binary variable
  • calculate the number of sessions and the average number of items per session for each user
  • count the occurrences of each page type and calculate the number of each page type per session
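A plain-Python sketch of the first eight steps on a toy event log, assuming millisecond timestamps (as in the Sparkify data), UTC time, and that registration and gender are constant per user. "Average items per session" is approximated here as events per session, which is an assumption. In PySpark the same steps map roughly onto `filter`, a window `max` over `userId` for the churn flag, and `from_unixtime(col("ts") / 1000)`. Names such as `preprocess` are hypothetical:

```python
from collections import defaultdict
from datetime import datetime, timedelta, timezone

WINDOW = timedelta(days=30)  # activity window before each user's last event

def to_datetime(ms):
    """Convert a UNIX timestamp in milliseconds to a UTC datetime."""
    return datetime.fromtimestamp(ms / 1000, tz=timezone.utc)

def preprocess(events):
    # 1. Keep only logged-in traffic (rows with a non-empty userId)
    events = [e for e in events if e.get("userId")]
    by_user = defaultdict(list)
    for e in events:
        by_user[e["userId"]].append(e)

    rows = []
    for user_id, user_events in by_user.items():
        # 2-3. The user is labeled as churned if any event is a cancellation
        churn = int(any(e["page"] == "Cancellation Confirmation" for e in user_events))
        # 4-5. Convert timestamps; days from registration to last activity
        last_ts = max(to_datetime(e["ts"]) for e in user_events)
        registered = to_datetime(user_events[0]["registration"])
        # 6. Keep only events from the 30 days before the last activity
        recent = [e for e in user_events if last_ts - to_datetime(e["ts"]) <= WINDOW]
        # 8. Sessions and average events per session (assumed definition)
        sessions = {e["sessionId"] for e in recent}
        rows.append({
            "userId": user_id,
            "label": churn,
            "daysFromReg": (last_ts - registered).days,
            "isMale": int(user_events[0]["gender"] == "M"),  # 7. binary gender
            "sessCount": len(sessions),
            "avgItemsInSession": len(recent) / len(sessions),
        })
    return rows

DAY_MS = 86_400_000  # one day in milliseconds
sample = [
    {"userId": "", "page": "Home"},  # anonymous traffic, dropped
    {"userId": "1", "page": "NextSong", "ts": 0, "registration": 0,
     "gender": "M", "sessionId": 1},
    {"userId": "1", "page": "NextSong", "ts": 40 * DAY_MS, "registration": 0,
     "gender": "M", "sessionId": 2},
    {"userId": "1", "page": "Cancellation Confirmation", "ts": 45 * DAY_MS,
     "registration": 0, "gender": "M", "sessionId": 3},
    {"userId": "2", "page": "NextSong", "ts": 1 * DAY_MS, "registration": 0,
     "gender": "F", "sessionId": 10},
    {"userId": "2", "page": "Thumbs Up", "ts": 2 * DAY_MS, "registration": 0,
     "gender": "F", "sessionId": 10},
    {"userId": "2", "page": "NextSong", "ts": 3 * DAY_MS, "registration": 0,
     "gender": "F", "sessionId": 11},
]
rows = preprocess(sample)
# User "1": label 1, daysFromReg 45, sessCount 2 (the day-0 session is outside the window)
```

The per-page counts from the last step are sketched separately, since they only need the filtered events and the session count.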

Implementation & Results

The training process involved several machine learning algorithms built into PySpark:

  • LogisticRegression
  • RandomForestClassifier
  • GBTClassifier
  • MultilayerPerceptronClassifier

The dataset ready for the modeling part includes the following features:

# C - event count
# CPS - event count per session
inputCols = [
    'avgItemsInSession',
    'daysFromReg',
    'isMale',
    'sessCount',
    'About_CPS',
    'About_C',
    'Add Friend_CPS',
    'Add Friend_C',
    'Add to Playlist_CPS',
    'Add to Playlist_C',
    'Downgrade_CPS',
    'Downgrade_C',
    'Error_CPS',
    'Error_C',
    'Help_CPS',
    'Help_C',
    'Logout_CPS',
    'Logout_C',
    'NextSong_CPS',
    'NextSong_C',
    'Roll Advert_CPS',
    'Roll Advert_C',
    'Save Settings_CPS',
    'Save Settings_C',
    'Settings_CPS',
    'Settings_C',
    'Submit Downgrade_CPS',
    'Submit Downgrade_C',
    'Submit Upgrade_CPS',
    'Submit Upgrade_C',
    'Thumbs Down_CPS',
    'Thumbs Down_C',
    'Thumbs Up_CPS',
    'Thumbs Up_C',
    'Upgrade_CPS',
    'Upgrade_C',
]
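The `_C`/`_CPS` pairs above could be derived as in the following sketch, which assumes that "per session" means dividing a page count by the user's `sessCount`. The helper `add_page_features` and the three-page subset are hypothetical. In PySpark, a list like `inputCols` would typically be passed to `VectorAssembler(inputCols=inputCols, outputCol="features")` before training:

```python
PAGES = ["Thumbs Down", "Roll Advert", "NextSong"]  # subset, for illustration

def add_page_features(row, page_counts):
    """Add '<page>_C' (count) and '<page>_CPS' (count per session) columns."""
    for page in PAGES:
        count = page_counts.get(page, 0)  # pages never visited count as 0
        row[f"{page}_C"] = count
        row[f"{page}_CPS"] = count / row["sessCount"]
    return row

row = add_page_features({"sessCount": 4}, {"NextSong": 120, "Thumbs Down": 6})
print(row["Thumbs Down_CPS"], row["Roll Advert_C"], row["NextSong_CPS"])  # 1.5 0 30.0
```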

There were two stages of the training process:

Stage I — training performed on the small dataset

Specifications:

  • the environment provided by Udacity
  • local instance of PySpark ver. 2.4.3
  • train-test split rate: [0.7, 0.3]
  • labels: {1: 35, 0: 122}

Results (test set):

The results of the first stage were disappointing. The best results were achieved with logistic regression; even the GBT classifier tuned with a param grid and 2-fold cross-validation wasn't effective. However, during experimentation I switched from a 60-day to a 30-day window before each user's last activity, which noticeably improved the results. Before Stage II, I made further improvements, which are described in the “Refinements” section and already included in the results above.
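Grid search with cross-validation gets expensive quickly, which partly explains why it was dropped later on EMR. This sketch counts model fits for a hypothetical GBT grid (the values are illustrative, not the project's settings). Spark's `CrossValidator` trains every combination once per fold and then refits the best combination on the whole training set:

```python
from itertools import product

# Hypothetical GBT grid: illustrative values, not the project's settings
param_grid = {
    "maxDepth": [3, 5],
    "maxIter": [10, 20],
}
num_folds = 2

combinations = list(product(*param_grid.values()))
# One fit per combination per fold, plus one final refit of the best model
total_fits = len(combinations) * num_folds + 1
print(len(combinations), total_fits)  # 4 9
```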

Stage II — training performed on the full dataset using AWS EMR cluster

Specifications:

  • AWS Elastic MapReduce cluster environment (cluster components: 1 × master node + 2 × core nodes), with EMR Notebooks support
  • emr-5.20.0 with Spark ver. 2.4.0
  • train-test split rate: [0.7, 0.3]
  • labels: {1: 5003, 0: 17263}

EMR cluster components

The software version was quite important, since not all EMR releases support the EMR Notebooks service. Surprisingly, EMR Notebooks weren't available for more recent releases (for instance emr-5.32.0). The service is available for emr-5.20.0 with Spark 2.4.0, the version I used in the project.

EMR software setup

The results achieved on the full dataset are as follows (test set):

In this case, both the Logistic Regression model and the Artificial Neural Network reached an F1-score of around 0.53. However, the LR model had a recall of 0.75 and a precision of 0.41, while in the MLP model the values were reversed.

With LR, I managed to identify 75% of the users who churned; however, there were even more false positives than correctly identified churners.
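Both statements follow from the reported precision and recall alone. Since precision = TP / (TP + FP), the ratio of false positives to true positives is (1 - precision) / precision, and because F1 is symmetric in its two components, swapping them (as in the MLP model) gives the same score:

```python
precision, recall = 0.41, 0.75  # LR results reported above

f1 = 2 * precision * recall / (precision + recall)
# From precision = TP / (TP + FP): FP / TP = (1 - precision) / precision
fp_per_tp = (1 - precision) / precision

print(f"F1={f1:.2f}, FP per TP={fp_per_tp:.2f}")  # F1=0.53, FP per TP=1.44
```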

As you can see above, no param grid or cross-validation was used in the AWS EMR approach. Unfortunately, using a notebook for training turned out to be an unstable solution. Dead kernels caused by an unstable connection and other random factors were a real challenge. Moreover, there were some errors related to AWS itself:

AWS Spark Monitor Widget error

It turned out that I wasn't the first person to experience this issue.

Refinements

The results presented above include multiple improvements, for instance:

  • using weights in logistic regression to deal with imbalanced classes
  • extending and limiting the feature set used in modeling
  • using a param grid
  • using cross-validation
  • changing the user activity period from 60 to 30 days
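The article doesn't specify how the logistic regression weights were chosen. One common scheme, assumed here, is "balanced" weighting, where each class contributes the same total weight to the loss. The sketch applies it to the Stage II label counts; in PySpark the per-row weight would go into a column passed as `weightCol` to `LogisticRegression`:

```python
# Stage II label counts reported above
label_counts = {1: 5003, 0: 17263}
n_total = sum(label_counts.values())
n_classes = len(label_counts)

# "Balanced" weighting (an assumption): weight = n_total / (n_classes * count)
class_weights = {
    label: n_total / (n_classes * count) for label, count in label_counts.items()
}
print({label: round(w, 3) for label, w in class_weights.items()})  # {1: 2.225, 0: 0.645}
```

With these weights, the 5003 churners and the 17263 active users each carry half of the total weight.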

Conclusions

Reflection

Churn prediction sits on the border between product analytics and business analytics, which is one of the reasons I decided to take on this challenge.

The most valuable part for me was exploring Spark and AWS. Spark itself is a great tool; however, it's not as simple as Pandas and Scikit-learn. Unfortunately, AWS EMR Notebooks weren't really helpful during the project. They're great for data processing and exploration, but for the training process I would suggest a solution other than notebooks. Moreover, built-in hyperparameter optimization is limited to grid search (at least in emr-5.20.0 with Spark 2.4.0). Another disadvantage is depending on the stability of your internet connection. Also, during longer training runs, sessions sometimes died for no clear reason. I spent a lot of time struggling with AWS EMR, time I could have used for further data investigation and for improvements on the small dataset. Instead, I assumed that the amount of data was what blocked me (which was only partly true).

Improvements

I can see many things that could be improved and that require further exploration; the results show there is still room for improvement. The best results came from the logistic regression model, and even Gradient Boosted Trees combined with cross-validation and grid search didn't improve them. This may suggest that more work on feature engineering is needed, and the provided data requires further investigation. There are many topics worth exploring:

  • a deeper investigation of the features and their impact (absolute features vs features “per session”)
  • introduce features based on trends (decrease/increase of page views WoW or MoM)
  • investigate locations/levels/user-agents
  • investigate false negatives
  • try other combinations of features
  • further investigate the param grid builder and cross-validator, since using them didn't really improve the results during Stage I
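As an example of the trend-based features mentioned above, a week-over-week (WoW) change in page views might look like the sketch below. It assumes events have already been bucketed into weekly counts per user; the function name and the sample numbers are hypothetical:

```python
def week_over_week_change(weekly_counts):
    """Relative change between the last two weeks of page views.

    weekly_counts: list of page-view counts, oldest week first.
    Returns None when there is no previous week or it had zero views.
    """
    if len(weekly_counts) < 2 or weekly_counts[-2] == 0:
        return None
    previous, current = weekly_counts[-2], weekly_counts[-1]
    return (current - previous) / previous

# e.g. weekly NextSong counts for a user whose activity is fading
print(week_over_week_change([80, 75, 60, 30]))  # -0.5
print(week_over_week_change([12]))              # None
```

A strongly negative value would flag users whose engagement is dropping, which is exactly the behavior one would expect before a cancellation. MoM features follow the same pattern with monthly buckets.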
