import logging
import os
import sys
from datetime import datetime
from typing import List, Optional, Tuple

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError


# botocore settings shared by every Timestream write client created in this
# module: a 20 second read timeout, a connection pool of up to 5000
# connections, and a retry policy with max_attempts=10.
DEFAULT_CLIENT_CONFIG = Config(
    retries=dict(max_attempts=10),
    max_pool_connections=5000,
    read_timeout=20,
)

class TimestreamLoader:
    """Write collector output to an Amazon Timestream table.

    Each item passed to :meth:`load` is a flat ``dict``:

    - exact ``int``/``float`` values become DOUBLE measures,
    - ``list`` values become a DOUBLE measure holding the list length,
    - ``str`` values become VARCHAR dimensions,
    - ``None`` values are skipped.

    Optional ``'Time'``/``'Version'`` keys are copied onto every record built
    from the item so that its rows can be upserted.
    """

    # The Timestream WriteRecords API accepts at most 100 records per request,
    # so larger record lists are split into several requests.
    MAX_RECORDS_PER_WRITE = 100

    def __init__(
        self,
        database_name: Optional[str] = None,
        table_name: Optional[str] = None,
        session: Optional[boto3.Session] = None,
        logger: Optional[logging.Logger] = None
    ):
        """Create a loader bound to a single Timestream table.

        Args:
            database_name: Target database; defaults to ``'scorecards-db'``.
            table_name: Target table (required).
            session: boto3 session used to build the ``timestream-write``
                client; a new default session is created when omitted.
            logger: Logger used for write errors; a stdout logger is built
                by :meth:`_get_logger` when omitted.

        Raises:
            AssertionError: if ``table_name`` is empty.
        """
        assert table_name, 'Must specify table_name'
        self.__database_name = database_name or 'scorecards-db'
        self.__table_name = table_name
        self.__logger = logger or self._get_logger()
        self.__session = session or boto3.Session()
        self.__timestream_write_client = self.__session.client('timestream-write', config=DEFAULT_CLIENT_CONFIG)

    def load(self, collector_name: str, data: list, capture_time: datetime):
        """Write every item of ``data`` to Timestream.

        Args:
            collector_name: Added to each item's dimensions as ``source``.
                Items from ``'pagerduty-analytics'`` are written as upserts
                (per-record Time/Version, no common Time); all others share
                ``capture_time``.
            data: Iterable of flat dicts (see the class docstring).
            capture_time: Timestamp applied to non-upsert items.

        Raises:
            AssertionError: if ``collector_name`` or ``capture_time`` is empty.
            botocore.exceptions.ClientError: if a write request fails.
        """
        assert collector_name
        assert capture_time
        for value in data:
            records, dimensions = self.__record_and_dimensions_builder(value)
            if not records:
                # Items without any measure produce no write request.
                continue
            dimensions.append({
                'Name': 'source',
                'Value': collector_name,
                'DimensionValueType': 'VARCHAR'
            })
            if collector_name == 'pagerduty-analytics':
                self.__write_upsert_records(dimensions, records)
            else:
                self.__write_common_records(dimensions, records, capture_time)

    def __write_common_records(self, dimensions: List[dict], records: List[dict], capture_time: datetime):
        """Write ``records`` sharing ``dimensions``, ``capture_time`` (in ms) and Version 1."""
        assert capture_time, 'A Capture time needs to be provided'
        capture_time_ms = self.__dt_to_ms(capture_time)
        common_attributes = {
            'Dimensions': dimensions,
            'MeasureValueType': 'DOUBLE',
            'Time': str(capture_time_ms),
            'Version': 1
        }
        self.__write_in_batches(common_attributes, records, 'common', 'Common')

    def __write_upsert_records(self, dimensions: List[dict], records: List[dict]):
        """Write ``records`` sharing ``dimensions``; Time/Version come from each record."""
        common_attributes = {
            'Dimensions': dimensions,
            'MeasureValueType': 'DOUBLE'
        }
        self.__write_in_batches(common_attributes, records, 'upsert', 'Upsert')

    def __write_in_batches(self, common_attributes: dict, records: List[dict], kind: str, label: str):
        """Send ``records`` in chunks of at most ``MAX_RECORDS_PER_WRITE``.

        Each chunk is a separate WriteRecords request, so when a later chunk
        fails the earlier chunks have already been sent. On a ``ClientError``
        the error and any rejected records are logged and the error is
        re-raised.

        Args:
            common_attributes: ``CommonAttributes`` sent with every chunk.
            records: Records to write.
            kind: Lower-case name used in the failure message.
            label: Capitalized name used in the rejected-record message.
        """
        step = self.MAX_RECORDS_PER_WRITE
        for start in range(0, len(records), step):
            batch = records[start:start + step]
            try:
                self.__timestream_write_client.write_records(
                    DatabaseName=self.__database_name,
                    TableName=self.__table_name,
                    CommonAttributes=common_attributes,
                    Records=batch
                )
            except ClientError as e:
                self.__logger.error(f'Failed to write {kind} records: {e}')
                # RecordIndex refers to the position within this request, i.e. within `batch`.
                for rej in e.response.get('RejectedRecords', []):
                    self.__logger.error(
                        f'rejected record: {rej} || Dimensions {common_attributes["Dimensions"]} '
                        f'|| {label} Records {batch[rej["RecordIndex"]]}'
                    )
                raise

    def __record_time_attributes(self, item: dict) -> dict:
        """Return the per-record Time/Version attributes for ``item``.

        Returns an empty dict when ``item`` has no truthy ``'Time'``. The time
        is sent as a string (a ``datetime`` is first converted to epoch
        milliseconds, like ``capture_time``). ``'Version'`` is included only
        when it is not None, because a None value fails botocore parameter
        validation.
        """
        time_value = item.get('Time')
        if not time_value:
            return {}
        if isinstance(time_value, datetime):
            time_value = self.__dt_to_ms(time_value)
        attributes = {'Time': str(time_value)}
        version = item.get('Version')
        if version is not None:
            attributes['Version'] = version
        return attributes

    def __record_and_dimensions_builder(self, dict_to_iterate: dict) -> Tuple[List[dict], List[dict]]:
        """Split one item into Timestream measure records and dimensions.

        - exact ``int``/``float`` values -> measure whose value is ``str(v)``
        - ``list`` values -> measure whose value is the list length
        - ``str`` values -> VARCHAR dimension (name and value stripped)
        - ``None`` values and the ``'Time'``/``'Version'`` keys -> skipped
        - anything else -> warning logged, skipped

        Returns:
            ``(records, dimensions)`` as two new lists.
        """
        records, dimensions = [], []
        # Time and Version are required on each record (not in CommonAttributes) to handle upserts.
        base_record = self.__record_time_attributes(dict_to_iterate)

        for k, v in dict_to_iterate.items():
            if v is None or k in ('Time', 'Version'):
                continue
            value_type = type(v)
            # Exact type checks on purpose: bool (an int subclass) is not a measure.
            if value_type in (float, int):
                measure_value = str(v)
            elif value_type == list:
                measure_value = str(len(v))
            elif value_type == str:
                dimensions.append({
                    'Name': k.strip(),
                    'Value': v.strip(),
                    'DimensionValueType': 'VARCHAR'
                })
                continue
            else:
                self.__logger.warning(f'Could not create a record or dimension from type: {type(v)}')
                continue
            # Each record gets its own copy of the Time/Version attributes.
            records.append({
                **base_record,
                'MeasureName': k.strip(),
                'MeasureValue': measure_value.strip()
            })
        return records, dimensions

    def __dt_to_ms(self, dt: datetime) -> int:
        """Return ``dt`` as rounded epoch milliseconds, or -1 when ``dt`` is falsy.

        NOTE(review): a naive datetime is interpreted in local time by
        ``datetime.timestamp()``; confirm callers pass timezone-aware values.
        """
        if dt:
            return int(round(dt.timestamp() * 1000))
        return -1

    def _get_logger(self, log_level: str = 'info') -> logging.Logger:
        """Return the ``prometheus_exporter`` logger, writing to stdout at ``log_level``.

        Unknown level names fall back to INFO. The stdout handler is attached
        only once, so creating several loaders does not duplicate log lines.
        """
        level = {
            'debug': logging.DEBUG,
            'warning': logging.WARNING,
            'error': logging.ERROR,
            'critical': logging.CRITICAL,
        }.get(log_level.lower(), logging.INFO)

        log = logging.getLogger('prometheus_exporter')
        log.setLevel(level)

        # Reuse the handler a previous call attached instead of stacking another one.
        handler = next(
            (h for h in log.handlers if getattr(h, '_timestream_loader_stdout', False)),
            None
        )
        if handler is None:
            handler = logging.StreamHandler(sys.stdout)
            handler._timestream_loader_stdout = True
            handler.setFormatter(logging.Formatter(
                '%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s'))
            log.addHandler(handler)
        handler.setLevel(level)

        # Do not also pass records to ancestor (e.g. root) logger handlers.
        log.propagate = False
        return log
