
# pylint: disable=broad-exception-caught
"""
Notifier that supports "Detected" and "Ended" statuses.
- Keeps the title you set in payload.json (e.g., "AjaxVPN TEST").
- Replaces {{status}} with the provided status.
- Optionally shows duration if your template includes {{duration}}.
"""
import copy
import json
import os
import re
import sys
import time
from datetime import datetime, timezone

import requests

from core import config
from core import design

class DiscordNotifier:
    """Send attack notifications to a Discord webhook using a JSON embed template.

    The template is read from ``payload.json`` (relative to the current working
    directory). Only the ``value`` strings of ``embeds[0]["fields"]`` are
    templated; the rest of the embed (title, colour, ...) is sent as written,
    except ``timestamp``, which is set to the current UTC time on every send.
    """

    # Matches "{{name}}" placeholders inside embed field values.
    _PLACEHOLDER_RE = re.compile(r"\{\{(\w+)\}\}")

    # Placeholders filled from the exported JSON file; "N/A" when unavailable.
    _EXPORT_PLACEHOLDERS = (
        "total_ips",
        "ipv4_count",
        "ipv6_count",
        "most_common_source_ip",
        "most_common_dest_ip",
        "packets",
        "avg_pps",
    )

    def __init__(self):
        """Read the webhook URL from config and load the embed template (may exit)."""
        self.webhook_url = config.CONFIG["notification"]["Embed_Webhook_URL"]
        self.output = design.Output()
        # Not referenced inside this class; kept for external users of the attribute.
        self.export_dir = "./beacon_data/detected_ips/"
        self.payload_template = self._load_payload_template()

    def _load_payload_template(self):
        """Load the embed template from ``payload.json``.

        Prints an error and exits the process with status 1 if the file is
        missing. A malformed file propagates ``json.JSONDecodeError``.
        """
        path = 'payload.json'
        if not os.path.exists(path):
            print(f"{self.output.get_output()} Error: payload.json not found.")
            sys.exit(1)
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _mask_ip(ip: str) -> str:
        """Mask every segment except the first of an IP address with 'x' characters.

        Strings containing ':' are treated as IPv6 and keep their first hextet;
        dotted-quad IPv4 keeps its first octet. Any other string is returned
        unchanged, and non-strings are converted with ``str()``.
        """
        if not isinstance(ip, str):
            return str(ip)
        if ':' in ip:
            head, *rest = ip.split(':')
            return ':'.join([head] + ['x' * len(s) for s in rest])
        parts = ip.split('.')
        if len(parts) != 4:
            return ip
        return '.'.join([parts[0]] + ['x' * len(o) for o in parts[1:]])

    def _read_export_json(self, json_file: str):
        """Return the parsed export file as a dict, or None if it cannot be used.

        Unreadable files, invalid JSON and non-object top-level values all
        yield None, so a broken export never blocks the notification.
        """
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            return None
        return data if isinstance(data, dict) else None

    def _post(self, payload: dict) -> bool:
        """POST *payload* to the webhook and return True on any 2xx response.

        Discord answers 204 No Content by default and 200 when the webhook URL
        carries ``?wait=true``; both mean the message was accepted. Failures
        are logged and reported as False rather than raised (best effort).
        """
        try:
            resp = requests.post(self.webhook_url, json=payload, timeout=10)
        except Exception as e:
            print(f"{self.output.get_output()} Error posting to Discord webhook: {e}")
            return False
        if 200 <= resp.status_code < 300:
            return True
        print(
            f"{self.output.get_output()} Discord webhook returned HTTP "
            f"{resp.status_code}: {resp.text[:200]}"
        )
        return False

    @staticmethod
    def _fmt_int(n):
        """Format *n* as an integer with thousands separators, else ``str(n)``."""
        try:
            return f"{int(n):,}"
        except (TypeError, ValueError, OverflowError):
            return str(n)

    @staticmethod
    def _fmt_mbps(n):
        """Format *n* as Mbps, appending a Gbps figure at or above 1000 Mbps."""
        try:
            n = float(n)
        except (TypeError, ValueError):
            return str(n)
        return f"{n:,.0f} Mbps" + (f" ({n / 1000:.2f} Gbps)" if n >= 1000 else "")

    @staticmethod
    def _as_text(value, default: str) -> str:
        """Return *value* as a string, or *default* when it is None."""
        return default if value is None else str(value)

    def _attack_values(self, attack_data: dict) -> dict:
        """Build the placeholder values that come from *attack_data*.

        Every value is a string, so a non-string input (e.g. a float CPU
        reading) can no longer break template substitution.
        """
        return {
            "pps": self._fmt_int(attack_data.get("pps", "N/A")),
            "mbps": self._fmt_mbps(attack_data.get("mbps", "N/A")),
            "cpu": self._as_text(attack_data.get("cpu"), "N/A"),
            "status": self._as_text(attack_data.get("status"), "Detected"),
            "duration": f"{self._fmt_int(attack_data.get('duration', 0))}s",
            "pcap": self._as_text(attack_data.get("pcap"), "N/A"),
            "attack_vector": self._as_text(attack_data.get("attack_vector"), "Undetected"),
        }

    def _export_values(self, export_data) -> dict:
        """Build the placeholder values that come from the export file contents.

        Every export placeholder defaults to "N/A", so a template never shows a
        raw ``{{...}}`` token when the export is absent or incomplete.
        """
        values = dict.fromkeys(self._EXPORT_PLACEHOLDERS, "N/A")
        if not export_data:
            return values

        if "ipv4_addresses" in export_data or "ipv6_addresses" in export_data:
            ipv4 = len(export_data.get("ipv4_addresses") or ())
            ipv6 = len(export_data.get("ipv6_addresses") or ())
            values["total_ips"] = f"{ipv4 + ipv6:,}"
            values["ipv4_count"] = f"{ipv4:,}"
            values["ipv6_count"] = f"{ipv6:,}"

        for key in ("most_common_source_ip", "most_common_dest_ip"):
            if key in export_data:
                values[key] = self._mask_ip(export_data[key])

        packets = export_data.get("pcap_packets_captured")
        seconds = export_data.get("pcap_duration_seconds")
        # Only compute the rate when both inputs are numeric; otherwise keep "N/A".
        if isinstance(packets, (int, float)) and isinstance(seconds, (int, float)):
            avg_pps = round(packets / seconds) if seconds > 0 else 0
            values["packets"] = f"{packets:,}"
            values["avg_pps"] = f"{avg_pps:,}"
        return values

    def _render_fields(self, embed: dict, values: dict) -> None:
        """Substitute ``{{name}}`` placeholders in every field value of *embed*, in place.

        Unknown placeholders and non-string values are left untouched. The
        substitution is a single regex pass, so text inserted from *values* is
        never itself re-scanned for placeholders.
        """
        def substitute(match):
            return values.get(match.group(1), match.group(0))

        for field in embed.get("fields", []):
            value = field.get("value", "")
            if isinstance(value, str):
                value = self._PLACEHOLDER_RE.sub(substitute, value)
            field["value"] = value

    def send_notification(self, attack_data: dict, export_json: str) -> bool:
        """Fill the embed template with attack details and post it to Discord.

        Args:
            attack_data: Template values. Keys read: ``status`` (default
                "Detected"), ``pps``, ``mbps``, ``cpu``, ``duration``
                (rendered with an "s" suffix), ``pcap`` and ``attack_vector``
                (default "Undetected").
            export_json: Path to an exported JSON file with IP/pcap statistics,
                or a falsy value to skip it.

        Returns:
            True if the webhook answered with a 2xx status, False on any failure.
        """
        try:
            # Work on a copy so the loaded template stays pristine between sends.
            payload = copy.deepcopy(self.payload_template)
            embed = payload["embeds"][0]
            embed["timestamp"] = datetime.now(timezone.utc).isoformat()

            export_data = None
            if export_json and os.path.exists(export_json):
                export_data = self._read_export_json(export_json)

            values = self._attack_values(attack_data)
            values.update(self._export_values(export_data))
            self._render_fields(embed, values)

            return self._post(payload)
        except Exception as e:
            print(f"{self.output.get_output()} Error sending Discord notification: {e}")
            return False
