import boto3
import logging
import pytz
import sys

from botocore.config import Config
from botocore.exceptions import ClientError
from datetime import datetime
from timestream.dimensions import TimestreamDimensions
from timestream.load_result import TimestreamLoadResult
from timestream.record import RejectedTimestreamRecord, TimestreamRecord
from typing import List


# Default botocore client configuration applied to the Timestream write client:
# up to 10 attempts per call, a 20 second socket read timeout, and a connection
# pool of up to 5000 connections.
DEFAULT_CLIENT_CONFIG = Config(
    retries=dict(max_attempts=10),
    read_timeout=20,
    max_pool_connections=5000,
)


class TimestreamLoader:
    '''
    Parses lists of dictionaries into Timestream records and writes them to a
    single Timestream table.

    Numeric values (and list lengths) of each item become measures, string values
    become dimensions. Records that share the same dimensions are written together
    using common attributes.
    '''

    # The Timestream WriteRecords API accepts at most 100 records per request
    _MAX_RECORDS_PER_WRITE = 100

    # Name given to the stdout handler installed by _get_logger so it can be reused
    _LOG_HANDLER_NAME = 'timestream_loader_stdout'

    def __init__(
        self,
        database_name: str = None,
        table_name: str = None,
        session: boto3.Session = None,
        current_time_ms: int = None,
        logger: logging.Logger = None
    ):
        '''
        Creates a loader bound to one Timestream database/table.

        database_name:   Required. Name of the Timestream database to write to.
        table_name:      Required. Name of the Timestream table to write to.
        session:         Optional boto3 session; a default session is created if omitted.
        current_time_ms: Optional override (positive int, milliseconds) of the value
                         used as the record "Version" for upserts. Defaults to the
                         time at which the loader was constructed.
        logger:          Optional logger; a stdout logger is created if omitted.

        Raises an AssertionError if a required argument is missing or
        current_time_ms is not a positive integer.
        '''
        assert database_name, 'Must specify a database_name'
        assert table_name, 'Must specify table_name'
        self.__database_name = database_name
        self.__table_name = table_name
        self.__logger = logger if logger else self._get_logger()
        self.__session = session if session else boto3.Session()
        self.__timestream_write_client = self.__session.client('timestream-write', config=DEFAULT_CLIENT_CONFIG)

        # POSIX timestamps are timezone independent, so an aware "now" in UTC is sufficient
        self.__current_time_ms = self.__dt_to_ms(datetime.now(tz=pytz.utc))
        if current_time_ms is not None:
            assert type(current_time_ms) is int and current_time_ms > 0, f'Invalid current_time_ms: {current_time_ms}. Expecting an integer greater than 0'
            self.__logger.warning(f'Overriding TimestreamLoader.__current_time_ms with value {current_time_ms}')
            self.__current_time_ms = current_time_ms

    def load(
        self,
        data: list,
        capture_time: datetime,
        source: str = None,
        upsert: bool = False,
        ignore_parse_error: bool = True
    ) -> dict:
        '''
        Performs a Timestream load. Validates the input and parses the data into Timestream records.
        Returns the result of the load (as a dictionary) after writing to Timestream.

        data:               List of dictionaries to parse into records.
        capture_time:       Timestamp applied to all records of an insert (required for inserts).
        source:             Optional value written as the "source" dimension of every record.
        upsert:             If true, every item must carry its own "time" field and records
                            are written with the loader's current time as their version.
        ignore_parse_error: If true, values that cannot be converted are logged and skipped
                            instead of invalidating the whole item.
        '''
        return self.__load(data, capture_time, source, upsert, ignore_parse_error, False).to_dict()

    def load_preview(
        self,
        data: list,
        capture_time: datetime,
        source: str = None,
        upsert: bool = False,
        ignore_parse_error: bool = True
    ) -> dict:
        '''
        Performs a dry run of a Timestream load, including validating the input values and parsing data
        into Timestream records. Takes the same arguments as load().
        Returns the result of the load (as a dictionary) before writing to Timestream.
        '''
        return self.__load(data, capture_time, source, upsert, ignore_parse_error, True).to_dict()

    def __load(
        self,
        data: list,
        capture_time: datetime,
        source: str = None,
        upsert: bool = False,
        ignore_parse_error: bool = True,
        preview: bool = False
    ) -> TimestreamLoadResult:
        '''
        Shared implementation of load() and load_preview().

        Parses the data first; if this is a preview or any item failed to parse, the
        parse result is returned and nothing is written. Otherwise the records are
        grouped by their dimensions and written to Timestream, and a new result
        describing the accepted and rejected records is returned.
        '''
        # Parse incoming data into TimestreamRecords
        result = self.__parse_data(data, upsert, ignore_parse_error)
        if preview or result.error:
            return result

        # Group records by common dimensions using a hash string
        grouped_records, common_dimensions = {}, {}
        for r in result.accepted_records:
            hashstr = r.dimensions.hashstr()
            common_dimensions[hashstr] = r.dimensions
            grouped_records.setdefault(hashstr, []).append(r)

        # Write each group to Timestream, split into batches the API will accept
        result = TimestreamLoadResult()
        batch_size = self._MAX_RECORDS_PER_WRITE
        for hashstr, records in grouped_records.items():
            dimensions = common_dimensions[hashstr]
            if source:
                dimensions.add_dimension('source', source, overwrite=True)

            for start in range(0, len(records), batch_size):
                batch = records[start:start + batch_size]
                if upsert:
                    accepted_records, rejected_records = self.__write_upsert_records(dimensions, batch)
                else:
                    accepted_records, rejected_records = self.__write_insert_records(dimensions, batch, capture_time)

                if accepted_records:
                    result.add_accepted_records(accepted_records)
                if rejected_records:
                    result.add_rejected_records(rejected_records)
                    result.error = Exception('Some records were rejected by Timestream')
        return result

    def __write_insert_records(
        self,
        dimensions: TimestreamDimensions,
        records: List[TimestreamRecord],
        capture_time: datetime
    ) -> tuple:
        '''
        Writes records that all share the given dimensions and the given capture time,
        using a fixed version of 1.

        Returns a tuple of (accepted_records, rejected_records).
        Raises an AssertionError if no capture time is provided.
        '''
        assert capture_time, 'A capture time needs to be provided'
        capture_time_ms = self.__dt_to_ms(capture_time)
        common_attributes = {
            'Dimensions': dimensions.loadable(),
            'MeasureValueType': 'DOUBLE',
            'Time': str(capture_time_ms),
            'Version': 1
        }
        return self.__write_records(common_attributes, records, 'insert')

    def __write_upsert_records(
        self,
        dimensions: TimestreamDimensions,
        records: List[TimestreamRecord]
    ) -> tuple:
        '''
        Writes records that all share the given dimensions, using the loader's current
        time (in milliseconds) as the record version. No common time is set here.

        Returns a tuple of (accepted_records, rejected_records).
        '''
        common_attributes = {
            'Dimensions': dimensions.loadable(),
            'MeasureValueType': 'DOUBLE',
            'Version': self.__current_time_ms
        }
        return self.__write_records(common_attributes, records, 'upsert')

    def __write_records(
        self,
        common_attributes: dict,
        records: List[TimestreamRecord],
        operation: str
    ) -> tuple:
        '''
        Sends a single write_records request for the given records and common attributes.
        "operation" is only used to label the error log message ("insert" or "upsert").

        Returns a tuple of (accepted_records, rejected_records). A ClientError is logged
        and converted into rejected records rather than raised.
        '''
        try:
            self.__timestream_write_client.write_records(
                DatabaseName=self.__database_name,
                TableName=self.__table_name,
                CommonAttributes=common_attributes,
                Records=[r.loadable() for r in records]
            )
        except ClientError as e:
            self.__logger.error(f'Failed to write {operation} records: {e}')
            return self.__process_client_error(e, records)
        return records, []

    def __process_client_error(self, e: ClientError, records: List[TimestreamRecord]) -> tuple:
        '''
        Processes a boto ClientError from a Timestream write operation. Logs error messages
        for each rejected Timestream record and returns 2 lists:
          1. accepted_records: the records that were successfully written to Timestream
          2. rejected_records: the records that Timestream rejected, and the reasons why

        If the error does not list any individually rejected records (e.g. throttling,
        validation or permission errors), the request as a whole failed, so every record
        is reported as rejected with the ClientError itself as the reason.
        '''
        rejections = e.response.get('RejectedRecords')
        if not rejections:
            return [], [RejectedTimestreamRecord(record=r, error=e) for r in records]

        rejected_records, rejected_indices = [], set()
        for rejection_data in rejections:
            index = rejection_data['RecordIndex']
            rejected_record = RejectedTimestreamRecord(
                record=records[index],
                error=Exception(rejection_data['Reason'])
            )
            self.__logger.error(f'Rejected record:\n{rejected_record}')
            rejected_records.append(rejected_record)
            rejected_indices.add(index)

        accepted_records = [r for i, r in enumerate(records) if i not in rejected_indices]
        return accepted_records, rejected_records

    def __parse_data(self, data: List[dict], upsert: bool, ignore_parse_error: bool) -> TimestreamLoadResult:
        '''
        Parses the given data and returns a TimestreamLoadResult that contains:
          1. Successfully parsed TimestreamRecords
          2. Invalid items that failed to parse into TimestreamRecords (and the errors for why they failed to parse)

        Empty/missing data yields an empty result. Raises an AssertionError if data is
        not a list.
        '''
        result = TimestreamLoadResult()
        if not data:
            return result
        assert isinstance(data, list), f'Invalid data, expecting a list but got type "{type(data)}"'

        for item in data:
            try:
                records_from_item = self.__parse_item(item, upsert, ignore_parse_error)
            except AssertionError as e:
                # Keep parsing the remaining items so every invalid item gets reported
                result.error = Exception('Some items could not be parsed into Timestream records')
                result.add_invalid_item(item, e)
                continue
            result.add_accepted_records(records_from_item)
        return result

    def __parse_item(self, item: dict, upsert: bool, ignore_parse_error: bool) -> List[TimestreamRecord]:
        '''
        Parses a single dictionary into Timestream records. A single dictionary may break down into
        multiple Timestream records if there are multiple numerical values in the dictionary.

        Ints and floats become a record with that value, lists become a record whose value
        is the length of the list, and strings become dimensions shared by all records of
        the item. None values, the "time" field(s) and the key "version" are skipped.

        Raises an AssertionError if the item (dict) cannot be parsed into Timestream records.
        If "ignore_parse_error" is true, this will attempt to ignore all values of the item
        that cannot be converted into either a Timestream record or dimension (rather than
        raising an AssertionError).
        '''
        if not item:
            return []
        assert isinstance(item, dict), f'Invalid item with type "{type(item)}", expecting a dictionary'

        records, dimensions = [], TimestreamDimensions()

        # Find the "time" field(s). Case insensitive and ignores leading/trailing whitespace.
        # If several keys match, the value of the last one is used.
        time_keys = [k for k in item if str(k).lower().strip() == 'time']
        mstime = item[time_keys[-1]] if time_keys else None

        # Make sure user is not attempting to override the Version
        assert 'Version' not in item, 'Field "Version" is a protected field and cannot exist on an item'

        if upsert:
            assert mstime is not None, 'To perform an upsert, item must have a "time" field'
            assert len(time_keys) < 2, 'Item has multiple time fields'
        else:
            assert mstime is None, 'To perform an insert, item cannot specify a "time" field'

        # Build records and dimensions
        for k, v in item.items():
            # Skip every key that was recognised as the time field (e.g. "Time", " time "),
            # not just the literal "time", so the timestamp never becomes a measure/dimension
            if v is None or k in time_keys or k == 'version':
                continue
            elif type(v) in [float, int, list]:
                # Ints, floats, and lists get converted into records.
                # Exact type checks are used, so bool values fall through to the unsupported branch.
                records.append(TimestreamRecord(
                    measure_name=k,
                    measure_value=v if type(v) in [float, int] else len(v),
                    mstime=mstime
                ))
            elif type(v) is str:
                # Strings get converted into dimensions
                dimensions.add_dimension(k, v)
            else:
                msg = f'Could not create a record or dimension from type "{type(v)}" at key {k}'
                if not ignore_parse_error:
                    raise AssertionError(msg)
                self.__logger.warning(msg)

        # Recreate all records with dimensions attached and return. The dimensions are only
        # complete once every key of the item has been visited.
        return [
            TimestreamRecord(
                measure_name=r.measure_name,
                measure_value=r.measure_value,
                dimensions=dimensions,
                mstime=r.mstime
            )
            for r in records
        ]

    def __dt_to_ms(self, dt: datetime) -> int:
        '''
        Converts a datetime into a POSIX timestamp in milliseconds.
        Returns -1 if no datetime is given.
        '''
        if dt:
            return int(round(dt.timestamp() * 1000))
        return -1

    def _get_logger(self, log_level: str = 'info') -> logging.Logger:
        '''
        Return the "prometheus_exporter" logger, writing to stdout with the given logging level set.

        Accepts "debug", "info", "warning", "error" or "critical" (case insensitive);
        any other value falls back to INFO. The stdout handler is installed once and
        reused on later calls, so creating several loaders does not duplicate log lines.
        '''
        levels = {
            'debug': logging.DEBUG,
            'warning': logging.WARNING,
            'error': logging.ERROR,
            'critical': logging.CRITICAL,
        }
        level = levels.get(log_level.lower(), logging.INFO)

        log = logging.getLogger('prometheus_exporter')
        log.setLevel(level)
        log.propagate = False

        # Loggers are process-wide singletons: reuse the handler installed by a previous call
        handler = next((h for h in log.handlers if h.get_name() == self._LOG_HANDLER_NAME), None)
        if handler is None:
            handler = logging.StreamHandler(sys.stdout)
            handler.set_name(self._LOG_HANDLER_NAME)
            handler.setFormatter(logging.Formatter(
                '%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s'))
            log.addHandler(handler)
        handler.setLevel(level)
        return log
