""" Example: Cohort Analysis Advanced example showing customer cohort retention analysis This demonstrates: - Cohort-based analysis - Retention rate calculations - Revenue retention metrics - Advanced visualization """ import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns from pathlib import Path from operator import attrgetter # Import utilities from data_loader import load_sales_data, validate_data_structure from validate_revenue import validate_revenue from analysis_utils import ( get_ltm_period_config, apply_exclusion_filters, setup_revenue_chart, save_chart, format_currency ) from config import ( OUTPUT_DIR, MAX_DATE, CHART_SIZES, ensure_directories, get_data_path, COMPANY_NAME, REVENUE_COLUMN, CUSTOMER_COLUMN, DATE_COLUMN, MIN_YEAR ) # ============================================================================ # CONFIGURATION # ============================================================================ ANALYSIS_NAME = "Cohort Analysis" DESCRIPTION = "Customer cohort retention and revenue analysis" # ============================================================================ # COHORT ANALYSIS FUNCTIONS # ============================================================================ def create_cohorts(df): """ Create customer cohorts based on first purchase date Args: df: DataFrame with customer and date columns Returns: DataFrame: Original DataFrame with 'Cohort' and 'CohortPeriod' columns """ from config import CUSTOMER_COLUMN, DATE_COLUMN # Get first purchase date for each customer first_purchase = df.groupby(CUSTOMER_COLUMN)[DATE_COLUMN].min().reset_index() first_purchase.columns = [CUSTOMER_COLUMN, 'FirstPurchaseDate'] # Extract cohort year-month first_purchase['Cohort'] = first_purchase['FirstPurchaseDate'].dt.to_period('M') # Merge back to original data df_with_cohort = df.merge(first_purchase[[CUSTOMER_COLUMN, 'Cohort']], on=CUSTOMER_COLUMN) # Calculate period number (months since first purchase) df_with_cohort['Period'] = df_with_cohort[DATE_COLUMN].dt.to_period('M') df_with_cohort['CohortPeriod'] = (df_with_cohort['Period'] - df_with_cohort['Cohort']).apply(attrgetter('n')) return df_with_cohort def calculate_cohort_metrics(df_with_cohort): """ Calculate cohort retention metrics Args: df_with_cohort: DataFrame with Cohort and CohortPeriod columns Returns: DataFrame: Cohort metrics by period """ from config import REVENUE_COLUMN, CUSTOMER_COLUMN # Customer count by cohort and period cohort_size = df_with_cohort.groupby('Cohort')[CUSTOMER_COLUMN].nunique() # Revenue by cohort and period cohort_revenue = df_with_cohort.groupby(['Cohort', 'CohortPeriod']).agg({ CUSTOMER_COLUMN: 'nunique', REVENUE_COLUMN: 'sum' }).reset_index() cohort_revenue.columns = ['Cohort', 'Period', 'Customers', 'Revenue'] # Calculate retention rates cohort_retention = [] for cohort in cohort_revenue['Cohort'].unique(): cohort_data = cohort_revenue[cohort_revenue['Cohort'] == cohort].copy() initial_customers = cohort_data[cohort_data['Period'] == 0]['Customers'].values[0] cohort_data['Retention_Rate'] = (cohort_data['Customers'] / initial_customers) * 100 cohort_data['Revenue_Retention'] = cohort_data['Revenue'] / cohort_data[cohort_data['Period'] == 0]['Revenue'].values[0] * 100 cohort_retention.append(cohort_data) return pd.concat(cohort_retention, ignore_index=True) # ============================================================================ # MAIN ANALYSIS FUNCTION # ============================================================================ def main(): """Main analysis function""" print(f"\n{'='*60}") print(f"{ANALYSIS_NAME}") print(f"{'='*60}\n") # 1. Load data print("Loading data...") try: df = load_sales_data(get_data_path()) print(f"Loaded {len(df):,} transactions") except Exception as e: print(f"ERROR loading data: {e}") return # 2. Validate is_valid, msg = validate_data_structure(df) if not is_valid: print(f"ERROR: {msg}") return if CUSTOMER_COLUMN not in df.columns: print(f"ERROR: Customer column '{CUSTOMER_COLUMN}' not found") return # 3. Apply filters df = apply_exclusion_filters(df) df = df[df['Year'] >= MIN_YEAR] if DATE_COLUMN in df.columns: df = df[df[DATE_COLUMN] <= MAX_DATE] # 4. Create cohorts print("\nCreating customer cohorts...") df_cohort = create_cohorts(df) # 5. Calculate cohort metrics print("Calculating cohort metrics...") cohort_metrics = calculate_cohort_metrics(df_cohort) # 6. Print summary print("\nCohort Summary:") print("-" * 60) for cohort in sorted(cohort_metrics['Cohort'].unique())[:5]: # Show top 5 cohorts cohort_data = cohort_metrics[cohort_metrics['Cohort'] == cohort] period_0 = cohort_data[cohort_data['Period'] == 0] if len(period_0) > 0: initial_customers = period_0['Customers'].values[0] initial_revenue = period_0['Revenue'].values[0] print(f"\n{cohort}:") print(f" Initial: {initial_customers:,} customers, {format_currency(initial_revenue)}") # Show retention at period 12 period_12 = cohort_data[cohort_data['Period'] == 12] if len(period_12) > 0: retention = period_12['Retention_Rate'].values[0] revenue_ret = period_12['Revenue_Retention'].values[0] print(f" Period 12: {retention:.1f}% customer retention, {revenue_ret:.1f}% revenue retention") # 7. Create visualizations print("\nGenerating charts...") ensure_directories() # Heatmap: Customer retention pivot_retention = cohort_metrics.pivot_table( index='Cohort', columns='Period', values='Retention_Rate', aggfunc='mean' ) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=CHART_SIZES['wide']) # Retention heatmap sns.heatmap(pivot_retention, annot=True, fmt='.0f', cmap='YlOrRd', ax=ax1, cbar_kws={'label': 'Retention %'}) ax1.set_title('Customer Retention by Cohort\n(Period 0 = 100%)', fontsize=12, fontweight='bold') ax1.set_xlabel('Months Since First Purchase') ax1.set_ylabel('Cohort') # Revenue retention heatmap pivot_revenue = cohort_metrics.pivot_table( index='Cohort', columns='Period', values='Revenue_Retention', aggfunc='mean' ) sns.heatmap(pivot_revenue, annot=True, fmt='.0f', cmap='YlGnBu', ax=ax2, cbar_kws={'label': 'Revenue Retention %'}) ax2.set_title('Revenue Retention by Cohort\n(Period 0 = 100%)', fontsize=12, fontweight='bold') ax2.set_xlabel('Months Since First Purchase') ax2.set_ylabel('Cohort') plt.suptitle(f'Cohort Analysis - {COMPANY_NAME}', fontsize=14, fontweight='bold', y=1.02) plt.tight_layout() save_chart(fig, 'cohort_analysis.png') plt.close() # 8. Validate print("\nValidating revenue...") validate_revenue(df, ANALYSIS_NAME) print(f"\n{ANALYSIS_NAME} complete!") print(f"Charts saved to: {OUTPUT_DIR}") # ============================================================================ # RUN ANALYSIS # ============================================================================ if __name__ == "__main__": main()