#!/usr/bin/env python3
# Copyright (c) 2025 Edulilo Pty Ltd
#
# This software is provided "as is", without warranty of any kind, either express or implied,
# including but not limited to the warranties of merchantability and fitness for a particular purpose.
# In no event shall the author be liable for any claim, damages or other liability,
# whether in an action of contract, tort or otherwise, arising from, out of or in connection
# with the software or the use or other dealings in the software.
#
# Requires the IAM permissions defined in spend_lens_policy.json
# to run successfully. Attach that policy to the user or role running this script.
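#
# Generates a standalone report.html summarising the last 30 days of AWS spend
# (via Cost Explorer) together with S3 storage, EBS volumes, and peak CPU
# utilisation for EC2 instances and RDS databases.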

import json
from collections import defaultdict
from pathlib import Path
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import List

import boto3

DATA_DIR = Path("aws-cost-data")
EC2_DIR = DATA_DIR / "ec2"
RDS_DIR = DATA_DIR / "rds"
EBS_DIR = DATA_DIR / "ebs"
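# Raw describe_* responses are dumped into these directories and the per-resource
# sections of the report are rebuilt from the files on disk.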

# Look-back window used for both CloudWatch metrics and Cost Explorer queries
END_TIME = datetime.utcnow()
START_TIME = END_TIME - timedelta(days=30)
# Cost Explorer expects date-only boundaries
time_period_for_ce = {
    "Start": START_TIME.strftime("%Y-%m-%d"),
    "End":   END_TIME.strftime("%Y-%m-%d"),
}


@dataclass
class ChartItem:
    """Represents a single item in a bar chart."""
    label: str
    amount: float


def generate_divided_bar_with_legend(
    items: List[ChartItem],
    width: int = 600,
    margin: int = 20,
    unit_prefix: str = '$',
    unit_suffix: str = ''
) -> str:
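    """Render a 100%-stacked horizontal bar with a colour-keyed legend as an SVG string."""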
    bar_h = 20
    legend_fs = 12
    square = 10
    line_sp = 4
    colors = ["steelblue", "orange", "green", "purple", "red", "pink"]

    total = sum(item.amount for item in items)
    chart_w = width - 2 * margin

    legend_line_h = max(square, legend_fs) + line_sp
    legend_h = len(items) * legend_line_h
    height = margin*2 + bar_h + margin/2 + legend_h

    parts = [
        f'<svg xmlns="http://www.w3.org/2000/svg"',
        f' viewBox="0 0 {width} {height}"',
        ' preserveAspectRatio="xMinYMin meet"',
        ' style="width:100%; height:auto; font-family:\'IBM Plex Sans\', sans-serif">',
    ]
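    # viewBox + width:100%/height:auto keeps the chart responsive to its container.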
    # Generate the bar segments
    x = margin
    for i, item in enumerate(items):
        seg = (item.amount / total) * chart_w if total > 0 else 0
        parts.append(f'<rect x="{x:.1f}" y="{margin:.1f}" width="{seg:.1f}" height="{bar_h}" fill="{colors[i % len(colors)]}"/>')
        x += seg

    # Generate the legend
    ys = margin + bar_h + (margin / 2)
    for i, item in enumerate(items):
        y = ys + i * legend_line_h
        
        # Format the amount based on the suffix
        if unit_suffix == " GB":
            amount_str = f"{item.amount:.3f}"
        else:
            amount_str = f"{item.amount:.2f}"
            
        label = f"{item.label}: {unit_prefix}{amount_str}{unit_suffix}"
        
        parts.append(f'<rect x="{margin}" y="{y}" width="{square}" height="{square}" fill="{colors[i % len(colors)]}"/>')
        parts.append(f'<text x="{margin + square + 5}" y="{y + square}" font-size="{legend_fs}">{label}</text>')

    parts.append("</svg>")
    return "\n".join(parts)

def max_cpu_for(resource, region, id_key, cw_namespace):
    """Return the peak CPU utilisation for a resource.

    The function queries CloudWatch to find the maximum CPU utilisation
    recorded in any 1‑hour period during the lookback window defined at the
    top of this script.  This value represents the highest sustained load on
    the instance and can be used to identify resources that may be
    over‑provisioned.
    """
    # Give each call its own session: this function runs from ThreadPoolExecutor
    # workers, and boto3 sessions are not thread-safe.
    cw = boto3.Session().client("cloudwatch", region_name=region)
    resp = cw.get_metric_statistics(
        Namespace=cw_namespace,
        MetricName="CPUUtilization",
        Dimensions=[{"Name": id_key, "Value": resource}],
        StartTime=START_TIME,
        EndTime=END_TIME,
        Period=3600,
        Statistics=["Maximum"],
    )
    points = [dp["Maximum"] for dp in resp.get("Datapoints", [])]
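    # No datapoints (e.g. the resource was stopped for the whole window) is reported
    # as 0.0, which the report will flag as a right-sizing candidate.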
    return max(points) if points else 0.0

def get_instance_name(tags):
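    """Return the value of the 'Name' tag, or a placeholder when it is absent."""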
    for t in tags or []:
        if t.get("Key") == "Name":
            return t.get("Value", "(no name)")
    return "(no name)"

if __name__ == "__main__":
    EC2_DIR.mkdir(parents=True, exist_ok=True)
    RDS_DIR.mkdir(parents=True, exist_ok=True)
    EBS_DIR.mkdir(parents=True, exist_ok=True)

    # 1) Savings Plans: flag whether the account has any active plans
    sp = boto3.client("savingsplans")
    sp_resp = sp.describe_savings_plans(states=["active"])
    has_sp = bool(sp_resp.get("SavingsPlans"))

    ce = boto3.client("ce")

    def cached_ce(key: str, **params) -> dict:
        cache_file = DATA_DIR / f"ce_{key}.json"
        if cache_file.exists():
            mtime_utc = datetime.utcfromtimestamp(cache_file.stat().st_mtime)
            age = datetime.utcnow() - mtime_utc
            if age < timedelta(days=1):
                age_str = str(age).split('.')[0]
                print(f"🗄️ Loading cached CE data '{key}' (age {age_str})")
                return json.loads(cache_file.read_text())
        print(f"🔍 Fetching fresh CE data '{key}'")
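        # A single call is assumed to be sufficient for a 30-day monthly query;
        # get_cost_and_usage can return a NextPageToken for larger result sets,
        # which this helper does not follow.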
        resp = ce.get_cost_and_usage(**params)
        cache_file.write_text(json.dumps(resp, default=str))
        return resp

    def get_grouped_costs(
        key: str,
        group_by: list,
    ) -> dict:
        """
        Returns { dimension_value_string: total_cost }.
        """
        resp = cached_ce(key,
            TimePeriod=time_period_for_ce,
            Granularity="MONTHLY",
            Metrics=["UnblendedCost"],
            GroupBy=group_by,
        )

        totals = defaultdict(float)
        for period in resp["ResultsByTime"]:
            for g in period["Groups"]:
                # join multi-key dims with '/'
                name = "/".join(g["Keys"])
                totals[name] += float(g["Metrics"]["UnblendedCost"]["Amount"])
        return totals

    region_costs = get_grouped_costs(
        key="region",
        group_by=[{"Type":"DIMENSION","Key":"REGION"}],
    )
    region_data = sorted(
        [
            ChartItem(label=region, amount=cost)
            for region, cost in region_costs.items() if cost > 0.01
        ],
        key=lambda x: x.amount,
        reverse=True
    )
    region_costs_svg = generate_divided_bar_with_legend(region_data)
    # Only enumerate per-resource details in regions with more than $1 of spend
    regions_with_cost_over_one_dollar = [
        region for region, total in region_costs.items()
        if region not in ("global", "NoRegion") and total > 1.0
    ]

    service_costs = get_grouped_costs(
        key="service",
        group_by=[{"Type":"DIMENSION","Key":"SERVICE"}],
    )
    services = sorted(
        [ChartItem(label=service, amount=amount) for service, amount in service_costs.items()],
        key=lambda x: x.amount, reverse=True
    )

    usage_costs = get_grouped_costs(
        key="usage",
        group_by=[
            {"Type":"DIMENSION","Key":"SERVICE"},
            {"Type":"DIMENSION","Key":"USAGE_TYPE"},
        ],
    )
    usages = sorted(
        [ChartItem(label=usage, amount=cost) for usage, cost in usage_costs.items()],
        key=lambda x: x.amount, reverse=True
    )

    ec2_list, rds_list, ebs_list = [], [], []
    with ThreadPoolExecutor(max_workers=10) as pool:
        future_to_ec2 = {}
        future_to_rds = {}
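        # CloudWatch CPU lookups are slow and independent of each other, so they are
        # submitted to the pool while the main thread keeps describing resources.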

        for region in regions_with_cost_over_one_dollar:
            regional_ec2 = boto3.client("ec2", region_name=region)
            paginator = regional_ec2.get_paginator("describe_instances")
            pages = paginator.paginate(PaginationConfig={"PageSize": 100})

            all_reservations = []
            for page in pages:
                all_reservations.extend(page.get("Reservations", []))

            ec2_dump = {"Reservations": all_reservations}
            (EC2_DIR / f"{region}.json").write_text(json.dumps(ec2_dump, default=str))

            # fetch EBS volumes
            vol_pages = regional_ec2.get_paginator("describe_volumes").paginate(PaginationConfig={"PageSize": 100})
            all_volumes = []
            for vpage in vol_pages:
                all_volumes.extend(vpage.get("Volumes", []))
            ebs_dump = {"Volumes": all_volumes}
            (EBS_DIR / f"{region}.json").write_text(json.dumps(ebs_dump, default=str))

            for reservation in all_reservations:
                for instance in reservation.get("Instances", []):
                    future_to_ec2[
                        pool.submit(max_cpu_for, instance["InstanceId"], region, "InstanceId", "AWS/EC2")
                    ] = (region, instance)

        for region in regions_with_cost_over_one_dollar:
            regional_rds = boto3.client("rds", region_name=region)
            databases = []
            for db_page in regional_rds.get_paginator("describe_db_instances").paginate():
                databases.extend(db_page.get("DBInstances", []))

            rds_dump = {"DBInstances": databases}
            (RDS_DIR / f"{region}.json").write_text(json.dumps(rds_dump, default=str))

            for database in databases:
                database_id = database["DBInstanceIdentifier"]
                future_to_rds[
                    pool.submit(max_cpu_for, database_id, region, "DBInstanceIdentifier", "AWS/RDS")
                ] = (region, database)

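        # Collect the CPU results as the workers finish.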
        ec2_cpu = {}
        for fut in as_completed(future_to_ec2):
            region, instance = future_to_ec2[fut]
            ec2_cpu[(region, instance["InstanceId"])] = fut.result()

        rds_cpu = {}
        for fut in as_completed(future_to_rds):
            region, database = future_to_rds[fut]
            rds_cpu[(region, database["DBInstanceIdentifier"])] = fut.result()

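    # The resource lists are rebuilt from the JSON dumps on disk, so files left over
    # from earlier runs (e.g. regions that have since dropped below the $1 threshold)
    # are included as well.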
    # build EC2 list
    for jsonf in EC2_DIR.glob("*.json"):
        region = jsonf.stem
        data = json.loads(jsonf.read_text())
        for reservation in data.get("Reservations", []):
            for instance in reservation.get("Instances", []):
                instance_id = instance["InstanceId"]
                ec2_list.append({
                    "Name": get_instance_name(instance.get("Tags")),
                    "Region": region,
                    "InstanceId": instance_id,
                    "InstanceType": instance.get("InstanceType"),
                    "State": instance.get("State", {}).get("Name"),
                    "MaxCPU": ec2_cpu.get((region, instance_id), 0.0),
                })

    # build RDS list
    for jsonf in RDS_DIR.glob("*.json"):
        region = jsonf.stem
        data = json.loads(jsonf.read_text())
        for database in data.get("DBInstances", []):
            database_id = database["DBInstanceIdentifier"]
            rds_list.append({
                "Identifier": database_id,
                "Region": region,
                "DBInstanceClass": database.get("DBInstanceClass"),
                "Status": database.get("DBInstanceStatus"),
                "MaxCPU": rds_cpu.get((region, database_id), 0.0),
            })

    # build EBS volume list
    for jsonf in EBS_DIR.glob("*.json"):
        region = jsonf.stem
        data = json.loads(jsonf.read_text())
        for volume in data.get("Volumes", []):
            attachments = volume.get("Attachments", [])
            attached_to = attachments[0]["InstanceId"] if attachments else None
            ebs_list.append({
                "VolumeId": volume["VolumeId"],
                "Region": region,
                "Size": volume.get("Size"),
                "Type": volume.get("VolumeType"),
                "State": volume.get("State"),
                "AttachedTo": attached_to,
            })

    ebs_volume_count = len(ebs_list)
    unattached_ebs_count = sum(1 for v in ebs_list if v["AttachedTo"] is None)
    total_ebs_size = sum(v["Size"] or 0 for v in ebs_list)


    s3 = boto3.client("s3")
    buckets = s3.list_buckets().get("Buckets", [])
    s3_bucket_count = len(buckets)

    cw = boto3.client("cloudwatch")
    s3_data = []
    for bucket in buckets:
        name = bucket["Name"]
        # get latest StandardStorage size (bytes)
        resp_size = cw.get_metric_statistics(
            Namespace="AWS/S3",
            MetricName="BucketSizeBytes",
            Dimensions=[
                {"Name": "BucketName", "Value": name},
                {"Name": "StorageType", "Value": "StandardStorage"},
            ],
            StartTime=START_TIME - timedelta(days=1),
            EndTime=END_TIME,
            Period=86400,
            Statistics=["Average"],
        )
        size_bytes = max((dp["Average"] for dp in resp_size.get("Datapoints", [])), default=0)
        size_gb = size_bytes / (1024 ** 3)

        s3_data.append({
            "Bucket": name,
            "SizeGB": size_gb,
        })

    total_storage = sum(d["SizeGB"] for d in s3_data)

    min_size_gb = 0.01

    # keep only buckets ≥ min_size_gb
    large = [d for d in s3_data if d["SizeGB"] >= min_size_gb]
    # sort desc
    large.sort(key=lambda d: d["SizeGB"], reverse=True)

    # pick top 5
    top = large[:5]

    # sum *all* buckets not in top (small ones + any beyond 5)
    top_names = {b["Bucket"] for b in top}
    other_storage = sum(
        d["SizeGB"] for d in s3_data
        if d["Bucket"] not in top_names
    )

    # assemble chart entries
    chart_data = [
        ChartItem(label=b["Bucket"], amount=b["SizeGB"])
        for b in top
    ]
    if other_storage > 0:
        chart_data.append(ChartItem(label="All other buckets", amount=other_storage))

    # render PNG with GB units
    bucket_size_svg = generate_divided_bar_with_legend(
        chart_data,
        unit_prefix='',
        unit_suffix=' GB'
    )
    # Report generation

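    # Collapse the long tail into an "Other" bucket so the divided bars stay readable.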
    other_services = sum(s.amount for s in services[3:])
    services = services[:3] + ([ChartItem(label="Other Services", amount=other_services)] if other_services > 0 else [])
    top_services_svg = generate_divided_bar_with_legend(services)

    other_usage_types = sum(u.amount for u in usages[5:])
    usage_types = usages[:5] + ([ChartItem(label="Other Usage Types", amount=other_usage_types)] if other_usage_types > 0 else [])
    top_usage_types_svg = generate_divided_bar_with_legend(usage_types)

    styles = """
    body {
    margin: 0;
    font-family: 'IBM Plex Sans', sans-serif;
    background-color: #F5F7FA;
    color: #1C2D4A;
    line-height: 1.5;
    }

    header {
    margin-top: 0;
    background-color: white;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
    }

    h1 {
    margin-top: 0;
    font-size: 2.5rem;
    margin-bottom: 0.5rem;
    text-align: center;
    }

    main {
    max-width: 900px;
    margin: 0 auto;
    }
    h2 {
    font-size: 1.75rem;
    margin: 2rem 0 1rem;
    padding-bottom: 0.5rem;
    }

    p {
    font-size: 1.125rem;
    margin-bottom: 1.5rem;
    }


    svg {
    display: block;
    background-color: white;
    border-radius: 0.5rem;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
    }

    ul {
    list-style: none;
    padding: 0;
    margin: 0 0 1.5rem;
    }
    ul li {
    background-color: white;
    padding: 1rem;
    border-radius: 0.5rem;
    margin-bottom: 0.75rem;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
    }

    table {
    width: 100%;
    border-collapse: collapse;
    margin: 0 0 1.5rem;
    background-color: white;
    border-radius: 0.5rem;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
    overflow: visible;
    }
    th, td {
    padding: 0.5rem 0.75rem;
    text-align: left;
    border-bottom: 1px solid #ddd;
    }
    th {
    background-color: #f9fafb;
    font-weight: 600;
    }
    .warning {
    background-color: #fff3cd;
    }
    .tooltip {
    position: relative;
    }
    .tooltip:hover::after {
    content: attr(data-tip);
    position: absolute;
    left: 50%;
    transform: translateX(-50%);
    bottom: 125%;
    background-color: #333;
    color: #fff;
    padding: 0.25rem 0.5rem;
    border-radius: 0.25rem;
    white-space: nowrap;
    font-size: 0.75rem;
    }
    .info-icon {
    font-style: normal;
    margin-left: 0.25rem;
    cursor: help;
    }
    """

    html = [
        "<!DOCTYPE html>",
        "<html lang=\"en\">",
        "<head>",
        "  <meta charset=\"utf-8\">",
        "  <title>AWS Cost Report</title>",
        "  <link href='https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:wght@400;600&display=swap' rel='stylesheet'>"
        f" <style>{styles}</style>",
        "</head>",
        "<body>",
        "  <header><h1>AWS Cost Report (Last 30 Days)</h1></header>",
        "  <main>"
        f"  <p><strong>Savings Plans:</strong> {'✅ Yes' if has_sp else '❌ No'}</p>",
        "  <h2>Top 3 Services by Cost</h2>",
        f"  {top_services_svg}",
        "  <h2>Top 5 Usage Types by Cost</h2>",
        f"  {top_usage_types_svg}",
        "  <h2>Cost by Region</h2>",
        f"  {region_costs_svg}",
        "  <h2>S3 Summary</h2>",
        f"  <p>Buckets: {s3_bucket_count}, Total storage: {total_storage:.2f} GB</p>",
        "  <h3>Top 5 Buckets by Storage</h3>",
        f"  {bucket_size_svg}"

        "  <h2>EC2 Instances</h2>",
    ]
    if ec2_list:
        html.append("  <table>")
        html.append(
            "    <tr><th>Name</th><th>Instance ID</th><th>Type</th><th>Region</th><th>State</th>"
            "<th>Peak CPU (1h) <span class='info-icon tooltip'"
            " data-tip='Maximum CPU utilisation recorded in any one-hour period during the last 30 days'>\u24D8</span></th></tr>"
        )
        for instance in sorted(ec2_list, key=lambda x: (x["Region"], x["Name"])):
            cpu = instance['MaxCPU']
            tip = (
                "Below 50% – potential for right-sizing"
                if cpu < 50
                else "Peak CPU utilisation over the last 30 days"
            )
            classes = "warning tooltip" if cpu < 50 else "tooltip"
            cpu_cell = (
                f"<td class='{classes}' data-tip='{tip}'>{cpu:.1f}% <span class='info-icon'>\u24D8</span></td>"
            )
            html.append(
                f"    <tr><td>{instance['Name']}</td><td>{instance['InstanceId']}</td><td>{instance['InstanceType']}</td>"
                f"<td>{instance['Region']}</td><td>{instance['State']}</td>{cpu_cell}</tr>"
            )
        html.append("  </table>")
    else:
        html.append("  <p>No EC2 instances found.</p>")
    html.extend([
        "  <h2>RDS Instances</h2>",
        "  <ul>"
    ])
    if rds_list:
        for database in sorted(rds_list, key=lambda x: (x["Region"], x["Identifier"])):
            cpu = database['MaxCPU']
            tip = (
                "Below 50% – potential for right-sizing"
                if cpu < 50
                else "Peak CPU utilisation over the last 30 days"
            )
            classes = "warning tooltip" if cpu < 50 else "tooltip"
            cpu_str = (
                f"<span class='{classes}' data-tip='{tip}'>{cpu:.1f}% <span class='info-icon'>\u24D8</span></span>"
            )
            html.append(
                f"    <li><strong>{database['Identifier']}</strong> "
                f"(Class: {database['DBInstanceClass']}, Region: {database['Region']}, "
                f"Status: {database['Status']}, Peak CPU (1h): {cpu_str})</li>"
            )
    else:
        html.append("    <li>No RDS instances found.</li>")
    html.extend([
        "  </ul>",
        "  <h2>EBS Volumes</h2>",
        f"  <p>Volumes: {ebs_volume_count}, Unattached: {unattached_ebs_count}, Total size: {total_ebs_size} GB</p>"
    ])
    if ebs_list:
        html.append("  <table>")
        html.append("    <tr><th>Volume ID</th><th>Size (GB)</th><th>Type</th><th>Region</th><th>State</th><th>Attached To</th></tr>")
        for vol in sorted(ebs_list, key=lambda x: (x['Region'], x['VolumeId'])):
            if vol['AttachedTo']:
                attached = vol['AttachedTo']
                attached_cell = f"<td>{attached}</td>"
            else:
                attached = 'unattached'
                tooltip = 'This volume is not attached to any instance and may be a candidate for deletion if no longer needed.'
                attached_cell = f"<td class='warning tooltip' data-tip='{tooltip}'>{attached} <span class='info-icon'>ⓘ</span></td>"
            
            html.append(
                f"    <tr>"
                f"<td>{vol['VolumeId']}</td>"
                f"<td>{vol['Size']}</td>"
                f"<td>{vol['Type']}</td>"
                f"<td>{vol['Region']}</td>"
                f"<td>{vol['State']}</td>"
                f"{attached_cell}"
                "</tr>"
            )
        html.append("  </table>")
    else:
        html.append("  <p>No EBS volumes found.</p>")
    html.extend([
        "  </main>",
        "</body>",
        "</html>"
    ])

    Path("report.html").write_text("\n".join(html))
    print(f"✅ HTML report (report.html) generated")
