""" Common utilities for analysis scripts Provides formatters, LTM setup, and helper functions This module is designed to work with any sales data structure by using configuration from config.py """ import pandas as pd import numpy as np from matplotlib.ticker import FuncFormatter from pathlib import Path from config import ( REVENUE_COLUMN, LTM_ENABLED, get_ltm_period, get_ltm_label, OUTPUT_DIR, CHART_DPI, CHART_BBOX ) # ============================================================================ # CHART FORMATTERS # ============================================================================ def millions_formatter(x: float, pos: int) -> str: """ Format numbers in millions for chart display (e.g., $99.9m) This formatter is used with matplotlib FuncFormatter to display revenue values in millions on chart axes. Args: x: Numeric value (already in millions, e.g., 99.9 for $99.9m) pos: Position parameter (required by FuncFormatter, not used) Returns: str: Formatted string like "$99.9m" Example: >>> from matplotlib.ticker import FuncFormatter >>> formatter = FuncFormatter(millions_formatter) >>> ax.yaxis.set_major_formatter(formatter) """ return f'${x:.1f}m' def thousands_formatter(x: float, pos: int) -> str: """ Format numbers in thousands for chart display (e.g., $99.9k) Args: x: Numeric value (already in thousands) pos: Position parameter (required by FuncFormatter, not used) Returns: str: Formatted string like "$99.9k" """ return f'${x:.1f}k' def get_millions_formatter() -> FuncFormatter: """ Get FuncFormatter for millions Returns: FuncFormatter: Configured formatter for millions display """ return FuncFormatter(millions_formatter) def get_thousands_formatter() -> FuncFormatter: """ Get FuncFormatter for thousands Returns: FuncFormatter: Configured formatter for thousands display """ return FuncFormatter(thousands_formatter) # ============================================================================ # LTM (Last Twelve Months) SETUP # ============================================================================ def get_ltm_period_config(): """ Get LTM period boundaries from config Returns: tuple: (ltm_start, ltm_end) as pd.Period objects, or (None, None) if disabled """ if LTM_ENABLED: return get_ltm_period() return None, None def get_annual_data(df, year, ltm_start=None, ltm_end=None): """ Get data for a specific year, using LTM for the most recent partial year Args: df: DataFrame with 'Year' and 'YearMonth' columns year: Year to extract (int) ltm_start: LTM start period (defaults to config if None) ltm_end: LTM end period (defaults to config if None) Returns: tuple: (year_data DataFrame, year_label string) """ from config import LTM_END_YEAR # Get LTM period from config if not provided if ltm_start is None or ltm_end is None: ltm_start, ltm_end = get_ltm_period_config() # Use LTM for the most recent year if enabled if LTM_ENABLED and ltm_start and ltm_end and year == LTM_END_YEAR: if 'YearMonth' in df.columns: year_data = df[(df['YearMonth'] >= ltm_start) & (df['YearMonth'] <= ltm_end)] year_label = get_ltm_label() or str(year) else: # Fallback if YearMonth not available year_data = df[df['Year'] == year] year_label = str(year) else: # Use full calendar year year_data = df[df['Year'] == year] year_label = str(year) return year_data, year_label def calculate_annual_metrics(df, metrics_func, ltm_start=None, ltm_end=None): """ Calculate annual metrics for all years, using LTM for most recent year Args: df: DataFrame with 'Year' and 'YearMonth' columns metrics_func: Function that takes a DataFrame and returns a dict of metrics ltm_start: LTM start period (defaults to config if None) ltm_end: LTM end period (defaults to config if None) Returns: DataFrame with 'Year' index and metric columns """ from config import ANALYSIS_YEARS if ltm_start is None or ltm_end is None: ltm_start, ltm_end = get_ltm_period_config() annual_data = [] for year in sorted(ANALYSIS_YEARS): if year in df['Year'].unique(): year_data, year_label = get_annual_data(df, year, ltm_start, ltm_end) if len(year_data) > 0: metrics = metrics_func(year_data) metrics['Year'] = year_label annual_data.append(metrics) if not annual_data: return pd.DataFrame() return pd.DataFrame(annual_data).set_index('Year') # ============================================================================ # MIXED TYPE HANDLING # ============================================================================ def create_year_sort_column(df, year_col='Year'): """ Create a numeric sort column for mixed int/str year columns Args: df: DataFrame year_col: Name of year column Returns: Series with numeric sort values """ from config import LTM_END_YEAR def sort_value(x): if isinstance(x, str) and str(LTM_END_YEAR) in str(x): return float(LTM_END_YEAR) + 0.5 elif isinstance(x, (int, float)): return float(x) else: return 9999 return df[year_col].apply(sort_value) def sort_mixed_years(df, year_col='Year'): """ Sort DataFrame by year column that may contain mixed int/str types Args: df: DataFrame year_col: Name of year column Returns: Sorted DataFrame """ df = df.copy() df['_Year_Sort'] = create_year_sort_column(df, year_col) df = df.sort_values('_Year_Sort').drop(columns=['_Year_Sort']) return df def safe_year_labels(years): """ Convert year values to safe string labels for chart axes Args: years: Iterable of year values (int or str) Returns: List of string labels """ return [str(year) for year in years] # ============================================================================ # CHART HELPERS # ============================================================================ def setup_revenue_chart(ax, ylabel: str = 'Revenue (Millions USD)') -> None: """ Setup a chart axis for revenue display (millions) CRITICAL: Always use this function for revenue charts. It applies the millions formatter and standard styling. IMPORTANT: Data must be divided by 1e6 BEFORE plotting: ax.plot(revenue / 1e6, ...) # ✅ Correct ax.plot(revenue, ...) # ❌ Wrong - will show scientific notation Args: ax: Matplotlib axis object to configure ylabel: Y-axis label (default: 'Revenue (Millions USD)') Returns: None: Modifies ax in place Example: >>> import matplotlib.pyplot as plt >>> from analysis_utils import setup_revenue_chart >>> fig, ax = plt.subplots() >>> ax.plot(revenue_data / 1e6, marker='o') # Divide by 1e6 first! >>> setup_revenue_chart(ax) >>> plt.show() See Also: - .cursor/rules/chart_formatting.md for detailed patterns - save_chart() for saving charts """ ax.yaxis.set_major_formatter(get_millions_formatter()) ax.set_ylabel(ylabel) ax.grid(True, alpha=0.3) def save_chart(fig, filename, output_dir=None): """ Save chart to file with organized directory structure Args: fig: Matplotlib figure object filename: Output filename (e.g., 'revenue_trend.png') output_dir: Output directory (defaults to config.OUTPUT_DIR) """ if output_dir is None: output_dir = OUTPUT_DIR else: output_dir = Path(output_dir) output_dir.mkdir(exist_ok=True) filepath = output_dir / filename fig.savefig(filepath, dpi=CHART_DPI, bbox_inches=CHART_BBOX, format='png') print(f"Chart saved: {filepath}") # ============================================================================ # DATA VALIDATION # ============================================================================ def validate_dataframe(df, required_columns=None): """ Validate DataFrame has required columns and basic data quality Args: df: DataFrame to validate required_columns: List of required column names (defaults to config) Returns: tuple: (is_valid bool, error_message str) """ if required_columns is None: required_columns = [REVENUE_COLUMN, 'Year'] if 'YearMonth' in df.columns: required_columns.append('YearMonth') missing_cols = [col for col in required_columns if col not in df.columns] if missing_cols: return False, f"Missing required columns: {missing_cols}" if len(df) == 0: return False, "DataFrame is empty" if REVENUE_COLUMN in df.columns: if df[REVENUE_COLUMN].isna().all(): return False, f"All {REVENUE_COLUMN} values are NaN" return True, "OK" # ============================================================================ # PRICE CALCULATION # ============================================================================ def calculate_price_per_unit(df, quantity_col=None, revenue_col=None): """ Calculate average price per unit, excluding invalid quantities Args: df: DataFrame with quantity and revenue columns quantity_col: Name of quantity column (defaults to config) revenue_col: Name of revenue column (defaults to config) Returns: float: Average price per unit """ from config import QUANTITY_COLUMN, REVENUE_COLUMN, MIN_QUANTITY, MAX_QUANTITY if quantity_col is None: quantity_col = QUANTITY_COLUMN if revenue_col is None: revenue_col = REVENUE_COLUMN # Check if quantity column exists if quantity_col not in df.columns: return np.nan # Filter for valid quantity transactions df_valid = df[(df[quantity_col] > MIN_QUANTITY) & (df[quantity_col] <= MAX_QUANTITY)].copy() if len(df_valid) == 0: return np.nan total_revenue = df_valid[revenue_col].sum() total_quantity = df_valid[quantity_col].sum() if total_quantity == 0: return np.nan return total_revenue / total_quantity # ============================================================================ # OUTPUT FORMATTING # ============================================================================ def format_currency(value: float, millions: bool = True) -> str: """ Format currency value for console output Args: value: Numeric value to format millions: If True, format as millions ($X.Xm), else thousands ($X.Xk) Returns: str: Formatted string like "$99.9m" or "$99.9k" or "N/A" if NaN Example: >>> format_currency(1000000) '$1.00m' >>> format_currency(1000, millions=False) '$1.00k' """ if pd.isna(value): return "N/A" if millions: return f"${value / 1e6:.2f}m" else: return f"${value / 1e3:.2f}k" def print_annual_summary(annual_df, metric_col='Revenue', label='Revenue'): """ Print formatted annual summary to console Args: annual_df: DataFrame with annual metrics (indexed by Year) metric_col: Column name to print label: Label for the metric """ print(f"\n{label} by Year:") print("-" * 40) for year in annual_df.index: value = annual_df.loc[year, metric_col] formatted = format_currency(value) print(f" {year}: {formatted}") print() # ============================================================================ # DATA FILTERING HELPERS # ============================================================================ def apply_exclusion_filters(df): """ Apply exclusion filters from config Args: df: DataFrame to filter Returns: Filtered DataFrame """ from config import EXCLUSION_FILTERS if not EXCLUSION_FILTERS.get('enabled', False): return df exclude_col = EXCLUSION_FILTERS.get('exclude_by_column') exclude_values = EXCLUSION_FILTERS.get('exclude_values', []) if exclude_col and exclude_col in df.columns and exclude_values: original_count = len(df) df_filtered = df[~df[exclude_col].isin(exclude_values)] excluded_count = original_count - len(df_filtered) if excluded_count > 0: print(f"Excluded {excluded_count:,} rows based on {exclude_col} filter") return df_filtered return df # ============================================================================ # INTERACTIVE VISUALIZATIONS (OPTIONAL - PLOTLY) # ============================================================================ def create_interactive_chart(data, chart_type='line', title=None, xlabel=None, ylabel=None): """ Create interactive chart using Plotly (optional dependency) Args: data: DataFrame or dict with chart data chart_type: Type of chart ('line', 'bar', 'scatter') title: Chart title xlabel: X-axis label ylabel: Y-axis label Returns: plotly.graph_objects.Figure: Plotly figure object Raises: ImportError: If plotly is not installed Example: fig = create_interactive_chart( {'x': [1, 2, 3], 'y': [10, 20, 30]}, chart_type='line', title='Revenue Trend' ) fig.show() """ try: import plotly.graph_objects as go from plotly.subplots import make_subplots except ImportError: raise ImportError( "plotly is required for interactive charts. Install with: pip install plotly" ) fig = go.Figure() if chart_type == 'line': if isinstance(data, dict) and 'x' in data and 'y' in data: fig.add_trace(go.Scatter( x=data['x'], y=data['y'], mode='lines+markers', name='Data' )) elif chart_type == 'bar': if isinstance(data, dict) and 'x' in data and 'y' in data: fig.add_trace(go.Bar( x=data['x'], y=data['y'], name='Data' )) if title: fig.update_layout(title=title) if xlabel: fig.update_xaxes(title_text=xlabel) if ylabel: fig.update_yaxes(title_text=ylabel) fig.update_layout( template='plotly_white', hovermode='x unified' ) return fig def save_interactive_chart(fig, filename, output_dir=None): """ Save interactive Plotly chart to HTML file Args: fig: Plotly figure object filename: Output filename (e.g., 'chart.html') output_dir: Output directory (defaults to config.OUTPUT_DIR) """ if output_dir is None: output_dir = OUTPUT_DIR else: output_dir = Path(output_dir) output_dir.mkdir(exist_ok=True) filepath = output_dir / filename fig.write_html(str(filepath)) print(f"Interactive chart saved: {filepath}") return filepath