Files
sales-data-analysis/examples/customer_segmentation.py
Jonathan Pressnell cf0b596449 Initial commit: sales analysis template
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-06 09:16:34 -05:00

214 lines
7.3 KiB
Python

"""
Example: Customer Segmentation (RFM) Analysis
Example showing customer segmentation using RFM methodology
This example demonstrates:
- Customer-level aggregation
- RFM segmentation (Recency, Frequency, Monetary)
- Segment analysis and visualization
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
# Import utilities
from data_loader import load_sales_data, validate_data_structure
from validate_revenue import validate_revenue
from analysis_utils import (
get_ltm_period_config, apply_exclusion_filters,
setup_revenue_chart, save_chart, format_currency
)
from config import (
OUTPUT_DIR, MAX_DATE, CHART_SIZES, ensure_directories,
get_data_path, COMPANY_NAME, REVENUE_COLUMN, CUSTOMER_COLUMN,
DATE_COLUMN, MIN_YEAR
)
# ============================================================================
# CONFIGURATION
# ============================================================================
# Title shown in the console banner and completion message, and passed to
# validate_revenue() as the analysis label.
ANALYSIS_NAME: str = "Customer Segmentation (RFM)"
# Short human-readable description of this analysis.
DESCRIPTION: str = "Customer segmentation using RFM methodology"
# ============================================================================
# RFM SEGMENTATION FUNCTIONS
# ============================================================================
def _quintile_score(values, labels):
    """
    Bucket *values* into len(labels) equal-sized bins and return integer labels.

    Values are ranked first (ties broken by order of appearance) so pd.qcut
    always sees distinct values and can form distinct bin edges.

    Args:
        values: Numeric Series to score
        labels: Bin labels ordered from the lowest-value bin to the highest

    Returns:
        Integer Series aligned with values.index
    """
    if len(values) < 2:
        # qcut cannot build distinct bin edges from fewer than two values (it
        # would raise a label/edge-count mismatch), so give every row the
        # neutral middle score instead of crashing on tiny datasets.
        return pd.Series(labels[len(labels) // 2], index=values.index, dtype=int)
    ranked = values.rank(method='first')
    return pd.qcut(ranked, q=len(labels), labels=labels, duplicates='drop').astype(int)


def calculate_rfm_scores(df, analysis_date=None):
    """
    Calculate RFM scores and a segment label for each customer

    Args:
        df: DataFrame with one row per transaction containing CUSTOMER_COLUMN,
            DATE_COLUMN (datetime-like) and REVENUE_COLUMN
        analysis_date: Reference date for the recency calculation; anything
            pd.Timestamp accepts (Timestamp, datetime, date string).
            Defaults to the latest date in df.

    Returns:
        DataFrame with one row per customer and columns CUSTOMER_COLUMN,
        LastPurchaseDate, Frequency, Monetary, Recency, R_Score, F_Score,
        M_Score, RFM_Score and Segment. Scores run 1-5 with 5 = best.
        With fewer than two customers every score is the neutral value 3.
    """
    if analysis_date is None:
        analysis_date = df[DATE_COLUMN].max()
    else:
        # Accept strings / datetime objects as well as Timestamps.
        analysis_date = pd.Timestamp(analysis_date)

    # Customer-level metrics; Frequency counts transactions (non-null dates).
    customer_metrics = df.groupby(CUSTOMER_COLUMN).agg(
        LastPurchaseDate=(DATE_COLUMN, 'max'),
        Frequency=(DATE_COLUMN, 'count'),
        Monetary=(REVENUE_COLUMN, 'sum'),
    ).reset_index()

    # Recency = whole days between the last purchase and the reference date.
    customer_metrics['Recency'] = (analysis_date - customer_metrics['LastPurchaseDate']).dt.days

    # Score each dimension on a 1-5 scale (5 = best). Fewer days since the last
    # purchase is better, so the recency labels run high-to-low.
    customer_metrics['R_Score'] = _quintile_score(customer_metrics['Recency'], [5, 4, 3, 2, 1])
    customer_metrics['F_Score'] = _quintile_score(customer_metrics['Frequency'], [1, 2, 3, 4, 5])
    customer_metrics['M_Score'] = _quintile_score(customer_metrics['Monetary'], [1, 2, 3, 4, 5])

    # Composite RFM score (sum of R, F, M)
    customer_metrics['RFM_Score'] = (
        customer_metrics['R_Score'] +
        customer_metrics['F_Score'] +
        customer_metrics['M_Score']
    )

    # Assign segments: the first matching rule wins, like an if/elif chain.
    r = customer_metrics['R_Score']
    f = customer_metrics['F_Score']
    m = customer_metrics['M_Score']
    segment_rules = [
        ((r >= 4) & (f >= 4) & (m >= 4), 'Champions'),
        ((r >= 3) & (f >= 3) & (m >= 4), 'Loyal Customers'),
        # Bought recently but rarely: new or emerging customers.
        ((r >= 4) & (f <= 2), 'New Customers'),
        # Used to buy often but has lapsed: the customers at risk of churning.
        ((r <= 2) & (f >= 3), 'At Risk'),
        (r <= 2, 'Hibernating'),
        ((r >= 3) & (f >= 3) & (m <= 2), 'Potential Loyalists'),
    ]
    customer_metrics['Segment'] = np.select(
        [condition for condition, _ in segment_rules],
        [segment for _, segment in segment_rules],
        default='Need Attention',
    )
    return customer_metrics
# ============================================================================
# MAIN ANALYSIS FUNCTION
# ============================================================================
def _print_segment_summary(segment_summary, total_customers, total_revenue):
    """Print one line per segment with its share of customers and of revenue."""
    for segment, count, revenue in zip(
        segment_summary['Segment'],
        segment_summary['Customer Count'],
        segment_summary['Total Revenue'],
    ):
        pct_customers = (count / total_customers) * 100
        # Avoid dividing by a zero revenue total (e.g. fully netted-out sales).
        pct_revenue = (revenue / total_revenue) * 100 if total_revenue else 0.0
        print(f"{segment:20s}: {count:5d} customers ({pct_customers:5.1f}%), "
              f"{format_currency(revenue)} ({pct_revenue:5.1f}% of revenue)")


def _plot_segment_charts(segment_summary):
    """Draw revenue and customer-count bar charts per segment and save them."""
    # Ascending order puts the largest segment at the top of a horizontal bar chart.
    ordered = segment_summary.sort_values('Total Revenue', ascending=True)
    positions = range(len(ordered))
    segment_labels = ordered['Segment'].values

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=CHART_SIZES['wide'])

    # Chart 1: Revenue by Segment
    ax1.barh(positions, ordered['Total Revenue'].values / 1e6, color='#2E86AB')
    ax1.set_yticks(positions)
    ax1.set_yticklabels(segment_labels)
    ax1.set_xlabel('Revenue (Millions USD)')
    ax1.set_title('Revenue by Customer Segment', fontsize=12, fontweight='bold')
    setup_revenue_chart(ax1)
    ax1.set_ylabel('')

    # Chart 2: Customer Count by Segment
    ax2.barh(positions, ordered['Customer Count'].values, color='#A23B72')
    ax2.set_yticks(positions)
    ax2.set_yticklabels(segment_labels)
    ax2.set_xlabel('Number of Customers')
    ax2.set_title('Customer Count by Segment', fontsize=12, fontweight='bold')
    ax2.set_ylabel('')
    ax2.grid(True, alpha=0.3)

    plt.suptitle(f'Customer Segmentation Analysis - {COMPANY_NAME}',
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    save_chart(fig, 'customer_segmentation.png')
    # Close this specific figure rather than whichever one happens to be current.
    plt.close(fig)


def main():
    """Run the RFM analysis: load, validate, filter, score, summarize, chart, validate revenue."""
    print(f"\n{'='*60}")
    print(f"{ANALYSIS_NAME}")
    print(f"{'='*60}\n")

    # 1. Load data
    print("Loading data...")
    try:
        df = load_sales_data(get_data_path())
        print(f"Loaded {len(df):,} transactions")
    except Exception as e:
        # Script entry point: report the failure and stop instead of a traceback.
        print(f"ERROR loading data: {e}")
        return

    # 2. Validate data structure
    is_valid, msg = validate_data_structure(df)
    if not is_valid:
        print(f"ERROR: {msg}")
        return
    # RFM needs both the customer and the date column; check them up front so
    # a missing column yields a clear message instead of a KeyError later.
    if CUSTOMER_COLUMN not in df.columns:
        print(f"ERROR: Customer column '{CUSTOMER_COLUMN}' not found in data")
        return
    if DATE_COLUMN not in df.columns:
        print(f"ERROR: Date column '{DATE_COLUMN}' not found in data")
        return
    print("Data validation passed")

    # 3. Apply exclusion filters
    df = apply_exclusion_filters(df)

    # 4. Filter by date range; derive the year from the date when there is
    # no precomputed 'Year' column.
    years = df['Year'] if 'Year' in df.columns else df[DATE_COLUMN].dt.year
    df = df[years >= MIN_YEAR]
    df = df[df[DATE_COLUMN] <= MAX_DATE]
    if df.empty:
        print("ERROR: No transactions remain after filtering; nothing to segment")
        return

    # 5. Calculate RFM scores
    print("\nCalculating RFM scores...")
    rfm_df = calculate_rfm_scores(df)

    # 6. Segment summary
    print("\nCustomer Segmentation Summary:")
    print("-" * 60)
    segment_summary = rfm_df.groupby('Segment').agg({
        CUSTOMER_COLUMN: 'count',
        'Monetary': 'sum'
    }).reset_index()
    segment_summary.columns = ['Segment', 'Customer Count', 'Total Revenue']
    segment_summary = segment_summary.sort_values('Total Revenue', ascending=False)
    # Totals are computed once, not per printed row.
    _print_segment_summary(segment_summary, len(rfm_df), rfm_df['Monetary'].sum())

    # 7. Create visualizations
    print("\nGenerating charts...")
    ensure_directories()
    _plot_segment_charts(segment_summary)

    # 8. Validate revenue
    print("\nValidating revenue...")
    validate_revenue(df, ANALYSIS_NAME)
    print(f"\n{ANALYSIS_NAME} complete!")
    print(f"Charts saved to: {OUTPUT_DIR}")
# ============================================================================
# RUN ANALYSIS
# ============================================================================
if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()