import boto3
import logging
import pytz

from botocore.config import Config
from botocore.exceptions import ClientError
from datetime import datetime
from timestream import DEFAULT_LOGGER, log_it
from timestream.batch import batch_records, DEFAULT_MAX_BATCH_SIZE, DEFAULT_MIN_COMMON_DIMENSIONS
from timestream.dimensions import TimestreamDimensions
from timestream.load_result import TimestreamLoadResult
from timestream.record import RejectedTimestreamRecord, TimestreamRecord
from typing import List


# botocore configuration applied to the 'timestream-write' client that
# TimestreamLoader creates: retry budget, connection pool size and read timeout.
DEFAULT_CLIENT_CONFIG = Config(
    retries={'max_attempts': 10},
    max_pool_connections=5000,
    read_timeout=20,
)


class TimestreamLoader:
    def __init__(
        self,
        database_name: str = None,
        table_name: str = None,
        session: boto3.Session = None,
        current_time_ms: int = None,
        logger: logging.Logger = None
    ):
        '''
        Creates a loader bound to a single Timestream database and table.

        Parameters
        ----------
        database_name : str
            Name of the target database. Required, despite the None default.

        table_name : str
            Name of the target table. Required, despite the None default.

        session : boto3.Session, Default : None
            Session used to create the 'timestream-write' client. A new
            boto3.Session() is created when omitted.

        current_time_ms : int, Default : None
            Optional override (integer greater than 0) for the loader's
            "current time" in milliseconds, which is sent as the Version of
            upserted records. Defaults to the time of construction.

        logger : logging.Logger, Default : None
            Logger to use. Falls back to DEFAULT_LOGGER.

        Raises
        ------
        AssertionError
            If database_name or table_name is missing, or current_time_ms is
            provided but is not an integer greater than 0
        '''
        assert database_name, 'Must specify a database_name'
        assert table_name, 'Must specify table_name'
        self.__database_name = database_name
        self.__table_name = table_name
        self.__session = session or boto3.Session()
        self.__timestream_write_client = self.__session.client(
            'timestream-write',
            config=DEFAULT_CLIENT_CONFIG
        )
        self._logger = logger or DEFAULT_LOGGER

        if current_time_ms is None:
            # No override: use "now", normalised to UTC before conversion
            local_now = datetime.now().astimezone()
            self.__current_time_ms = self.__dt_to_ms(local_now.astimezone(pytz.utc))
        else:
            assert type(current_time_ms) is int and current_time_ms > 0, f'Invalid current_time_ms: {current_time_ms}. Expecting an integer greater than 0'
            self._logger.warning(f'Overriding TimestreamLoader.__current_time_ms with value {current_time_ms}')
            self.__current_time_ms = current_time_ms

    @log_it
    def does_database_exist(
        self,
        database_name: str = None,
        raise_err: bool = False
    ) -> bool:
        '''
        Checks whether a database exists by issuing a boto3 describe_database
        API call.

        Parameters
        ----------
        database_name : str
            The database to look up. Defaults to the database this loader was
            created with.

        raise_err : bool, Default : False
            When True, a failed lookup re-raises the boto3 error instead of
            returning False

        Returns
        -------
        bool
            True if the describe call succeeded, otherwise False.
            NOTE: any botocore ClientError (not only ResourceNotFoundException)
            is reported as False when raise_err is False.

        Raises
        ------
        botocore.exceptions.ClientError/ResourceNotFoundException
            Only when raise_err is True
        '''
        target_database = database_name or self.__database_name
        try:
            # Only success/failure matters; the response body is not used
            self.__timestream_write_client.describe_database(
                DatabaseName=target_database)
        except ClientError:
            if raise_err:
                raise
            return False
        return True

    @log_it
    def does_table_exist(
        self,
        database_name: str = None,
        table_name: str = None,
        raise_err: bool = False
    ) -> bool:
        '''
        Checks whether a table exists in a database by issuing a boto3
        describe_table API call.

        Parameters
        ----------
        database_name : str
            The database that should contain the table. Defaults to the
            database this loader was created with.

        table_name : str
            The table to look up. Defaults to the table this loader was
            created with.

        raise_err : bool, Default : False
            When True, a failed lookup re-raises the boto3 error instead of
            returning False

        Returns
        -------
        bool
            True if the describe call succeeded, otherwise False.
            NOTE: any botocore ClientError (not only ResourceNotFoundException)
            is reported as False when raise_err is False.

        Raises
        ------
        botocore.exceptions.ClientError/ResourceNotFoundException
            Only when raise_err is True
        '''
        target_database = database_name or self.__database_name
        target_table = table_name or self.__table_name
        try:
            # Only success/failure matters; the response body is not used
            self.__timestream_write_client.describe_table(
                DatabaseName=target_database,
                TableName=target_table)
        except ClientError:
            if raise_err:
                raise
            return False
        return True

    def load(
        self,
        data: list,
        capture_time: datetime,
        source: str = None,
        upsert: bool = False,
        ignore_parse_error: bool = True,
        max_batch_size: int = DEFAULT_MAX_BATCH_SIZE,
        min_common_dimensions: int = DEFAULT_MIN_COMMON_DIMENSIONS
    ) -> dict:
        '''
        Parses the given data into Timestream records and writes them to the
        loader's database/table.

        Parameters
        ----------
        data : list
            List of dictionaries to convert into Timestream records

        capture_time : datetime
            Timestamp applied to every record of an insert (upsert=False)

        source : str, Default : None
            When set, added to every batch as the "source" dimension

        upsert : bool, Default : False
            When True, each item must carry its own "time" field; when False,
            items must not specify one

        ignore_parse_error : bool, Default : True
            When True, values that cannot become a record or dimension are
            skipped with a warning instead of invalidating the item

        max_batch_size : int
            Forwarded to batch_records

        min_common_dimensions : int
            Forwarded to batch_records

        Returns
        -------
        dict
            The load result, converted with TimestreamLoadResult.to_dict()
        '''
        load_result = self.__load(
            data=data,
            capture_time=capture_time,
            source=source,
            upsert=upsert,
            ignore_parse_error=ignore_parse_error,
            max_batch_size=max_batch_size,
            min_common_dimensions=min_common_dimensions,
            preview=False,
        )
        return load_result.to_dict()

    def load_preview(
        self,
        data: list,
        capture_time: datetime,
        source: str = None,
        upsert: bool = False,
        ignore_parse_error: bool = True,
        max_batch_size: int = DEFAULT_MAX_BATCH_SIZE,
        min_common_dimensions: int = DEFAULT_MIN_COMMON_DIMENSIONS
    ) -> dict:
        '''
        Dry run of a Timestream load: the data is validated and parsed into
        Timestream records exactly as a real load would, but nothing is
        written to Timestream.

        Takes the same parameters as a real load.

        Returns
        -------
        dict
            The parse result (accepted records and invalid items), converted
            with TimestreamLoadResult.to_dict()
        '''
        preview_result = self.__load(
            data=data,
            capture_time=capture_time,
            source=source,
            upsert=upsert,
            ignore_parse_error=ignore_parse_error,
            max_batch_size=max_batch_size,
            min_common_dimensions=min_common_dimensions,
            preview=True,
        )
        return preview_result.to_dict()

    def __load(
        self,
        data: list,
        capture_time: datetime,
        source: str = None,
        upsert: bool = False,
        ignore_parse_error: bool = True,
        max_batch_size: int = DEFAULT_MAX_BATCH_SIZE,
        min_common_dimensions: int = DEFAULT_MIN_COMMON_DIMENSIONS,
        preview: bool = False
    ) -> TimestreamLoadResult:
        '''
        Shared implementation behind the public load methods.

        Steps:
          1. Parse the data into TimestreamRecords.
          2. Return the parse result as-is when this is a preview, when any
             item failed to parse, or when there is nothing to write.
          3. Batch the accepted records by common dimensions.
          4. Write each batch (upsert or insert) and collect the accepted and
             rejected records into a new TimestreamLoadResult.

        The returned result has its error set if any record was rejected
        during the write.
        '''
        # Falsy data (e.g. None) is treated as an empty load by the parser, so
        # the log line must not crash on len(None) before the parser sees it
        self._logger.info(f'Loading data with {len(data) if data else 0} elements')

        # Parse incoming data into TimestreamRecords
        parsed = self.__parse_data(data, upsert, ignore_parse_error)
        if preview or parsed.error or len(parsed.accepted_records) == 0:
            return parsed

        # Batch the accepted records by common dimensions
        batches = batch_records(
            parsed.accepted_records,
            max_batch_size=max_batch_size,
            min_common_dimensions=min_common_dimensions
        )
        assert isinstance(batches, list) and len(batches) > 0, 'Batches must be a non-empty list'
        assert all(len(b.records) > 0 for b in batches), "Not all batches have records (shouldn't be possible)"

        # Write records in batches to Timestream
        result = TimestreamLoadResult()
        for b in batches:
            if not b.records:
                # Defensive: with assertions disabled an empty batch must not
                # reuse the outcome of the previous batch
                continue

            # NOTE: this mutates the batch's own common dimensions object
            dimensions = b.common_dimensions
            if source:
                dimensions.add_dimension('source', source, overwrite=True)

            if upsert:
                accepted_records, rejected_records = self.__write_upsert_records(dimensions, b.records)
            else:
                accepted_records, rejected_records = self.__write_insert_records(dimensions, b.records, capture_time)

            if accepted_records:
                result.add_accepted_records(accepted_records)
            if rejected_records:
                result.add_rejected_records(rejected_records)
                result.error = Exception('Some records were rejected by Timestream')
        return result

    def __write_insert_records(
        self,
        common_dimensions: TimestreamDimensions,
        records: List[TimestreamRecord],
        capture_time: datetime
    ) -> tuple:
        '''
        Writes one batch of records as an insert. Every record in the batch
        shares the common dimensions, the capture time (in milliseconds) and
        Version 1.

        Returns a tuple (accepted_records, rejected_records). When the write
        succeeds, all records are accepted; when boto raises a ClientError the
        split is determined by processing that error.

        Raises an AssertionError if no capture_time is provided.
        '''
        self._logger.info(f'Writing insert records with {len(records)} records')
        assert capture_time, 'A capture time needs to be provided'
        capture_time_ms = self.__dt_to_ms(capture_time)
        try:
            shared_attributes = {
                'Dimensions': common_dimensions.loadable(),
                'MeasureValueType': 'DOUBLE',
                'Time': str(capture_time_ms),
                'Version': 1
            }
            # Common dimensions are sent once, so leave them off each record
            payload = [
                rec.loadable(omit_dimensions=common_dimensions)
                for rec in records
            ]
            self.__timestream_write_client.write_records(
                DatabaseName=self.__database_name,
                TableName=self.__table_name,
                CommonAttributes=shared_attributes,
                Records=payload
            )
        except ClientError as err:
            self._logger.error(f'Failed to write insert records: {err}')
            accepted, rejected = self.__process_client_error(err, records)
            return accepted, rejected
        return records, []

    def __write_upsert_records(
        self,
        common_dimensions: TimestreamDimensions,
        records: List[TimestreamRecord]
    ) -> tuple:
        '''
        Writes one batch of records as an upsert. The batch shares the common
        dimensions and uses the loader's current time (ms) as the Version; no
        common Time is sent with the batch.

        Returns a tuple (accepted_records, rejected_records). When the write
        succeeds, all records are accepted; when boto raises a ClientError the
        split is determined by processing that error.
        '''
        self._logger.info(f'Writing upsert records with {len(records)} records')
        try:
            shared_attributes = {
                'Dimensions': common_dimensions.loadable(),
                'MeasureValueType': 'DOUBLE',
                'Version': self.__current_time_ms
            }
            # Common dimensions are sent once, so leave them off each record
            payload = [
                rec.loadable(omit_dimensions=common_dimensions)
                for rec in records
            ]
            self.__timestream_write_client.write_records(
                DatabaseName=self.__database_name,
                TableName=self.__table_name,
                CommonAttributes=shared_attributes,
                Records=payload
            )
        except ClientError as err:
            self._logger.error(f'Failed to write upsert records: {err}')
            accepted, rejected = self.__process_client_error(err, records)
            return accepted, rejected
        return records, []

    def __process_client_error(self, e: ClientError, records: List[TimestreamRecord]) -> tuple:
        '''
        Processes a boto ClientError from a Timestream write operation. Logs error messages
        for the rejected Timestream records and returns 2 lists:
          1. accepted_records: the records that were not reported as rejected
          2. rejected_records: the records that Timestream rejected, and the reasons why

        If the error response carries no per-record "RejectedRecords" details (i.e. the
        request failed as a whole), none of the records can be assumed written, so every
        record is returned as rejected with the ClientError itself as the reason.
        '''
        rejections = e.response.get('RejectedRecords')
        if not rejections:
            # Whole-request failure: reporting these records as accepted would
            # hide the failed write from the caller
            self._logger.error(
                f'Write failed without per-record rejections, treating all {len(records)} records as rejected'
            )
            return [], [RejectedTimestreamRecord(record=r, error=e) for r in records]

        rejected_records, rejected_indices = [], set()
        for rejection_data in rejections:
            # RecordIndex is the position of the record in the request's Records list
            index = rejection_data['RecordIndex']
            rejected_record = RejectedTimestreamRecord(
                record=records[index],
                error=Exception(rejection_data['Reason'])
            )
            self._logger.error(f'Rejected record:\n{rejected_record}')
            rejected_records.append(rejected_record)
            rejected_indices.add(index)

        accepted_records = [
            record for i, record in enumerate(records)
            if i not in rejected_indices
        ]
        return accepted_records, rejected_records

    def __parse_data(self, data: List[dict], upsert: bool, ignore_parse_error: bool) -> TimestreamLoadResult:
        '''
        Converts a list of items into a TimestreamLoadResult holding:
          1. The TimestreamRecords of every item that parsed successfully
          2. Every item that failed to parse, paired with the AssertionError explaining why

        Falsy data yields an empty result. Truthy data that is not a list raises an
        AssertionError. The result's error is set as soon as one item fails to parse.
        '''
        # TODO: detect duplicates
        parsed = TimestreamLoadResult()
        if data:
            assert type(data) is list, f'Invalid data, expecting a list but got type "{type(data)}"'

            for entry in data:
                try:
                    entry_records = self.__parse_item(entry, upsert, ignore_parse_error)
                except AssertionError as parse_error:
                    # Remember the bad item and keep going with the rest
                    parsed.error = Exception('Some items could not be parsed into Timestream records')
                    parsed.add_invalid_item(entry, parse_error)
                else:
                    parsed.add_accepted_records(entry_records)
        return parsed

    def __parse_item(self, item: dict, upsert: bool, ignore_parse_error: bool) -> List[TimestreamRecord]:
        '''
        Parses a single dictionary into Timestream records. A single dictionary may break down into
        multiple Timestream records if there are multiple numerical values in the dictionary:
          - int and float values become records (measure name = key, measure value = value)
          - list values become records whose measure value is the length of the list
          - str values become dimensions shared by all records of the item
          - None values, the "time" field(s) and the key "version" are skipped

        An upsert requires exactly one "time" field on the item; an insert must not specify one.
        The "time" field is matched case-insensitively and ignoring leading/trailing whitespace.

        Raises an AssertionError if the item (dict) cannot be parsed into Timestream records.
        If "ignore_parse_error" is true, this will attempt to ignore all values of the item
        that cannot be converted into either a Timestream record or dimension (rather than
        raising an AssertionError).

        Returns an empty list for a falsy item.
        '''
        if not item:
            return []
        assert type(item) is dict, f'Invalid item with type "{type(item)}", expecting a dictionary'

        records, dimensions = [], TimestreamDimensions()

        # Find every key that spells "time". Case insensitive and ignores leading/trailing whitespace.
        # If several match, the last one in the item provides the value.
        time_keys = [k for k in item if str(k).lower().strip() == 'time']
        mstime = item[time_keys[-1]] if time_keys else None

        # Make sure user is not attempting to override the Version
        assert 'Version' not in item, 'Field "Version" is a protected field and cannot exist on an item'

        if upsert:
            assert mstime is not None, 'To perform an upsert, item must have a "time" field'
            assert len(time_keys) < 2, 'Item has multiple time fields'
        else:
            assert mstime is None, 'To perform an insert, item cannot specify a "time" field'

        # Build records and dimensions
        for k, v in item.items():
            # Skip every spelling of the time field that was detected above (e.g. "Time",
            # " time "), otherwise the timestamp would also be written as a measure
            if v is None or k == 'version' or k in time_keys:
                continue
            elif type(v) in [float, int, list]:
                # Ints, floats, and lists get converted into records.
                # Exact type checks are deliberate: bool is not treated as an int here.
                records.append(TimestreamRecord(
                    measure_name=k,
                    measure_value=v if type(v) in [float, int] else len(v),
                    mstime=mstime
                ))
            elif type(v) is str:
                # Strings get converted into dimensions
                dimensions.add_dimension(k, v)
            else:
                msg = f'Could not create a record or dimension from type "{type(v)}" at key {k}'
                if not ignore_parse_error:
                    raise AssertionError(msg)
                self._logger.warning(msg)

        # Recreate all records with dimensions attached and return. Dimensions are only
        # complete once the whole item has been scanned.
        return [
            TimestreamRecord(
                measure_name=r.measure_name,
                measure_value=r.measure_value,
                dimensions=dimensions,
                mstime=r.mstime
            )
            for r in records
        ]

    def __dt_to_ms(self, dt: datetime) -> int:
        '''
        Converts a datetime into a POSIX timestamp in milliseconds, rounded to
        the nearest integer. Returns -1 when no datetime is given.
        '''
        if not dt:
            return -1
        epoch_seconds = dt.timestamp()
        return int(round(epoch_seconds * 1000))
