"""Unit tests for analysis_utils.py."""
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import pytest

# Make the directory one level above the one containing this file importable,
# so ``analysis_utils`` resolves regardless of the current working directory.
# ``resolve()`` turns a relative ``__file__`` (possible when this file is run
# directly on Python < 3.9) into an absolute path; without it,
# ``Path("test_x.py").parent.parent`` collapses to ``Path(".")`` and the wrong
# directory is added.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from analysis_utils import (  # noqa: E402  (must follow the sys.path setup)
    millions_formatter,
    thousands_formatter,
    get_millions_formatter,
    get_thousands_formatter,
    format_currency,
    calculate_price_per_unit,
    sort_mixed_years,
    safe_year_labels,
)
class TestFormatters:
    """Checks for the string formatters exported by analysis_utils."""

    def test_millions_formatter(self):
        """Values get a '$' prefix, one decimal place and an 'm' suffix."""
        expectations = [
            (10.5, '$10.5m'),
            (0, '$0.0m'),
            (100.0, '$100.0m'),
        ]
        for amount, rendered in expectations:
            assert millions_formatter(amount, None) == rendered

    def test_thousands_formatter(self):
        """Values get a '$' prefix, one decimal place and a 'k' suffix."""
        expectations = [
            (10.5, '$10.5k'),
            (0, '$0.0k'),
        ]
        for amount, rendered in expectations:
            assert thousands_formatter(amount, None) == rendered

    def test_format_currency(self):
        """format_currency shows millions by default and thousands on request."""
        assert format_currency(1_000_000) == '$1.00m'
        assert format_currency(1_000, millions=False) == '$1.00k'
        # NaN input yields a placeholder rather than a formatted number.
        assert format_currency(np.nan) == 'N/A'
class TestPriceCalculation:
    """Tests for calculate_price_per_unit."""

    def test_calculate_price_per_unit(self):
        """Price per unit is total revenue divided by total quantity."""
        df = pd.DataFrame({
            'Quantity': [10, 20, 30],
            'Revenue': [100, 200, 300],
        })

        price = calculate_price_per_unit(df, 'Quantity', 'Revenue')

        # (100 + 200 + 300) / (10 + 20 + 30). approx() keeps the test from
        # failing on harmless floating-point rounding in the implementation.
        assert price == pytest.approx(10.0)

    def test_calculate_price_per_unit_with_outliers(self):
        """Rows above the default quantity cutoff are left out of the price."""
        df = pd.DataFrame({
            'Quantity': [10, 20, 30, 2000],  # 2000 is an outlier
            'Revenue': [100, 200, 300, 10000],
        })

        # Quantities > 1000 are expected to be excluded by default. Under the
        # total-revenue / total-quantity formula, keeping the outlier row would
        # give 10600 / 2060 ~= 5.15 instead of 10.0.
        price = calculate_price_per_unit(df, 'Quantity', 'Revenue')

        assert price == pytest.approx(10.0)  # only the first three rows count
class TestYearHandling:
    """Tests for helpers handling year columns that mix ints and strings."""

    def test_sort_mixed_years(self):
        """Integer years sort ascending, with the '2025 (LTM)' label after them."""
        df = pd.DataFrame({
            'Year': [2023, '2025 (LTM)', 2024, 2022],
            'Value': [100, 150, 120, 90],
        })

        sorted_df = sort_mixed_years(df, 'Year')

        # Check the whole ordering, not just the endpoints, so a misplaced
        # middle year (e.g. 2024 before 2023) is caught.
        assert sorted_df['Year'].tolist() == [2022, 2023, 2024, '2025 (LTM)']
        # Other columns must travel with their year, i.e. rows stay intact.
        assert sorted_df['Value'].tolist() == [90, 100, 120, 150]

    def test_safe_year_labels(self):
        """Every year, int or str, is converted to its string label."""
        years = [2021, 2022, '2025 (LTM)']

        labels = safe_year_labels(years)

        assert labels == ['2021', '2022', '2025 (LTM)']
if __name__ == "__main__":
    # Propagate pytest's exit status; otherwise running this file directly
    # reports success to the shell/CI even when tests fail.
    sys.exit(pytest.main([__file__, '-v']))