import logging
import re
import requests
from datetime import datetime, timedelta
from typing import Any, Optional

import boto3
import json
import pytz

from psi_collector import get_logger


class CustomCollector:
    """Base class for collectors that gather records for a capture-time window.

    Subclasses must override :meth:`start_date` and :meth:`end_date`; both are
    invoked from ``__init__``. :meth:`collect` delegates to ``self._collect``,
    which is not defined on this class and is therefore expected to be supplied
    by the subclass as well.
    """

    # Compiled once; used with fullmatch() so the entire string must conform.
    _TAPPID_PATTERN = re.compile(r"SVC[0-9]{5}")

    def __init__(
        self,
        name: str,
        capture_time: datetime,
        api: Any,
        aws_session: boto3.session.Session,
        http_session: requests.Session,
        logger: Optional[logging.Logger] = None,
    ):
        """Store the collaborators and derive the collection window.

        Args:
            name: Collector name; included in ``repr()`` output.
            capture_time: Reference time handed to ``start_date``/``end_date``
                and used to derive the default capture time.
            api: Stored as ``self._api``; not used by this base class.
            aws_session: Session used to build S3/SSM clients. A falsy value
                is replaced with a fresh ``boto3.Session()``.
            http_session: Stored as ``self._http_session``; not used by this
                base class.
            logger: Logger to use; defaults to ``get_logger()``.

        Raises:
            AssertionError: If the computed start date is after the end date.
        """
        self._name = name
        self._api = api
        self._aws_session = aws_session or boto3.Session()
        self._http_session = http_session
        self._logger = logger or get_logger()
        # The logger and sessions are assigned first so that subclass
        # overrides of start_date()/end_date() may already rely on them.
        self._start_date = self.start_date(capture_time)
        self._end_date = self.end_date(capture_time)
        assert self._start_date <= self._end_date, "Start date must be before the end date"

        # Default capture time of 6am UTC on capture_time's calendar date.
        default_capture_time = (
            datetime.combine(capture_time.date(), datetime.min.time()) + timedelta(hours=6)
        ).replace(tzinfo=pytz.UTC)
        self._default_capture_time_ms = self.dt_to_ms(default_capture_time)
        self._logger.info(f"Initialized custom collector: {self}")

    def collect(self, s3_path: Optional[str] = None, **kwargs) -> list[dict]:
        """Run the subclass's ``_collect`` and normalise empty strings.

        Keyword arguments with falsy values are dropped before being forwarded
        to ``self._collect``; ``s3_path`` is always forwarded, even when
        ``None``.

        Returns:
            The records returned by ``_collect``, modified in place so that
            every empty-string value becomes the literal string ``"null"``.
        """
        params = {key: val for key, val in kwargs.items() if val}
        params["s3_path"] = s3_path
        data = self._collect(**params)
        for record in data:
            # Only values of existing keys are reassigned, so the dict size is
            # stable and mutating it while iterating is safe.
            for key, value in record.items():
                if isinstance(value, str) and not value:
                    # Replace all empty strings with "null"
                    record[key] = "null"
        return data

    def start_date(self, capture_time: datetime) -> datetime:
        """Return the start of the collection window; must be overridden.

        Raises:
            NotImplementedError: Always, on this base class.
        """
        raise NotImplementedError(
            'Function "start_date()" not implemented on CustomCollector. Must be overridden by the child class.'
        )

    def end_date(self, capture_time: datetime) -> datetime:
        """Return the end of the collection window; must be overridden.

        Raises:
            NotImplementedError: Always, on this base class.
        """
        raise NotImplementedError(
            'Function "end_date()" not implemented on CustomCollector. Must be overridden by the child class.'
        )

    def _get_s3_object(self, s3_path: str) -> Any:
        """Fetch an S3 object and return its body decoded as UTF-8 text.

        Args:
            s3_path: Location such as ``"s3://bucket/path/to/key"``; the
                ``s3://`` scheme is optional.

        Raises:
            ValueError: If the path has no bucket or no key component.
            Exception: Any error from the S3 client or from decoding is
                logged and re-raised unchanged.
        """
        assert s3_path

        # Drop the scheme, then split once: bucket before the first "/", key
        # after it. partition() leaves the key empty when there is no "/",
        # instead of silently reusing the bucket name as the key.
        s3_path = s3_path.split("s3://")[-1]
        bucket, _, key = s3_path.partition("/")

        try:
            if not bucket or not key:
                raise ValueError(
                    'S3 path must look like "s3://bucket/key" with a non-empty bucket and key'
                )
            client = self._aws_session.client("s3")
            response = client.get_object(Bucket=bucket, Key=key)
            return response["Body"].read().decode("utf-8")
        except Exception as e:
            self._logger.error(f"Failed to get s3 object {s3_path} with exception: {e}")
            raise

    def _get_parameter(self, parameter: str) -> str:
        """Read a parameter from SSM, requesting decryption.

        Args:
            parameter: Name of the parameter to read.

        Raises:
            Exception: If the response carries no value, or if the SSM call
                fails; the error is logged and re-raised.
        """
        assert parameter
        try:
            client = self._aws_session.client("ssm")
            response = client.get_parameter(Name=parameter, WithDecryption=True)
            value = response.get("Parameter", {}).get("Value")
            if value is None:
                raise Exception("Parameter value not found in response")
            return value
        except Exception as e:
            self._logger.error(f"Failed to get parameter {parameter} with exception: {e}")
            raise

    def __repr__(self) -> str:
        """Return a JSON string with the collector name and window bounds."""
        return json.dumps(
            {
                "name": self._name,
                "start_date": self.format_date(self._start_date),
                "end_date": self.format_date(self._end_date),
            }
        )

    def format_date(self, dt: datetime) -> str:
        """Formats the date as a string. If the provided datetime is naive (no timezone), this
        assumes UTC and returns a string like "2021-05-26T01:02:03Z". The same "Z" form is used
        for any timezone-aware datetime whose UTC offset is zero. If the provided datetime has
        a non-zero UTC offset, returns a string like "2021-05-26T01:02:03-05:00".
        """
        # Decide on the actual offset rather than on tzinfo identity, so that
        # datetime.timezone.utc and other UTC implementations are formatted
        # the same way as pytz.UTC. utcoffset() is None for naive datetimes.
        offset = dt.utcoffset()
        if offset:
            return dt.isoformat(timespec="seconds")
        return dt.strftime("%Y-%m-%dT%H:%M:%SZ")

    def parse_date(self, dt: str) -> datetime:
        """Parses the given string into a datetime. By default, expects UTC offset information in
        the given string (e.g. "2021-05-26T01:02:03-05:00") and assigns the appropriate time zone.
        If dt is a string like "2021-05-26T01:02:03Z", assumes the timezone is UTC.

        Raises:
            ValueError: If dt cannot be parsed in either format; the underlying error is
                attached as the cause.
        """
        try:
            if dt.endswith("Z"):
                parsed = datetime.strptime(dt, "%Y-%m-%dT%H:%M:%SZ")
                return pytz.UTC.localize(parsed)
            # Remove the last colon (the one inside the "-05:00" offset) so the
            # offset is in the "-0500" form before it is matched by %z.
            without_offset_colon = "".join(dt.rsplit(":", 1))
            return datetime.strptime(without_offset_colon, "%Y-%m-%dT%H:%M:%S%z")
        except Exception as e:
            raise ValueError(
                'Datetime must be a string like "2021-05-26T01:02:03Z" for UTC times or "2021-05-26T01:02:03-05:00" for non-UTC times'
            ) from e

    def _is_tappid_compliant(self, t_appid: str) -> bool:
        """Return True when t_appid is exactly "SVC" followed by five ASCII digits."""
        # fullmatch() rejects a trailing newline, which match() with "$" accepts.
        return self._TAPPID_PATTERN.fullmatch(t_appid) is not None

    def dt_to_ms(self, dt: datetime) -> int:
        """Convert a datetime to integer milliseconds since the Unix epoch.

        Returns -1 when dt is falsy (e.g. None).
        """
        # NOTE(review): datetime.timestamp() interprets a naive datetime in the
        # system's local timezone, whereas format_date() assumes UTC for naive
        # values - confirm callers only pass timezone-aware datetimes here.
        if not dt:
            return -1
        return int(round(dt.timestamp() * 1000))
