# Files
# sales-data-analysis/examples/cohort_analysis.py
# Jonathan Pressnell cf0b596449 Initial commit: sales analysis template
# Co-authored-by: Cursor <cursoragent@cursor.com>
# 2026-02-06 09:16:34 -05:00
#
# 219 lines
# 7.5 KiB
# Python

"""
Example: Cohort Analysis
Advanced example showing customer cohort retention analysis
This demonstrates:
- Cohort-based analysis
- Retention rate calculations
- Revenue retention metrics
- Advanced visualization
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from operator import attrgetter
# Import utilities
from data_loader import load_sales_data, validate_data_structure
from validate_revenue import validate_revenue
from analysis_utils import (
get_ltm_period_config, apply_exclusion_filters,
setup_revenue_chart, save_chart, format_currency
)
from config import (
OUTPUT_DIR, MAX_DATE, CHART_SIZES, ensure_directories,
get_data_path, COMPANY_NAME, REVENUE_COLUMN, CUSTOMER_COLUMN,
DATE_COLUMN, MIN_YEAR
)
# ============================================================================
# CONFIGURATION
# ============================================================================
# Title shown in the console banner and passed to validate_revenue() by main().
ANALYSIS_NAME: str = "Cohort Analysis"
# Human-readable summary of this example (not referenced elsewhere in this script).
DESCRIPTION: str = "Customer cohort retention and revenue analysis"
# ============================================================================
# COHORT ANALYSIS FUNCTIONS
# ============================================================================
def create_cohorts(df, customer_col=None, date_col=None):
    """
    Create customer cohorts based on first purchase date.

    Each customer is assigned to the calendar month of their earliest
    transaction ("Cohort"). Every transaction is tagged with the number of
    calendar months between its own month and that cohort month
    ("CohortPeriod"; 0 means the cohort month itself).

    Args:
        df: DataFrame with customer and date columns. The date column must be
            datetime-typed because it is read through the ``.dt`` accessor.
        customer_col: Name of the customer-ID column. Defaults to
            ``config.CUSTOMER_COLUMN`` (looked up at call time).
        date_col: Name of the transaction-date column. Defaults to
            ``config.DATE_COLUMN`` (looked up at call time).

    Returns:
        DataFrame: A new DataFrame (``df`` itself is not modified) with the
        original columns plus 'Cohort' (monthly Period of the customer's first
        purchase), 'Period' (monthly Period of the transaction) and
        'CohortPeriod' (integer months since the cohort month). Rows with a
        missing customer ID are dropped by the inner merge.
    """
    if customer_col is None or date_col is None:
        # Resolve defaults from config at call time, as the original local
        # import did; explicit arguments bypass config entirely.
        import config
        customer_col = config.CUSTOMER_COLUMN if customer_col is None else customer_col
        date_col = config.DATE_COLUMN if date_col is None else date_col

    # Earliest purchase date for each customer -> that month is their cohort
    first_purchase = (
        df.groupby(customer_col)[date_col]
        .min()
        .rename('FirstPurchaseDate')
        .reset_index()
    )
    first_purchase['Cohort'] = first_purchase['FirstPurchaseDate'].dt.to_period('M')

    # Attach each customer's cohort to all of their transactions
    df_with_cohort = df.merge(first_purchase[[customer_col, 'Cohort']], on=customer_col)

    # Months since first purchase: subtracting two monthly Periods yields a
    # month offset whose `.n` is the whole-month difference.
    # NOTE(review): a NaT transaction date produces NaT here, which has no
    # `.n`; callers are expected to drop missing dates beforehand.
    df_with_cohort['Period'] = df_with_cohort[date_col].dt.to_period('M')
    month_offsets = df_with_cohort['Period'] - df_with_cohort['Cohort']
    df_with_cohort['CohortPeriod'] = month_offsets.apply(attrgetter('n'))
    return df_with_cohort
def calculate_cohort_metrics(df_with_cohort, revenue_col=None, customer_col=None):
    """
    Calculate cohort retention metrics.

    Customer and revenue figures for every (Cohort, CohortPeriod) pair are
    expressed as a percentage of the same cohort's Period-0 figures.

    Args:
        df_with_cohort: DataFrame with 'Cohort' and 'CohortPeriod' columns
            (as produced by ``create_cohorts``) plus customer and revenue
            columns.
        revenue_col: Name of the revenue column. Defaults to
            ``config.REVENUE_COLUMN`` (looked up at call time).
        customer_col: Name of the customer-ID column. Defaults to
            ``config.CUSTOMER_COLUMN`` (looked up at call time).

    Returns:
        DataFrame: One row per (Cohort, Period) pair present in the data,
        sorted by Cohort then Period, with columns
        'Cohort', 'Period', 'Customers' (unique customers), 'Revenue' (sum),
        'Retention_Rate' and 'Revenue_Retention' (both in percent of the
        cohort's Period-0 value). 'Revenue_Retention' is NaN when the cohort's
        Period-0 revenue is zero or missing, and both rates are NaN for a
        cohort with no Period-0 row. Empty input yields an empty DataFrame.
    """
    if revenue_col is None or customer_col is None:
        # Resolve defaults from config at call time, as the original local
        # import did; explicit arguments bypass config entirely.
        import config
        revenue_col = config.REVENUE_COLUMN if revenue_col is None else revenue_col
        customer_col = config.CUSTOMER_COLUMN if customer_col is None else customer_col

    # Unique customers and total revenue per cohort and month offset.
    # Named aggregation fixes the output column names explicitly.
    metrics = (
        df_with_cohort.groupby(['Cohort', 'CohortPeriod'])
        .agg(Customers=(customer_col, 'nunique'), Revenue=(revenue_col, 'sum'))
        .reset_index()
        .rename(columns={'CohortPeriod': 'Period'})
    )

    # Broadcast each cohort's Period-0 baseline onto all of its rows in one
    # vectorised lookup instead of re-filtering the frame once per cohort.
    baseline = metrics[metrics['Period'] == 0].set_index('Cohort')
    initial_customers = metrics['Cohort'].map(baseline['Customers'])
    initial_revenue = metrics['Cohort'].map(baseline['Revenue'])
    # A zero revenue baseline would give inf/nan; treat it as "undefined"
    initial_revenue = initial_revenue.where(initial_revenue != 0)

    metrics['Retention_Rate'] = (metrics['Customers'] / initial_customers) * 100
    metrics['Revenue_Retention'] = metrics['Revenue'] / initial_revenue * 100
    return metrics
# ============================================================================
# MAIN ANALYSIS FUNCTION
# ============================================================================
def _print_cohort_summary(cohort_metrics, max_cohorts=5, milestone_period=12):
    """Print Period-0 size/revenue and milestone retention for the earliest cohorts."""
    print("\nCohort Summary:")
    print("-" * 60)
    # Cohorts sort chronologically, so this shows the *earliest* cohorts
    for cohort in sorted(cohort_metrics['Cohort'].unique())[:max_cohorts]:
        cohort_data = cohort_metrics[cohort_metrics['Cohort'] == cohort]
        period_0 = cohort_data[cohort_data['Period'] == 0]
        if len(period_0) > 0:
            initial_customers = period_0['Customers'].values[0]
            initial_revenue = period_0['Revenue'].values[0]
            print(f"\n{cohort}:")
            print(f" Initial: {initial_customers:,} customers, {format_currency(initial_revenue)}")
        # Retention at the milestone month, when the cohort is old enough
        milestone = cohort_data[cohort_data['Period'] == milestone_period]
        if len(milestone) > 0:
            retention = milestone['Retention_Rate'].values[0]
            revenue_ret = milestone['Revenue_Retention'].values[0]
            print(f" Period {milestone_period}: {retention:.1f}% customer retention, {revenue_ret:.1f}% revenue retention")


def _plot_cohort_heatmaps(cohort_metrics):
    """Draw customer- and revenue-retention heatmaps side by side and save the figure."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=CHART_SIZES['wide'])
    # (axis, metric column, colormap, colorbar label, title) for each panel
    panels = (
        (ax1, 'Retention_Rate', 'YlOrRd', 'Retention %',
         'Customer Retention by Cohort\n(Period 0 = 100%)'),
        (ax2, 'Revenue_Retention', 'YlGnBu', 'Revenue Retention %',
         'Revenue Retention by Cohort\n(Period 0 = 100%)'),
    )
    for ax, value_col, cmap, cbar_label, title in panels:
        # Rows: cohorts; columns: months since first purchase.
        # Each (Cohort, Period) pair is unique, so 'mean' just reshapes.
        pivot = cohort_metrics.pivot_table(
            index='Cohort',
            columns='Period',
            values=value_col,
            aggfunc='mean'
        )
        sns.heatmap(pivot, annot=True, fmt='.0f', cmap=cmap, ax=ax, cbar_kws={'label': cbar_label})
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.set_xlabel('Months Since First Purchase')
        ax.set_ylabel('Cohort')
    plt.suptitle(f'Cohort Analysis - {COMPANY_NAME}', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    save_chart(fig, 'cohort_analysis.png')
    # Close this specific figure to free its memory
    plt.close(fig)


def main():
    """
    Run the cohort analysis end to end.

    Loads the sales data, validates it, applies the configured filters,
    builds monthly customer cohorts, prints a summary, saves the retention
    heatmaps and runs the revenue validation. Problems are reported on
    stdout and end the run early instead of raising.
    """
    print(f"\n{'='*60}")
    print(f"{ANALYSIS_NAME}")
    print(f"{'='*60}\n")
    # 1. Load data
    print("Loading data...")
    try:
        df = load_sales_data(get_data_path())
    except Exception as e:
        # Top-level boundary: report the failure and stop cleanly
        print(f"ERROR loading data: {e}")
        return
    print(f"Loaded {len(df):,} transactions")
    # 2. Validate
    is_valid, msg = validate_data_structure(df)
    if not is_valid:
        print(f"ERROR: {msg}")
        return
    if CUSTOMER_COLUMN not in df.columns:
        print(f"ERROR: Customer column '{CUSTOMER_COLUMN}' not found")
        return
    # Cohorts are built from transaction dates, so the date column is mandatory
    if DATE_COLUMN not in df.columns:
        print(f"ERROR: Date column '{DATE_COLUMN}' not found")
        return
    # 3. Apply filters
    df = apply_exclusion_filters(df)
    # NOTE(review): assumes a 'Year' column is present (presumably added by
    # load_sales_data) — confirm.
    df = df[df['Year'] >= MIN_YEAR]
    df = df[df[DATE_COLUMN] <= MAX_DATE]
    if df.empty:
        # Nothing to group; downstream aggregation and plotting would fail
        print("ERROR: No transactions remain after filtering")
        return
    # 4. Create cohorts
    print("\nCreating customer cohorts...")
    df_cohort = create_cohorts(df)
    # 5. Calculate cohort metrics
    print("Calculating cohort metrics...")
    cohort_metrics = calculate_cohort_metrics(df_cohort)
    # 6. Print summary
    _print_cohort_summary(cohort_metrics)
    # 7. Create visualizations
    print("\nGenerating charts...")
    ensure_directories()
    _plot_cohort_heatmaps(cohort_metrics)
    # 8. Validate
    print("\nValidating revenue...")
    validate_revenue(df, ANALYSIS_NAME)
    print(f"\n{ANALYSIS_NAME} complete!")
    print(f"Charts saved to: {OUTPUT_DIR}")
# ============================================================================
# RUN ANALYSIS
# ============================================================================
if __name__ == "__main__":
    # Executed only when this file is run directly as a script; importing the
    # module (e.g. to reuse create_cohorts) does not start the analysis.
    main()