219 lines
7.5 KiB
Python
219 lines
7.5 KiB
Python
"""
Example: Cohort Analysis

Advanced example showing customer cohort retention analysis.

This demonstrates:
- Cohort-based analysis
- Retention rate calculations
- Revenue retention metrics
- Advanced visualization
"""
|
|
import pandas as pd
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
from pathlib import Path
|
|
from operator import attrgetter
|
|
|
|
# Import utilities
|
|
from data_loader import load_sales_data, validate_data_structure
|
|
from validate_revenue import validate_revenue
|
|
from analysis_utils import (
|
|
get_ltm_period_config, apply_exclusion_filters,
|
|
setup_revenue_chart, save_chart, format_currency
|
|
)
|
|
from config import (
|
|
OUTPUT_DIR, MAX_DATE, CHART_SIZES, ensure_directories,
|
|
get_data_path, COMPANY_NAME, REVENUE_COLUMN, CUSTOMER_COLUMN,
|
|
DATE_COLUMN, MIN_YEAR
|
|
)
|
|
|
|
# ============================================================================
|
|
# CONFIGURATION
|
|
# ============================================================================
|
|
|
|
# Display name of this analysis: printed in the console banner and in the
# completion message, and passed as the second argument to validate_revenue().
ANALYSIS_NAME = "Cohort Analysis"
# One-line summary of the analysis; not referenced by the code visible in this file.
DESCRIPTION = "Customer cohort retention and revenue analysis"
|
|
|
|
# ============================================================================
|
|
# COHORT ANALYSIS FUNCTIONS
|
|
# ============================================================================
|
|
|
|
def create_cohorts(df, customer_col=None, date_col=None):
    """
    Create customer cohorts based on first purchase date

    A customer's cohort is the calendar month of their earliest transaction.
    Every transaction row is then labelled with the number of months between
    its own month and that cohort month.

    Args:
        df: DataFrame with customer and date columns. The date column must be
            datetime-typed, because the `.dt` accessor is used on it.
        customer_col: Name of the customer identifier column. Defaults to
            `config.CUSTOMER_COLUMN` when omitted.
        date_col: Name of the transaction date column. Defaults to
            `config.DATE_COLUMN` when omitted.

    Returns:
        DataFrame: `df` merged with the per-customer cohort, with three added
        columns: 'Cohort' (monthly period of the first purchase), 'Period'
        (monthly period of the transaction) and 'CohortPeriod' (month offset
        from the cohort month; 0 for the acquisition month).
    """
    if customer_col is None or date_col is None:
        # Deferred import: callers that pass both names never need the
        # project config module (keeps the helper reusable and testable).
        from config import CUSTOMER_COLUMN, DATE_COLUMN

        if customer_col is None:
            customer_col = CUSTOMER_COLUMN
        if date_col is None:
            date_col = DATE_COLUMN

    # Get first purchase date for each customer
    first_purchase = (
        df.groupby(customer_col)[date_col]
        .min()
        .rename('FirstPurchaseDate')
        .reset_index()
    )

    # Extract cohort year-month
    first_purchase['Cohort'] = first_purchase['FirstPurchaseDate'].dt.to_period('M')

    # Merge back to original data (inner join on the customer column)
    df_with_cohort = df.merge(first_purchase[[customer_col, 'Cohort']], on=customer_col)

    # Calculate period number (months since first purchase). Subtracting two
    # monthly periods yields an offset object whose `.n` is the month count.
    df_with_cohort['Period'] = df_with_cohort[date_col].dt.to_period('M')
    month_offsets = df_with_cohort['Period'] - df_with_cohort['Cohort']
    df_with_cohort['CohortPeriod'] = month_offsets.apply(attrgetter('n'))

    return df_with_cohort
|
|
|
|
def calculate_cohort_metrics(df_with_cohort, customer_col=None, revenue_col=None):
    """
    Calculate cohort retention metrics

    For every (cohort, period) pair the distinct customer count and total
    revenue are computed and expressed as a percentage of the same cohort's
    period-0 values.

    Args:
        df_with_cohort: DataFrame with Cohort and CohortPeriod columns
        customer_col: Name of the customer identifier column. Defaults to
            `config.CUSTOMER_COLUMN` when omitted.
        revenue_col: Name of the revenue column. Defaults to
            `config.REVENUE_COLUMN` when omitted.

    Returns:
        DataFrame: One row per cohort and period with the columns 'Cohort',
        'Period', 'Customers', 'Revenue', 'Retention_Rate' and
        'Revenue_Retention' (both rates in percent, period 0 = 100).
        A rate is NaN when the cohort's period-0 baseline is zero or absent.
        An empty input yields an empty DataFrame with these columns.
    """
    if customer_col is None or revenue_col is None:
        # Deferred import: callers that pass both names never need the
        # project config module (keeps the helper reusable and testable).
        from config import REVENUE_COLUMN, CUSTOMER_COLUMN

        if customer_col is None:
            customer_col = CUSTOMER_COLUMN
        if revenue_col is None:
            revenue_col = REVENUE_COLUMN

    output_columns = [
        'Cohort', 'Period', 'Customers', 'Revenue',
        'Retention_Rate', 'Revenue_Retention',
    ]
    if df_with_cohort.empty:
        # Nothing to aggregate; return the expected schema instead of failing.
        return pd.DataFrame(columns=output_columns)

    # Distinct customers and total revenue by cohort and period
    metrics = (
        df_with_cohort.groupby(['Cohort', 'CohortPeriod'])
        .agg({customer_col: 'nunique', revenue_col: 'sum'})
        .reset_index()
    )
    metrics.columns = ['Cohort', 'Period', 'Customers', 'Revenue']

    # Period-0 values of each cohort are the 100% baseline
    baseline = metrics.loc[
        metrics['Period'] == 0, ['Cohort', 'Customers', 'Revenue']
    ].rename(columns={'Customers': 'InitialCustomers', 'Revenue': 'InitialRevenue'})

    # A left merge keeps the (cohort, period) row order produced by the groupby
    metrics = metrics.merge(baseline, on='Cohort', how='left')
    initial_customers = metrics.pop('InitialCustomers')
    initial_revenue = metrics.pop('InitialRevenue')

    # A zero baseline would produce inf percentages; treat it as missing
    initial_customers = initial_customers.where(initial_customers != 0)
    initial_revenue = initial_revenue.where(initial_revenue != 0)

    # Calculate retention rates
    metrics['Retention_Rate'] = (metrics['Customers'] / initial_customers) * 100
    metrics['Revenue_Retention'] = metrics['Revenue'] / initial_revenue * 100

    return metrics
|
|
|
|
# ============================================================================
|
|
# MAIN ANALYSIS FUNCTION
|
|
# ============================================================================
|
|
|
|
def _print_cohort_summary(cohort_metrics, max_cohorts=5):
    """Print period-0 size and period-12 retention for the earliest cohorts."""
    print("\nCohort Summary:")
    print("-" * 60)
    # Sorting ascending and slicing shows the earliest cohorts, not the largest
    for cohort in sorted(cohort_metrics['Cohort'].unique())[:max_cohorts]:
        cohort_data = cohort_metrics[cohort_metrics['Cohort'] == cohort]
        period_0 = cohort_data[cohort_data['Period'] == 0]
        if period_0.empty:
            continue

        initial_customers = period_0['Customers'].values[0]
        initial_revenue = period_0['Revenue'].values[0]
        print(f"\n{cohort}:")
        print(f" Initial: {initial_customers:,} customers, {format_currency(initial_revenue)}")

        # Show retention at period 12 (only cohorts that have reached it)
        period_12 = cohort_data[cohort_data['Period'] == 12]
        if not period_12.empty:
            retention = period_12['Retention_Rate'].values[0]
            revenue_ret = period_12['Revenue_Retention'].values[0]
            print(f" Period 12: {retention:.1f}% customer retention, {revenue_ret:.1f}% revenue retention")


def _draw_retention_heatmap(ax, cohort_metrics, value_column, cmap, colorbar_label, title):
    """Draw one cohort-by-period heatmap of `value_column` on the given axes."""
    pivot = cohort_metrics.pivot_table(
        index='Cohort',
        columns='Period',
        values=value_column,
        aggfunc='mean'
    )
    sns.heatmap(pivot, annot=True, fmt='.0f', cmap=cmap, ax=ax, cbar_kws={'label': colorbar_label})
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('Months Since First Purchase')
    ax.set_ylabel('Cohort')


def _plot_retention_heatmaps(cohort_metrics):
    """Render the customer and revenue retention heatmaps and save them as one chart."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=CHART_SIZES['wide'])

    _draw_retention_heatmap(
        ax1, cohort_metrics, 'Retention_Rate', 'YlOrRd', 'Retention %',
        'Customer Retention by Cohort\n(Period 0 = 100%)',
    )
    _draw_retention_heatmap(
        ax2, cohort_metrics, 'Revenue_Retention', 'YlGnBu', 'Revenue Retention %',
        'Revenue Retention by Cohort\n(Period 0 = 100%)',
    )

    plt.suptitle(f'Cohort Analysis - {COMPANY_NAME}', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    save_chart(fig, 'cohort_analysis.png')
    # Close this specific figure rather than whichever one happens to be current
    plt.close(fig)


def main():
    """
    Main analysis function

    Loads and validates the sales data, applies the configured filters,
    builds customer cohorts, prints a retention summary, saves the retention
    heatmaps and finally runs the revenue validation. Problems with the input
    are reported on stdout and end the run early; nothing is returned.
    """
    print(f"\n{'='*60}")
    print(f"{ANALYSIS_NAME}")
    print(f"{'='*60}\n")

    # 1. Load data
    print("Loading data...")
    try:
        df = load_sales_data(get_data_path())
    except Exception as e:  # top-level boundary: report any load failure and stop
        print(f"ERROR loading data: {e}")
        return
    print(f"Loaded {len(df):,} transactions")

    # 2. Validate
    is_valid, msg = validate_data_structure(df)
    if not is_valid:
        print(f"ERROR: {msg}")
        return

    if CUSTOMER_COLUMN not in df.columns:
        print(f"ERROR: Customer column '{CUSTOMER_COLUMN}' not found")
        return

    # Cohorts are derived from the purchase date, so it is as mandatory as the
    # customer column; failing here avoids a KeyError inside create_cohorts().
    if DATE_COLUMN not in df.columns:
        print(f"ERROR: Date column '{DATE_COLUMN}' not found")
        return

    # 3. Apply filters
    df = apply_exclusion_filters(df)
    df = df[df['Year'] >= MIN_YEAR]
    df = df[df[DATE_COLUMN] <= MAX_DATE]

    if df.empty:
        # Without rows there are no cohorts to aggregate or plot
        print("ERROR: No transactions remain after applying filters")
        return

    # 4. Create cohorts
    print("\nCreating customer cohorts...")
    df_cohort = create_cohorts(df)

    # 5. Calculate cohort metrics
    print("Calculating cohort metrics...")
    cohort_metrics = calculate_cohort_metrics(df_cohort)

    # 6. Print summary
    _print_cohort_summary(cohort_metrics)

    # 7. Create visualizations
    print("\nGenerating charts...")
    ensure_directories()
    _plot_retention_heatmaps(cohort_metrics)

    # 8. Validate
    print("\nValidating revenue...")
    validate_revenue(df, ANALYSIS_NAME)

    print(f"\n{ANALYSIS_NAME} complete!")
    print(f"Charts saved to: {OUTPUT_DIR}")
|
|
|
|
# ============================================================================
|
|
# RUN ANALYSIS
|
|
# ============================================================================
|
|
|
|
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|