import logging
from datetime import datetime
from typing import Any, Optional

from typeguard import typechecked

from psi_elt.abstract_runner import AbstractRunner, NoneType
from psi_elt.custom_transformer import CustomTransformer
from psi_elt.utils import filter_dict


@typechecked
class TransformerRunner(AbstractRunner):
    """Runner that parses transformer events and dispatches them to a named transformer.

    The registry of transformers is handed to ``AbstractRunner`` as its ``jobs``.
    ``_run`` looks up the requested transformer via ``get_job`` and calls its
    ``transform`` method.
    """

    def __init__(
        self,
        transformers: dict[str, Any],
        aws_wrangler: Optional[Any] = None,
        timestream_loader: Optional[Any] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """
        Initializes a TransformerRunner.

        Accepts a transformers dict, mapping transformer name to transformer class.
        For example:
        {
            "tag-compliance": TagComplianceTransformer,
            "qualys": QualysTransformer
        }

        Args:
            transformers: Mapping of transformer name to transformer class. It is
                passed to the base runner as its ``jobs``.
            aws_wrangler: Optional object forwarded to every transformer's
                constructor as the ``aws_wrangler`` keyword.
            timestream_loader: Optional object forwarded to every transformer's
                constructor as the ``timestream_loader`` keyword.
            logger: Optional logger, passed through to the base runner.
        """
        super().__init__(jobs=transformers, logger=logger)
        # Name-mangled so subclasses cannot accidentally shadow these dependencies.
        self.__aws_wrangler = aws_wrangler
        self.__timestream_loader = timestream_loader

    def parse_event(self, event: dict[str, Any]) -> dict[str, Any]:
        """Parses a transformer event. Child classes may override this function to
        implement custom event parsing rules.

        #### Required event keys
          1. name : the name of the transformer to run

        #### Optional event keys
          1. capture_time                         : target time for the transformer to run against
          2. is_test                              : boolean for whether or not this is a test run
          3. input_options.athena_workgroup       : Athena workgroup to use when reading data
          4. input_options.athena_data_catalog    : Athena data catalog to use when reading data
          5. input_options.athena_database        : Athena database to read data from
          6. output_options.database_name         : Timestream database to write data to
          7. output_options.table_name            : Timestream table to write data to
          8. output_options.write_to_timestream   : boolean for whether to write to Timestream
          9. output_options.upsert                : boolean for whether to upsert or insert data
         10. output_options.max_batch_size        : int max batch size used when batching
         11. output_options.min_common_dimensions : int min common dimensions used when batching
         12. output_options.is_multi              : boolean for whether to use multi-measure records

        Returns:
            A dict with keys ``name``, ``capture_time``, ``job_init_kwargs``,
            ``is_test``, ``transform_kwargs`` and ``timestream_load_kwargs``.
            Its keys match the keyword parameters of ``_run``.
        """
        # Looked up once: it feeds both the transform kwargs and the load kwargs.
        output_options = event.get("output_options")

        parsed_event = {
            "name": event.get("name"),
            "capture_time": self.get_capture_time(event),
            "job_init_kwargs": {},
            # Any truthy value counts as a test run.
            "is_test": bool(event.get("is_test")),
            # Input options control how the transformer reads its source data.
            # NOTE(review): how filter_dict handles a missing (None) options dict
            # is defined outside this file -- confirm.
            "transform_kwargs": filter_dict(
                dct=event.get("input_options"),
                keys=["athena_workgroup", "athena_data_catalog", "athena_database"],
            ),
            # Output options control how results are loaded into Timestream.
            # NOTE(review): the validation schema also allows "ignore_parse_error",
            # but it is not extracted here -- confirm whether it should be.
            "timestream_load_kwargs": filter_dict(
                dct=output_options,
                keys=[
                    "database_name",
                    "table_name",
                    "upsert",
                    "max_batch_size",
                    "min_common_dimensions",
                    "is_multi",
                ],
            ),
        }

        # write_to_timestream lives under output_options, but it is consumed by the
        # transform step rather than by the Timestream loader.
        parsed_event["transform_kwargs"].update(
            filter_dict(dct=output_options, keys=["write_to_timestream"])
        )
        return parsed_event

    def validate_parsed_event(self, parsed_event: dict[str, Any]) -> None:
        """Validates that a parsed event has the required format.

        Raises:
            AssertionError: if ``_validate_dict`` rejects the event. The error is
                logged and then re-raised unchanged.
        """
        schema = {
            "name": str,
            "capture_time": datetime,
            "job_init_kwargs": {},
            "is_test": bool,
            "transform_kwargs": {
                "athena_workgroup": [str, NoneType],
                "athena_data_catalog": [str, NoneType],
                "athena_database": [str, NoneType],
                "write_to_timestream": [bool, NoneType],
            },
            "timestream_load_kwargs": {
                "database_name": [str, NoneType],
                "table_name": [str, NoneType],
                "upsert": [bool, NoneType],
                "ignore_parse_error": [bool, NoneType],
                "max_batch_size": [int, NoneType],
                "min_common_dimensions": [int, NoneType],
                "is_multi": [bool, NoneType],
            },
        }
        try:
            self._validate_dict(parsed_event, schema)
        except AssertionError:
            self.logger.error("Invalid event")
            # Bare raise re-raises the active exception with its original traceback.
            raise

    def _run(
        self,
        name: str,
        capture_time: datetime,
        job_init_kwargs: dict,
        is_test: bool,
        transform_kwargs: dict,
        timestream_load_kwargs: dict,
    ) -> CustomTransformer:
        """Builds the named transformer and runs its transform step.

        The runner's ``aws_wrangler`` and ``timestream_loader`` are injected into
        the constructor kwargs. They override any same-named keys in
        ``job_init_kwargs``. The caller's dict is not mutated.

        Returns:
            The transformer instance after ``transform`` has been called on it.
        """
        # Build a new dict rather than writing into the caller's (e.g. parsed_event)
        # so the injected dependencies do not leak back out.
        init_kwargs = {
            **job_init_kwargs,
            "aws_wrangler": self.__aws_wrangler,
            "timestream_loader": self.__timestream_loader,
        }
        transformer: CustomTransformer = self.get_job(name, capture_time, init_kwargs)
        transformer.transform(is_test, transform_kwargs, timestream_load_kwargs)
        return transformer
