Analyzing Customer Churn using PySpark

Sunny Anand
Nov 16, 2020

Project Overview

Online services need to understand how customers use and feel about their product. As a data scientist, one has to determine whether customers will continue to use the product with the same enthusiasm, downgrade from a paid service to a free one, or stop using the service altogether. This loss of customers is called customer churn.

Predicting churn rates is a challenging and common problem that data scientists and analysts regularly encounter in any customer-facing business.

In this project, we will be looking at Sparkify, a fictional music service similar to Spotify, which needs a model to help predict customer churn. The dataset consists of customer events with timestamps and is provided by Udacity.

Problem Definition

The dataset from Udacity for this project includes user logs for the past few months, containing all user activity. A user can have more than one sample, as they can take multiple actions. Each log entry contains some basic information about the user as well as information about a single action.

A user's churn can be identified whenever there is an account cancellation page event for this product. This project is part of Udacity's Data Scientist Nanodegree program and analyzes the behavior of users of an app called Sparkify to predict user churn.

As part of this project, we will create a model from the provided dataset using machine learning and use that model to predict whether a given user is likely to leave. The model aims to predict which users are likely to cancel their account so that, in the future, we can take actions such as providing discounts or offers to retain them.

To create the model, we will experiment with a few machine learning models suited to classification problems. Since we have a good amount of data from which customer behavior can be learned, and since this is a binary classification problem, it makes sense to choose supervised classification techniques such as the Random Forest Classifier, Logistic Regression Classifier, SVM Classifier, and GBT Classifier. We will build each model on aggregated per-user data in order to predict churned users and, with this information, take decisions to retain them. Finally, after comparing the models on a suitable metric, we will propose the best one from the list above.

Metrics

In this project, we are trying to solve the classification problem of whether a customer will churn or not. Accuracy will be used as a metric for evaluating model performance; however, since the provided dataset is imbalanced (there are more active users than churned users), we also use the F1 score, as it provides a better measure for imbalanced classes. So the metrics of choice for this project are the F1 score along with accuracy, with stratified sampling used to balance active and churned users in the dataset.

The project is structured as follows:

  1. Data Preprocessing
  2. Exploratory Data Analysis
  3. Feature Engineering
  4. Modeling
  5. Model Comparison and Evaluation
  6. Model Justification
  7. Conclusion & Remarks
  8. Solution Improvement
  9. Reflection & Remarks
  10. Tools & Software
  11. References

Data Set Details

The dataset contains 18 columns and 286,500 rows; the log file is stored in JSON format. Each row represents a user interaction with the service along with information about the user's session. Below is the schema for the dataset.

root
|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

1. Data Pre-Processing

To understand customer churn, I did some pre-processing of the user logs to ensure that the data being fed to the model does not suffer from low data quality. The steps were as follows, with a code sketch after the list:

  • Checked for NA values in the dataset.
  • Checked the userId and sessionId columns for null values and dropped any such rows from the dataset.
  • Dropped rows with an empty user id.
  • Found the number of unique pages.
  • The number of rows after this processing was 278,154.
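
A minimal sketch of these cleaning steps, assuming the log file is named mini_sparkify_event_data.json (the file name is an assumption):

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('Sparkify').getOrCreate()
my_df = spark.read.json('mini_sparkify_event_data.json')  # file name assumed

# Drop rows with null userId or sessionId, then remove empty-string user ids
my_df = my_df.dropna(how='any', subset=['userId', 'sessionId'])
my_df = my_df.filter(my_df['userId'] != '')
print(my_df.count())  # 278,154 rows remained after this processing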

Once this preliminary analysis was done, I created a column Churn to use as the label for the model. I used Cancellation Confirmation events to define churn, which happen for both paid and free users.

Columns user_churned and user_downgraded, representing users who churned and downgraded, were added. I also added new columns derived from the ts timestamp column.
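
As a sketch, the user_churned flag can be derived as below, assuming the cleaned my_df from the previous step; user_downgraded follows the same pattern with the downgrade page events:

from pyspark.sql import functions as F
from pyspark.sql import Window

# Flag Cancellation Confirmation events, then propagate the flag to all
# of a user's rows so every log line of a churned user is labeled 1
churn_event = F.when(F.col('page') == 'Cancellation Confirmation', 1).otherwise(0)
my_df = my_df.withColumn('churn_event', churn_event)
my_df = my_df.withColumn('user_churned',
                         F.max('churn_event').over(Window.partitionBy('userId')))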

2. Exploratory Data Analysis

In order to perform the EDA, I did some data aggregation, conversion, and cleaning. Below are the steps I took to get the data into a format in which I could start looking at user behavior under the definition of churn given earlier (a sketch of the timestamp conversion follows the list):

  • Converted the epoch timestamp column into datetime values and added them as new columns.
  • Created new columns for downgrade events, users who downgraded, and users who churned in order to understand the impact of downgrade events.
  • Created a churn label column from the Cancellation Confirmation events to track that event separately as the target.
  • Joined the new columns and dropped any duplicates from the transformed dataset, so that the dataset was ready for EDA with the required churn information.
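
A sketch of the timestamp conversion, assuming ts holds milliseconds since the epoch (as it does in the Sparkify log):

from pyspark.sql import functions as F

# Convert epoch-millisecond ts to a timestamp, then derive the hour of day
my_df = my_df.withColumn('datetime',
                         F.from_unixtime(F.col('ts') / 1000).cast('timestamp'))
my_df = my_df.withColumn('hour', F.hour('datetime'))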

As part of the EDA, I analyzed the distribution of churned users by gender and found that male users are more likely to churn than female users.

I also checked whether free or paid users churn more: paid users are more likely to churn than free users.

Comparing the number of songs played per session, churned users tend to play fewer songs on average than non-churned users.

Comparing total user engagements per session, churned users tend to engage less with the website than non-churned users.

Finally, I looked at the number of users active by hour. As a percentage, active users use the application more in the morning, while churned users clearly account for a higher share of activity later in the evening and at night.
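
As a sketch, one such aggregation (the churn distribution by gender) can be computed as below, assuming the user_churned column created earlier:

# One row per user, then count users per (gender, churn) combination
(my_df.select('userId', 'gender', 'user_churned')
      .dropDuplicates()
      .groupBy('gender', 'user_churned')
      .count()
      .show())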

3. Feature Engineering

To find features that can be used for modeling, I derived various features by querying the given dataset. I created a list of around 10 features, with the churn feature aliased as the label for the dataset. The derived features were as follows:

'number_songs_per_user', 'number_thumbs_up_per_user', 'number_thumbs_down_per_user', 'total_time', 'total_songtime_per_user', 'gender', 'average_played_songs', 'add_to_playlist', 'tot_artist_played', 'tot_friends'

Code example of creating a derived feature like number_thumbs_up_per_user:

# number of thumbs up a user gives
feature2 = my_df.select(['userId', 'page']).where(my_df['page'] == 'Thumbs Up') \
    .groupBy('userId').count().withColumnRenamed('count', 'number_thumbs_up_per_user')

As part of the feature engineering, I set the churned column as the target variable that we aim to predict.

from pyspark.sql.functions import col

target = my_df.select('userId', col('user_churned').alias('target')).dropDuplicates()
target.show()

Finally, I combined all the features into a clean dataset and set the target variable of churn for prediction.

final_df = feature1.join(feature2, 'userId', 'inner').join(feature3, 'userId', 'inner').join(feature4, 'userId', 'inner')\
.join(feature5, 'userId', 'inner').join(feature6, 'userId', 'inner').join(feature7, 'userId', 'inner')\
.join(feature8, 'userId', 'inner').join(feature9, 'userId', 'inner').join(feature10, 'userId', 'inner')\
.join(target, 'userId', 'inner').drop('userId')

The final dataset schema looked as below:

root
|-- number_songs_per_user: long (nullable = false)
|-- number_thumbs_up_per_user: long (nullable = false)
|-- number_thumbs_down_per_user: long (nullable = false)
|-- total_time: double (nullable = true)
|-- total_songtime_per_user: double (nullable = true)
|-- gender: integer (nullable = true)
|-- average_played_songs: double (nullable = true)
|-- add_to_playlist: long (nullable = false)
|-- tot_artist_played: long (nullable = false)
|-- tot_friends: long (nullable = false)
|-- target: integer (nullable = true)

4. Modeling

Once the dataset with all the engineered features was derived, I performed some standard steps: converting the selected features into a vector and applying standard scaling to them. Finally, I created a dataset consisting of the label and the input features, which is then passed to the various estimators for model building and prediction.
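
A minimal sketch of these steps, assuming the final_df built above (variable names are illustrative):

from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.sql.functions import col

# Assemble all feature columns into one vector, then standard-scale it
feature_cols = [c for c in final_df.columns if c != 'target']
assembler = VectorAssembler(inputCols=feature_cols, outputCol='raw_features')
assembled = assembler.transform(final_df)

scaler = StandardScaler(inputCol='raw_features', outputCol='features')
scaled = scaler.fit(assembled).transform(assembled)

# Dataset of (label, features) ready for the estimators
data = scaled.select(col('target').alias('label'), 'features')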

I split the full dataset into training and validation sets, tested the four machine learning methods below, and evaluated the models using both the F1 score and accuracy. The four models and their F1 scores and accuracies are given below; a sketch of the shared train-and-evaluate scaffolding follows.
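
A minimal sketch of that scaffolding, assuming the scaled data from the previous step; the 80/20 split ratio and the fit_and_score helper are my assumptions, not stated in the original analysis:

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Split ratio is an assumption; the original split is not specified
train, validation = data.randomSplit([0.8, 0.2], seed=42)
evaluator = MulticlassClassificationEvaluator(labelCol='label',
                                              predictionCol='prediction')

def fit_and_score(estimator):
    # Fit on the training set, then report F1 and accuracy on validation
    model = estimator.fit(train)
    preds = model.transform(validation)
    f1 = evaluator.evaluate(preds, {evaluator.metricName: 'f1'})
    acc = evaluator.evaluate(preds, {evaluator.metricName: 'accuracy'})
    return f1, acc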

Logistic Regression

The hyperparameter setting for the Logistic Regression model was as below (a sketch follows the list):

  • maxIter=10
  • regParam=0.0
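
A sketch of the model with these settings, reusing the hypothetical fit_and_score helper from above:

from pyspark.ml.classification import LogisticRegression

f1, acc = fit_and_score(LogisticRegression(maxIter=10, regParam=0.0))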

Below we see that the LR model is right only about two-thirds of the time on the validation set.

Logistic Regression Classifier metrics are as below:
The F-1 Score is 0.7391304347826086
The accuracy is 0.6618993135011442

Random Forest Classifier

The hyperparameter settings for the Random Forest Classifier were set as below (a sketch follows the list):

  • numTrees = 20
  • maxDepth = 10
  • seed = 42
  • featureSubsetStrategy = auto
  • subsamplingRate = 1.0
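
A sketch with these settings:

from pyspark.ml.classification import RandomForestClassifier

rf = RandomForestClassifier(numTrees=20, maxDepth=10, seed=42,
                            featureSubsetStrategy='auto', subsamplingRate=1.0)
f1, acc = fit_and_score(rf)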

Below we see that the Random Forest model is right about three-quarters of the time on the validation set.

Random Forest Classifier metrics are as below:
The F-1 Score is 0.782608695652174
The accuracy is 0.7622811970638057

Support Vector Machine

The hyperparameter setting for the Support Vector Machine model was set to (a sketch follows the list):

  • maxIter = 15
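
A sketch with this setting, assuming Spark MLlib's linear SVM implementation (LinearSVC):

from pyspark.ml.classification import LinearSVC

f1, acc = fit_and_score(LinearSVC(maxIter=15))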

The SVM model performs even worse than the LR model, with a low accuracy of 57% on the validation set. SVM training also took a lot of time.

SVM Classifier metrics are as below:
F-1 Score is 0.6956521739130435
Accuracy is 0.5707915273132664

Gradient Boosted Tree

The hyperparameter settings for the Gradient Boosted Tree model were set to (a sketch follows the list):

  • maxIter = 15
  • maxDepth = 10
  • seed = 42
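
A sketch with these settings:

from pyspark.ml.classification import GBTClassifier

f1, acc = fit_and_score(GBTClassifier(maxIter=15, maxDepth=10, seed=42))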

The GBT model performs better than the LR and SVM models, but it does not reach the accuracy level of the Random Forest. It took the longest to train.

GBT Classifier metrics are as below:
F-1 Score is 0.7391304347826086
Accuracy is 0.7391304347826086

5. Model Comparison and Evaluation

Based on the F1 scores of the four models on the validation set, the Random Forest Classifier performs the best, so I did some hyperparameter tuning on this model to try to improve its accuracy and F1 score. LR and SVM perform roughly the same, implying they are not able to capture the variance in the data well. Random Forest and GBT show higher F1 scores and accuracy, implying they capture more of the data variance and understand the data better.

In order to achieve a better Random Forest Classifier model, I tuned a couple of its hyperparameters, trying values for which the model performs well (a sketch of the tuned model follows the list). The hyperparameters were:

  • numTrees = 30
  • maxDepth = 5
  • seed = 42
  • featureSubsetStrategy = auto
  • subsamplingRate = 1.0
  • impurity = gini
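
A sketch of the tuned model:

from pyspark.ml.classification import RandomForestClassifier

rf_tuned = RandomForestClassifier(numTrees=30, maxDepth=5, seed=42,
                                  featureSubsetStrategy='auto',
                                  subsamplingRate=1.0, impurity='gini')
f1, acc = fit_and_score(rf_tuned)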

The overall model improved after fine-tuning: including more trees in the forest and reducing the depth of each tree to half of what was originally used. The more trees one includes, the better the model becomes at making correct predictions for this dataset. There was an improvement of about 4 percentage points in the F1 score.

Random Forest Classifier metrics after hyperparameter tuning are as below:
The F-1 Score is 0.8260869565217391
The accuracy is 0.800966183574879

A closer look at the hyperparameters helps us understand why the model improved after fine-tuning numTrees and maxDepth. Looking at the dataset, we see that the model uses a total of 10 features over rows in the range of 23k; reducing maxDepth from 10 to 5 and setting featureSubsetStrategy to auto helped build trees in the random forest with more information gain per tree. This model thus performs better on the validation data than the original one.

6. Model Justification

Looking at the final model and comparing it with the question at hand: with an F1 score and accuracy of over 80%, the model correctly predicts for roughly four out of every five customers whether they will leave the app. The initial problem I set out to solve, predicting customer churn for this app, is thus addressed with a good first accuracy and F1 score. In a real-world setup this is a solid first baseline; however, one would want to see this number in the high 90s, which would mean a far better understanding of the app and its customers, capturing the user-app interactions via logs and building a model with an even higher true prediction rate. Given the small amount of data it was trained on, this model may also be a strong candidate for fast production deployment and use by the app company.

7. Conclusion & Remarks

I applied data processing to the small dataset and then extracted features that I thought would help predict user churn. I used four machine learning algorithms (LR, Random Forest, SVM, and Gradient Boosted Trees) and saw that Random Forest did the best job at prediction. Hyperparameter tuning further improved the classifier's performance, raising the F1 score by about 4 percentage points.

8. Solution Improvement

  • It may be possible to add more features, try other classifiers, or experiment with different hyperparameter settings to further improve prediction accuracy, as seen in the Random Forest Classifier tuning above.
  • A continuous learning framework for model training could prove beneficial: with each new batch of user logs, the models could be retrained based on changes in the data and in accuracy. This would require building an MLOps pipeline around the model.
  • I would also like to try other classification algorithms and set up experiments similar to the Random Forest one to determine the best hyperparameter values for them.
  • More data will never hurt, and data quality should be checked continuously.

9. Reflection & Remarks

This project gave me exposure to Apache Spark, which is typically used for big data applications and large-scale machine learning pipelines.

I worked with PySpark and Pandas and used a few other packages to build an end-to-end pipeline for predicting customer churn. This project introduced the idea of customer churn and how important accurate churn prediction is for retaining customers. Using Spark MLlib gave me exposure to the various machine learning classifiers in the Spark framework.

With more domain experience and a better understanding of the data, I could perhaps build an even more performant model in the future, one that is better at making the churn prediction.

10. Tools & Software

  • PySpark
  • Spark MLlib
  • Pandas
  • Matplotlib
  • Seaborn

11. References

GitHub
