import logging
import re
import uuid
from typing import Any, Iterator, Optional, Union

import awswrangler
import boto3
import numpy
import pandas
from typeguard import typechecked

from psi_awswrangler import get_logger


# Fallback option values used by AwsWrangler when a caller passes no (or a falsy) value.
ATHENA_READ_SQL_QUERY_DEFAULT_ENCRYPTION: str = "SSE_S3"  # Athena query-result encryption
S3_READ_JSON_DEFAULT_ORIENT: str = "columns"  # pandas JSON orientation for s3_read_json
S3_TO_PARQUET_DEFAULT_COMPRESSION: str = "snappy"  # Parquet codec for s3_to_parquet
S3_TO_PARQUET_DEFAULT_MODE: str = "overwrite_partitions"  # dataset write mode for s3_to_parquet


@typechecked
class AwsWrangler:
    """Thin wrapper around ``awswrangler`` (AWS SDK for pandas) Athena and S3 helpers.

    Each public method validates its S3 path(s) and logs what it is about to do. It then forwards
    the call to the matching ``awswrangler`` function, using this instance's boto3 session unless
    another one is supplied. Any failure is logged and re-raised. The read methods can also
    replace missing values with ``None`` and strip surrounding whitespace from string columns.
    """

    # Any "s3://<bucket>/<anything>" path.
    _S3_PATH_PATTERN = re.compile(r"^s3://([^/]+)/(.*?)$")
    # A directory-style path that ends in '/': the bucket root or a prefix whose last component
    # is non-empty, e.g. "s3://bucket/" or "s3://bucket/some/prefix/". Used with fullmatch.
    _S3_DIRECTORY_PATTERN = re.compile(r"s3://[^/]+/(?:.*[^/]/)?")

    def __init__(
        self,
        athena_output_path: Optional[str] = None,
        aws_session: Optional[Any] = None,
        logger: Optional[logging.Logger] = None,
    ):
        """Create the wrapper.

        :param athena_output_path: Default S3 prefix (must end in '/') for Athena query results.
            It is used when a query specifies neither an output path nor a workgroup.
        :param aws_session: boto3 session used for AWS calls; a new ``boto3.Session()`` is
            created when omitted.
        :param logger: Logger to use; defaults to ``get_logger()``.
        """
        self.__athena_output_path = None
        if athena_output_path:
            self.athena_output_path = athena_output_path
        self.aws_session = aws_session or boto3.Session()
        self.logger = logger or get_logger()

    # Properties -----------------------------------------------------------------------------------

    @property
    def athena_output_path(self) -> Union[str, None]:
        """Default S3 prefix for Athena query results, or ``None`` when not configured."""
        return self.__athena_output_path

    @athena_output_path.setter
    def athena_output_path(self, s3_path: str):
        """Set the default Athena output prefix.

        :raises AssertionError: if ``s3_path`` is not an S3 path ending in '/'.
        """
        self.__require_valid_s3_path(s3_path, reject_uri=True)
        self.__athena_output_path = s3_path

    # Public Functions -----------------------------------------------------------------------------

    def athena_read_sql_query(  # Transform
        self,
        sql: str,
        database: str,
        ctas_approach: bool = False,
        unload_approach: bool = False,
        s3_output_path: Optional[str] = None,
        workgroup: Optional[str] = None,
        data_catalog: Optional[str] = None,
        encryption: str = ATHENA_READ_SQL_QUERY_DEFAULT_ENCRYPTION,
        use_threads: Union[bool, int] = False,
        aws_session: Optional[boto3.Session] = None,
        remove_whitespace: bool = True,
        remove_nan: bool = True,
        **awswrangler_additional_kwargs,
    ) -> Union[pandas.DataFrame, Iterator[pandas.DataFrame]]:
        """Execute any SQL query on AWS Athena and return the results as a Pandas DataFrame:
        https://aws-sdk-pandas.readthedocs.io/en/stable/stubs/awswrangler.athena.read_sql_query.html

        When neither ``s3_output_path`` nor ``workgroup`` is given, results are written to a new,
        uniquely named sub-prefix of ``self.athena_output_path``.

        :param data_catalog: Forwarded to awswrangler as ``data_source``.
        :param encryption: Result encryption option; falls back to
            ``ATHENA_READ_SQL_QUERY_DEFAULT_ENCRYPTION`` when falsy.
        :param aws_session: Overrides the instance's boto3 session for this call.
        :param remove_whitespace: Strip surrounding whitespace from string columns.
        :param remove_nan: Replace missing values with ``None``.
        :returns: A DataFrame, or an iterator of DataFrames when awswrangler returns chunks. In
            the chunked case, cleaning is applied lazily to each chunk.
        :raises AssertionError: if no output location can be determined or a path is invalid.
        """
        if not s3_output_path and not workgroup:
            if not self.athena_output_path:
                raise AssertionError("An Athena output path is required to interact with Athena")
            # A unique sub-prefix per query keeps the result files of different queries apart.
            s3_output_path = f"{self.athena_output_path}{uuid.uuid4().hex}/"

        self.logger.info(f"Running Athena SQL query:\n{sql}")
        try:
            if s3_output_path is not None:
                self.__require_valid_s3_path(s3_output_path)
            self.logger.debug(f"Athena S3 output path: {s3_output_path}")
            result = awswrangler.athena.read_sql_query(
                sql=sql,
                database=database,
                ctas_approach=ctas_approach,
                unload_approach=unload_approach,
                s3_output=s3_output_path,
                workgroup=workgroup,
                data_source=data_catalog,
                encryption=encryption or ATHENA_READ_SQL_QUERY_DEFAULT_ENCRYPTION,
                use_threads=use_threads,
                boto3_session=aws_session or self.aws_session,
                **awswrangler_additional_kwargs,
            )
            return self.__clean(result, remove_nan, remove_whitespace)
        except Exception as e:
            self.logger.error(f"Failed to read Athena SQL query with exception: {e}")
            raise

    def s3_read_csv(  # Extract
        self,
        path: Union[str, list[str]],
        use_threads: Union[bool, int] = False,
        aws_session: Optional[boto3.Session] = None,
        remove_whitespace: bool = True,
        remove_nans: bool = True,
        **awswrangler_additional_kwargs,
    ) -> Union[pandas.DataFrame, Iterator[pandas.DataFrame]]:
        """Read CSV file(s) from a received S3 prefix or list of S3 objects paths.
        https://aws-sdk-pandas.readthedocs.io/en/stable/stubs/awswrangler.s3.read_csv.html

        :param aws_session: Overrides the instance's boto3 session for this call.
        :param remove_whitespace: Strip surrounding whitespace from string columns.
        :param remove_nans: Replace missing values with ``None``.
        :returns: A DataFrame, or a lazily cleaned iterator of DataFrames for chunked reads.
        :raises AssertionError: if any path is not a valid S3 path.
        """
        self.logger.info(f"Reading CSV from S3 path(s): {path}")
        try:
            paths = path if isinstance(path, list) else [path]
            for p in paths:
                self.__require_valid_s3_path(p)
            result = awswrangler.s3.read_csv(
                path=path,
                use_threads=use_threads,
                boto3_session=aws_session or self.aws_session,
                **awswrangler_additional_kwargs,
            )
            return self.__clean(result, remove_nans, remove_whitespace)
        except Exception as e:
            self.logger.error(f"Failed to read CSV file(s) from S3 with exception: {e}")
            raise

    def s3_read_json(  # Load
        self,
        s3_uri: Union[str, list[str]],
        orient: str = S3_READ_JSON_DEFAULT_ORIENT,
        use_threads: Union[bool, int] = False,
        boto3_session: Optional[boto3.Session] = None,
        remove_whitespace: bool = True,
        remove_nans: bool = True,
        **awswrangler_additional_kwargs,
    ) -> Union[pandas.DataFrame, Iterator[pandas.DataFrame]]:
        """Read JSON file(s) from a received S3 prefix or list of S3 objects paths/uris.
        https://aws-sdk-pandas.readthedocs.io/en/stable/stubs/awswrangler.s3.read_json.html

        :param orient: pandas JSON orientation; falls back to ``S3_READ_JSON_DEFAULT_ORIENT``
            when falsy.
        :param boto3_session: Overrides the instance's boto3 session for this call.
        :param remove_whitespace: Strip surrounding whitespace from string columns.
        :param remove_nans: Replace missing values with ``None``.
        :returns: A DataFrame, or a lazily cleaned iterator of DataFrames for chunked reads.
        :raises AssertionError: if any path is not a valid S3 path.
        """
        self.logger.info(f"Reading JSON from S3 path(s): {s3_uri}")
        try:
            paths = s3_uri if isinstance(s3_uri, list) else [s3_uri]
            for p in paths:
                self.__require_valid_s3_path(p)
            # Let pandas infer dtypes unless the caller chose otherwise.
            awswrangler_additional_kwargs.setdefault("dtype", True)
            result = awswrangler.s3.read_json(
                path=s3_uri,
                orient=orient or S3_READ_JSON_DEFAULT_ORIENT,
                use_threads=use_threads,
                boto3_session=boto3_session or self.aws_session,
                **awswrangler_additional_kwargs,
            )
            return self.__clean(result, remove_nans, remove_whitespace)
        except Exception as e:
            self.logger.error(f"Failed to read JSON file(s) from S3 with exception: {e}")
            raise

    def s3_read_parquet(  # Load
        self,
        path: Union[str, list[str]],
        use_threads: Union[bool, int] = False,
        remove_whitespace: bool = True,
        remove_nans: bool = True,
        aws_session: Optional[boto3.Session] = None,
        **awswrangler_additional_kwargs,
    ) -> Union[pandas.DataFrame, Iterator[pandas.DataFrame]]:
        """Read Apache Parquet file(s) from a received S3 prefix or list of S3 objects paths.
        https://aws-sdk-pandas.readthedocs.io/en/stable/stubs/awswrangler.s3.read_parquet.html

        :param remove_whitespace: Strip surrounding whitespace from string columns.
        :param remove_nans: Replace missing values with ``None``.
        :param aws_session: Overrides the instance's boto3 session for this call.
        :returns: A DataFrame, or a lazily cleaned iterator of DataFrames for chunked reads. An
            empty DataFrame is returned when awswrangler raises ``NoFilesFound``.
        :raises AssertionError: if any path is not a valid S3 path.
        """
        self.logger.info(f"Reading Parquet from S3 path(s): {path}")
        try:
            paths = path if isinstance(path, list) else [path]
            for p in paths:
                self.__require_valid_s3_path(p)
            result = awswrangler.s3.read_parquet(
                path=path,
                boto3_session=aws_session or self.aws_session,
                use_threads=use_threads,
                **awswrangler_additional_kwargs,
            )
            return self.__clean(result, remove_nans, remove_whitespace)
        except awswrangler.exceptions.NoFilesFound:
            self.logger.warning(f"No parquet files found at {path}, returning empty data frame")
            return pandas.DataFrame()
        except Exception as e:
            self.logger.error(f"Failed to read Parquet file(s) from S3 with exception: {e}")
            raise

    def s3_list_directories(  # Load
        self,
        path: str,
        chunked: bool = False,
        boto3_session: Optional[boto3.Session] = None,
        **awswrangler_additional_kwargs,
    ) -> Union[list[str], Iterator[list[str]]]:
        """List the S3 "directories" (common prefixes) directly under ``path``.
        https://aws-sdk-pandas.readthedocs.io/en/stable/stubs/awswrangler.s3.list_directories.html

        :param path: S3 prefix ending in '/', e.g. ``s3://bucket/prefix/``.
        :param chunked: Forwarded to awswrangler; when true, results are yielded in chunks.
        :param boto3_session: Overrides the instance's boto3 session for this call.
        :raises AssertionError: if ``path`` is not an S3 path ending in '/'.
        """
        self.logger.info(f"Listing S3 directories at path: {path}")
        try:
            self.__require_valid_s3_path(path, reject_uri=True)
            return awswrangler.s3.list_directories(
                path=path,
                chunked=chunked,
                boto3_session=boto3_session or self.aws_session,
                **awswrangler_additional_kwargs,
            )
        except Exception as e:
            self.logger.error(f"Failed to list S3 directories with exception: {e}")
            raise

    def s3_to_parquet(  # Load
        self,
        raw_data: pandas.DataFrame,
        s3_path: str,
        partition_cols: list[str],
        compression: str = S3_TO_PARQUET_DEFAULT_COMPRESSION,
        use_threads: Union[bool, int] = False,
        boto3_session: Optional[boto3.Session] = None,
        dataset: bool = True,
        mode: str = S3_TO_PARQUET_DEFAULT_MODE,
        athena_database: Optional[str] = None,
        athena_table: Optional[str] = None,
        **awswrangler_additional_kwargs,
    ) -> Any:
        """Write a DataFrame to S3 as Parquet, optionally registering it as an Athena table.
        https://aws-sdk-pandas.readthedocs.io/en/stable/stubs/awswrangler.s3.to_parquet.html

        :param compression: Parquet codec; falls back to ``S3_TO_PARQUET_DEFAULT_COMPRESSION``
            when falsy.
        :param mode: Dataset write mode; falls back to ``S3_TO_PARQUET_DEFAULT_MODE`` when falsy.
        :param athena_database: Forwarded as ``database``.
        :param athena_table: Forwarded as ``table``.
        :returns: Whatever ``awswrangler.s3.to_parquet`` returns.
        :raises AssertionError: if ``s3_path`` is not a valid S3 path.
        """
        self.logger.info(f"Writing DataFrame to S3 as parquet with path: {s3_path}")
        try:
            self.__require_valid_s3_path(s3_path)
            return awswrangler.s3.to_parquet(
                df=pandas.DataFrame(raw_data),
                path=s3_path,
                compression=compression or S3_TO_PARQUET_DEFAULT_COMPRESSION,
                use_threads=use_threads,
                boto3_session=boto3_session or self.aws_session,
                dataset=dataset,
                table=athena_table,
                database=athena_database,
                mode=mode or S3_TO_PARQUET_DEFAULT_MODE,
                partition_cols=partition_cols,
                **awswrangler_additional_kwargs,
            )
        except Exception as e:
            self.logger.error(f"Failed to write DataFrame to S3 as parquet with exception: {e}")
            raise

    # Private Functions ----------------------------------------------------------------------------

    def __is_valid_s3_path(self, s3_path: str, reject_uri: bool = False) -> bool:
        """Return whether ``s3_path`` looks like an ``s3://bucket/...`` path.

        :param reject_uri: When true, additionally require a directory-style path ending in '/'.
            That is the bucket root or a prefix such as ``s3://bucket/some/prefix/``. A path to a
            single object such as ``s3://bucket/file.csv`` is rejected.
        """
        if reject_uri:
            return bool(self._S3_DIRECTORY_PATTERN.fullmatch(s3_path))
        return bool(self._S3_PATH_PATTERN.match(s3_path))

    def __require_valid_s3_path(self, s3_path: str, reject_uri: bool = False) -> None:
        """Raise ``AssertionError`` unless ``s3_path`` passes ``__is_valid_s3_path``."""
        if not self.__is_valid_s3_path(s3_path, reject_uri):
            # Raised explicitly rather than via ``assert`` so the check survives ``python -O``.
            # The exception type stays what callers saw before.
            raise AssertionError(f"Invalid S3 path: {s3_path}")

    def __clean(
        self,
        result: Union[pandas.DataFrame, Iterator[pandas.DataFrame]],
        remove_nans: bool,
        remove_whitespace: bool,
    ) -> Union[pandas.DataFrame, Iterator[pandas.DataFrame]]:
        """Apply the optional NaN/whitespace cleaning to a frame, or lazily to each chunk."""
        if not (remove_nans or remove_whitespace):
            return result
        if isinstance(result, pandas.DataFrame):
            return self.__clean_frame(result, remove_nans, remove_whitespace)
        # Chunked reads (e.g. ``chunksize=...``) return an iterator of frames; clean each one as
        # it is consumed instead of passing the iterator to DataFrame-only helpers.
        return (self.__clean_frame(chunk, remove_nans, remove_whitespace) for chunk in result)

    def __clean_frame(
        self, df: pandas.DataFrame, remove_nans: bool, remove_whitespace: bool
    ) -> pandas.DataFrame:
        """Apply NaN removal, then whitespace stripping, to a single DataFrame."""
        if remove_nans:
            df = self.__nan_remover(df)
        if remove_whitespace:
            df = self.__whitespace_remover(df)
        return df

    def __nan_remover(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Return a new DataFrame in which NaN values are replaced by ``None``."""
        # NOTE(review): the fillna(nan) pass presumably normalises other missing-value markers to
        # NaN first, so the replace catches them too. Confirm with representative data.
        return df.fillna(numpy.nan).replace([numpy.nan], [None])

    def __whitespace_remover(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Return a deep copy of ``df`` with leading/trailing whitespace stripped from strings.

        Only ``string`` and ``object`` columns are touched. Non-``str`` values inside ``object``
        columns are left unchanged.
        """
        cleaned = df.copy(deep=True)
        for column in cleaned.columns:
            dtype = cleaned[column].dtype
            if dtype == "string":
                cleaned[column] = cleaned[column].str.strip()
            elif dtype == "object":
                cleaned[column] = cleaned[column].apply(
                    lambda value: value.strip() if isinstance(value, str) else value
                )
        return cleaned


# 0.2.5 7/7
