SMS Spam Detection with BERT

Overview

This project implements a spam detection model for SMS messages using BERT, a state-of-the-art Transformer model for natural language processing. Transformers excel at capturing contextual relationships in text, making them well suited to distinguishing spam from non-spam messages. With the growing volume of SMS spam, this project addresses the need for automated filtering to improve user experience and security.

Project Objective

The goal is to develop a robust classifier to identify spam SMS messages using the SMSSpamCollection dataset. By leveraging BERT and addressing class imbalance with techniques like class weighting, this project showcases advanced NLP methods for real-world applications.

Dataset

The SMSSpamCollection dataset contains 5,574 SMS messages, with approximately 13% labeled as spam and 87% as non-spam. The dataset's imbalance necessitated the use of class weighting to improve model performance on the minority class (spam).

Prerequisites

To run this project, install the following Python packages:

  • tensorflow
  • tensorflow-hub
  • tensorflow-text
  • pandas
  • scikit-learn
  • matplotlib
  • seaborn
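
All of the packages above can be installed in one step with pip, for example:

```shell
pip install tensorflow tensorflow-hub tensorflow-text pandas scikit-learn matplotlib seaborn
```

Note that tensorflow-text must match the installed tensorflow version; pinning both to the same release avoids import errors.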

Project Structure

The project consists of two Jupyter Notebook files:

  • EDA.ipynb: Handles data preprocessing and exploratory data analysis (EDA), including class distribution, message length analysis, and word frequency.
  • SMSSpamDetection.ipynb: Implements data balancing with class weights, model training with BERT, and evaluation with metrics like accuracy, precision, recall, and F1-score.

Note: These are Jupyter Notebook files (.ipynb) and should be run in a Jupyter environment (e.g., JupyterLab or Jupyter Notebook), not as standard Python scripts.

How to Run

  1. Clone the repository:

    git clone https://github.com/elnazparsaei/sms-spam-detection.git
  2. Navigate to the project directory:

    cd sms-spam-detection
  3. Download the dataset (SMSSpamCollection.tsv) from the UCI repository or use the provided file.

  4. Open JupyterLab or Jupyter Notebook:

    jupyter lab
  5. Run the notebooks in the following order:

    • Open and run EDA.ipynb to perform data analysis and visualize distributions.
    • Open and run SMSSpamDetection.ipynb to train the model and evaluate performance.

Exploratory Data Analysis (EDA)

The EDA.ipynb notebook analyzes the dataset to uncover key insights:

  • Class Distribution: Approximately 87% of messages are non-spam, and 13% are spam, highlighting the need for class imbalance handling.

    Class Distribution

  • Message Length: Spam messages are generally longer than non-spam messages, as shown below.

    Text Length Distribution
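
The core of this analysis can be sketched with pandas. The snippet below uses a small hypothetical sample in place of the real TSV file; for the actual data, something like `pd.read_csv("SMSSpamCollection.tsv", sep="\t", names=["label", "text"])` would load it (the tab-separated label/text layout is an assumption about the file).

```python
import pandas as pd

# Hypothetical mini-sample standing in for SMSSpamCollection.tsv
df = pd.DataFrame({
    "label": ["ham", "ham", "spam", "ham"],
    "text": [
        "Ok lar... Joking wif u oni...",
        "I'll call you later tonight.",
        "WINNER!! Claim your free prize now!",
        "Are we still on for lunch?",
    ],
})

# Class distribution (the full dataset is roughly 87% ham / 13% spam)
print(df["label"].value_counts(normalize=True))

# Average message length per class
df["length"] = df["text"].str.len()
print(df.groupby("label")["length"].mean())
```

The same two statistics drive the plots in the notebook: the class-distribution bar chart and the per-class text-length histogram.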

Model Implementation

The SMSSpamDetection.ipynb notebook implements the following:

  • Data Balancing: Used sklearn.utils.class_weight.compute_class_weight to assign weights {0: 0.577, 1: 3.726} for non-spam and spam classes, addressing the dataset's imbalance.
  • Model Architecture: Utilizes BERT from TensorFlow Hub for text embedding, followed by a dropout layer (0.3) and a dense layer with sigmoid activation for binary classification.
  • Training: Fine-tuned for 5 epochs with a learning rate of 2e-5 to optimize performance.
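
The class-weighting step can be sketched as follows. The exact weights reported above ({0: 0.577, 1: 3.726}) come from the real train split; here mock labels with roughly the dataset's 87/13 split are used instead, so the numbers are close but not identical.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Mock labels approximating the dataset's 87% ham (0) / 13% spam (1) split
y = np.array([0] * 87 + [1] * 13)

classes = np.unique(y)
weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)
class_weight_dict = dict(zip(classes, weights))
print(class_weight_dict)  # the minority (spam) class gets the larger weight

# The dict is then passed to Keras training, e.g.:
# model.fit(X_train, y_train, epochs=5, class_weight=class_weight_dict)
```

With the "balanced" strategy each weight is n_samples / (n_classes * class_count), so losses on rare spam examples are scaled up during training.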

Results

The model achieved the following performance on the test set:

  • Accuracy: 92% (improved from 87% after increasing epochs)
  • Precision (Spam): 0.63
  • Recall (Spam): 0.98
  • F1-Score (Spam): 0.77
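
As a sanity check, the reported spam F1-score is consistent with the precision and recall above, since F1 is their harmonic mean:

```python
# F1 is the harmonic mean of precision and recall
precision = 0.63
recall = 0.98

f1 = 2 * precision * recall / (precision + recall)
print(round(f1, 2))  # 0.77
```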

The confusion matrix visualizes the model's predictions:

Confusion Matrix

Challenges and Solutions

  • Class Imbalance: Addressed using compute_class_weight to assign higher weights to the spam class ({0: 0.577, 1: 3.726}).
  • Model Performance: Improved accuracy from 87% to 92% by increasing epochs and fine-tuning the learning rate to 2e-5.

Future Improvements

  • Optimize classification threshold to further improve recall for spam detection.
  • Enhance text preprocessing (e.g., removing punctuation or stopwords).
  • Experiment with other Transformer models like DistilBERT for faster inference.
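
The first item, threshold tuning, can be sketched with scikit-learn. The probabilities and labels below are hypothetical stand-ins for the model's sigmoid outputs on a validation set; the idea is to pick the threshold that maximizes F1 instead of using the default 0.5.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Hypothetical validation labels and sigmoid probabilities
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 0])
y_prob = np.array([0.05, 0.10, 0.20, 0.30, 0.35, 0.45, 0.60, 0.80, 0.90, 0.55])

precision, recall, thresholds = precision_recall_curve(y_true, y_prob)

# precision/recall have one more entry than thresholds; drop the last
# point (recall 0) so the F1 values align with the candidate thresholds
f1 = 2 * precision * recall / (precision + recall + 1e-12)
best = thresholds[np.argmax(f1[:-1])]
print(best)
```

Because spam recall is already high (0.98) while precision is low (0.63), raising the threshold slightly would likely trade a little recall for substantially better precision.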

License

This project is licensed under the MIT License.
