#!/usr/bin/env python3
# Copyright (c) 2025 Edulilo Pty Ltd
#
# This software is provided "as is", without warranty of any kind, either express or implied,
# including but not limited to the warranties of merchantability and fitness for a particular purpose.
# In no event shall the author be liable for any claim, damages or other liability,
# whether in an action of contract, tort or otherwise, arising from, out of or in connection
# with the software or the use or other dealings in the software.
#
# Requires the IAM permissions defined in cloud_snapshot_policy.json
# to run successfully. Attach that policy to the user or role running this script.

import json
from collections import defaultdict
from pathlib import Path
from datetime import datetime, timedelta, timezone
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import List, Dict, Any, Optional

import boto3

session = boto3.session.Session()
DEFAULT_REGION = session.region_name or "us-east-1"

# Track start time for generating the report
start_time = datetime.now(timezone.utc)


@dataclass
class ChartItem:
    """Represents a single item in a bar chart."""
    label: str
    amount: float


def generate_divided_bar_with_legend(
    items: List[ChartItem],
    unit_prefix: str = '$',
    unit_suffix: str = ''
) -> str:
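    """Render a horizontal stacked bar as an inline SVG string.

    Each item becomes one segment sized by its share of the total, followed by a
    colour-keyed legend line per item.  Returns an empty "<svg></svg>" when there
    are no items to plot.
    """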
    width = 600
    bar_h = 30
    legend_font_size = 12
    legend_text_margin = 4
    margin = 16
    colors = ["steelblue", "orange", "green", "purple", "red", "pink"]

    if not items:
        return "<svg></svg>"

    height = margin * 2 + bar_h + len(items) * (legend_font_size + legend_text_margin)

    parts = [
        '<svg xmlns="http://www.w3.org/2000/svg"',
        f' viewBox="0 0 {width} {height}"',
        ' preserveAspectRatio="xMinYMin meet"',
        ' style="width:100%; height:auto; font-family:\'IBM Plex Sans\', sans-serif">',
    ]

    # Generate the bar segments
    segment_start_x = margin
    total = sum(item.amount for item in items) or 1  # guard against a zero total
    for i, item in enumerate(items):
        segment_width = (item.amount / total) * (width - 2 * margin)
        parts.append(f'<rect x="{segment_start_x:.1f}" y="{margin:.1f}" width="{segment_width:.1f}" height="{bar_h}" fill="{colors[i % len(colors)]}"/>')
        segment_start_x += segment_width

    # Generate the legend
    top_of_legend_y = 1.5 * margin + bar_h
    for i, item in enumerate(items):
        legend_start_y = top_of_legend_y + i * (legend_font_size + legend_text_margin)
        
        # Format the amount based on the suffix
        if unit_suffix == " GB":
            amount_str = f"{item.amount:.3f}"
        else:
            amount_str = f"{item.amount:.2f}"

        parts.append(f'<rect x="{margin}" y="{legend_start_y}" width="{legend_font_size - 2}" height="{legend_font_size - 2}" fill="{colors[i % len(colors)]}"/>')
        parts.append(f'<text x="{margin + legend_font_size}" y="{legend_start_y + legend_font_size - 2}" font-size="{legend_font_size}">{item.label}: {unit_prefix}{amount_str}{unit_suffix}</text>')

    parts.append("</svg>")
    return "\n".join(parts)

def get_iam_user_details():
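    """Collect per-user IAM details: creation date, console access, and access-key usage."""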
    iam = boto3.client("iam", region_name=DEFAULT_REGION)
    users = []
    
    # Get all users
    paginator = iam.get_paginator('list_users')
    for page in paginator.paginate():
        for user in page['Users']:
            user_name = user['UserName']
            create_date = user['CreateDate'].strftime('%Y-%m-%d %H:%M:%S%z')
            
            # Check for console access; get_login_profile raises NoSuchEntityException
            # when the user has no console password.
            try:
                iam.get_login_profile(UserName=user_name)
                has_console_access = 'Yes'
                password_last_used = user.get('PasswordLastUsed', 'Never')
            except iam.exceptions.NoSuchEntityException:
                has_console_access = 'No'
                password_last_used = 'No console access'
            
            # Check for access keys
            access_keys = iam.list_access_keys(UserName=user_name)['AccessKeyMetadata']
            has_access_key = 'Yes' if access_keys else 'No'
            
            # Get last used date for access keys
            access_key_last_used = 'Never used'
            if has_access_key == 'Yes':
                for key in access_keys:
                    key_last_used = iam.get_access_key_last_used(AccessKeyId=key['AccessKeyId'])
                    if 'LastUsedDate' in key_last_used['AccessKeyLastUsed']:
                        last_used = key_last_used['AccessKeyLastUsed']['LastUsedDate'].strftime('%Y-%m-%d %H:%M:%S%z')
                        if access_key_last_used == 'Never used' or last_used > access_key_last_used:
                            access_key_last_used = last_used
            
            users.append({
                'UserName': user_name,
                'CreateDate': create_date,
                'HasConsoleAccess': has_console_access,
                'PasswordLastUsed': password_last_used,
                'HasAccessKey': has_access_key,
                'AccessKeyLastUsed': access_key_last_used
            })
    
    return users


def get_iam_group_details():
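    """Collect IAM groups with their creation dates and member counts."""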
    iam = boto3.client("iam", region_name=DEFAULT_REGION)
    groups = []
    
    # Get all groups
    paginator = iam.get_paginator('list_groups')
    for page in paginator.paginate():
        for group in page['Groups']:
            group_name = group['GroupName']
            create_date = group['CreateDate'].strftime('%Y-%m-%d %H:%M:%S%z')
            
            # Get number of users in the group
            user_count = 0
            user_paginator = iam.get_paginator('get_group')
            for user_page in user_paginator.paginate(GroupName=group_name):
                user_count += len(user_page.get('Users', []))
            
            groups.append({
                'GroupName': group_name,
                'CreateDate': create_date,
                'UserCount': user_count
            })
    
    return groups


def get_iam_role_details():
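    """Collect IAM roles, skipping AWS service roles and counting how many were skipped."""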
    iam = boto3.client("iam", region_name=DEFAULT_REGION)
    roles = []
    service_role_count = 0

    paginator = iam.get_paginator('list_roles')
    for page in paginator.paginate():
        for role in page['Roles']:
            role_name = role['RoleName']
            create_date = role['CreateDate'].strftime('%Y-%m-%d %H:%M:%S%z')

            # Skip AWS service roles to reduce clutter
            if role_name.startswith('AWSServiceRoleFor') or role.get('Path', '').startswith('/aws-service-role/'):
                service_role_count += 1
                continue

            roles.append({
                'RoleName': role_name,
                'CreateDate': create_date,
            })

    return roles, service_role_count


def get_cloudwatch_alarms(regions: List[str]) -> List[Dict[str, str]]:
    """Retrieve CloudWatch alarms across multiple regions.

    Args:
        regions: List of AWS regions to query.

    Returns:
        List of dictionaries with alarm details.
    """
    alarms: List[Dict[str, str]] = []
    comparison_map = {
        "GreaterThanThreshold": ">",
        "GreaterThanOrEqualToThreshold": ">=",
        "LessThanThreshold": "<",
        "LessThanOrEqualToThreshold": "<=",
    }
    for region in regions:
        cw = boto3.client("cloudwatch", region_name=region)
        paginator = cw.get_paginator("describe_alarms")
        for page in paginator.paginate():
            for alarm in page.get("MetricAlarms", []):
                operator = comparison_map.get(
                    alarm.get("ComparisonOperator", ""),
                    alarm.get("ComparisonOperator", ""),
                )
                condition = (
                    f"{alarm.get('MetricName', '')} {operator} {alarm.get('Threshold')}"
                )
                alarms.append(
                    {
                        "Name": alarm.get("AlarmName", ""),
                        "State": alarm.get("StateValue", ""),
                        "Region": region,
                        "Condition": condition,
                    }
                )
    return alarms


def extract_snapshot_details(snapshot: Dict[str, Any], region: str) -> Dict[str, Any]:
    """Normalize snapshot information for reporting.

    Args:
        snapshot: Raw snapshot dictionary returned from AWS.
        region: AWS region where the snapshot resides.

    Returns:
        Dictionary containing the fields needed for the report.
    """
    tags = {tag["Key"]: tag["Value"] for tag in snapshot.get("Tags", [])}

    return {
        "Name": tags.get("Name", ""),
        "SnapshotId": snapshot.get("SnapshotId"),
        "FullSnapshotSize": snapshot.get("FullSnapshotSizeInBytes", 0) / (1024 ** 3),
        "VolumeSize": snapshot.get("VolumeSize"),
        "Description": snapshot.get("Description", ""),
        "StorageTier": snapshot.get("StorageTier", ""),
        "StartTime": snapshot.get("StartTime"),
        "Region": region,
    }

def format_days_since(dt, access_type=None) -> tuple:
    """Format a datetime to show days since that time and check if it's inactive.
    
    Args:
        dt: Either a datetime object or a string that can be:
            - 'Never'
            - 'Never used'
            - 'No console access'
            - 'No access key'
            - An ISO format datetime string
        access_type: Either 'console' or 'key' to specify the access type for warnings
        
    Returns:
        A tuple of (formatted_date, warning_message)
    """
    warning = None
    
    # Handle special string cases
    if isinstance(dt, str):
        if dt in ['Never used', 'Never'] and access_type:
            warning = f"{access_type.capitalize()} access is enabled but has never been used"
            return dt, warning
        elif dt in ['No console access', 'No access key', 'Never used', 'Never']:
            return dt, None
        try:
            dt = datetime.fromisoformat(dt.replace('Z', '+00:00'))
        except (ValueError, AttributeError):
            return str(dt), None
    
    # Handle datetime objects
    if isinstance(dt, datetime):
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        delta = datetime.now(timezone.utc) - dt
        days = delta.days
        
        if days > 90 and access_type:
            warning = f"{access_type.capitalize()} access is enabled but hasn't been used in {days} days"
            
        return f"{days} days ago", warning
    
    # Fallback for any other type
    return str(dt), None


def max_cpu_for(resource, region, id_key, cw_namespace) -> float:
    """Return the peak CPU utilisation for a resource.

    Queries CloudWatch for the maximum CPU utilisation recorded in any 1-hour
    period over the last 30 days.  This value represents the highest sustained
    load on the resource and can be used to identify resources that may be
    over-provisioned.
    """
    cw = boto3.client("cloudwatch", region_name=region)
    now = datetime.now(timezone.utc)
    resp = cw.get_metric_statistics(
        Namespace=cw_namespace,
        MetricName="CPUUtilization",
        Dimensions=[{"Name": id_key, "Value": resource}],
        StartTime=now - timedelta(days=30),
        EndTime=now,
        Period=3600,
        Statistics=["Maximum"],
    )
    points = [dp["Maximum"] for dp in resp.get("Datapoints", [])]
    return max(points) if points else 0.0


if __name__ == "__main__":
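    # Local cache layout: raw per-region API responses are dumped as JSON under
    # aws-cost-data/<service>/<region>.json and re-read below when building the report.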
    Path("aws-cost-data/ec2").mkdir(parents=True, exist_ok=True)
    Path("aws-cost-data/rds").mkdir(parents=True, exist_ok=True)
    Path("aws-cost-data/ebs").mkdir(parents=True, exist_ok=True)

    ce = boto3.client("ce", region_name=DEFAULT_REGION)
    def get_ce_costs(
        key: str,
        group_by: list,
    ) -> list:
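        """Return Cost Explorer totals for the last 30 days, grouped by `group_by`.

        Responses are cached on disk (aws-cost-data/ce_<key>.json) and reused for up
        to one day to limit repeated Cost Explorer requests.  Groups totalling under
        $0.01 are dropped and the rest are returned sorted by cost, descending.
        """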
        cache_file = Path("aws-cost-data") / f"ce_{key}.json"
        ce_json = None
        if cache_file.exists():
            age = datetime.now(timezone.utc) - datetime.fromtimestamp(
                cache_file.stat().st_mtime, tz=timezone.utc
            )
            if age < timedelta(days=1):
                age_str = str(age).split('.')[0]
                print(f"🗄️ Loading cached CE data '{key}' (age {age_str})")
                ce_json = json.loads(cache_file.read_text())
        if not ce_json:
            print(f"🔍 Fetching fresh CE data '{key}'")
            try:
                ce_json = ce.get_cost_and_usage(
                    TimePeriod={
                        "Start": (datetime.now(timezone.utc) - timedelta(days=30)).strftime("%Y-%m-%d"),
                        "End": datetime.now(timezone.utc).strftime("%Y-%m-%d"),
                    },
                    Granularity="MONTHLY",
                    GroupBy=group_by,
                    Metrics=['UnblendedCost']
                )
                cache_file.write_text(json.dumps(ce_json, default=str))
            except Exception as e:
                print(f"⚠️ Error fetching CE data '{key}': {e}")
                return []

        combined_monthly_totals = defaultdict(float)
        for period in ce_json["ResultsByTime"]:
            for group in period["Groups"]:
                name = "/".join(group["Keys"])
                combined_monthly_totals[name] += float(group["Metrics"]["UnblendedCost"]["Amount"])

        return sorted(
            [
                ChartItem(label=dimension, amount=cost)
                for dimension, cost in combined_monthly_totals.items() if cost > 0.01
            ],
            key=lambda x: x.amount,
            reverse=True
        )

    region_costs = get_ce_costs(
        key="region",
        group_by=[{"Type":"DIMENSION","Key":"REGION"}],
    )
    region_costs_svg = generate_divided_bar_with_legend(region_costs)
    regions_with_cost_over_one_dollar = [region.label for region in region_costs if region.label not in ("global", "NoRegion") and region.amount > 1.0]

    service_costs = get_ce_costs(
        key="service",
        group_by=[{"Type":"DIMENSION","Key":"SERVICE"}],
    )
    usage_costs = get_ce_costs(
        key="usage",
        group_by=[
            {"Type":"DIMENSION","Key":"SERVICE"},
            {"Type":"DIMENSION","Key":"USAGE_TYPE"},
        ],
    )

    ec2_list, rds_list, ebs_volume_list, ebs_snapshot_list = [], [], [], []
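    # For each region with meaningful spend, dump the raw EC2/EBS/RDS responses to
    # disk, then fan the per-resource CloudWatch CPU queries out across a thread
    # pool; each future maps back to its (region, resource) pair.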
    with ThreadPoolExecutor(max_workers=10) as pool:
        future_to_ec2 = {}
        future_to_rds = {}

        for region in regions_with_cost_over_one_dollar:
            regional_ec2 = boto3.client("ec2", region_name=region)
            paginator = regional_ec2.get_paginator("describe_instances")
            pages = paginator.paginate(PaginationConfig={"PageSize": 100})

            all_reservations = [res for page in pages for res in page.get("Reservations", [])]

            ec2_dump = {"Reservations": all_reservations}
            (Path("aws-cost-data/ec2") / f"{region}.json").write_text(json.dumps(ec2_dump, default=str))

            # fetch EBS volumes and snapshots
            vol_pages = regional_ec2.get_paginator("describe_volumes").paginate(PaginationConfig={"PageSize": 100})
            all_volumes = [vol for vpage in vol_pages for vol in vpage.get("Volumes", [])]

            snap_pages = regional_ec2.get_paginator("describe_snapshots").paginate(
                OwnerIds=["self"], PaginationConfig={"PageSize": 100}
            )
            all_snapshots = [snap for spage in snap_pages for snap in spage.get("Snapshots", [])]

            ebs_dump = {"Volumes": all_volumes, "Snapshots": all_snapshots}
            (Path("aws-cost-data/ebs") / f"{region}.json").write_text(json.dumps(ebs_dump, default=str))

            for reservation in all_reservations:
                for instance in reservation.get("Instances", []):
                    future_to_ec2[
                        pool.submit(max_cpu_for, instance["InstanceId"], region, "InstanceId", "AWS/EC2")
                    ] = (region, instance)

            regional_rds = boto3.client("rds", region_name=region)
            databases = regional_rds.describe_db_instances().get("DBInstances", [])

            rds_dump = {"DBInstances": databases}
            (Path("aws-cost-data/rds") / f"{region}.json").write_text(json.dumps(rds_dump, default=str))

            for database in databases:
                database_id = database["DBInstanceIdentifier"]
                future_to_rds[
                    pool.submit(max_cpu_for, database_id, region, "DBInstanceIdentifier", "AWS/RDS")
                ] = (region, database)

        ec2_cpu = {}
        for fut in as_completed(future_to_ec2):
            region, instance = future_to_ec2[fut]
            ec2_cpu[(region, instance["InstanceId"])] = fut.result()

        rds_cpu = {}
        for fut in as_completed(future_to_rds):
            region, database = future_to_rds[fut]
            rds_cpu[(region, database["DBInstanceIdentifier"])] = fut.result()

    # build EC2 list
    for jsonf in Path("aws-cost-data/ec2").glob("*.json"):
        region = jsonf.stem
        data = json.loads(jsonf.read_text())
        for reservation in data.get("Reservations", []):
            for instance in reservation.get("Instances", []):
                instance_id = instance["InstanceId"]
                ec2_list.append({
                    "Name": next((t.get("Value") for t in instance.get("Tags", []) if t.get("Key") == "Name"), ""),
                    "Region": region,
                    "InstanceId": instance_id,
                    "InstanceType": instance.get("InstanceType"),
                    "State": instance.get("State", {}).get("Name"),
                    "MaxCPU": ec2_cpu.get((region, instance_id), 0.0),
                })

    # build RDS list
    for jsonf in Path("aws-cost-data/rds").glob("*.json"):
        region = jsonf.stem
        data = json.loads(jsonf.read_text())
        for database in data.get("DBInstances", []):
            database_id = database["DBInstanceIdentifier"]
            rds_list.append({
                "Identifier": database_id,
                "Region": region,
                "DBInstanceClass": database.get("DBInstanceClass"),
                "Status": database.get("DBInstanceStatus"),
                "MaxCPU": rds_cpu.get((region, database_id), 0.0),
            })

    # build EBS volume and snapshot lists
    for jsonf in Path("aws-cost-data/ebs").glob("*.json"):
        region = jsonf.stem
        data = json.loads(jsonf.read_text())
        for volume in data.get("Volumes", []):
            attachments = volume.get("Attachments", [])
            attached_to = attachments[0]["InstanceId"] if attachments else None
            ebs_volume_list.append({
                "VolumeId": volume["VolumeId"],
                "Region": region,
                "Size": volume.get("Size"),
                "Type": volume.get("VolumeType"),
                "State": volume.get("State"),
                "AttachedTo": attached_to,
            })

        for snapshot in data.get("Snapshots", []):
            ebs_snapshot_list.append(extract_snapshot_details(snapshot, region))

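    # S3: bucket sizes come from the CloudWatch BucketSizeBytes metric (daily
    # datapoints, StandardStorage class only), taking the largest average over
    # the last 31 days.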
    s3 = boto3.client("s3", region_name=DEFAULT_REGION)
    buckets = s3.list_buckets().get("Buckets", [])

    s3_data = []
    cw_clients: Dict[str, Any] = {}
    for bucket in buckets:
        name = bucket["Name"]
        try:
            loc = s3.get_bucket_location(Bucket=name)
            bucket_region = loc.get("LocationConstraint") or "us-east-1"
        except Exception:
            bucket_region = DEFAULT_REGION
        cw = cw_clients.setdefault(
            bucket_region, boto3.client("cloudwatch", region_name=bucket_region)
        )
        resp_size = cw.get_metric_statistics(
            Namespace="AWS/S3",
            MetricName="BucketSizeBytes",
            Dimensions=[
                {"Name": "BucketName", "Value": name},
                {"Name": "StorageType", "Value": "StandardStorage"},
            ],
            StartTime=datetime.now(timezone.utc) - timedelta(days=31),
            EndTime=datetime.now(timezone.utc),
            Period=86400,
            Statistics=["Average"],
        )
        size_bytes = max((dp["Average"] for dp in resp_size.get("Datapoints", [])), default=0)
        size_gb = size_bytes / (1024 ** 3)

        s3_data.append({
            "Bucket": name,
            "SizeGB": size_gb,
        })

    cloudwatch_alarm_list = get_cloudwatch_alarms(regions_with_cost_over_one_dollar)
    total_storage = sum(d["SizeGB"] for d in s3_data)

    bucket_sizes = [bucket for bucket in s3_data if bucket["SizeGB"] >= 0.01]
    bucket_sizes.sort(key=lambda bucket: bucket["SizeGB"], reverse=True)
    five_largest_buckets = bucket_sizes[:5]

    # assemble chart entries
    bucket_sizes = [
        ChartItem(label=b["Bucket"], amount=b["SizeGB"])
        for b in five_largest_buckets
    ]

    top_names = {b["Bucket"] for b in five_largest_buckets}
    other_storage = sum(
        d["SizeGB"] for d in s3_data
        if d["Bucket"] not in top_names
    )
    if other_storage > 0:
        bucket_sizes.append(ChartItem(label="All other buckets", amount=other_storage))

    # Aggregate the long tail of services and usage types into "Other" buckets for the charts
    other_services = sum(s.amount for s in service_costs[3:])
    service_costs = service_costs[:3] + ([ChartItem(label="Other Services", amount=other_services)] if other_services > 0 else [])

    other_usage_types = sum(u.amount for u in usage_costs[5:])
    usage_types = usage_costs[:5]  + ([ChartItem(label="Other Usage Types", amount=other_usage_types)] if other_usage_types > 0 else [])

    styles = """
    body {
    margin: 0;
    font-family: 'IBM Plex Sans', sans-serif;
    background-color: #F5F7FA;
    color: #1C2D4A;
    line-height: 1.5;
    }

    header {
    text-align: center;
    margin-top: 0;
    background-color: white;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
    }

    h1 {
    margin-top: 0;
    font-size: 2.5rem;
    margin-bottom: 0.5rem;
    }

    main {
    max-width: 900px;
    margin: 0 auto;
    }

    h2 {
    font-size: 1.75rem;
    margin: 2rem 0 1rem;
    padding-bottom: 0.5rem;
    }

    p {
    font-size: 1.125rem;
    margin-bottom: 1.5rem;
    }


    svg {
    display: block;
    background-color: white;
    border-radius: 0.5rem;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
    }

    ul {
    list-style: none;
    padding: 0;
    margin: 0 0 1.5rem;
    }
    ul li {
    background-color: white;
    padding: 1rem;
    border-radius: 0.5rem;
    margin-bottom: 0.75rem;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
    }

    table {
    width: 100%;
    border-collapse: collapse;
    margin: 0 0 1.5rem;
    background-color: white;
    border-radius: 0.5rem;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
    overflow: visible;
    }
    th, td {
    padding: 0.5rem 0.75rem;
    text-align: left;
    border-bottom: 1px solid #ddd;
    }
    th {
    background-color: #f9fafb;
    font-weight: 600;
    }
    .warning {
    background-color: #fff3cd;
    }
    .tooltip {
    position: relative;
    }
    .tooltip:hover::after {
    content: attr(data-tip);
    position: absolute;
    left: 50%;
    transform: translateX(-50%);
    bottom: 125%;
    background-color: #333;
    color: #fff;
    padding: 0.25rem 0.5rem;
    border-radius: 0.25rem;
    white-space: nowrap;
    font-size: 0.75rem;
    }
    .info-icon {
    font-style: normal;
    margin-left: 0.25rem;
    cursor: help;
    }
    """

    try:
        # Note: the Savings Plans API returns lowerCamelCase response keys.
        sp_client = boto3.client("savingsplans", region_name=DEFAULT_REGION)
        has_savings_plans = bool(sp_client.describe_savings_plans().get("savingsPlans"))
    except Exception:
        has_savings_plans = False

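    # The report is assembled as a list of HTML fragments and joined with newlines at the end.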
    html = [
        "<!DOCTYPE html>",
        "<html lang=\"en\">",
        "<head>",
        "  <meta charset=\"utf-8\">",
        "  <title>AWS Cost Report</title>",
        "  <link href='https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:wght@400;600&display=swap' rel='stylesheet'>",
        f"  <style>{styles}</style>",
        "</head>",
        "<body>",
        "  <header>",
        "    <h1>AWS Cost Report (Last 30 Days)</h1>",
        "  </header>",
        "  <main>",
        f"  <p><strong>Savings Plans:</strong> {'✅ Yes' if has_savings_plans else '❌ No'}</p>",
        "  <h2>Top 3 Services by Cost</h2>",
        f"  {generate_divided_bar_with_legend(service_costs)}",
        "  <h2>Top 5 Usage Types by Cost</h2>",
        f"  {generate_divided_bar_with_legend(usage_types)}",
        "  <h2>Cost by Region</h2>",
        f"  {generate_divided_bar_with_legend(region_costs)}",
        "  <h2>S3 Summary</h2>",
        f"  <p>Buckets: {len(buckets)}, Total storage: {total_storage:.2f} GB</p>",
        "  <h3>Top 5 Buckets by Storage</h3>",
        f"  {generate_divided_bar_with_legend(bucket_sizes, unit_prefix='', unit_suffix=' GB')}",
        "  <h2>EC2 Instances</h2>",
    ]
    if ec2_list:
        html.append("  <table>")
        html.append(
            "    <tr><th>Name</th><th>Instance ID</th><th>Type</th><th>Region</th><th>State</th>"
            "<th>Peak CPU (1h) <span class='info-icon tooltip'"
            " data-tip='Maximum CPU utilisation recorded in any one-hour period during the last 30 days'>\u24D8</span></th></tr>"
        )
        for instance in sorted(ec2_list, key=lambda x: (x["Region"], x["Name"])):
            tip = (
                "Below 50% – potential for right-sizing"
                if instance['MaxCPU'] < 50
                else "Peak CPU utilisation over the last 30 days"
            )
            classes = "warning" if instance['MaxCPU'] < 50 else ""
            cpu_cell = (
                f"<td class='{classes} tooltip' data-tip='{tip}'>{instance['MaxCPU']:.1f}% <span class='info-icon'>\u24D8</span></td>"
            )
            html.append(
                f"    <tr><td>{instance['Name']}</td><td>{instance['InstanceId']}</td><td>{instance['InstanceType']}</td><td>{instance['Region']}</td><td>{instance['State']}</td>{cpu_cell}</tr>"
            )
        html.append("  </table>")
    else:
        html.append("  <p>No EC2 instances found.</p>")
    html.extend([
        "  <h2>RDS Instances</h2>",
        "  <ul>"
    ])
    if rds_list:
        for database in sorted(rds_list, key=lambda x: (x["Region"], x["Identifier"])):
            cpu = database['MaxCPU']
            tip = (
                "Below 50% – potential for right-sizing"
                if cpu < 50
                else "Peak CPU utilisation over the last 30 days"
            )
            classes = "warning tooltip" if cpu < 50 else "tooltip"
            cpu_str = (
                f"<span class='{classes}' data-tip='{tip}'>{cpu:.1f}% <span class='info-icon'>\u24D8</span></span>"
            )
            html.append(
                f"    <li><strong>{database['Identifier']}</strong> "
                f"(Class: {database['DBInstanceClass']}, Region: {database['Region']}, "
                f"Status: {database['Status']}, Peak CPU (1h): {cpu_str})</li>"
            )
    else:
        html.append("    <li>No RDS instances found.</li>")
    html.extend([
        "  </ul>",
        "  <h2>EBS Volumes</h2>",
        f"  <p>Volumes: {len(ebs_volume_list)}, Unattached: {sum(1 for v in ebs_volume_list if v['AttachedTo'] is None)}, Total size: {sum(v.get('Size', 0) for v in ebs_volume_list)} GB</p>"])
    if ebs_volume_list:
        html.append("  <table>")
        html.append("    <tr><th>Volume ID</th><th>Size (GB)</th><th>Type</th><th>Region</th><th>State</th><th>Attached To</th></tr>")
        for vol in sorted(ebs_volume_list, key=lambda x: (x['Region'], x['VolumeId'])):
            if vol['AttachedTo']:
                attached = vol['AttachedTo']
                attached_cell = f"<td>{attached}</td>"
            else:
                attached = 'unattached'
                tooltip = 'This volume is not attached to any instance and may be a candidate for deletion if no longer needed.'
                attached_cell = f"<td class='warning tooltip' data-tip='{tooltip}'>{attached} <span class='info-icon'>ⓘ</span></td>"

            html.append(
                f"    <tr>"
                f"<td>{vol['VolumeId']}</td>"
                f"<td>{vol['Size']}</td>"
                f"<td>{vol['Type']}</td>"
                f"<td>{vol['Region']}</td>"
                f"<td>{vol['State']}</td>"
                f"{attached_cell}"
                "</tr>"
            )
        html.append("  </table>")
    else:
        html.append("  <p>No EBS volumes found.</p>")

    # EBS snapshot report
    html.append("  <h2>EBS Snapshots</h2>")
    html.append(
        f"  <p>Snapshots: {len(ebs_snapshot_list)}, Total size: {sum(s.get('FullSnapshotSize', 0) for s in ebs_snapshot_list):.2f} GB</p>"
    )
    if ebs_snapshot_list:
        html.append("  <table>")
        html.append(
            "    <tr><th>Name</th><th>Snapshot ID</th><th>Full Size (GB)</th><th>Volume Size (GB)</th><th>Description</th><th>Storage Tier</th><th>Started</th></tr>"
        )
        for snap in sorted(ebs_snapshot_list, key=lambda x: (x['Region'], x['SnapshotId'])):
            html.append(
                f"    <tr><td>{snap['Name']}</td><td>{snap['SnapshotId']}</td><td>{snap['FullSnapshotSize']:.2f}</td>"
                f"<td>{snap['VolumeSize']}</td><td>{snap['Description']}</td><td>{snap['StorageTier']}</td><td>{snap['StartTime']}</td></tr>"
            )
        html.append("  </table>")
    else:
        html.append("  <p>No EBS snapshots found.</p>")
    html.append("<h2>Budgets</h2>")
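    # Budgets: list the budgets, then fetch each budget's notifications and their
    # subscribers to summarise alert thresholds and recipients.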
    sts = boto3.client("sts", region_name=DEFAULT_REGION)
    try:
        budgets_client = boto3.client("budgets", region_name=DEFAULT_REGION)
        budgets = budgets_client.describe_budgets(
            AccountId=sts.get_caller_identity()["Account"]
        ).get("Budgets", [])
    except Exception as e:
        html.append(f"  <p>Error retrieving budgets: {str(e)}</p>")
        budgets = None

    if budgets:
        html.append("""
        <table>
            <thead>
                <tr>
                    <th>Budget Name</th>
                    <th>Limit</th>
                    <th>Time Unit</th>
                    <th>Notifications</th>
                </tr>
            </thead>
            <tbody>
        """)
        
        for budget in budgets:
            # Get unique notification addresses with their subscription types
            notification_addresses = set()
            try:
                for notification in budgets_client.describe_notifications_for_budget(
                    AccountId=sts.get_caller_identity()['Account'], 
                    BudgetName=budget['BudgetName']
                ).get("Notifications", []):
                    subscribers = budgets_client.describe_subscribers_for_notification(
                        AccountId=sts.get_caller_identity()["Account"], 
                        BudgetName=budget["BudgetName"],
                        Notification=notification
                    )
                    threshold = notification['Threshold']
                    comparison = notification['ComparisonOperator']
                    if comparison == 'GREATER_THAN':
                        threshold_text = f"Over {threshold}%"
                    elif comparison == 'LESS_THAN':
                        threshold_text = f"Under {threshold}%"
                    else:
                        threshold_text = f"At {threshold}%"
                    for subscriber in subscribers.get("Subscribers", []):
                        address = subscriber.get('Address', '')
                        sub_type = subscriber.get('SubscriptionType', 'UNKNOWN')
                        notification_addresses.add(f"{threshold_text} {sub_type}: {address}")
            except Exception:
                notification_addresses = {"Error retrieving notifications"}
            
            time_unit = budget.get("TimeUnit", "").capitalize()
            
            html.append(f"""
                <tr>
                    <td>{budget["BudgetName"]}</td>
                    <td>{budget["BudgetLimit"]["Amount"]}{budget["BudgetLimit"]["Unit"]}</td>
                    <td>{time_unit}</td>
                    <td>{'<br>'.join(sorted(notification_addresses)) if notification_addresses else 'No notifications'}</td>
                </tr>
            """)
        html.append("""
            </tbody>
        </table>
        """)
    elif budgets is not None:
        html.append("  <p>No budgets found.</p>")

    # Add AWS Cost Optimization Status section
    html.append("<h2>AWS Optimization Status</h2>")
    html.append("<table>")
    html.append("<tr><th>Service</th><th>Status</th></tr>")
    
    try:
        compute_optimizer = boto3.client(
            "compute-optimizer", region_name=DEFAULT_REGION
        )
        status = compute_optimizer.get_enrollment_status().get("status", "Inactive")
        if status == "Active":
            html.append("""
        <tr>
            <td>AWS Compute Optimizer (Rightsizing)</td>
            <td class="success">✅ Enabled</td>
        </tr>
        """)
        else:
            html.append("""
        <tr>
            <td>AWS Compute Optimizer (Rightsizing)</td>
            <td class="warning">⚠️ Not enabled</td>
        </tr>
        """)
    except Exception:
        html.append("""
        <tr>
            <td>AWS Compute Optimizer (Rightsizing)</td>
            <td class="warning">⚠️ Not available</td>
        </tr>
        """)
    html.append("</table>")

    html.append("<h2>CloudWatch Alarms</h2>")
    if cloudwatch_alarm_list:
        html.append("<table>")
        html.append("<tr><th>Name</th><th>State</th><th>Region</th><th>Condition</th></tr>")
        for alarm in sorted(cloudwatch_alarm_list, key=lambda x: (x['Region'], x['Name'])):
            html.append(
                f"<tr><td>{alarm['Name']}</td><td>{alarm['State']}</td><td>{alarm['Region']}</td><td>{alarm['Condition']}</td></tr>"
            )
        html.append("</table>")
    else:
        html.append("<p>No CloudWatch alarms found.</p>")

    html.append("<h2>IAM Users</h2>")
    try:
        html.append("<div class='table-container'><table>")
        html.append("""
            <tr>
                <th>Username</th>
                <th>Created</th>
                <th>Console Access</th>
                <th>Last Console Login</th>
                <th>Access Key</th>
                <th>Last Key Used</th>
            </tr>
        """)
        
        users = get_iam_user_details()
        for user in sorted(users, key=lambda x: x['UserName'].lower()):
            # Format console access info
            console_text, console_warning = format_days_since(
                user['PasswordLastUsed'], 
                'console' if user['HasConsoleAccess'] == 'Yes' else None
            )
            
            # Format access key info
            key_text, key_warning = format_days_since(
                user['AccessKeyLastUsed'],
                'access key' if user['HasAccessKey'] == 'Yes' else None
            )
            
            html.append(f"""
                <tr>
                    <td>{user['UserName']}</td>
                    <td>{datetime.strptime(str(user['CreateDate']), '%Y-%m-%d %H:%M:%S%z').strftime('%Y-%m-%d')}</td>
                    <td style="text-align: center;">{user['HasConsoleAccess']}</td>
                    <td{' class="warning tooltip" data-tip="' + console_warning + '"' if console_warning else ''}>
                        {console_text}{' <span class="info-icon">ⓘ</span>' if console_warning else ''}
                    </td>
                    <td style="text-align: center;">{user['HasAccessKey']}</td>
                    <td{' class="warning tooltip" data-tip="' + key_warning + '"' if key_warning else ''}>
                        {key_text}{' <span class="info-icon">ⓘ</span>' if key_warning else ''}
                    </td>
                </tr>
            """)
        
        html.append("</table></div>")
    except Exception as e:
        html.append(f"<p>Error retrieving IAM user details: {str(e)}</p>")

    # Add IAM Groups section
    html.append("<h2>IAM Groups</h2>")
    try:
        groups = get_iam_group_details()
        if groups:
            html.append("""
                <div class='table-container'><table>
                    <tr>
                        <th>Group Name</th>
                        <th>Created</th>
                        <th style="text-align: center;">Users</th>
                    </tr>
            """)
            
            for group in sorted(groups, key=lambda x: x['GroupName'].lower()):
                html.append(f"""
                    <tr>
                        <td>{group['GroupName']}</td>
                        <td>{datetime.strptime(str(group['CreateDate']), '%Y-%m-%d %H:%M:%S%z').strftime('%Y-%m-%d')}</td>
                        <td style="text-align: center;">{group['UserCount']}</td>
                    </tr>
                """)
            
            html.append("</table></div>")
        else:
            html.append("<p>No IAM groups found.</p>")
    except Exception as e:
        html.append(f"<p>Error retrieving IAM group details: {str(e)}</p>")

    html.append("<h2>IAM Roles</h2>")
    try:
        roles, service_role_count = get_iam_role_details()
        if service_role_count:
            html.append(f"<p>{service_role_count} AWS Service Roles not shown.</p>")
        if roles:
            html.append("""
                <div class='table-container'><table>
                    <tr>
                        <th>Role Name</th>
                        <th>Created</th>
                    </tr>
            """)

            for role in sorted(roles, key=lambda x: x['RoleName'].lower()):
                html.append(f"""
                    <tr>
                        <td>{role['RoleName']}</td>
                        <td>{datetime.strptime(str(role['CreateDate']), '%Y-%m-%d %H:%M:%S%z').strftime('%Y-%m-%d')}</td>
                    </tr>
                """)

            html.append("</table></div>")
        else:
            html.append("<p>No IAM roles found.</p>")
    except Exception as e:
        html.append(f"<p>Error retrieving IAM role details: {str(e)}</p>")

    html.append("<h2>CloudTrail Trails</h2>")
    try:
        cloudtrail = boto3.client("cloudtrail", region_name=DEFAULT_REGION)
        trails = cloudtrail.describe_trails().get('trailList', [])
        
        if trails:
            html.append("""
            <table>
                <thead>
                    <tr>
                        <th>Name</th>
                        <th>Home Region</th>
                        <th>Multi-Region</th>
                        <th>Log File Validation</th>
                        <th>Logging</th>
                        <th>Organization Trail</th>
                        <th>S3 Bucket</th>
                    </tr>
                </thead>
                <tbody>
            """)
            
            for trail in trails:

                try:
                    status = cloudtrail.get_trail_status(Name=trail['Name'])
                    is_logging = status.get('IsLogging', False)
                    logging_status = '✅ Enabled' if is_logging else '❌ Disabled'
                except Exception:
                    logging_status = '❓ Unknown'
                
                is_org_trail = trail.get('IsOrganizationTrail', False)
                org_trail_status = '✅ Yes' if is_org_trail else '❌ No'
                
                html.append(f"""
                    <tr>
                        <td>{trail.get('Name', 'N/A')}</td>
                        <td>{trail.get('HomeRegion', 'N/A')}</td>
                        <td>{'✅ Yes' if trail.get('IsMultiRegionTrail', False) else '❌ No'}</td>
                        <td>{'✅ Enabled' if trail.get('LogFileValidationEnabled', False) else '❌ Disabled'}</td>
                        <td>{logging_status}</td>
                        <td>{org_trail_status}</td>
                        <td>{trail.get('S3BucketName', 'N/A')}</td>
                    </tr>
                """)
            
            html.append("""
                </tbody>
            </table>
            """)
        else:
            html.append("  <p>No CloudTrail trails found. It's recommended to enable CloudTrail for security and compliance.</p>")
    except Exception as e:
        html.append(f"  <p>Error retrieving CloudTrail information: {str(e)}</p>")

    html.extend([
        "  </main>",
        "</body>",
        "</html>"
    ])

    # Add generation timestamp and duration information
    generated_at = datetime.now(timezone.utc)
    elapsed = generated_at - start_time
    header_close_index = html.index("  </header>")
    html.insert(
        header_close_index,
        f"    <p>Generated: {generated_at.strftime('%Y-%m-%d %H:%M:%S %Z')} (in {elapsed.total_seconds():.2f} seconds)</p>"
    )

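    # Archive any existing report before overwriting it.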
    report_file = Path("report.html")
    if report_file.exists():
        archive_dir = Path("archived_reports")
        archive_dir.mkdir(exist_ok=True)
        backup_file = archive_dir / f"report_{generated_at.strftime('%Y%m%d_%H%M%S')}.html"
        report_file.rename(backup_file)
        print(f"🔁 Existing report moved to {backup_file}")

    report_file.write_text("\n".join(html))
    print(
        f"✅ HTML report (report.html) generated at {generated_at.strftime('%Y-%m-%d %H:%M:%S %Z')} in {elapsed.total_seconds():.2f} seconds"
    )
