import logging
from datetime import datetime, timedelta
from typing import Any, Optional

import pytz

from psi_collector import TIME_UNITS, get_logger
from psi_collector.custom_collector import CustomCollector


class CollectorRunner:
    def __init__(
        self,
        collectors: dict[str, Any],
        s3_loader: Optional[Any] = None,
        timestream_loader: Optional[Any] = None,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initializes a CollectorRunner.

        Accepts a collectors dict, mapping collector name to collector class.
        For example:
        {
            "divvy": DivvyCollector,
            "pagerduty": PagerDutyCollector
        }

        """
        assert collectors and isinstance(
            collectors, dict
        ), "A collectors dict is required to initialize a CollectorRunner"
        assert s3_loader, "A S3 loader is required to initialize a CollectorRunner"
        assert timestream_loader, "A Timestream loader is required to initialize a CollectorRunner"
        assert type(logger) in [
            logging.Logger,
            type(None),
        ], f"A logger of type logging.Logger is required to initialize a CollectorRunner. Got {logger}"
        self.__collectors = collectors
        self.__s3_loader = s3_loader
        self.__timestream_loader = timestream_loader
        self.__logger = logger or get_logger()

    def collect(self, event: dict):
        """Initializes and runs a collector.

        Accepts an event dictionary with the following required attributes:
            1. collector_name : the name of the collector to run
            2. s3_path        : the full s3 path to the data to collect from S3

        If writing to S3, the following attributes are required:
            1. write_to_s3    : boolean that must be set to true if writing to S3
            2. bucket_name    : S3 bucket to write to
            3. output_s3_path : S3 path to write to

        If writing to Timestream, the following attributes are required:
            1. write_to_timestream     : boolean that must be set to true if writing to Timestream
            2. database_name           : Timestream database to write to
            3. output_timestream_table : Timestream table to write to
            (optional) is_upsert       : whether or not to upsert Timestream records (default True)

        Optionally, capture_time may be set on the event as a millisecond timestamp. If unset,
        capture_time is taken from either the s3_path or the current time.
        """
        t0 = datetime.now()
        parsed_event = self.parse_event(event)
        self.__validate_parsed_event(parsed_event)
        self.__collect(**parsed_event)
        self.__logger.info(f"Collection completed in {str(datetime.now() - t0)}")

    def parse_event(self, event: dict) -> dict:
        """Parses a collector event into the following parts:
        1. collector_name         : the name of the collector to run
        2. capture_time           : time of collection
        3. write_to_s3            : boolean for whether or not to write to S3
        4. write_to_timestream    : boolean for whether or not to write to Timestream
        5. collector_init_kwargs  : kwargs passed into the collector's __init__
        6. collect_kwargs         : kwargs passed into the collector's collect function
        7. s3_load_kwargs         : kwargs passed into the S3 loader's load function
        8. timestream_load_kwargs : kwargs passed into the Timestream loader's load function

        Child classes may override this function to implement custom event parsing rules.
        """
        assert event and isinstance(
            event, dict
        ), f"Event must be a non-empty dictionary, got: {event}"
        return {
            "collector_name": event.get("collector_name"),
            "capture_time": self.get_capture_time(event),
            "write_to_s3": event.get("write_to_s3", True),
            "write_to_timestream": event.get("write_to_timestream", True),
            "collector_init_kwargs": {},
            "collect_kwargs": {
                "s3_path": event.get("s3_path"),
                "account_id": event.get("account_id"),
            },
            "s3_load_kwargs": {},
            "timestream_load_kwargs": {
                "upsert": event.get("is_upsert", True),
                "ignore_parse_error": True,
                "max_batch_size": event.get("max_batch_size"),
                "min_common_dimensions": event.get("min_common_dimensions"),
                "is_multi": bool(event.get("is_multi")),
            },
        }

    def __validate_parsed_event(self, parsed_event: dict):
        assert parsed_event and isinstance(
            parsed_event, dict
        ), f"Parsed event must be a non-empty dictionary, got: {parsed_event}"
        required_keys = [
            "collector_name",
            "capture_time",
            "write_to_s3",
            "write_to_timestream",
            "collector_init_kwargs",
            "collect_kwargs",
            "s3_load_kwargs",
            "timestream_load_kwargs",
        ]
        missing_keys = [k for k in required_keys if k not in parsed_event]
        extra_keys = [k for k in parsed_event if k not in required_keys]
        assert not missing_keys, f"Parsed event is missing keys: {missing_keys}"
        assert not extra_keys, f"Parsed event has extra keys: {extra_keys}"
        assert isinstance(
            parsed_event["collector_name"], str
        ), f"Invalid collector_name, expected a string but got: {parsed_event['collector_name']}"
        assert isinstance(
            parsed_event["capture_time"], datetime
        ), f"Invalid capture_time, expected a datetime but got: {parsed_event['capture_time']}"
        assert isinstance(
            parsed_event["write_to_s3"], bool
        ), f"Invalid write_to_s3, expected a bool but got: {parsed_event['write_to_s3']}"
        assert isinstance(
            parsed_event["write_to_timestream"], bool
        ), f"Invalid write_to_timestream, expected a bool but got: {parsed_event['write_to_timestream']}"
        assert isinstance(
            parsed_event["collector_init_kwargs"], dict
        ), f"Invalid collector_init_kwargs, expected a dict but got: {parsed_event['collector_init_kwargs']}"
        assert isinstance(
            parsed_event["collect_kwargs"], dict
        ), f"Invalid collect_kwargs, expected a dict but got: {parsed_event['collect_kwargs']}"
        assert isinstance(
            parsed_event["s3_load_kwargs"], dict
        ), f"Invalid s3_load_kwargs, expected a dict but got: {parsed_event['s3_load_kwargs']}"
        assert isinstance(
            parsed_event["timestream_load_kwargs"], dict
        ), f"Invalid timestream_load_kwargs, expected a dict but got: {parsed_event['timestream_load_kwargs']}"

        collect_kwargs = parsed_event["collect_kwargs"]
        if "s3_path" in collect_kwargs:
            assert type(collect_kwargs["s3_path"]) in [
                type(None),
                str,
            ], f"Invalid collect_kwargs.s3_path, expected a string but got: {collect_kwargs['s3_path']}"
        if "account_id" in collect_kwargs:
            assert type(collect_kwargs["account_id"]) in [
                type(None),
                str,
            ], f"Invalid collect_kwargs.account_id, expected a string but got: {collect_kwargs['account_id']}"

        timestream_load_kwargs = parsed_event["timestream_load_kwargs"]
        if "upsert" in timestream_load_kwargs:
            assert type(timestream_load_kwargs["upsert"]) in [
                type(None),
                bool,
            ], f"Invalid timestream_load_kwargs.upsert, expected a bool but got: {timestream_load_kwargs['upsert']}"
        if "ignore_parse_error" in timestream_load_kwargs:
            assert type(timestream_load_kwargs["ignore_parse_error"]) in [
                type(None),
                bool,
            ], f"Invalid timestream_load_kwargs.ignore_parse_error, expected a bool but got: {timestream_load_kwargs['ignore_parse_error']}"
        if "max_batch_size" in timestream_load_kwargs:
            assert type(timestream_load_kwargs["max_batch_size"]) in [
                type(None),
                int,
            ], f"Invalid timestream_load_kwargs.max_batch_size, expected an int but got: {timestream_load_kwargs['max_batch_size']}"
        if "min_common_dimensions" in timestream_load_kwargs:
            assert type(timestream_load_kwargs["min_common_dimensions"]) in [
                type(None),
                int,
            ], f"Invalid timestream_load_kwargs.min_common_dimensions, expected an int but got: {timestream_load_kwargs['min_common_dimensions']}"
        if "is_multi" in timestream_load_kwargs:
            assert type(timestream_load_kwargs["is_multi"]) in [
                type(None),
                bool,
            ], f"Invalid timestream_load_kwargs.is_multi, expected a bool but got: {timestream_load_kwargs['is_multi']}"

    def __collect(
        self,
        collector_name: str,
        capture_time: datetime,
        write_to_s3: bool,
        write_to_timestream: bool,
        collector_init_kwargs: dict,
        collect_kwargs: dict,
        s3_load_kwargs: dict,
        timestream_load_kwargs: dict,
    ):
        collector = self.get_collector(collector_name, capture_time, **collector_init_kwargs)
        data = collector.collect(**collect_kwargs)
        if write_to_s3:
            self.__logger.info("Beginning upload to S3")
            self.__s3_loader.load(collector_name, data, capture_time, **s3_load_kwargs)
        if write_to_timestream:
            self.__logger.info("Beginning write to Timestream")
            result = self.__timestream_loader.load(
                data=data,
                capture_time=capture_time,
                source=collector_name,
                **timestream_load_kwargs,
            )
            if len(result["invalid_items"]) > 0:
                raise Exception(
                    f"Failed to write to Timestream with invalid items {result['invalid_items']}"
                )
            if len(result["rejected_records"]) > 0:
                raise Exception(
                    f"Failed to write to Timestream with rejected records {result['rejected_records']}"
                )

    def get_capture_time(self, event: dict) -> datetime:
        if "capture_time" in event and event["capture_time"]:
            # fromisoformat returns a naive datetime object that IS NOT adjusted for the local timezone
            try:
                return datetime.fromisoformat(event["capture_time"]).replace(tzinfo=pytz.utc)
            except Exception:
                return self._parse_ago_time(event["capture_time"])
        elif "s3_path" in event and event["s3_path"]:
            capture_time_ms = event["s3_path"].split("/")[-1]
            # fromtimestamp returns a naive datetime object that IS adjusted for the local timezone
            return (
                datetime.fromtimestamp(int(capture_time_ms) / 1000)
                .astimezone()
                .astimezone(pytz.utc)
            )
        else:
            # now returns a naive datetime object that IS adjusted for the local timezone
            return datetime.now().astimezone().astimezone(pytz.utc)

    def _parse_ago_time(self, capture_time: str):
        try:
            if capture_time[-2:].isalpha():
                time_interval = TIME_UNITS[capture_time[-2:]]
                number = capture_time[0:-2]
            elif capture_time[-1:].isalpha():
                time_interval = TIME_UNITS[capture_time[-1:]]
                number = capture_time[0:-1]
            else:
                raise Exception(f"Unable to parse {capture_time}")
            timedelta_params = {time_interval: int(number)}
            return datetime.now() - timedelta(**timedelta_params)
        except Exception as e:
            self.__logger.error(f'Failed to parse ago time "{capture_time}" with exception: {e}')
            raise e

    def get_collector(
        self, collector_name: str, capture_time: datetime, **collector_init_kwargs
    ) -> CustomCollector:
        assert capture_time, "A capture_time is required to initialize a collector"
        cls = self.__collectors.get(collector_name)
        if cls:
            return cls(
                capture_time,
                collector_init_kwargs.pop("api", None),
                collector_init_kwargs.pop("aws_session", None),
                collector_init_kwargs.pop("http_session", None),
                **collector_init_kwargs,
            )
        raise Exception(f"Unrecognized collector with name: {collector_name}")
