import logging
from datetime import datetime
from typing import Any, Optional

from typeguard import typechecked

from psi_elt.abstract_runner import AbstractRunner, NoneType
from psi_elt.custom_loader import CustomLoader
from psi_elt.utils import filter_dict


@typechecked
class LoaderRunner(AbstractRunner):
    """Runner that parses loader events and runs the named loader job."""

    def __init__(
        self,
        loaders: dict[str, Any],
        aws_wrangler: Optional[Any] = None,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initializes a LoaderRunner.

        Accepts a loaders dict, mapping loader name to loader class.
        For example:
        {
            "slater": SlaterLoader,
            "tep": TepLoader
        }

        Args:
            loaders: mapping of loader name to loader class; forwarded to
                ``AbstractRunner.__init__`` as ``jobs``.
            aws_wrangler: optional object passed to every loader as the
                ``aws_wrangler`` init kwarg (see ``_run``).
            logger: optional logger forwarded to ``AbstractRunner.__init__``.
        """
        super().__init__(jobs=loaders, logger=logger)
        # Double-underscore name is mangled to ``_LoaderRunner__aws_wrangler``;
        # kept unchanged so existing attribute access continues to work.
        self.__aws_wrangler = aws_wrangler

    def parse_event(self, event: dict[str, Any]) -> dict[str, Any]:
        """Parses a loader event. Child classes may override this function to
        implement custom event parsing rules.

        #### Required event keys
           1. name : the name of the loader to run

        #### Optional event keys
           1. capture_time                   : target time for the job to run against
           2. is_test                        : boolean for whether or not this is a test run
           3. input_options.s3_uri           : S3 location to read data from
           4. input_options.s3_path          : S3 path to read data from
           5. output_options.s3_uri          : S3 location to write data to
           6. output_options.s3_path         : S3 path to write data to
           7. output_options.write_to_athena : boolean for whether or not to write to Athena
           8. output_options.athena_database : Athena database to write data to
           9. output_options.athena_table    : Athena table to write data to

        Returns:
            A dict with keys ``name``, ``capture_time``, ``job_init_kwargs``
            (always an empty dict here), ``is_test`` and ``load_kwargs``
            (holding the filtered ``input_options`` and ``output_options``).
        """
        parsed_event: dict[str, Any] = {
            "name": event.get("name"),
            "capture_time": self.get_capture_time(event),
            "job_init_kwargs": {},
            "is_test": bool(event.get("is_test")),
        }

        # Option sections are optional; treat a missing or null section as an
        # empty mapping rather than handing ``None`` to ``filter_dict``.
        parsed_event["load_kwargs"] = {
            "input_options": filter_dict(
                dct=event.get("input_options") or {},
                keys=["s3_uri", "s3_path"],
            ),
            "output_options": filter_dict(
                dct=event.get("output_options") or {},
                keys=["s3_uri", "s3_path", "write_to_athena", "athena_database", "athena_table"],
            ),
        }
        return parsed_event

    def validate_parsed_event(self, parsed_event: dict[str, Any]) -> None:
        """Validates that a parsed event has the required format.

        Raises:
            AssertionError: propagated from ``self._validate_dict`` when the
                event does not match the schema (an error is logged first).
        """
        schema = {
            "name": str,
            "capture_time": datetime,
            "job_init_kwargs": {},
            "is_test": bool,
            "load_kwargs": {
                "input_options": {"s3_uri": [str, NoneType], "s3_path": [str, NoneType]},
                "output_options": {
                    "s3_uri": [str, NoneType],
                    "s3_path": [str, NoneType],
                    "write_to_athena": [bool, NoneType],
                    "athena_database": [str, NoneType],
                    "athena_table": [str, NoneType],
                },
            },
        }
        try:
            self._validate_dict(parsed_event, schema)
        except AssertionError:
            # NOTE(review): if ``_validate_dict`` uses ``assert`` statements,
            # validation is skipped under ``python -O`` — confirm in AbstractRunner.
            self.logger.error("Invalid event")
            raise

    def _run(
        self,
        name: str,
        capture_time: datetime,
        job_init_kwargs: dict,
        is_test: bool,
        load_kwargs: dict,
    ) -> CustomLoader:
        """Builds the named loader, runs its ``load`` and returns it.

        ``aws_wrangler`` from this runner is added to the loader's init kwargs,
        overriding any value already present under that key.
        """
        # Copy instead of mutating the caller's dict (e.g. the parsed event).
        init_kwargs = {**job_init_kwargs, "aws_wrangler": self.__aws_wrangler}
        loader: CustomLoader = self.get_job(name, capture_time, init_kwargs)
        loader.load(is_test, load_kwargs)
        return loader
