511 lines
15 KiB
Python
511 lines
15 KiB
Python
"""
|
|
Common utilities for analysis scripts
|
|
Provides formatters, LTM setup, and helper functions
|
|
|
|
This module is designed to work with any sales data structure
|
|
by using configuration from config.py
|
|
"""
|
|
import pandas as pd
|
|
import numpy as np
|
|
from matplotlib.ticker import FuncFormatter
|
|
from pathlib import Path
|
|
from config import (
|
|
REVENUE_COLUMN, LTM_ENABLED, get_ltm_period, get_ltm_label,
|
|
OUTPUT_DIR, CHART_DPI, CHART_BBOX
|
|
)
|
|
|
|
# ============================================================================
|
|
# CHART FORMATTERS
|
|
# ============================================================================
|
|
|
|
def millions_formatter(x: float, pos: int) -> str:
    """
    Render a tick value as a dollar amount in millions (e.g., $99.9m)

    Intended as the callback for matplotlib's FuncFormatter. The value is
    printed as-is with one decimal place; no scaling happens here, so the
    plotted data must already be expressed in millions.

    Args:
        x: Tick value, already in millions (99.9 renders as "$99.9m")
        pos: Tick position supplied by FuncFormatter; ignored

    Returns:
        str: The value with a "$" prefix, one decimal place and an "m" suffix

    Example:
        >>> millions_formatter(99.9, 0)
        '$99.9m'
    """
    return "$" + format(x, ".1f") + "m"
|
|
|
|
def thousands_formatter(x: float, pos: int) -> str:
    """
    Render a tick value as a dollar amount in thousands (e.g., $99.9k)

    The value is printed as-is with one decimal place; no scaling happens
    here, so the plotted data must already be expressed in thousands.

    Args:
        x: Tick value, already in thousands
        pos: Tick position supplied by FuncFormatter; ignored

    Returns:
        str: The value with a "$" prefix, one decimal place and a "k" suffix
    """
    return "$" + format(x, ".1f") + "k"
|
|
|
|
def get_millions_formatter() -> FuncFormatter:
    """
    Build a matplotlib tick formatter that labels values in millions

    Returns:
        FuncFormatter: Formatter wrapping millions_formatter
    """
    formatter = FuncFormatter(millions_formatter)
    return formatter
|
|
|
|
def get_thousands_formatter() -> FuncFormatter:
    """
    Build a matplotlib tick formatter that labels values in thousands

    Returns:
        FuncFormatter: Formatter wrapping thousands_formatter
    """
    formatter = FuncFormatter(thousands_formatter)
    return formatter
|
|
|
|
# ============================================================================
|
|
# LTM (Last Twelve Months) SETUP
|
|
# ============================================================================
|
|
|
|
def get_ltm_period_config():
    """
    Look up the LTM window boundaries from config

    Returns:
        tuple: The (ltm_start, ltm_end) pair produced by config.get_ltm_period()
        when LTM_ENABLED is truthy, otherwise (None, None)
    """
    if not LTM_ENABLED:
        return None, None
    return get_ltm_period()
|
|
|
|
def get_annual_data(df, year, ltm_start=None, ltm_end=None):
    """
    Select the rows for one reporting year and return them with a display label

    When LTM is enabled in config and ``year`` equals config.LTM_END_YEAR, rows
    are taken from the LTM window (ltm_start <= YearMonth <= ltm_end) instead of
    the calendar year, provided the frame has a 'YearMonth' column. In every
    other case the rows where 'Year' == ``year`` are returned.

    Args:
        df: DataFrame with a 'Year' column and, optionally, a 'YearMonth' column
        year: Year to extract
        ltm_start: Start of the LTM window (taken from config if either bound is None)
        ltm_end: End of the LTM window (taken from config if either bound is None)

    Returns:
        tuple: (year_data DataFrame, year_label string)
    """
    from config import LTM_END_YEAR

    # A partially specified window is replaced wholesale by the configured one
    if ltm_start is None or ltm_end is None:
        ltm_start, ltm_end = get_ltm_period_config()

    wants_ltm = LTM_ENABLED and ltm_start and ltm_end and year == LTM_END_YEAR
    if wants_ltm and 'YearMonth' in df.columns:
        months = df['YearMonth']
        in_window = (months >= ltm_start) & (months <= ltm_end)
        # Fall back to the plain year when config supplies no LTM label
        return df[in_window], get_ltm_label() or str(year)

    # Calendar-year selection; also the fallback when 'YearMonth' is missing
    return df[df['Year'] == year], str(year)
|
|
|
|
def calculate_annual_metrics(df, metrics_func, ltm_start=None, ltm_end=None):
    """
    Calculate annual metrics for every configured year, using LTM for the LTM end year

    Iterates config.ANALYSIS_YEARS in ascending order. Years that do not occur
    in df['Year'], or whose selection (see get_annual_data) is empty, are skipped.

    Args:
        df: DataFrame with 'Year' and 'YearMonth' columns
        metrics_func: Function that takes a DataFrame and returns a dict of metrics.
            The returned mapping is copied, never modified.
        ltm_start: LTM start period (taken from config if either bound is None)
        ltm_end: LTM end period (taken from config if either bound is None)

    Returns:
        DataFrame indexed by the 'Year' label with one column per metric, or an
        empty DataFrame when no year produced any rows
    """
    from config import ANALYSIS_YEARS

    if ltm_start is None or ltm_end is None:
        ltm_start, ltm_end = get_ltm_period_config()

    # Loop-invariant: collect the years present once instead of per iteration
    years_present = df['Year'].unique()

    rows = []
    for year in sorted(ANALYSIS_YEARS):
        if year not in years_present:
            continue

        year_data, year_label = get_annual_data(df, year, ltm_start, ltm_end)
        if len(year_data) == 0:
            continue

        # Copy before adding the label so a dict that metrics_func caches or
        # shares between calls is not mutated behind its back
        row = dict(metrics_func(year_data))
        row['Year'] = year_label
        rows.append(row)

    if not rows:
        return pd.DataFrame()

    return pd.DataFrame(rows).set_index('Year')
|
|
|
|
# ============================================================================
|
|
# MIXED TYPE HANDLING
|
|
# ============================================================================
|
|
|
|
def create_year_sort_column(df, year_col='Year'):
    """
    Create a numeric sort key for a year column holding mixed int/str values

    Key assignment, per value:
      * a string containing str(config.LTM_END_YEAR) -> LTM_END_YEAR + 0.5,
        so the LTM label sorts directly after that calendar year
      * any other string that parses as a number (e.g. '2021') -> that number
      * a real number (Python or NumPy scalar) -> float(value)
      * anything else -> 9999, which sorts after realistic years

    Args:
        df: DataFrame
        year_col: Name of year column

    Returns:
        Series with numeric sort values, aligned to df's index
    """
    from numbers import Real

    from config import LTM_END_YEAR

    ltm_token = str(LTM_END_YEAR)
    ltm_key = float(LTM_END_YEAR) + 0.5

    def sort_value(x):
        if isinstance(x, str):
            if ltm_token in x:
                return ltm_key
            # Plain numeric labels such as '2021' must sort by value rather
            # than being lumped together at 9999
            try:
                return float(x)
            except ValueError:
                return 9999
        # Real also matches NumPy integer/float scalars, which are not
        # subclasses of the built-in int
        if isinstance(x, Real):
            return float(x)
        return 9999

    return df[year_col].apply(sort_value)
|
|
|
|
def sort_mixed_years(df, year_col='Year'):
    """
    Return a copy of the DataFrame ordered by a year column of mixed int/str values

    The ordering key comes from create_year_sort_column and is held in a
    temporary '_Year_Sort' column that is removed before returning. The input
    frame is not modified.

    Args:
        df: DataFrame
        year_col: Name of year column

    Returns:
        Sorted DataFrame
    """
    ordered = df.copy()
    ordered['_Year_Sort'] = create_year_sort_column(ordered, year_col)
    ordered = ordered.sort_values('_Year_Sort')
    return ordered.drop(columns=['_Year_Sort'])
|
|
|
|
def safe_year_labels(years):
    """
    Turn year values into string labels suitable for chart axes

    Args:
        years: Iterable of year values (int or str)

    Returns:
        List with str() applied to each value, in the original order
    """
    return list(map(str, years))
|
|
|
|
# ============================================================================
|
|
# CHART HELPERS
|
|
# ============================================================================
|
|
|
|
def setup_revenue_chart(ax, ylabel: str = 'Revenue (Millions USD)') -> None:
    """
    Configure a chart axis for revenue shown in millions

    Installs the millions tick formatter on the y-axis, sets the y-axis label
    and switches on a light grid (alpha 0.3).

    IMPORTANT: The formatter only appends "$" and "m"; it does not rescale.
    Divide the data by 1e6 BEFORE plotting:
        ax.plot(revenue / 1e6, ...)   # correct
        ax.plot(revenue, ...)         # wrong - raw values get labelled as millions

    Args:
        ax: Matplotlib axis object to configure
        ylabel: Y-axis label (default: 'Revenue (Millions USD)')

    Returns:
        None: Modifies ax in place

    Example:
        >>> import matplotlib.pyplot as plt
        >>> from analysis_utils import setup_revenue_chart
        >>> fig, ax = plt.subplots()
        >>> ax.plot(revenue_data / 1e6, marker='o')  # Divide by 1e6 first!
        >>> setup_revenue_chart(ax)
        >>> plt.show()

    See Also:
        - .cursor/rules/chart_formatting.md for detailed patterns
        - save_chart() for saving charts
    """
    formatter = get_millions_formatter()
    ax.yaxis.set_major_formatter(formatter)
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
|
|
|
|
def save_chart(fig, filename, output_dir=None):
    """
    Save a Matplotlib figure as a PNG file, creating the output directory if needed

    Args:
        fig: Matplotlib figure object
        filename: Output filename (e.g., 'revenue_trend.png'). The image is
            always written in PNG format regardless of the extension.
        output_dir: Output directory as str or Path (defaults to config.OUTPUT_DIR)

    Returns:
        Path: Location of the written file
    """
    # Coerce both the argument and the config default, so a plain string
    # OUTPUT_DIR works as well as a Path
    target_dir = Path(OUTPUT_DIR if output_dir is None else output_dir)

    # parents=True lets nested directories such as 'out/charts' be created
    target_dir.mkdir(parents=True, exist_ok=True)

    filepath = target_dir / filename
    fig.savefig(filepath, dpi=CHART_DPI, bbox_inches=CHART_BBOX, format='png')
    print(f"Chart saved: {filepath}")

    return filepath
|
|
|
|
# ============================================================================
|
|
# DATA VALIDATION
|
|
# ============================================================================
|
|
|
|
def validate_dataframe(df, required_columns=None):
    """
    Check that a DataFrame has the required columns and usable data

    Checks run in this order and stop at the first failure:
      1. every required column is present
      2. the frame has at least one row
      3. the revenue column (if present) is not entirely NaN

    Args:
        df: DataFrame to validate
        required_columns: List of required column names. Defaults to the
            configured revenue column plus 'Year' (and 'YearMonth' when the
            frame already has it).

    Returns:
        tuple: (is_valid bool, error_message str); the message is "OK" on success
    """
    if required_columns is None:
        optional = ['YearMonth'] if 'YearMonth' in df.columns else []
        required_columns = [REVENUE_COLUMN, 'Year'] + optional

    absent = [name for name in required_columns if name not in df.columns]
    if absent:
        return False, f"Missing required columns: {absent}"

    if len(df) == 0:
        return False, "DataFrame is empty"

    if REVENUE_COLUMN in df.columns and df[REVENUE_COLUMN].isna().all():
        return False, f"All {REVENUE_COLUMN} values are NaN"

    return True, "OK"
|
|
|
|
# ============================================================================
|
|
# PRICE CALCULATION
|
|
# ============================================================================
|
|
|
|
def calculate_price_per_unit(df, quantity_col=None, revenue_col=None):
    """
    Compute the average price per unit over rows with a valid quantity

    A row counts as valid when MIN_QUANTITY < quantity <= MAX_QUANTITY (both
    limits come from config). The result is total revenue divided by total
    quantity across those rows.

    Args:
        df: DataFrame with quantity and revenue columns
        quantity_col: Name of quantity column (defaults to config)
        revenue_col: Name of revenue column (defaults to config)

    Returns:
        float: Average price per unit, or NaN when the quantity column is
        missing, no row has a valid quantity, or the valid quantities sum to zero
    """
    from config import QUANTITY_COLUMN, REVENUE_COLUMN, MIN_QUANTITY, MAX_QUANTITY

    qty_name = QUANTITY_COLUMN if quantity_col is None else quantity_col
    rev_name = REVENUE_COLUMN if revenue_col is None else revenue_col

    # Without a quantity column there is nothing to divide by
    if qty_name not in df.columns:
        return np.nan

    quantities = df[qty_name]
    usable = df[(quantities > MIN_QUANTITY) & (quantities <= MAX_QUANTITY)]
    if len(usable) == 0:
        return np.nan

    revenue_sum = usable[rev_name].sum()
    quantity_sum = usable[qty_name].sum()

    # Guard the division
    if quantity_sum == 0:
        return np.nan

    return revenue_sum / quantity_sum
|
|
|
|
# ============================================================================
|
|
# OUTPUT FORMATTING
|
|
# ============================================================================
|
|
|
|
def format_currency(value: float, millions: bool = True) -> str:
    """
    Format a raw currency amount for console output

    Unlike the chart formatters, this function scales the value itself:
    it divides by 1e6 (millions) or 1e3 (thousands) and prints two decimals.

    Args:
        value: Numeric value to format, in base currency units
        millions: If True, format as millions ($X.XXm), else thousands ($X.XXk)

    Returns:
        str: Formatted string like "$99.90m" or "$99.90k", or "N/A" if the
        value is missing (NaN/None)

    Example:
        >>> format_currency(1000000)
        '$1.00m'
        >>> format_currency(1000, millions=False)
        '$1.00k'
    """
    if pd.isna(value):
        return "N/A"

    divisor, suffix = (1e6, "m") if millions else (1e3, "k")
    return f"${value / divisor:.2f}{suffix}"
|
|
|
|
def print_annual_summary(annual_df, metric_col='Revenue', label='Revenue'):
    """
    Print one metric per year to the console

    Writes a heading, a 40-character rule, one line per index entry with the
    value rendered by format_currency (millions), and a trailing blank line.

    Args:
        annual_df: DataFrame with annual metrics (indexed by Year)
        metric_col: Column name to print
        label: Label for the metric, used in the heading
    """
    heading = f"\n{label} by Year:"
    rule = "-" * 40
    print(heading)
    print(rule)

    for year in annual_df.index:
        amount = format_currency(annual_df.loc[year, metric_col])
        print(f" {year}: {amount}")

    print()
|
|
|
|
# ============================================================================
|
|
# DATA FILTERING HELPERS
|
|
# ============================================================================
|
|
|
|
def apply_exclusion_filters(df):
    """
    Drop rows matching the exclusion rule defined in config.EXCLUSION_FILTERS

    The rule is applied only when it is enabled, names a column that exists in
    the frame, and lists at least one value to exclude. Otherwise the input is
    returned unchanged. A message is printed when any rows are removed.

    Args:
        df: DataFrame to filter

    Returns:
        Filtered DataFrame (the original object when nothing applies)
    """
    from config import EXCLUSION_FILTERS

    if not EXCLUSION_FILTERS.get('enabled', False):
        return df

    column = EXCLUSION_FILTERS.get('exclude_by_column')
    values = EXCLUSION_FILTERS.get('exclude_values', [])
    if not (column and column in df.columns and values):
        return df

    kept = df[~df[column].isin(values)]
    dropped = len(df) - len(kept)
    if dropped > 0:
        print(f"Excluded {dropped:,} rows based on {column} filter")
    return kept
|
|
|
|
# ============================================================================
|
|
# INTERACTIVE VISUALIZATIONS (OPTIONAL - PLOTLY)
|
|
# ============================================================================
|
|
|
|
def create_interactive_chart(data, chart_type='line', title=None, xlabel=None, ylabel=None):
    """
    Create interactive chart using Plotly (optional dependency)

    A single trace named 'Data' is added when ``data`` is a dict or DataFrame
    exposing both 'x' and 'y' and ``chart_type`` is one of the supported
    values. In any other case the figure is returned without traces (layout
    settings are still applied).

    Args:
        data: dict with 'x' and 'y' entries, or DataFrame with 'x' and 'y' columns
        chart_type: Type of chart ('line', 'bar', 'scatter')
        title: Chart title
        xlabel: X-axis label
        ylabel: Y-axis label

    Returns:
        plotly.graph_objects.Figure: Plotly figure object

    Raises:
        ImportError: If plotly is not installed

    Example:
        fig = create_interactive_chart(
            {'x': [1, 2, 3], 'y': [10, 20, 30]},
            chart_type='line',
            title='Revenue Trend'
        )
        fig.show()
    """
    try:
        import plotly.graph_objects as go
    except ImportError as err:
        # Chain the cause so the underlying import failure stays visible
        raise ImportError(
            "plotly is required for interactive charts. Install with: pip install plotly"
        ) from err

    fig = go.Figure()

    # One builder per supported chart type; 'scatter' is markers only,
    # 'line' keeps the original lines+markers look
    trace_builders = {
        'line': lambda x, y: go.Scatter(x=x, y=y, mode='lines+markers', name='Data'),
        'scatter': lambda x, y: go.Scatter(x=x, y=y, mode='markers', name='Data'),
        'bar': lambda x, y: go.Bar(x=x, y=y, name='Data'),
    }

    # dict and DataFrame both support `'x' in data` and `data['x']`
    has_xy = isinstance(data, (dict, pd.DataFrame)) and 'x' in data and 'y' in data
    build_trace = trace_builders.get(chart_type)
    if has_xy and build_trace is not None:
        fig.add_trace(build_trace(data['x'], data['y']))

    if title:
        fig.update_layout(title=title)
    if xlabel:
        fig.update_xaxes(title_text=xlabel)
    if ylabel:
        fig.update_yaxes(title_text=ylabel)

    fig.update_layout(
        template='plotly_white',
        hovermode='x unified'
    )

    return fig
|
|
|
|
def save_interactive_chart(fig, filename, output_dir=None):
    """
    Save interactive Plotly chart to HTML file, creating the output directory if needed

    Args:
        fig: Plotly figure object (anything exposing write_html(path))
        filename: Output filename (e.g., 'chart.html')
        output_dir: Output directory as str or Path (defaults to config.OUTPUT_DIR)

    Returns:
        Path: Location of the written HTML file
    """
    # Coerce both the argument and the config default, so a plain string
    # OUTPUT_DIR works as well as a Path
    target_dir = Path(OUTPUT_DIR if output_dir is None else output_dir)

    # parents=True lets nested directories such as 'out/charts' be created
    target_dir.mkdir(parents=True, exist_ok=True)
    filepath = target_dir / filename

    fig.write_html(str(filepath))
    print(f"Interactive chart saved: {filepath}")

    return filepath
|