import logging
from datetime import datetime, timedelta
from typing import Any, Optional

import pytz
from typeguard import typechecked

from psi_elt import get_logger, TIME_UNITS
from psi_elt.abstract_elt_job import AbstractELTJob

# The class of the ``None`` singleton. AbstractRunner._validate_dict looks for
# it inside list-valued schema entries to decide that a key is optional.
# (Equivalent to ``types.NoneType`` on Python 3.10+.)
NoneType = None.__class__


class UnrecognizedELTJobException(Exception):
    """Raised by ``AbstractRunner.get_job`` when no job is registered under the requested name."""


@typechecked
class AbstractRunner:
    """Base class for ELT runners: parse an event, validate it, then run a job."""

    def __init__(self, jobs: dict[str, Any], logger: Optional[logging.Logger] = None):
        """
        Initializes an AbstractRunner.

        This class should be treated as abstract and should be extended.
        Child classes must override ``_run``, ``parse_event`` and
        ``validate_parsed_event``; the base implementations raise
        ``NotImplementedError``.

        This class includes utilities for getting a capture time, validating
        a parsed event against a simple schema, and initializing an ELT job.

        :param jobs: mapping of job name to the callable used to build that
            job; ``get_job`` calls it as
            ``job_cls(job_name, capture_time, **job_init_kwargs)``. Must be
            non-empty.
        :param logger: logger to use; falls back to ``get_logger()``.
        :raises AssertionError: if ``jobs`` is empty.
        """
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        if not jobs:
            raise AssertionError("A jobs dict is required to initialize an AbstractRunner")
        self.__jobs = jobs
        self.__logger = logger or get_logger()

    @property
    def logger(self) -> logging.Logger:
        """The logger this runner writes to."""
        return self.__logger

    def run(self, event: dict[str, Any]) -> dict[str, Any]:
        """
        Parse, validate and execute ``event``.

        Calls ``parse_event`` -> ``validate_parsed_event`` -> ``_run`` and
        logs the elapsed wall-clock time.

        :param event: the raw event handed to the runner.
        :returns: ``{"event": ..., "parsed_event": ..., "job": job.to_dict()}``
            where ``job`` is the object returned by ``_run``.
        """
        t0 = datetime.now()
        parsed_event = self.parse_event(event)
        self.validate_parsed_event(parsed_event)
        self.logger.info(f"Running with parsed event: {parsed_event}")
        # The parsed event is splatted into _run, so it must at least
        # provide the ``job`` argument.
        job = self._run(**parsed_event)
        self.logger.info(f"ELT job {job.name} completed in {str(datetime.now() - t0)}")
        return {"event": event, "parsed_event": parsed_event, "job": job.to_dict()}

    def _run(self, job: AbstractELTJob, **parsed_event):
        """
        Execute the job described by the parsed event. Must be overridden.

        The return value must expose ``name`` and ``to_dict()`` (``run``
        uses both) - presumably the completed job.
        """
        raise NotImplementedError("AbstractRunner._run should be overridden by the child class")

    def parse_event(self, event: dict[str, Any]) -> dict[str, Any]:
        """Turn the raw event into keyword arguments for ``_run``. Must be overridden."""
        raise NotImplementedError(
            "AbstractRunner.parse_event should be overridden by the child class"
        )

    def validate_parsed_event(self, parsed_event: dict[str, Any]):
        """Validate the output of ``parse_event``. Must be overridden."""
        raise NotImplementedError(
            "AbstractRunner.validate_parsed_event should be overridden by the child class"
        )

    def _validate_dict(self, dct: dict[str, Any], schema: dict[str, Any]) -> None:
        """
        Validate ``dct`` against ``schema``, recursing into nested schemas.

        Each schema value is interpreted as:
          * a type: the key is required and ``type(value)`` must equal it
            exactly (subclasses do not match, e.g. ``bool`` is not ``int``);
          * a list of types: ``type(value)`` must be one of them; the key is
            optional when ``NoneType`` is in the list;
          * a dict: the key is optional; when present its value must be a
            dict, validated recursively against that nested schema.
        Keys that do not appear in ``schema`` are rejected.

        :raises AssertionError: on the first violation found.
        """

        def is_optional(spec):
            return (isinstance(spec, list) and NoneType in spec) or isinstance(spec, dict)

        required_keys = [k for k, spec in schema.items() if not is_optional(spec)]

        # Explicit raises (not ``assert``) so validation still runs under -O.
        missing_keys = [k for k in required_keys if k not in dct]
        if missing_keys:
            raise AssertionError(f"Missing keys: {missing_keys}")

        # Required and optional keys together are exactly the schema's keys.
        extra_keys = [k for k in dct if k not in schema]
        if extra_keys:
            raise AssertionError(f"Extra keys: {extra_keys}")

        for k, v in dct.items():
            required_type = schema[k]
            if isinstance(required_type, list):
                if type(v) not in required_type:
                    raise AssertionError(f"Invalid {k}, expected {required_type} but got: {v}")
            elif isinstance(required_type, dict):
                if not isinstance(v, dict):
                    raise AssertionError(f"Invalid {k}, expected a dict but got: {v}")
                self._validate_dict(v, required_type)
            elif type(v) != required_type:
                raise AssertionError(f"Invalid {k}, expected {required_type} but got: {v}")

    def get_capture_time(self, event: dict) -> datetime:
        """
        Resolve the capture time for ``event`` as a UTC-aware datetime.

        Resolution order:
          1. ``event["capture_time"]`` when truthy: parsed as ISO-8601. Naive
             values are taken to be UTC; values with an offset are converted
             to UTC. If ISO parsing fails, the value is handed to
             ``_parse_ago_time`` as a relative time.
          2. ``event["input_options"]["s3_uri"]`` when truthy: the last path
             segment is read as a Unix timestamp in milliseconds.
          3. Otherwise, the current time.

        :returns: a datetime whose ``tzinfo`` is ``pytz.UTC``.
        """
        input_options: dict = event.get("input_options") or {}
        raw_capture_time = event.get("capture_time")
        s3_uri = input_options.get("s3_uri")

        if raw_capture_time:
            try:
                capture_time = datetime.fromisoformat(raw_capture_time)
            except ValueError:
                capture_time = self._parse_ago_time(raw_capture_time)
            else:
                if capture_time.tzinfo is None:
                    # No offset in the string: interpret it as UTC.
                    capture_time = capture_time.replace(tzinfo=pytz.utc)
                else:
                    # Convert (not overwrite) so an explicit offset is honoured.
                    capture_time = capture_time.astimezone(pytz.utc)
        elif s3_uri:
            # Read the URI from input_options - the dict the condition checked.
            capture_time_ms = s3_uri.rstrip("/").split("/")[-1]
            capture_time = datetime.fromtimestamp(int(capture_time_ms) / 1000, tz=pytz.utc)
        else:
            capture_time = datetime.now(pytz.utc)

        assert isinstance(
            capture_time, datetime
        ), f"Invalid capture time. Expected a datetime object, but got: {capture_time}"
        assert (
            capture_time.tzinfo == pytz.UTC
        ), f"Invalid capture time. Expected capture time to be UTC, but got: {capture_time}"
        return capture_time

    def _parse_ago_time(self, capture_time: str) -> datetime:
        """
        Convert a relative time string into an absolute UTC datetime.

        The string is an integer followed by a unit suffix that is looked up
        in ``TIME_UNITS``. A 2-letter suffix is tried first, then a 1-letter
        one. ``TIME_UNITS`` maps the suffix to a ``timedelta`` keyword.

        :returns: the current UTC time minus the parsed interval.
        :raises Exception: if there is no alphabetic suffix. ``KeyError`` is
            raised for an unknown unit and ``ValueError`` for a non-integer
            amount. Every failure is logged before being re-raised.
        """
        try:
            if capture_time[-2:].isalpha():
                unit, number = capture_time[-2:], capture_time[:-2]
            elif capture_time[-1:].isalpha():
                unit, number = capture_time[-1:], capture_time[:-1]
            else:
                raise Exception(f"Unable to parse {capture_time}")
            delta = timedelta(**{TIME_UNITS[unit]: int(number)})
            return datetime.now(pytz.utc) - delta
        except Exception as e:
            self.logger.error(f'Failed to parse ago time "{capture_time}" with exception: {e}')
            raise

    def get_job(
        self, job_name: str, capture_time: datetime, job_init_kwargs: dict
    ) -> AbstractELTJob:
        """
        Instantiate the job registered under ``job_name``.

        :param job_name: key into the ``jobs`` mapping given to ``__init__``.
        :param capture_time: passed positionally to the job constructor.
        :param job_init_kwargs: extra keyword arguments for the constructor.
        :raises UnrecognizedELTJobException: if no job is registered under
            ``job_name`` (or the registered value is falsy).
        """
        job_cls = self.__jobs.get(job_name)
        if not job_cls:
            raise UnrecognizedELTJobException(f"Unrecognized ELT job with name: {job_name}")
        return job_cls(job_name, capture_time, **job_init_kwargs)
