import pytest
import pytz

from botocore.exceptions import ClientError
from datetime import date, datetime
from timestream.timestream_reader import DEFAULT_WRITE_CLIENT_CONFIG, TimestreamReader
from test_utils.mock import MockApi


class TestTimestreamLoader_UnitTests:
    """Unit tests for ``TimestreamReader`` driven by ``MockApi`` stand-ins for the session and clients."""

    @staticmethod
    def _not_found_error(message, operation_name):
        """Build a ``ClientError`` shaped like a botocore ``ResourceNotFoundException``.

        botocore's ``ClientError(error_response, operation_name)`` reads the error
        code and message from ``error_response['Error']``, so the payload must sit
        at the top level (not nested under another key) for ``e.response['Error']``
        to be populated the way a real failed call would populate it.

        :param message: Error message placed in ``Error.Message`` and ``Message``.
        :param operation_name: Name of the AWS operation that "failed".
        :return: A ``ClientError`` instance (not raised).
        """
        error_response = {
            'Error': {
                'Message': message,
                'Code': 'ResourceNotFoundException',
            },
            'ResponseMetadata': {
                'RequestId': 'V6E3FBONKV7C6YK3WLSVISIJGI',
                'HTTPStatusCode': 404,
                'HTTPHeaders': {
                    'x-amzn-requestid': 'V6E3FBONKV7C6YK3WLSVISIJGI',
                    'content-type': 'application/x-amz-json-1.0',
                    'content-length': '129',
                    'date': 'Wed, 20 Oct 2021 17:04:42 GMT',
                },
                'RetryAttempts': 0,
            },
            'Message': message,
        }
        return ClientError(error_response, operation_name)

    @pytest.mark.unit
    def test_does_database_exist(self):
        """A missing database yields False (or raises with raise_err=True); an existing one yields True."""
        # Set expectations
        expected_error = self._not_found_error(
            'The database default does not exist.',
            'DescribeDatabase',
        )
        mock_response = {
            'Database': {
                'Arn': 'arn:aws:timestream:us-east-1:000000000000:database/test',
                'DatabaseName': 'test',
                'TableCount': 1,
                'KmsKeyId': 'arn:aws:kms:us-east-1:000000000000:key/12345678-1234-1234-1234-1234567890ab',
                'CreationTime': datetime(2021, 7, 19, 16, 25, 26, 419000),
                'LastUpdatedTime': datetime(2021, 10, 8, 16, 25, 37, 893000),
            },
            'ResponseMetadata': {
                'RequestId': '1234567890abcdefghijklmnop',
                'HTTPStatusCode': 200,
                'HTTPHeaders': {
                    'x-amzn-requestid': '1234567890abcdefghijklmnop',
                    'content-type': 'application/x-amz-json-1.0',
                    'content-length': '289',
                    'date': 'Wed, 20 Oct 2021 16:54:48 GMT',
                },
                'RetryAttempts': 0,
            },
        }

        # Set up mocks
        mock_write_client = MockApi()
        mock_write_client.on(
            'describe_database',
            DatabaseName='default',
        ).raise_exception(expected_error)
        mock_write_client.on(
            'describe_database',
            DatabaseName='test',
        ).return_val(mock_response)
        mock_session = self.get_mock_session(write_client=mock_write_client)

        # Test False when the database is not found
        reader = TimestreamReader(mock_session)
        result = reader.does_database_exist('default')
        assert not result

        # Test exception raised if not found and raise_err is set
        with pytest.raises(ClientError):
            reader.does_database_exist('default', raise_err=True)

        # Test True using test name
        result = reader.does_database_exist('test')
        assert result

        mock_session.assert_expectations()
        mock_write_client.assert_expectations()

    @pytest.mark.unit
    def test_does_table_exist(self):
        """A missing table yields False (or raises with raise_err=True); an existing one yields True."""
        # Set expectations
        expected_error = self._not_found_error(
            'The table default does not exist.',
            'DescribeTable',
        )
        mock_response = {
            'Table': {
                'Arn': 'arn:aws:timestream:us-east-1:000000000000:database/test/table/test',
                'TableName': 'test',
                'DatabaseName': 'test',
                'TableStatus': 'ACTIVE',
                'RetentionProperties': {
                    'MemoryStoreRetentionPeriodInHours': 720,
                    'MagneticStoreRetentionPeriodInDays': 1,
                },
                'CreationTime': datetime(2021, 8, 19, 18, 21, 1, 519000),
                'LastUpdatedTime': datetime(2021, 8, 31, 12, 36, 4, 503000),
            },
            'ResponseMetadata': {
                'RequestId': 'SDKOJJS7STGCQXOFVLXPBYDRKI',
                'HTTPStatusCode': 200,
                'HTTPHeaders': {
                    'x-amzn-requestid': 'SDKOJJS7STGCQXOFVLXPBYDRKI',
                    'content-type': 'application/x-amz-json-1.0',
                    'content-length': '366',
                    'date': 'Wed, 20 Oct 2021 17:41:09 GMT',
                },
                'RetryAttempts': 1,
            },
        }

        # Set up mocks
        mock_write_client = MockApi()
        mock_write_client.on(
            'describe_table',
            DatabaseName='default',
            TableName='default'
        ).raise_exception(expected_error)
        mock_write_client.on(
            'describe_table',
            DatabaseName='test',
            TableName='test'
        ).return_val(mock_response)
        mock_session = self.get_mock_session(write_client=mock_write_client)

        # Test False when the table is not found
        reader = TimestreamReader(mock_session)
        result = reader.does_table_exist('default', 'default')
        assert not result

        # Test exception raised if not found and raise_err is set
        with pytest.raises(ClientError):
            reader.does_table_exist('default', 'default', raise_err=True)

        # Test True using test name
        result = reader.does_table_exist('test', 'test')
        assert result

        mock_session.assert_expectations()
        mock_write_client.assert_expectations()

    @pytest.mark.unit
    def test_read(self):
        """read() issues the query, follows NextToken, and parses rows into records."""
        # Set expectations
        expected_time = datetime(2022, 1, 1, 11, 12, 13, 14, pytz.utc)
        expected_time_str = expected_time.strftime('%Y-%m-%d %H:%M:%S.%f000')
        expected_sql = 'SELECT * FROM "db"."table" LIMIT 5'
        expected_rows = [{
            'Data': [{'ScalarValue': 1.0}, {'NullValue': True}, {'ScalarValue': expected_time_str}]
        }]
        expected_columns = [{
            'Name': 'measure_value::double'
        }, {
            'Name': 'dimension1'
        }, {
            'Name': 'time'
        }]
        expected_next_token = 'token'
        expected_resp = {
            'Rows': expected_rows,
            'ColumnInfo': expected_columns,
            'NextToken': expected_next_token
        }
        expected_records = [{
            'measure_value::double': 1.0,
            'dimension1': None,
            'time': expected_time
        }]

        # Set up mocks: the first page carries a NextToken, the second page ends pagination
        mock_read_client = MockApi()
        mock_read_client.on(
            'query', QueryString=expected_sql, MaxRows=1000
        ).return_val(expected_resp)
        mock_read_client.on(
            'query', QueryString=expected_sql, NextToken=expected_next_token, MaxRows=1000
        ).return_val({})
        mock_session = self.get_mock_session(read_client=mock_read_client)

        # Test read
        reader = TimestreamReader(mock_session)
        records = reader.read('db', 'table', limit=5)
        assert records == expected_records

        # Verify both the initial and the paginated query were issued
        mock_session.assert_expectations()
        mock_read_client.assert_expectations()

    @pytest.mark.unit
    def test_generate_sql(self):
        """_generate_sql renders each supported filter into the expected SQL text."""
        reader = TimestreamReader(self.get_mock_session())
        time = datetime(2022, 1, 1, 11, 12, 13, 14, pytz.utc)
        min_time = datetime.combine(time.date(), datetime.min.time()).replace(tzinfo=pytz.utc)
        max_time = datetime.combine(time.date(), datetime.max.time()).replace(tzinfo=pytz.utc)

        # Test projection expression
        sql = reader._generate_sql('db', 'table', projection_expression='COUNT(*)')
        assert sql == 'SELECT COUNT(*) FROM "db"."table"'

        # Test measure value
        sql = reader._generate_sql('db', 'table', measure_value=1.0, measure_value_op='<>')
        assert sql == 'SELECT * FROM "db"."table" WHERE measure_value::double <> 1.0'

        # Test time
        sql = reader._generate_sql('db', 'table', time=time)
        assert sql == "SELECT * FROM \"db\".\"table\" WHERE time = FROM_ISO8601_TIMESTAMP('2022-01-01T11:12:13.000014+00:00')"

        # Test min time
        sql = reader._generate_sql('db', 'table', min_time=min_time)
        assert sql == "SELECT * FROM \"db\".\"table\" WHERE time >= FROM_ISO8601_TIMESTAMP('2022-01-01T00:00:00+00:00')"

        # Test max time
        sql = reader._generate_sql('db', 'table', max_time=max_time)
        assert sql == "SELECT * FROM \"db\".\"table\" WHERE time <= FROM_ISO8601_TIMESTAMP('2022-01-01T23:59:59.999999+00:00')"

        # Test min + max time
        sql = reader._generate_sql('db', 'table', min_time=min_time, max_time=max_time)
        assert sql == "SELECT * FROM \"db\".\"table\" WHERE time BETWEEN FROM_ISO8601_TIMESTAMP('2022-01-01T00:00:00+00:00') AND FROM_ISO8601_TIMESTAMP('2022-01-01T23:59:59.999999+00:00')"

        # Test kwargs + limit
        sql = reader._generate_sql('db', 'table', limit=5, dim1='val1', t_AppID=['SVC98765', 'SVC01234'])
        assert sql == "SELECT * FROM \"db\".\"table\" WHERE dim1 = 'val1' AND t_AppID IN ('SVC98765','SVC01234') LIMIT 5"

    @pytest.mark.unit
    def test_generate_sql_assertions(self):
        """_generate_sql rejects invalid arguments with descriptive AssertionErrors."""
        reader = TimestreamReader(self.get_mock_session())
        time = datetime(2022, 1, 1, 11, 12, 13, 14, pytz.utc)
        min_time = datetime.combine(time.date(), datetime.min.time()).replace(tzinfo=pytz.utc)
        max_time = datetime.combine(time.date(), datetime.max.time()).replace(tzinfo=pytz.utc)

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql(None, 'table')
        assert str(e_info.value).startswith('Cannot generate SQL without a database')

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql('db', None)
        assert str(e_info.value).startswith('Cannot generate SQL without a table')

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql('db', 'table', projection_expression=1)
        assert str(e_info.value).startswith('Invalid projection expression')

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql('db', 'table', measure_value='a')
        assert str(e_info.value).startswith('Invalid measure value')

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql('db', 'table', measure_value_op=10)
        assert str(e_info.value).startswith('Invalid measure value op')

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql('db', 'table', time=10)
        assert str(e_info.value).startswith('Invalid time')

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql('db', 'table', min_time=10)
        assert str(e_info.value).startswith('Invalid min_time')

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql('db', 'table', max_time=10)
        assert str(e_info.value).startswith('Invalid max_time')

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql('db', 'table', time=time, min_time=min_time)
        assert str(e_info.value).startswith('Cannot specify min_time when time is specified')

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql('db', 'table', time=time, max_time=max_time)
        assert str(e_info.value).startswith('Cannot specify max_time when time is specified')

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql('db', 'table', limit='10')
        assert str(e_info.value).startswith('Invalid limit')

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql('db', 'table', limit=-1)
        assert str(e_info.value).startswith('Invalid limit')

        with pytest.raises(AssertionError) as e_info:
            reader._generate_sql('db', 'table', dim1=10)
        assert str(e_info.value).startswith('Invalid kwarg value')

    @pytest.mark.unit
    def test_parse_rows(self):
        """_parse_rows converts Timestream rows to dicts and fails on an unknown datum type."""
        # Set expectations
        expected_time = datetime(2022, 1, 1, 11, 12, 13, 14, pytz.utc)
        expected_time_str = expected_time.strftime('%Y-%m-%d %H:%M:%S.%f000')
        expected_rows = [{
            'Data': [{'ScalarValue': 1.0}, {'NullValue': True}, {'ScalarValue': expected_time_str}]
        }, {
            'Data': [{'ScalarValue': 2.0}, {'Unknown': 'bad'}, {'ScalarValue': expected_time_str}]
        }]
        expected_columns = [{
            'Name': 'measure_value::double'
        }, {
            'Name': 'dimension1'
        }, {
            'Name': 'time'
        }]
        expected_records = [{
            'measure_value::double': 1.0,
            'dimension1': None,
            'time': expected_time
        }]
        reader = TimestreamReader(self.get_mock_session())

        # Test with bad row
        with pytest.raises(Exception) as e_info:
            reader._parse_rows(expected_rows, expected_columns)
        assert str(e_info.value).startswith('Unrecognized type for data')

        # Test without bad row
        expected_rows = expected_rows[:1]
        records = reader._parse_rows(expected_rows, expected_columns)
        assert records == expected_records

    @pytest.mark.unit
    def test_parse_rows_assertions(self):
        """_parse_rows rejects malformed rows/columns with descriptive AssertionErrors."""
        reader = TimestreamReader(self.get_mock_session())

        with pytest.raises(AssertionError) as e_info:
            reader._parse_rows('rows', [])
        assert str(e_info.value).startswith('Invalid rows to parse')

        with pytest.raises(AssertionError) as e_info:
            reader._parse_rows([], 'columns')
        assert str(e_info.value).startswith('Invalid columns to parse')

        with pytest.raises(AssertionError) as e_info:
            reader._parse_rows(['row'], [])
        assert str(e_info.value).startswith('Invalid row')

        with pytest.raises(AssertionError) as e_info:
            reader._parse_rows([{'Data': []}], [{'Name': 'col1'}])
        assert str(e_info.value).startswith('Invalid row data')

        with pytest.raises(AssertionError) as e_info:
            reader._parse_rows([{'Data': ['val1']}], ['column'])
        assert str(e_info.value).startswith('Invalid column')

        with pytest.raises(AssertionError) as e_info:
            reader._parse_rows([{'Data': ['val1']}], [{'bad': 'column'}])
        assert str(e_info.value).startswith('Invalid column')

    @pytest.mark.unit
    def test_dt_to_ms(self):
        """_dt_to_ms converts an aware datetime to epoch milliseconds and maps None to -1."""
        mock_session = self.get_mock_session()
        reader = TimestreamReader(mock_session)
        mock_session.assert_expectations()

        dt = datetime.combine(
            date(2021, 1, 1),
            datetime.min.time()
        ).replace(tzinfo=pytz.UTC)
        assert reader._dt_to_ms(dt) == 1609459200000
        assert reader._dt_to_ms(None) == -1

    def get_mock_session(self, write_client=None, read_client=None) -> MockApi:
        """Build a mock session whose ``client`` calls hand back the given mock clients.

        ``client('timestream-write', config=DEFAULT_WRITE_CLIENT_CONFIG)`` returns
        ``write_client`` and ``client('timestream-query')`` returns ``read_client``;
        a fresh ``MockApi`` is substituted for either one that is not supplied.
        """
        mock_session = MockApi()
        mock_session.on(
            'client',
            'timestream-write',
            config=DEFAULT_WRITE_CLIENT_CONFIG
        ).return_val(write_client or MockApi())
        mock_session.on(
            'client',
            'timestream-query'
        ).return_val(read_client or MockApi())
        return mock_session
