Files
sales-data-analysis/analysis_utils.py
Jonathan Pressnell cf0b596449 Initial commit: sales analysis template
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-06 09:16:34 -05:00

511 lines
15 KiB
Python

"""
Common utilities for analysis scripts
Provides formatters, LTM setup, and helper functions
This module is designed to work with any sales data structure
by using configuration from config.py
"""
import pandas as pd
import numpy as np
from matplotlib.ticker import FuncFormatter
from pathlib import Path
from config import (
REVENUE_COLUMN, LTM_ENABLED, get_ltm_period, get_ltm_label,
OUTPUT_DIR, CHART_DPI, CHART_BBOX
)
# ============================================================================
# CHART FORMATTERS
# ============================================================================
def millions_formatter(x: float, pos: int) -> str:
    """
    Render a tick value as a dollar amount in millions (e.g., $99.9m)

    Meant to be wrapped in matplotlib's FuncFormatter. The value is not
    rescaled here: it is only formatted to one decimal place and wrapped
    in "$...m", so callers must pass values already expressed in millions.

    Args:
        x: Tick value, already in millions (e.g., 99.9 for $99.9m)
        pos: Tick position passed by FuncFormatter; ignored

    Returns:
        str: Label such as "$99.9m"

    Example:
        >>> from matplotlib.ticker import FuncFormatter
        >>> formatter = FuncFormatter(millions_formatter)
        >>> ax.yaxis.set_major_formatter(formatter)
    """
    label = '${:.1f}m'.format(x)
    return label
def thousands_formatter(x: float, pos: int) -> str:
    """
    Render a tick value as a dollar amount in thousands (e.g., $99.9k)

    The value is not rescaled; it is formatted to one decimal place and
    wrapped in "$...k".

    Args:
        x: Tick value, already in thousands
        pos: Tick position passed by FuncFormatter; ignored

    Returns:
        str: Label such as "$99.9k"
    """
    label = '${:.1f}k'.format(x)
    return label
def get_millions_formatter() -> FuncFormatter:
    """
    Build a matplotlib tick formatter that labels values as "$X.Xm"

    Returns:
        FuncFormatter: Formatter wrapping millions_formatter
    """
    formatter = FuncFormatter(millions_formatter)
    return formatter
def get_thousands_formatter() -> FuncFormatter:
    """
    Build a matplotlib tick formatter that labels values as "$X.Xk"

    Returns:
        FuncFormatter: Formatter wrapping thousands_formatter
    """
    formatter = FuncFormatter(thousands_formatter)
    return formatter
# ============================================================================
# LTM (Last Twelve Months) SETUP
# ============================================================================
def get_ltm_period_config():
    """
    Fetch the LTM window from config, honouring the LTM_ENABLED switch

    Returns:
        tuple: Whatever config.get_ltm_period() returns (documented there
        as (ltm_start, ltm_end)) when LTM is enabled, otherwise (None, None)
    """
    if not LTM_ENABLED:
        return None, None
    return get_ltm_period()
def get_annual_data(df, year, ltm_start=None, ltm_end=None):
    """
    Select the rows for one year, substituting the LTM window for the LTM end year

    The LTM window is used only when all of the following hold: LTM is
    enabled in config, both window bounds are available, `year` equals
    config.LTM_END_YEAR, and the frame has a 'YearMonth' column. In every
    other case the rows are selected by calendar year.

    Args:
        df: DataFrame with a 'Year' column and, for LTM, a 'YearMonth' column
        year: Year to extract (int)
        ltm_start: LTM start period (looked up from config if either bound is None)
        ltm_end: LTM end period (looked up from config if either bound is None)

    Returns:
        tuple: (year_data DataFrame, year_label string)
    """
    from config import LTM_END_YEAR

    # If either bound is missing, both are re-read from config.
    if ltm_start is None or ltm_end is None:
        ltm_start, ltm_end = get_ltm_period_config()

    ltm_applies = (
        LTM_ENABLED
        and ltm_start
        and ltm_end
        and year == LTM_END_YEAR
        and 'YearMonth' in df.columns
    )
    if not ltm_applies:
        # Plain calendar year (also the fallback when YearMonth is absent)
        return df[df['Year'] == year], str(year)

    # Inclusive window on both ends
    in_window = (df['YearMonth'] >= ltm_start) & (df['YearMonth'] <= ltm_end)
    return df[in_window], get_ltm_label() or str(year)
def calculate_annual_metrics(df, metrics_func, ltm_start=None, ltm_end=None):
    """
    Calculate annual metrics for all configured years, using LTM for the LTM end year

    Iterates config.ANALYSIS_YEARS in ascending order, skips years that do
    not occur in df['Year'] or that yield no rows, and calls `metrics_func`
    on each remaining year's slice.

    Args:
        df: DataFrame with 'Year' and 'YearMonth' columns
        metrics_func: Function that takes a DataFrame and returns a dict of metrics
        ltm_start: LTM start period (looked up from config if either bound is None)
        ltm_end: LTM end period (looked up from config if either bound is None)

    Returns:
        DataFrame indexed by 'Year' (the string label from get_annual_data)
        with one column per metric, or an empty DataFrame if no year
        produced any rows
    """
    from config import ANALYSIS_YEARS

    if ltm_start is None or ltm_end is None:
        ltm_start, ltm_end = get_ltm_period_config()

    # Loop-invariant: the years present in the data do not change per iteration.
    years_present = df['Year'].unique()

    rows = []
    for year in sorted(ANALYSIS_YEARS):
        if year not in years_present:
            continue
        year_data, year_label = get_annual_data(df, year, ltm_start, ltm_end)
        if len(year_data) == 0:
            continue
        # Copy before labelling: if metrics_func hands back the same dict on
        # every call, writing 'Year' into it would relabel all earlier rows.
        row = dict(metrics_func(year_data))
        row['Year'] = year_label
        rows.append(row)

    if not rows:
        return pd.DataFrame()
    return pd.DataFrame(rows).set_index('Year')
# ============================================================================
# MIXED TYPE HANDLING
# ============================================================================
def create_year_sort_column(df, year_col='Year'):
    """
    Create a numeric sort column for mixed int/str year columns

    Mapping applied to each value of `year_col`:
      - a string containing str(config.LTM_END_YEAR) (e.g. an LTM label)
        -> LTM_END_YEAR + 0.5, so it sorts just after that calendar year
      - any other string that parses as a number (e.g. '2023') -> that number
      - int / float / NumPy integer or floating scalar -> float(value)
      - anything else -> 9999 (sorts last)

    Args:
        df: DataFrame
        year_col: Name of year column

    Returns:
        Series with numeric sort values, aligned to df's index
    """
    from config import LTM_END_YEAR

    ltm_marker = str(LTM_END_YEAR)

    def sort_value(x):
        if isinstance(x, str):
            if ltm_marker in x:
                return float(LTM_END_YEAR) + 0.5
            # Plain year labels such as '2023' must keep their numeric
            # position instead of being pushed to the end.
            try:
                return float(x.strip())
            except ValueError:
                return 9999
        # np.integer is listed explicitly because NumPy integer scalars are
        # not instances of the built-in int.
        if isinstance(x, (int, float, np.integer, np.floating)):
            return float(x)
        return 9999

    return df[year_col].apply(sort_value)
def sort_mixed_years(df, year_col='Year'):
    """
    Return a copy of the frame ordered by a year column of mixed int/str values

    The ordering key comes from create_year_sort_column and is stored in a
    temporary '_Year_Sort' column that is removed before returning. The
    input frame is not modified.

    Args:
        df: DataFrame
        year_col: Name of year column

    Returns:
        Sorted DataFrame
    """
    keyed = df.copy()
    keyed['_Year_Sort'] = create_year_sort_column(keyed, year_col)
    ordered = keyed.sort_values('_Year_Sort')
    return ordered.drop(columns=['_Year_Sort'])
def safe_year_labels(years):
    """
    Stringify year values so they can be used as chart axis labels

    Args:
        years: Iterable of year values (int or str)

    Returns:
        List of string labels, in the same order as the input
    """
    return list(map(str, years))
# ============================================================================
# CHART HELPERS
# ============================================================================
def setup_revenue_chart(ax, ylabel: str = 'Revenue (Millions USD)') -> None:
    """
    Configure a chart axis for revenue shown in millions

    Applies three settings to `ax`: the millions tick formatter on the
    y-axis, the y-axis label, and a light grid (alpha 0.3).

    IMPORTANT: The formatter only appends "$" and "m"; it does not rescale.
    Divide the data by 1e6 BEFORE plotting:
        ax.plot(revenue / 1e6, ...)   # correct
        ax.plot(revenue, ...)         # wrong - tick labels off by 1e6

    Args:
        ax: Matplotlib axis object to configure
        ylabel: Y-axis label (default: 'Revenue (Millions USD)')

    Returns:
        None: Modifies ax in place

    Example:
        >>> import matplotlib.pyplot as plt
        >>> from analysis_utils import setup_revenue_chart
        >>> fig, ax = plt.subplots()
        >>> ax.plot(revenue_data / 1e6, marker='o')  # Divide by 1e6 first!
        >>> setup_revenue_chart(ax)
        >>> plt.show()

    See Also:
        - .cursor/rules/chart_formatting.md for detailed patterns
        - save_chart() for saving charts
    """
    tick_formatter = get_millions_formatter()
    ax.yaxis.set_major_formatter(tick_formatter)
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
def save_chart(fig, filename, output_dir=None):
    """
    Save chart to file with organized directory structure

    The target directory is created if needed, including any missing
    parent directories. The figure is always written in PNG format using
    the DPI and bounding-box settings from config.

    Args:
        fig: Matplotlib figure object
        filename: Output filename (e.g., 'revenue_trend.png')
        output_dir: Output directory as str or Path (defaults to config.OUTPUT_DIR)

    Returns:
        Path: Location the chart was written to
    """
    # Coerce in both cases so a plain-string OUTPUT_DIR works as well.
    target_dir = Path(OUTPUT_DIR if output_dir is None else output_dir)
    # parents=True: nested output folders (e.g. 'charts/2025') must not fail.
    target_dir.mkdir(parents=True, exist_ok=True)

    filepath = target_dir / filename
    fig.savefig(filepath, dpi=CHART_DPI, bbox_inches=CHART_BBOX, format='png')
    print(f"Chart saved: {filepath}")
    return filepath
# ============================================================================
# DATA VALIDATION
# ============================================================================
def validate_dataframe(df, required_columns=None):
    """
    Check that a DataFrame has the expected columns and is usable

    Checks run in this order, and the first failure is reported:
      1. every required column is present
      2. the frame has at least one row
      3. the revenue column (if present) is not entirely NaN

    Args:
        df: DataFrame to validate
        required_columns: List of required column names. Defaults to the
            configured revenue column and 'Year', plus 'YearMonth' when the
            frame already has that column.

    Returns:
        tuple: (is_valid bool, error_message str); the message is "OK" on success
    """
    columns = df.columns

    if required_columns is None:
        optional = ['YearMonth'] if 'YearMonth' in columns else []
        required_columns = [REVENUE_COLUMN, 'Year', *optional]

    absent = [name for name in required_columns if name not in columns]
    if absent:
        return False, f"Missing required columns: {absent}"

    if not len(df):
        return False, "DataFrame is empty"

    if REVENUE_COLUMN in columns and df[REVENUE_COLUMN].isna().all():
        return False, f"All {REVENUE_COLUMN} values are NaN"

    return True, "OK"
# ============================================================================
# PRICE CALCULATION
# ============================================================================
def calculate_price_per_unit(df, quantity_col=None, revenue_col=None):
    """
    Compute revenue-weighted average price per unit over valid-quantity rows

    A row counts as valid when its quantity is strictly greater than
    config.MIN_QUANTITY and at most config.MAX_QUANTITY. The result is
    total revenue divided by total quantity across those rows.

    Args:
        df: DataFrame with quantity and revenue columns
        quantity_col: Name of quantity column (defaults to config.QUANTITY_COLUMN)
        revenue_col: Name of revenue column (defaults to config.REVENUE_COLUMN)

    Returns:
        float: Average price per unit, or NaN when the quantity column is
        missing, no row is valid, or the valid quantities sum to zero
    """
    from config import QUANTITY_COLUMN, REVENUE_COLUMN, MIN_QUANTITY, MAX_QUANTITY

    qty_name = QUANTITY_COLUMN if quantity_col is None else quantity_col
    rev_name = REVENUE_COLUMN if revenue_col is None else revenue_col

    if qty_name not in df.columns:
        return np.nan

    quantities = df[qty_name]
    in_range = (quantities > MIN_QUANTITY) & (quantities <= MAX_QUANTITY)
    valid_rows = df[in_range]
    if valid_rows.empty:
        return np.nan

    revenue_sum = valid_rows[rev_name].sum()
    quantity_sum = valid_rows[qty_name].sum()
    # Guard against division by zero
    if quantity_sum == 0:
        return np.nan
    return revenue_sum / quantity_sum
# ============================================================================
# OUTPUT FORMATTING
# ============================================================================
def format_currency(value: float, millions: bool = True) -> str:
    """
    Format a raw currency amount for console output

    Unlike the chart formatters, this function rescales the value itself
    (divides by 1e6 or 1e3) and prints two decimal places.

    Args:
        value: Numeric value to format, in base currency units
        millions: If True, format as millions ($X.XXm), else thousands ($X.XXk)

    Returns:
        str: Formatted string like "$1.00m" or "$1.00k", or "N/A" if the value is missing

    Example:
        >>> format_currency(1000000)
        '$1.00m'
        >>> format_currency(1000, millions=False)
        '$1.00k'
    """
    if pd.isna(value):
        return "N/A"
    divisor, suffix = (1e6, "m") if millions else (1e3, "k")
    return f"${value / divisor:.2f}{suffix}"
def print_annual_summary(annual_df, metric_col='Revenue', label='Revenue'):
    """
    Print one metric per year to the console, formatted in millions

    Output is a header line, a 40-character rule, one line per index entry
    of `annual_df`, and a trailing blank line.

    Args:
        annual_df: DataFrame with annual metrics (indexed by Year)
        metric_col: Column name to print
        label: Label for the metric, used in the header
    """
    rule = "-" * 40
    print(f"\n{label} by Year:")
    print(rule)
    for year in annual_df.index:
        formatted = format_currency(annual_df.loc[year, metric_col])
        print(f" {year}: {formatted}")
    print()
# ============================================================================
# DATA FILTERING HELPERS
# ============================================================================
def apply_exclusion_filters(df):
    """
    Drop rows matching the exclusion rule in config.EXCLUSION_FILTERS

    The input is returned unchanged when filtering is disabled, when no
    column or no values are configured, or when the configured column is
    not in the frame. When rows are removed, the count is printed.

    Args:
        df: DataFrame to filter

    Returns:
        Filtered DataFrame (or the original frame if no filter applies)
    """
    from config import EXCLUSION_FILTERS

    if not EXCLUSION_FILTERS.get('enabled', False):
        return df

    column = EXCLUSION_FILTERS.get('exclude_by_column')
    values = EXCLUSION_FILTERS.get('exclude_values', [])
    if not (column and column in df.columns and values):
        return df

    keep_mask = ~df[column].isin(values)
    kept = df[keep_mask]
    dropped = len(df) - len(kept)
    if dropped > 0:
        print(f"Excluded {dropped:,} rows based on {column} filter")
    return kept
# ============================================================================
# INTERACTIVE VISUALIZATIONS (OPTIONAL - PLOTLY)
# ============================================================================
def create_interactive_chart(data, chart_type='line', title=None, xlabel=None, ylabel=None):
    """
    Create interactive chart using Plotly (optional dependency)

    A single trace named 'Data' is added when `data` provides both an 'x'
    and a 'y' entry and `chart_type` is one of the supported types. If the
    data lacks those entries, or the chart type is not recognised, the
    figure is returned without a trace (layout settings are still applied).

    Args:
        data: dict with 'x' and 'y' keys, or DataFrame with 'x' and 'y' columns
        chart_type: Type of chart ('line', 'bar', 'scatter')
        title: Chart title
        xlabel: X-axis label
        ylabel: Y-axis label

    Returns:
        plotly.graph_objects.Figure: Plotly figure object

    Raises:
        ImportError: If plotly is not installed

    Example:
        fig = create_interactive_chart(
            {'x': [1, 2, 3], 'y': [10, 20, 30]},
            chart_type='line',
            title='Revenue Trend'
        )
        fig.show()
    """
    try:
        import plotly.graph_objects as go
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "plotly is required for interactive charts. Install with: pip install plotly"
        ) from err

    fig = go.Figure()

    # `in` checks keys for a dict and column labels for a DataFrame.
    has_xy = isinstance(data, (dict, pd.DataFrame)) and 'x' in data and 'y' in data
    if has_xy:
        if chart_type == 'line':
            fig.add_trace(go.Scatter(
                x=data['x'],
                y=data['y'],
                mode='lines+markers',
                name='Data'
            ))
        elif chart_type == 'scatter':
            fig.add_trace(go.Scatter(
                x=data['x'],
                y=data['y'],
                mode='markers',
                name='Data'
            ))
        elif chart_type == 'bar':
            fig.add_trace(go.Bar(
                x=data['x'],
                y=data['y'],
                name='Data'
            ))

    if title:
        fig.update_layout(title=title)
    if xlabel:
        fig.update_xaxes(title_text=xlabel)
    if ylabel:
        fig.update_yaxes(title_text=ylabel)
    fig.update_layout(
        template='plotly_white',
        hovermode='x unified'
    )
    return fig
def save_interactive_chart(fig, filename, output_dir=None):
    """
    Save interactive Plotly chart to HTML file

    The target directory is created if needed, including any missing
    parent directories.

    Args:
        fig: Plotly figure object (anything exposing write_html(path))
        filename: Output filename (e.g., 'chart.html')
        output_dir: Output directory as str or Path (defaults to config.OUTPUT_DIR)

    Returns:
        Path: Location the HTML file was written to
    """
    # Coerce in both cases so a plain-string OUTPUT_DIR works as well.
    target_dir = Path(OUTPUT_DIR if output_dir is None else output_dir)
    # parents=True: nested output folders must not raise FileNotFoundError.
    target_dir.mkdir(parents=True, exist_ok=True)

    filepath = target_dir / filename
    fig.write_html(str(filepath))
    print(f"Interactive chart saved: {filepath}")
    return filepath