
# pylint: disable=broad-exception-caught
"""
NOTICE OF LICENSE.

Copyright 2025 @AnabolicsAnonymous

Licensed under the Affero General Public License v3.0 (AGPL-3.0)
"""
import json
import os
import re
import sys
import time
from datetime import datetime, timezone

import requests

from core import config
from core import design

class DiscordNotifier:
    """
    Send attack notifications to a Discord webhook.

    The embed is rendered from a ``payload.json`` template whose field values
    may contain ``{{placeholder}}`` tokens. Delivery is retried with
    exponential backoff on transient failures.
    """

    # HTTP statuses treated as transient (rate limiting / server-side errors).
    _RETRYABLE_STATUSES = (429, 500, 502, 503, 504)
    # Discord replies 204 by default and 200 when the URL carries ``?wait=true``.
    _SUCCESS_STATUSES = (200, 204)

    def __init__(self):
        """Read the webhook URL from the config and load the payload template.

        The process exits (see ``_load_payload_template``) when the template
        is missing or cannot be parsed.
        """
        self.webhook_url = config.CONFIG["notification"]["Embed_Webhook_URL"]
        self.output = design.Output()
        self.export_dir = "./beacon_data/detected_ips/"
        self.payload_template = self._load_payload_template()

    def _load_payload_template(self):
        """Load and parse ``payload.json`` from the current working directory.

        Returns:
            The decoded JSON template.

        Prints a hint and exits with status 1 when the file is missing, is
        not valid JSON, or cannot be read.
        """
        template_path = 'payload.json'
        example_path = 'payload.json.example'

        if not os.path.exists(template_path):
            if os.path.exists(example_path):
                print(f"{self.output.get_output()} Error: payload.json not found.")
                print(f"{self.output.get_output()} Edit payload.json.example to payload.json")
            else:
                print(f"{self.output.get_output()} Error: Neither payload files found")
            sys.exit(1)

        try:
            with open(template_path, 'r', encoding='utf-8') as handle:
                return json.load(handle)
        except json.JSONDecodeError as e:
            print(f"{self.output.get_output()} Error: payload.json invalid JSON: {e}")
            sys.exit(1)
        except Exception as e:
            print(f"{self.output.get_output()} Error loading payload template: {e}")
            sys.exit(1)

    @staticmethod
    def _mask_ip(ip: str) -> str:
        """Hide everything but the first group of an IP address.

        Args:
            ip: IPv4 or IPv6 address in text form.

        Returns:
            The address with every group after the first replaced by ``x``
            characters of the same length. Non-string input is returned as
            ``str(ip)``; text that is neither colon-separated nor a four-part
            dotted address is returned unchanged.
        """
        if not isinstance(ip, str):
            return str(ip)
        if ':' in ip:
            head, *rest = ip.split(':')
            # Empty groups (from "::") stay empty because 'x' * 0 == ''.
            return ':'.join([head] + ['x' * len(group) for group in rest])
        octets = ip.split('.')
        if len(octets) == 4:
            return '.'.join([octets[0]] + ['x' * len(octet) for octet in octets[1:]])
        return ip

    def _clean_attack_vector(self, attack_vector: str) -> str:
        """Normalise whitespace in an attack-vector label.

        Runs of whitespace collapse to one space and any whitespace touching
        ``[`` or ``]`` is removed. A falsy value yields ``"Undetected"``.
        """
        if not attack_vector:
            return "Undetected"
        cleaned = re.sub(r'\s+', ' ', attack_vector)
        cleaned = re.sub(r'\s*\]\s*', ']', cleaned)
        cleaned = re.sub(r'\s*\[\s*', '[', cleaned)
        return cleaned

    def _attack_count(self) -> int:
        """Count the ``.txt`` / ``.json`` files in ``self.export_dir``.

        The directory is created when missing. Any error yields 0.
        """
        try:
            os.makedirs(self.export_dir, exist_ok=True)
            return sum(
                1 for name in os.listdir(self.export_dir)
                if name.endswith(('.txt', '.json'))
            )
        except Exception:
            return 0

    def _read_export_json(self, json_file: str):
        """Return the decoded content of *json_file*, or ``None`` on any error."""
        try:
            with open(json_file, 'r', encoding='utf-8') as handle:
                return json.load(handle)
        except Exception as e:
            print(f"{self.output.get_output()} Error reading export JSON: {e}")
            return None

    def _post_with_retries(self, payload: dict, max_attempts=3, base_sleep=1.5) -> bool:
        """POST *payload* to the webhook, retrying transient failures.

        Args:
            payload: JSON-serialisable webhook body.
            max_attempts: Maximum number of POST attempts.
            base_sleep: Delay in seconds before the second attempt; doubled
                before each further attempt.

        Returns:
            True when the webhook answered 200 or 204, False otherwise.
        """
        for attempt in range(1, max_attempts + 1):
            try:
                resp = requests.post(
                    self.webhook_url,
                    json=payload,
                    headers={"Content-Type": "application/json"},
                    timeout=10
                )
            except Exception as e:
                print(f"{self.output.get_output()} Webhook exception (attempt {attempt}): {e}")
            else:
                if resp.status_code in self._SUCCESS_STATUSES:
                    return True
                if resp.status_code not in self._RETRYABLE_STATUSES:
                    print(f"{self.output.get_output()} Webhook failed: {resp.status_code} -> {resp.text}")
                    return False
                print(f"{self.output.get_output()} Webhook returned {resp.status_code} (attempt {attempt})")
            # Back off only when another attempt follows; sleeping after the
            # final attempt would just delay the failure result.
            if attempt < max_attempts:
                time.sleep(base_sleep * (2 ** (attempt - 1)))
        return False

    # ---------- formatting helpers ----------
    @staticmethod
    def _fmt_int(n):
        """Format *n* as an integer with thousands separators, else ``str(n)``."""
        try:
            return f"{int(n):,}"
        except Exception:
            return str(n)

    @staticmethod
    def _fmt_mbps(n):
        """Format *n* as Mbps, adding a Gbps figure from 1000 up, else ``str(n)``."""
        try:
            n_float = float(n)
            if n_float >= 1000:
                return f"{n_float:,.0f} Mbps ({n_float/1000:.2f} Gbps)"
            return f"{n_float:,.0f} Mbps"
        except Exception:
            return str(n)

    # ---------- placeholder helpers ----------
    def _attack_replacements(self, attack_data: dict) -> list:
        """Return ordered ``(placeholder, text)`` pairs for the live attack metrics."""
        return [
            ("{{pps}}", self._fmt_int(attack_data.get("pps", "N/A"))),
            ("{{mbps}}", self._fmt_mbps(attack_data.get("mbps", "N/A"))),
            # str() because the value may be numeric and str.replace needs text.
            ("{{cpu}}", str(attack_data.get("cpu", "N/A"))),
            ("{{status}}", "Detected"),
            ("{{pcap}}", str(attack_data.get("pcap", "N/A"))),
            ("{{attack_vector}}",
             self._clean_attack_vector(attack_data.get("attack_vector", "Undetected"))),
        ]

    def _export_replacements(self, export_data: dict) -> list:
        """Return ordered ``(placeholder, text)`` pairs derived from the export file.

        Only placeholders whose source keys are present in *export_data* are
        produced; the rest are left untouched in the template text.
        """
        pairs = []
        if "ipv4_addresses" in export_data or "ipv6_addresses" in export_data:
            # ``or ()`` tolerates an explicit null in the export.
            ipv4_count = len(export_data.get("ipv4_addresses") or ())
            ipv6_count = len(export_data.get("ipv6_addresses") or ())
            pairs.append(("{{total_ips}}", f"{ipv4_count + ipv6_count:,}"))
            pairs.append(("{{ipv4_count}}", f"{ipv4_count:,}"))
            pairs.append(("{{ipv6_count}}", f"{ipv6_count:,}"))
        if "most_common_source_ip" in export_data:
            pairs.append(("{{most_common_source_ip}}",
                          self._mask_ip(export_data["most_common_source_ip"])))
        if "most_common_dest_ip" in export_data:
            pairs.append(("{{most_common_dest_ip}}",
                          self._mask_ip(export_data["most_common_dest_ip"])))
        if "pcap_packets_captured" in export_data and "pcap_duration_seconds" in export_data:
            packets = export_data["pcap_packets_captured"]
            duration = export_data["pcap_duration_seconds"]
            try:
                avg_pps = round(packets / duration) if duration > 0 else 0
                packets_text = f"{packets:,}"
                avg_text = f"{avg_pps:,}"
            except (TypeError, ValueError, OverflowError):
                # Non-numeric capture stats must not abort the notification.
                packets_text = str(packets)
                avg_text = "N/A"
            pairs.append(("{{packets}}", packets_text))
            pairs.append(("{{avg_pps}}", avg_text))
        return pairs

    @staticmethod
    def _apply_replacements(text: str, replacements) -> str:
        """Apply each ``(placeholder, text)`` pair to *text*, in order."""
        for placeholder, replacement in replacements:
            text = text.replace(placeholder, replacement)
        return text

    def send_notification(self, attack_data: dict, export_json: str) -> bool:
        """Render the embed template and post it to the Discord webhook.

        Args:
            attack_data: Live metrics; the keys ``pps``, ``mbps``, ``cpu``,
                ``pcap`` and ``attack_vector`` are read, each optional.
            export_json: Path to an export JSON file with extra statistics.
                May be empty or point to a missing file, in which case the
                export-derived placeholders are left as they are.

        Returns:
            True when the webhook accepted the message, False on any failure
            (the reason is printed).
        """
        try:
            if not self.webhook_url:
                print(f"{self.output.get_output()} Error: Discord webhook URL not configured")
                return False
            if not self.payload_template:
                print(f"{self.output.get_output()} Error: Payload template not loaded")
                return False

            # A JSON round-trip deep-copies the template, so the cached
            # original keeps its placeholders for the next notification.
            payload = json.loads(json.dumps(self.payload_template))
            embed = payload["embeds"][0]

            export_data = {}
            if export_json and os.path.exists(export_json):
                export_data = self._read_export_json(export_json) or {}
            if not isinstance(export_data, dict):
                # An export whose JSON root is not an object carries no usable keys.
                export_data = {}

            # The embed title is taken from the template unchanged.
            embed["description"] = "AjaxVPN has detected and analyzed a potential DDoS attack."
            embed["timestamp"] = datetime.now(timezone.utc).isoformat()

            # Built once and reused for every field; order matters because the
            # pairs are applied sequentially.
            replacements = (self._attack_replacements(attack_data)
                            + self._export_replacements(export_data))

            for field in embed.get("fields", []):
                raw_value = field.get("value", "")
                if isinstance(raw_value, str):
                    field["value"] = self._apply_replacements(raw_value, replacements)

            footer = embed.get("footer", {})
            footer["text"] = (footer.get("text", "PCAP: {{pcap}} • Powered by AjaxVPN AjaxSEC")
                              .replace("{{pcap}}", str(attack_data.get("pcap", "N/A"))))
            embed["footer"] = footer

            # Fail fast here on an unserialisable payload instead of burning
            # the retry budget inside _post_with_retries.
            json.dumps(payload)
            return self._post_with_retries(payload)
        except Exception as e:
            print(f"{self.output.get_output()} Error sending Discord notification: {e}")
            return False
