Initial commit: sales analysis template
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
510
analysis_utils.py
Normal file
510
analysis_utils.py
Normal file
@@ -0,0 +1,510 @@
|
||||
"""
|
||||
Common utilities for analysis scripts
|
||||
Provides formatters, LTM setup, and helper functions
|
||||
|
||||
This module is designed to work with any sales data structure
|
||||
by using configuration from config.py
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from matplotlib.ticker import FuncFormatter
|
||||
from pathlib import Path
|
||||
from config import (
|
||||
REVENUE_COLUMN, LTM_ENABLED, get_ltm_period, get_ltm_label,
|
||||
OUTPUT_DIR, CHART_DPI, CHART_BBOX
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# CHART FORMATTERS
|
||||
# ============================================================================
|
||||
|
||||
def millions_formatter(x: float, pos: int) -> str:
|
||||
"""
|
||||
Format numbers in millions for chart display (e.g., $99.9m)
|
||||
|
||||
This formatter is used with matplotlib FuncFormatter to display
|
||||
revenue values in millions on chart axes.
|
||||
|
||||
Args:
|
||||
x: Numeric value (already in millions, e.g., 99.9 for $99.9m)
|
||||
pos: Position parameter (required by FuncFormatter, not used)
|
||||
|
||||
Returns:
|
||||
str: Formatted string like "$99.9m"
|
||||
|
||||
Example:
|
||||
>>> from matplotlib.ticker import FuncFormatter
|
||||
>>> formatter = FuncFormatter(millions_formatter)
|
||||
>>> ax.yaxis.set_major_formatter(formatter)
|
||||
"""
|
||||
return f'${x:.1f}m'
|
||||
|
||||
def thousands_formatter(x: float, pos: int) -> str:
|
||||
"""
|
||||
Format numbers in thousands for chart display (e.g., $99.9k)
|
||||
|
||||
Args:
|
||||
x: Numeric value (already in thousands)
|
||||
pos: Position parameter (required by FuncFormatter, not used)
|
||||
|
||||
Returns:
|
||||
str: Formatted string like "$99.9k"
|
||||
"""
|
||||
return f'${x:.1f}k'
|
||||
|
||||
def get_millions_formatter() -> FuncFormatter:
|
||||
"""
|
||||
Get FuncFormatter for millions
|
||||
|
||||
Returns:
|
||||
FuncFormatter: Configured formatter for millions display
|
||||
"""
|
||||
return FuncFormatter(millions_formatter)
|
||||
|
||||
def get_thousands_formatter() -> FuncFormatter:
|
||||
"""
|
||||
Get FuncFormatter for thousands
|
||||
|
||||
Returns:
|
||||
FuncFormatter: Configured formatter for thousands display
|
||||
"""
|
||||
return FuncFormatter(thousands_formatter)
|
||||
|
||||
# ============================================================================
|
||||
# LTM (Last Twelve Months) SETUP
|
||||
# ============================================================================
|
||||
|
||||
def get_ltm_period_config():
|
||||
"""
|
||||
Get LTM period boundaries from config
|
||||
|
||||
Returns:
|
||||
tuple: (ltm_start, ltm_end) as pd.Period objects, or (None, None) if disabled
|
||||
"""
|
||||
if LTM_ENABLED:
|
||||
return get_ltm_period()
|
||||
return None, None
|
||||
|
||||
def get_annual_data(df, year, ltm_start=None, ltm_end=None):
|
||||
"""
|
||||
Get data for a specific year, using LTM for the most recent partial year
|
||||
|
||||
Args:
|
||||
df: DataFrame with 'Year' and 'YearMonth' columns
|
||||
year: Year to extract (int)
|
||||
ltm_start: LTM start period (defaults to config if None)
|
||||
ltm_end: LTM end period (defaults to config if None)
|
||||
|
||||
Returns:
|
||||
tuple: (year_data DataFrame, year_label string)
|
||||
"""
|
||||
from config import LTM_END_YEAR
|
||||
|
||||
# Get LTM period from config if not provided
|
||||
if ltm_start is None or ltm_end is None:
|
||||
ltm_start, ltm_end = get_ltm_period_config()
|
||||
|
||||
# Use LTM for the most recent year if enabled
|
||||
if LTM_ENABLED and ltm_start and ltm_end and year == LTM_END_YEAR:
|
||||
if 'YearMonth' in df.columns:
|
||||
year_data = df[(df['YearMonth'] >= ltm_start) & (df['YearMonth'] <= ltm_end)]
|
||||
year_label = get_ltm_label() or str(year)
|
||||
else:
|
||||
# Fallback if YearMonth not available
|
||||
year_data = df[df['Year'] == year]
|
||||
year_label = str(year)
|
||||
else:
|
||||
# Use full calendar year
|
||||
year_data = df[df['Year'] == year]
|
||||
year_label = str(year)
|
||||
|
||||
return year_data, year_label
|
||||
|
||||
def calculate_annual_metrics(df, metrics_func, ltm_start=None, ltm_end=None):
|
||||
"""
|
||||
Calculate annual metrics for all years, using LTM for most recent year
|
||||
|
||||
Args:
|
||||
df: DataFrame with 'Year' and 'YearMonth' columns
|
||||
metrics_func: Function that takes a DataFrame and returns a dict of metrics
|
||||
ltm_start: LTM start period (defaults to config if None)
|
||||
ltm_end: LTM end period (defaults to config if None)
|
||||
|
||||
Returns:
|
||||
DataFrame with 'Year' index and metric columns
|
||||
"""
|
||||
from config import ANALYSIS_YEARS
|
||||
|
||||
if ltm_start is None or ltm_end is None:
|
||||
ltm_start, ltm_end = get_ltm_period_config()
|
||||
|
||||
annual_data = []
|
||||
for year in sorted(ANALYSIS_YEARS):
|
||||
if year in df['Year'].unique():
|
||||
year_data, year_label = get_annual_data(df, year, ltm_start, ltm_end)
|
||||
|
||||
if len(year_data) > 0:
|
||||
metrics = metrics_func(year_data)
|
||||
metrics['Year'] = year_label
|
||||
annual_data.append(metrics)
|
||||
|
||||
if not annual_data:
|
||||
return pd.DataFrame()
|
||||
|
||||
return pd.DataFrame(annual_data).set_index('Year')
|
||||
|
||||
# ============================================================================
|
||||
# MIXED TYPE HANDLING
|
||||
# ============================================================================
|
||||
|
||||
def create_year_sort_column(df, year_col='Year'):
|
||||
"""
|
||||
Create a numeric sort column for mixed int/str year columns
|
||||
|
||||
Args:
|
||||
df: DataFrame
|
||||
year_col: Name of year column
|
||||
|
||||
Returns:
|
||||
Series with numeric sort values
|
||||
"""
|
||||
from config import LTM_END_YEAR
|
||||
|
||||
def sort_value(x):
|
||||
if isinstance(x, str) and str(LTM_END_YEAR) in str(x):
|
||||
return float(LTM_END_YEAR) + 0.5
|
||||
elif isinstance(x, (int, float)):
|
||||
return float(x)
|
||||
else:
|
||||
return 9999
|
||||
|
||||
return df[year_col].apply(sort_value)
|
||||
|
||||
def sort_mixed_years(df, year_col='Year'):
|
||||
"""
|
||||
Sort DataFrame by year column that may contain mixed int/str types
|
||||
|
||||
Args:
|
||||
df: DataFrame
|
||||
year_col: Name of year column
|
||||
|
||||
Returns:
|
||||
Sorted DataFrame
|
||||
"""
|
||||
df = df.copy()
|
||||
df['_Year_Sort'] = create_year_sort_column(df, year_col)
|
||||
df = df.sort_values('_Year_Sort').drop(columns=['_Year_Sort'])
|
||||
return df
|
||||
|
||||
def safe_year_labels(years):
|
||||
"""
|
||||
Convert year values to safe string labels for chart axes
|
||||
|
||||
Args:
|
||||
years: Iterable of year values (int or str)
|
||||
|
||||
Returns:
|
||||
List of string labels
|
||||
"""
|
||||
return [str(year) for year in years]
|
||||
|
||||
# ============================================================================
|
||||
# CHART HELPERS
|
||||
# ============================================================================
|
||||
|
||||
def setup_revenue_chart(ax, ylabel: str = 'Revenue (Millions USD)') -> None:
|
||||
"""
|
||||
Setup a chart axis for revenue display (millions)
|
||||
|
||||
CRITICAL: Always use this function for revenue charts. It applies
|
||||
the millions formatter and standard styling.
|
||||
|
||||
IMPORTANT: Data must be divided by 1e6 BEFORE plotting:
|
||||
ax.plot(revenue / 1e6, ...) # ✅ Correct
|
||||
ax.plot(revenue, ...) # ❌ Wrong - will show scientific notation
|
||||
|
||||
Args:
|
||||
ax: Matplotlib axis object to configure
|
||||
ylabel: Y-axis label (default: 'Revenue (Millions USD)')
|
||||
|
||||
Returns:
|
||||
None: Modifies ax in place
|
||||
|
||||
Example:
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from analysis_utils import setup_revenue_chart
|
||||
>>> fig, ax = plt.subplots()
|
||||
>>> ax.plot(revenue_data / 1e6, marker='o') # Divide by 1e6 first!
|
||||
>>> setup_revenue_chart(ax)
|
||||
>>> plt.show()
|
||||
|
||||
See Also:
|
||||
- .cursor/rules/chart_formatting.md for detailed patterns
|
||||
- save_chart() for saving charts
|
||||
"""
|
||||
ax.yaxis.set_major_formatter(get_millions_formatter())
|
||||
ax.set_ylabel(ylabel)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
def save_chart(fig, filename, output_dir=None):
|
||||
"""
|
||||
Save chart to file with organized directory structure
|
||||
|
||||
Args:
|
||||
fig: Matplotlib figure object
|
||||
filename: Output filename (e.g., 'revenue_trend.png')
|
||||
output_dir: Output directory (defaults to config.OUTPUT_DIR)
|
||||
"""
|
||||
if output_dir is None:
|
||||
output_dir = OUTPUT_DIR
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
filepath = output_dir / filename
|
||||
fig.savefig(filepath, dpi=CHART_DPI, bbox_inches=CHART_BBOX, format='png')
|
||||
print(f"Chart saved: {filepath}")
|
||||
|
||||
# ============================================================================
|
||||
# DATA VALIDATION
|
||||
# ============================================================================
|
||||
|
||||
def validate_dataframe(df, required_columns=None):
|
||||
"""
|
||||
Validate DataFrame has required columns and basic data quality
|
||||
|
||||
Args:
|
||||
df: DataFrame to validate
|
||||
required_columns: List of required column names (defaults to config)
|
||||
|
||||
Returns:
|
||||
tuple: (is_valid bool, error_message str)
|
||||
"""
|
||||
if required_columns is None:
|
||||
required_columns = [REVENUE_COLUMN, 'Year']
|
||||
if 'YearMonth' in df.columns:
|
||||
required_columns.append('YearMonth')
|
||||
|
||||
missing_cols = [col for col in required_columns if col not in df.columns]
|
||||
if missing_cols:
|
||||
return False, f"Missing required columns: {missing_cols}"
|
||||
|
||||
if len(df) == 0:
|
||||
return False, "DataFrame is empty"
|
||||
|
||||
if REVENUE_COLUMN in df.columns:
|
||||
if df[REVENUE_COLUMN].isna().all():
|
||||
return False, f"All {REVENUE_COLUMN} values are NaN"
|
||||
|
||||
return True, "OK"
|
||||
|
||||
# ============================================================================
|
||||
# PRICE CALCULATION
|
||||
# ============================================================================
|
||||
|
||||
def calculate_price_per_unit(df, quantity_col=None, revenue_col=None):
|
||||
"""
|
||||
Calculate average price per unit, excluding invalid quantities
|
||||
|
||||
Args:
|
||||
df: DataFrame with quantity and revenue columns
|
||||
quantity_col: Name of quantity column (defaults to config)
|
||||
revenue_col: Name of revenue column (defaults to config)
|
||||
|
||||
Returns:
|
||||
float: Average price per unit
|
||||
"""
|
||||
from config import QUANTITY_COLUMN, REVENUE_COLUMN, MIN_QUANTITY, MAX_QUANTITY
|
||||
|
||||
if quantity_col is None:
|
||||
quantity_col = QUANTITY_COLUMN
|
||||
if revenue_col is None:
|
||||
revenue_col = REVENUE_COLUMN
|
||||
|
||||
# Check if quantity column exists
|
||||
if quantity_col not in df.columns:
|
||||
return np.nan
|
||||
|
||||
# Filter for valid quantity transactions
|
||||
df_valid = df[(df[quantity_col] > MIN_QUANTITY) & (df[quantity_col] <= MAX_QUANTITY)].copy()
|
||||
|
||||
if len(df_valid) == 0:
|
||||
return np.nan
|
||||
|
||||
total_revenue = df_valid[revenue_col].sum()
|
||||
total_quantity = df_valid[quantity_col].sum()
|
||||
|
||||
if total_quantity == 0:
|
||||
return np.nan
|
||||
|
||||
return total_revenue / total_quantity
|
||||
|
||||
# ============================================================================
|
||||
# OUTPUT FORMATTING
|
||||
# ============================================================================
|
||||
|
||||
def format_currency(value: float, millions: bool = True) -> str:
|
||||
"""
|
||||
Format currency value for console output
|
||||
|
||||
Args:
|
||||
value: Numeric value to format
|
||||
millions: If True, format as millions ($X.Xm), else thousands ($X.Xk)
|
||||
|
||||
Returns:
|
||||
str: Formatted string like "$99.9m" or "$99.9k" or "N/A" if NaN
|
||||
|
||||
Example:
|
||||
>>> format_currency(1000000)
|
||||
'$1.00m'
|
||||
>>> format_currency(1000, millions=False)
|
||||
'$1.00k'
|
||||
"""
|
||||
if pd.isna(value):
|
||||
return "N/A"
|
||||
|
||||
if millions:
|
||||
return f"${value / 1e6:.2f}m"
|
||||
else:
|
||||
return f"${value / 1e3:.2f}k"
|
||||
|
||||
def print_annual_summary(annual_df, metric_col='Revenue', label='Revenue'):
|
||||
"""
|
||||
Print formatted annual summary to console
|
||||
|
||||
Args:
|
||||
annual_df: DataFrame with annual metrics (indexed by Year)
|
||||
metric_col: Column name to print
|
||||
label: Label for the metric
|
||||
"""
|
||||
print(f"\n{label} by Year:")
|
||||
print("-" * 40)
|
||||
for year in annual_df.index:
|
||||
value = annual_df.loc[year, metric_col]
|
||||
formatted = format_currency(value)
|
||||
print(f" {year}: {formatted}")
|
||||
print()
|
||||
|
||||
# ============================================================================
|
||||
# DATA FILTERING HELPERS
|
||||
# ============================================================================
|
||||
|
||||
def apply_exclusion_filters(df):
|
||||
"""
|
||||
Apply exclusion filters from config
|
||||
|
||||
Args:
|
||||
df: DataFrame to filter
|
||||
|
||||
Returns:
|
||||
Filtered DataFrame
|
||||
"""
|
||||
from config import EXCLUSION_FILTERS
|
||||
|
||||
if not EXCLUSION_FILTERS.get('enabled', False):
|
||||
return df
|
||||
|
||||
exclude_col = EXCLUSION_FILTERS.get('exclude_by_column')
|
||||
exclude_values = EXCLUSION_FILTERS.get('exclude_values', [])
|
||||
|
||||
if exclude_col and exclude_col in df.columns and exclude_values:
|
||||
original_count = len(df)
|
||||
df_filtered = df[~df[exclude_col].isin(exclude_values)]
|
||||
excluded_count = original_count - len(df_filtered)
|
||||
if excluded_count > 0:
|
||||
print(f"Excluded {excluded_count:,} rows based on {exclude_col} filter")
|
||||
return df_filtered
|
||||
|
||||
return df
|
||||
|
||||
# ============================================================================
|
||||
# INTERACTIVE VISUALIZATIONS (OPTIONAL - PLOTLY)
|
||||
# ============================================================================
|
||||
|
||||
def create_interactive_chart(data, chart_type='line', title=None, xlabel=None, ylabel=None):
|
||||
"""
|
||||
Create interactive chart using Plotly (optional dependency)
|
||||
|
||||
Args:
|
||||
data: DataFrame or dict with chart data
|
||||
chart_type: Type of chart ('line', 'bar', 'scatter')
|
||||
title: Chart title
|
||||
xlabel: X-axis label
|
||||
ylabel: Y-axis label
|
||||
|
||||
Returns:
|
||||
plotly.graph_objects.Figure: Plotly figure object
|
||||
|
||||
Raises:
|
||||
ImportError: If plotly is not installed
|
||||
|
||||
Example:
|
||||
fig = create_interactive_chart(
|
||||
{'x': [1, 2, 3], 'y': [10, 20, 30]},
|
||||
chart_type='line',
|
||||
title='Revenue Trend'
|
||||
)
|
||||
fig.show()
|
||||
"""
|
||||
try:
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"plotly is required for interactive charts. Install with: pip install plotly"
|
||||
)
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
if chart_type == 'line':
|
||||
if isinstance(data, dict) and 'x' in data and 'y' in data:
|
||||
fig.add_trace(go.Scatter(
|
||||
x=data['x'],
|
||||
y=data['y'],
|
||||
mode='lines+markers',
|
||||
name='Data'
|
||||
))
|
||||
elif chart_type == 'bar':
|
||||
if isinstance(data, dict) and 'x' in data and 'y' in data:
|
||||
fig.add_trace(go.Bar(
|
||||
x=data['x'],
|
||||
y=data['y'],
|
||||
name='Data'
|
||||
))
|
||||
|
||||
if title:
|
||||
fig.update_layout(title=title)
|
||||
if xlabel:
|
||||
fig.update_xaxes(title_text=xlabel)
|
||||
if ylabel:
|
||||
fig.update_yaxes(title_text=ylabel)
|
||||
|
||||
fig.update_layout(
|
||||
template='plotly_white',
|
||||
hovermode='x unified'
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def save_interactive_chart(fig, filename, output_dir=None):
|
||||
"""
|
||||
Save interactive Plotly chart to HTML file
|
||||
|
||||
Args:
|
||||
fig: Plotly figure object
|
||||
filename: Output filename (e.g., 'chart.html')
|
||||
output_dir: Output directory (defaults to config.OUTPUT_DIR)
|
||||
"""
|
||||
if output_dir is None:
|
||||
output_dir = OUTPUT_DIR
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
filepath = output_dir / filename
|
||||
|
||||
fig.write_html(str(filepath))
|
||||
print(f"Interactive chart saved: {filepath}")
|
||||
|
||||
return filepath
|
||||
Reference in New Issue
Block a user