import boto3
import logging
import pytz

from botocore.config import Config
from botocore.exceptions import ClientError
from datetime import datetime
from timestream import DEFAULT_LOGGER, log_it
from typing import List

# botocore client configuration passed to the 'timestream-write' client that
# TimestreamReader creates in its constructor.
DEFAULT_WRITE_CLIENT_CONFIG = Config(
    retries={'max_attempts': 10},
    max_pool_connections=5000,
    read_timeout=20
)


class TimestreamReader:
    def __init__(
        self,
        session: boto3.Session = None,
        current_time_ms: int = None,
        logger: logging.Logger = None
    ):
        '''
        Creates the Timestream query/write clients and records the current time

        Parameters
        ----------
        session : boto3.Session
            (optional) Session used to create the clients. A new default
            boto3.Session is created when omitted.

        current_time_ms : int
            (optional) Value stored as the instance's current time, in
            milliseconds, instead of the time taken from the system clock.
            Must be an integer greater than 0.

        logger : logging.Logger
            (optional) Logger to use. Falls back to DEFAULT_LOGGER.
        '''
        self._session = session or boto3.Session()
        self._timestream_read_client = self._session.client('timestream-query')
        self._timestream_write_client = self._session.client(
            'timestream-write',
            config=DEFAULT_WRITE_CLIENT_CONFIG
        )
        self._logger = logger if logger else DEFAULT_LOGGER

        # Default to "now" (UTC); an explicit override below replaces it.
        now_utc = datetime.now().astimezone(pytz.utc)
        self._current_time_ms = self._dt_to_ms(now_utc)
        if current_time_ms is None:
            return

        assert type(current_time_ms) is int and current_time_ms > 0, f'Invalid current_time_ms: {current_time_ms}. Expecting an integer greater than 0'
        self._logger.warning(f'Overriding current_time_ms with value {current_time_ms}')
        self._current_time_ms = current_time_ms

    @log_it
    def does_database_exist(
        self,
        database_name: str = None,
        raise_err: bool = False
    ) -> bool:
        '''
        Performs a boto3 describe_database API call and returns True if the
        specified database_name exists

        Parameters
        ----------
        database_name : str
            The name of the database you want to check

        raise_err : bool, Default : False
            Whether to re-raise the botocore ClientError (for example a
            ResourceNotFoundException) or just return False

        Returns
        -------
        bool
            True when describe_database succeeds. False when it fails with a
            ClientError and raise_err is False. Note that False is returned
            for every ClientError, not only ResourceNotFoundException; other
            error codes are logged as warnings.

        Raises
        ------
        botocore.exceptions.ClientError/ResourceNotFoundException
            Only when raise_err is True
        AssertionError
            When database_name is empty
        '''
        assert database_name, 'Cannot check if a database exists without a database name'
        try:
            self._timestream_write_client.describe_database(
                DatabaseName=database_name
            )
        except ClientError as ex:
            if raise_err:
                raise
            error_code = ex.response.get('Error', {}).get('Code')
            if error_code != 'ResourceNotFoundException':
                # The call failed for a reason other than "not found" (e.g.
                # permissions or throttling), so False may be misleading.
                self._logger.warning(
                    f'describe_database failed for "{database_name}" with {error_code}; '
                    f'reporting the database as non-existent'
                )
            return False
        return True

    @log_it
    def does_table_exist(
        self,
        database_name: str = None,
        table_name: str = None,
        raise_err: bool = False
    ) -> bool:
        '''
        Performs a boto3 describe_table API call and returns True if the
        specified table_name exists in the database

        Parameters
        ----------
        database_name : str
            The name of the database that should contain the table

        table_name : str
            The name of the table you want to check

        raise_err : bool, Default : False
            Whether to re-raise the botocore ClientError (for example a
            ResourceNotFoundException) or just return False

        Returns
        -------
        bool
            True when describe_table succeeds. False when it fails with a
            ClientError and raise_err is False. Note that False is returned
            for every ClientError, not only ResourceNotFoundException; other
            error codes are logged as warnings.

        Raises
        ------
        botocore.exceptions.ClientError/ResourceNotFoundException
            Only when raise_err is True
        AssertionError
            When database_name or table_name is empty
        '''
        assert database_name, 'Cannot check if a table exists without a database name'
        assert table_name, 'Cannot check if a table exists without a table name'
        try:
            self._timestream_write_client.describe_table(
                DatabaseName=database_name,
                TableName=table_name
            )
        except ClientError as ex:
            if raise_err:
                raise
            error_code = ex.response.get('Error', {}).get('Code')
            if error_code != 'ResourceNotFoundException':
                # The call failed for a reason other than "not found" (e.g.
                # permissions or throttling), so False may be misleading.
                self._logger.warning(
                    f'describe_table failed for "{database_name}"."{table_name}" with {error_code}; '
                    f'reporting the table as non-existent'
                )
            return False
        return True

    @log_it
    def read(
        self,
        database_name: str,
        table_name: str,
        projection_expression: str = None,
        measure_value: float = None,
        measure_value_op: str = None,
        time: datetime = None,
        min_time: datetime = None,
        max_time: datetime = None,
        limit: int = None,
        **kwargs
    ) -> List[dict]:
        '''Generates a SQL query string based on the provided parameters and executes it.
        Returns records as a list of dictionaries.

        Parameters
        ----------
        database_name : str
            (required) Name of the database to query

        table_name : str
            (required) Name of the database table to query

        projection_expression : str
            (optional) Expression to filter columns returned.
            For example:
             - t_AppID,measure_name,measure_value::double,time
             - COUNT(time)

            By default, uses '*'

        measure_value : float
            (optional) Measure value to filter records by.
            To set the operator, see "measure_value_op"

        measure_value_op : str
            (optional) Mathematical operator used when filtering by measure_value.
            One of [>, <, =, <>, >=, <=]
            If unspecified, uses '='

        time : datetime
            (optional) Timestamp to filter records by (searches for exact match).
            Forced into UTC time.
            If specified, "min_time" and "max_time" cannot be specified.

        min_time : datetime
            (optional) Timestamp to filter records by (searches for records later than this time).
            Forced into UTC time.
            If specified, "time" cannot be specified. May be used in conjunction with "max_time"

        max_time : datetime
            (optional) Timestamp to filter records by (searches for records earlier than this time).
            Forced into UTC time.
            If specified, "time" cannot be specified. May be used in conjunction with "min_time"

        limit : int
            (optional) Limit for the total number of records to return. Must be greater than 0.

        kwargs : dict
            (optional) All other keyword args are used for string comparisons. Accepts kwarg values
            of either strings or lists of strings.
            For example:
                reader.read(
                    'my_database',
                    'my_table',
                    dimensionA='first',
                    dimensionB=['second', 'third']
                )
            generates the where clause:
                dimensionA = 'first' AND dimensionB IN ('second','third')

        Returns
        -------
        records : List[dict]
            One dict per row, as produced by _parse_rows

        Raises
        ------
        AssertionError
            When _generate_sql rejects one of the arguments
        '''
        sql = self._generate_sql(
            database_name,
            table_name,
            projection_expression,
            measure_value,
            measure_value_op,
            time,
            min_time,
            max_time,
            limit,
            **kwargs
        )

        rows = []
        columns = []
        next_token = None
        # Page through the result set; NextToken is only sent on follow-up calls.
        while True:
            request = {'QueryString': sql, 'MaxRows': 1000}
            if next_token:
                request['NextToken'] = next_token
            resp = self._timestream_read_client.query(**request)

            rows.extend(resp.get('Rows') or [])
            # Keep the first non-empty ColumnInfo seen, so column metadata is
            # still picked up if the first page arrives without it.
            columns = columns or resp.get('ColumnInfo') or []

            next_token = resp.get('NextToken')
            if not next_token:
                break

        self._logger.info(f'Retrieved {len(rows)} records from Timestream')
        return self._parse_rows(rows, columns)

    @log_it
    def _generate_sql(
        self,
        database_name: str,
        table_name: str,
        projection_expression: str = None,
        measure_value: float = None,
        measure_value_op: str = None,
        time: datetime = None,
        min_time: datetime = None,
        max_time: datetime = None,
        limit: int = None,
        **kwargs
    ) -> str:
        '''Validates the arguments and builds the SQL query string used by read()

        Parameters
        ----------
        database_name : str
            (required) Name of the database to query

        table_name : str
            (required) Name of the database table to query

        projection_expression : str
            (optional) Column list / expression placed after SELECT. Defaults to '*'

        measure_value : float
            (optional) Value compared against measure_value::double.
            A value of 0.0 is a valid filter.

        measure_value_op : str
            (optional) One of [>, <, =, <>, >=, <=]. Defaults to '='

        time : datetime
            (optional) Exact timestamp to match. Cannot be combined with
            min_time or max_time.

        min_time : datetime
            (optional) Inclusive lower bound on time

        max_time : datetime
            (optional) Inclusive upper bound on time

        limit : int
            (optional) Value for the LIMIT clause. Must be greater than 0.

        kwargs : dict
            (optional) Column name -> string, or column name -> list of strings.
            A string produces "<column> = '<value>'", a list produces
            "<column> IN ('<v1>','<v2>')". Values are emitted as single-quoted
            SQL string literals with embedded single quotes doubled, so pass
            the raw value without surrounding quotes.

        Returns
        -------
        sql : str

        Raises
        ------
        AssertionError
            When any argument fails validation

        Notes
        -----
        All datetimes are converted with datetime.astimezone(pytz.utc); naive
        datetimes are therefore interpreted as local time before conversion.
        '''
        assert database_name, 'Cannot read from Timestream without a database name'
        assert table_name, 'Cannot read from Timestream without a table name'
        assert projection_expression is None or isinstance(projection_expression, str), f'Invalid projection expression: {projection_expression}. Expected a string'
        assert measure_value is None or isinstance(measure_value, float), f'Invalid measure value: {measure_value}. Expected a float'
        assert measure_value_op in [None, '>', '<', '=', '<>', '>=', '<='], f'Invalid measure value op: {measure_value_op}'
        assert time is None or isinstance(time, datetime), f'Invalid time: {time}. Expected a datetime'
        assert min_time is None or isinstance(min_time, datetime), f'Invalid min_time: {min_time}. Expected a datetime'
        assert max_time is None or isinstance(max_time, datetime), f'Invalid max_time: {max_time}. Expected a datetime'
        assert not(time and min_time), 'Cannot specify min_time when time is specified'
        assert not(time and max_time), 'Cannot specify max_time when time is specified'
        assert limit is None or (isinstance(limit, int) and limit > 0), f'Invalid limit: {limit}. Expected an integer greater than 0'
        projection_expression = projection_expression or '*'
        measure_value_op = measure_value_op or '='
        time = time.astimezone(pytz.utc) if time else None
        min_time = min_time.astimezone(pytz.utc) if min_time else None
        max_time = max_time.astimezone(pytz.utc) if max_time else None

        def escape(text: str) -> str:
            # Double embedded single quotes so the value stays inside its
            # SQL string literal.
            return text.replace("'", "''")

        # Build SQL query
        sql = f'SELECT {projection_expression} FROM "{database_name}"."{table_name}"'
        clauses = []
        # Compare against None: 0.0 is a legitimate value to filter on.
        if measure_value is not None:
            clauses.append(f'measure_value::double {measure_value_op} {measure_value}')
        if time:
            clauses.append(f"time = FROM_ISO8601_TIMESTAMP('{time.isoformat()}')")
        if min_time and max_time:
            clauses.append(f"time BETWEEN FROM_ISO8601_TIMESTAMP('{min_time.isoformat()}') AND FROM_ISO8601_TIMESTAMP('{max_time.isoformat()}')")
        elif max_time:
            clauses.append(f"time <= FROM_ISO8601_TIMESTAMP('{max_time.isoformat()}')")
        elif min_time:
            clauses.append(f"time >= FROM_ISO8601_TIMESTAMP('{min_time.isoformat()}')")
        for attr, val in kwargs.items():
            assert isinstance(attr, str), f'Invalid kwarg key: {attr}. Expected a string'
            if isinstance(val, str):
                # Quote the literal, matching the list branch and the
                # documented output (dimensionA = 'first').
                clauses.append(f"{attr} = '{escape(val)}'")
            elif isinstance(val, list) and all(isinstance(v, str) for v in val):
                vals = "','".join(escape(v) for v in val)
                clauses.append(f"{attr} IN ('{vals}')")
            else:
                raise AssertionError(f'Invalid kwarg value: {val}. Expected either a string or list of strings')

        if clauses:
            where = ' AND '.join(clauses)
            sql += f' WHERE {where}'
        if limit is not None:
            sql += f' LIMIT {limit}'
        self._logger.info(f'Generated SQL query: {sql}')
        return sql

    def _parse_rows(self, rows: List[dict], columns: List[dict]) -> List[dict]:
        '''Converts rows+columns returned by Timestream into a list of record dicts'''
        assert isinstance(rows, list)
        assert isinstance(columns, list)

        records = []
        for r in rows:
            assert isinstance(r, dict)
            data = r.get('Data') or []
            assert len(data) == len(columns)

            record = {}
            for i in range(len(columns)):
                assert isinstance(columns[i], dict) and 'Name' in columns[i]
                d, col_name = data[i], columns[i]['Name']
                if d.get('ScalarValue') is not None:
                    record[col_name] = d['ScalarValue']
                elif d.get('NullValue'):
                    record[col_name] = None
                else:
                    raise Exception(f'Unrecognized type for data: {d}')

            if 'measure_value::double' in record:
                record['measure_value::double'] = float(record.pop('measure_value::double'))
            if 'time' in record:
                record['time'] = datetime.strptime(record.pop('time'), '%Y-%m-%d %H:%M:%S.%f000').replace(tzinfo=pytz.utc)
            records.append(record)
        return records

    def _dt_to_ms(self, dt: datetime) -> int:
        '''Converts a datetime into a POSIX timestamp in milliseconds'''
        if dt:
            return int(round(dt.timestamp() * 1000))
        return -1
