import logging
from datetime import datetime
from typing import Any, Optional

from typeguard import typechecked

from psi_elt.abstract_elt_job import AbstractELTJob, TRANSFORMER_ELT_JOB_TYPE


@typechecked
class CustomTransformer(AbstractELTJob):
    """A CustomTransformer is the third of the ELT jobs. It is responsible for
    transforming data (e.g. previously loaded into S3, Athena, or Timestream)
    for use in a data mart.

    This class should be treated as abstract and should be extended.
    Child classes must override ``_transform``. (The original design also
    mentions a ``_read_data`` hook; it is not defined in this class and is
    presumably declared on ``AbstractELTJob`` -- confirm.)

    The optional ``aws_wrangler`` and ``timestream_loader`` collaborators are
    stored on the instance for use by subclasses; ``write_to_timestream``
    wraps ``timestream_loader``.
    """

    def __init__(
        self,
        name: str,
        capture_time: datetime,
        aws_wrangler: Optional[Any] = None,
        timestream_loader: Optional[Any] = None,
        aws_session: Optional[Any] = None,
        logger: Optional[logging.Logger] = None,
    ):
        """Create a transformer job.

        Args:
            name: Job name, forwarded to ``AbstractELTJob``.
            capture_time: Capture time, forwarded to ``AbstractELTJob``.
            aws_wrangler: Optional helper stored as ``self.aws_wrangler``.
            timestream_loader: Optional loader stored as
                ``self.timestream_loader``; required by
                ``write_to_timestream``.
            aws_session: Forwarded to ``AbstractELTJob``.
            logger: Forwarded to ``AbstractELTJob``.
        """
        super().__init__(
            name=name,
            elt_job_type=TRANSFORMER_ELT_JOB_TYPE,
            capture_time=capture_time,
            aws_session=aws_session,
            logger=logger,
        )
        self.aws_wrangler = aws_wrangler
        self.timestream_loader = timestream_loader

    def transform(
        self,
        is_test: bool,
        transform_kwargs: Optional[dict] = None,
        timestream_load_kwargs: Optional[dict] = None,
    ):
        """Run the transformation implemented by the subclass's ``_transform``.

        ``None``-valued entries are stripped from both kwargs dicts (via
        ``remove_none_from_dict``) before they are forwarded.

        Args:
            is_test: Forwarded unchanged to ``_transform``.
            transform_kwargs: Extra keyword arguments expanded into the
                ``_transform`` call. May be omitted.
            timestream_load_kwargs: Passed to ``_transform`` as the
                ``timestream_load_kwargs`` keyword argument.
        """
        transform_kwargs = self.remove_none_from_dict(transform_kwargs)
        timestream_load_kwargs = self.remove_none_from_dict(timestream_load_kwargs)
        # NOTE(review): remove_none_from_dict lives in AbstractELTJob and may
        # return None when given None; guard the ** expansion so the default
        # transform_kwargs=None cannot raise "argument after ** must be a
        # mapping".
        self._transform(
            is_test,
            timestream_load_kwargs=timestream_load_kwargs,
            **(transform_kwargs or {}),
        )

    def _transform(
        self, is_test: bool, timestream_load_kwargs: Optional[dict] = None, **transform_kwargs
    ):
        """Subclass hook that performs the actual transformation.

        Raises:
            NotImplementedError: Always; child classes must override this.
        """
        raise NotImplementedError(
            "CustomTransformer._transform must be overridden by the child class"
        )

    def write_to_timestream(
        self,
        data: list,
        capture_time: datetime,
        database_name: Optional[str] = None,
        table_name: Optional[str] = None,
        source: Optional[str] = None,
        upsert: bool = False,
        ignore_parse_error: bool = False,
        max_batch_size: Optional[int] = None,
        min_common_dimensions: Optional[int] = None,
        is_multi: bool = True,
        is_preview: bool = False,
    ):
        """Write ``data`` through ``self.timestream_loader``.

        Calls ``timestream_loader.load_preview()`` when ``is_preview`` is true,
        otherwise ``timestream_loader.load()``. Arguments that are ``None`` are
        dropped (via ``remove_none_from_dict``) so the loader's own defaults
        apply.

        Raises:
            RuntimeError: If no ``timestream_loader`` was provided.
            Exception: If the loader result reports any ``invalid_items`` or
                ``rejected_records``.
        """
        if self.timestream_loader is None:
            # Fail with a clear message instead of an opaque AttributeError
            # on NoneType.
            raise RuntimeError(
                "CustomTransformer.write_to_timestream requires a timestream_loader"
            )

        timestream_load_input = {
            "data": data,
            "capture_time": capture_time,
            "database_name": database_name,
            "table_name": table_name,
            "source": source,
            "upsert": upsert,
            "ignore_parse_error": ignore_parse_error,
            "max_batch_size": max_batch_size,
            "min_common_dimensions": min_common_dimensions,
            "is_multi": is_multi,
        }
        timestream_load_input = self.remove_none_from_dict(timestream_load_input)
        if is_preview:
            result = self.timestream_loader.load_preview(**timestream_load_input)
        else:
            result = self.timestream_loader.load(**timestream_load_input)

        # Assumes the loader result is a mapping with "invalid_items" and
        # "rejected_records" collections -- TODO confirm against TimestreamLoader.
        if result["invalid_items"]:
            raise Exception(
                f"Failed to write to Timestream with invalid items {result['invalid_items']}"
            )
        if result["rejected_records"]:
            raise Exception(
                f"Failed to write to Timestream with rejected records {result['rejected_records']}"
            )
