import pytest
import pytz

from botocore.exceptions import ClientError
from datetime import date, datetime, tzinfo
from test_utils.mock import MockApi
from timestream.api import DEFAULT_CLIENT_CONFIG, TimestreamLoader
from timestream.dimensions import TimestreamDimensions
from timestream.load_result import TimestreamLoadResult
from timestream.record import RejectedTimestreamRecord, TimestreamRecord
class TestTimestreamLoader_UnitTests:
    """Unit tests for ``TimestreamLoader`` driven entirely by ``MockApi`` doubles.

    The session handed to the loader is always a ``MockApi``; where a write is
    exercised, the ``timestream-write`` client it returns is a ``MockApi`` too.
    Private loader methods are reached through their name-mangled attributes
    (``_TimestreamLoader__name``).
    """

    # ------------------------------------------------------------------
    # Shared fixtures and assertions
    # ------------------------------------------------------------------

    @staticmethod
    def _make_loader(client=None, **loader_kwargs):
        """Create a loader for database/table ``'test'`` on a mocked session.

        The session is primed to answer
        ``client('timestream-write', config=DEFAULT_CLIENT_CONFIG)`` with
        *client*. Any extra keyword arguments (e.g. ``current_time_ms``) are
        forwarded to the ``TimestreamLoader`` constructor.

        Returns:
            A ``(loader, mock_session)`` tuple; the session is returned so the
            caller can verify its expectations.
        """
        mock_session = MockApi()
        mock_session.on(
            'client', 'timestream-write', config=DEFAULT_CLIENT_CONFIG
        ).return_val(client)
        loader = TimestreamLoader('test', 'test', mock_session, **loader_kwargs)
        return loader, mock_session

    @staticmethod
    def _rejection_error(reason):
        """Return a ``ClientError`` whose response rejects record index 1 with *reason*."""
        return ClientError({
            'RejectedRecords': [{
                'RecordIndex': 1,
                'Reason': reason
            }]
        }, 'test operation')

    @staticmethod
    def _failing_write_client(common_attributes, records, error):
        """Return a mock client whose expected ``write_records`` call raises *error*.

        The expectation is registered for database and table ``'test'`` with
        the given ``CommonAttributes`` and ``Records`` payloads.
        """
        mock_client = MockApi()
        mock_client.on(
            'write_records',
            DatabaseName='test',
            TableName='test',
            CommonAttributes=common_attributes,
            Records=records
        ).raise_exception(error)
        return mock_client

    @staticmethod
    def _assert_empty_result_dict(result):
        """Assert *result* is a plain dict describing an empty, error-free load."""
        # ``type(...) is`` keeps the exact-type check (subclasses are not accepted).
        assert type(result) is dict
        missing_keys = [
            k for k in ['error', 'invalid_items', 'accepted_records', 'rejected_records']
            if k not in result
        ]
        assert not missing_keys, f'Result {result} is missing keys {missing_keys}'
        for key in ('invalid_items', 'accepted_records', 'rejected_records'):
            assert type(result[key]) is list
            assert len(result[key]) == 0
        assert result['error'] is None

    @staticmethod
    def _assert_empty_load_result(result):
        """Assert *result* is a ``TimestreamLoadResult`` with no items and no error."""
        assert type(result) is TimestreamLoadResult
        assert len(result.invalid_items) == 0
        assert len(result.accepted_records) == 0
        assert len(result.rejected_records) == 0
        assert result.error is None

    @staticmethod
    def _assert_partial_rejection(result, expected_reason):
        """Assert *result* accepted the ``'valid'`` measure and rejected ``'invalid'``.

        The single rejected entry must be a ``RejectedTimestreamRecord`` whose
        error stringifies to *expected_reason*.
        """
        assert type(result) is TimestreamLoadResult
        assert len(result.accepted_records) == 1
        assert result.accepted_records[0].measure_name == 'valid'
        assert len(result.rejected_records) == 1
        rejected_record = result.rejected_records[0]
        assert type(rejected_record) is RejectedTimestreamRecord
        assert str(rejected_record.error) == expected_reason
        assert rejected_record.record.measure_name == 'invalid'

    @staticmethod
    def _assert_rejection_split(accepted_records, rejected_records,
                                expected_valid_record, expected_invalid_record,
                                expected_reason):
        """Assert exactly one record was accepted and one rejected, as expected.

        The accepted record must equal *expected_valid_record*; the rejected
        entry must wrap *expected_invalid_record* and carry *expected_reason*.
        """
        assert len(accepted_records) == 1
        assert accepted_records[0] == expected_valid_record
        assert len(rejected_records) == 1

        rejected_record = rejected_records[0]
        assert type(rejected_record) is RejectedTimestreamRecord
        assert str(rejected_record.error) == expected_reason
        assert rejected_record.record == expected_invalid_record

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @pytest.mark.unit
    def test_load(self):
        """``load`` with no data returns an empty, error-free result dict."""
        loader, mock_session = self._make_loader()
        mock_session.assert_expectations()

        result = loader.load(
            data=[],
            capture_time=datetime(2021, 1, 1, 0, 0, 0, 0, pytz.UTC)
        )
        self._assert_empty_result_dict(result)

    @pytest.mark.unit
    def test_load_preview(self):
        """``load_preview`` with no data returns an empty, error-free result dict."""
        loader, mock_session = self._make_loader()
        mock_session.assert_expectations()

        result = loader.load_preview(
            data=[],
            capture_time=datetime(2021, 1, 1, 0, 0, 0, 0, pytz.UTC)
        )
        self._assert_empty_result_dict(result)

    # ------------------------------------------------------------------
    # Private load / write paths
    # ------------------------------------------------------------------

    @pytest.mark.unit
    def test_private_load(self):
        """Exercise the private ``__load`` in preview, insert and upsert modes.

        In the insert and upsert scenarios the mocked ``write_records`` call
        raises a ``ClientError`` rejecting record index 1, and the result is
        expected to report ``'valid'`` as accepted and ``'invalid'`` as rejected.
        """
        # Test preview
        loader, mock_session = self._make_loader()
        mock_session.assert_expectations()

        result = loader._TimestreamLoader__load(
            data=[],
            capture_time=datetime(2021, 1, 1, 0, 0, 0, 0, pytz.UTC),
            preview=True
        )
        self._assert_empty_load_result(result)

        # Test insert
        expected_source = 'test soure'
        expected_reason = 'test reason'
        expected_error = self._rejection_error(expected_reason)
        # 2021-01-01T00:00:00Z in epoch milliseconds, as a string.
        expected_capture_time = '1609459200000'

        mock_client = self._failing_write_client(
            common_attributes={
                'Dimensions': [{'Name': 'source', 'Value': expected_source, 'DimensionValueType': 'VARCHAR'}],
                'MeasureValueType': 'DOUBLE',
                'Time': expected_capture_time,
                'Version': 1
            },
            records=[{
                'MeasureName': 'valid',
                'MeasureValue': '1.0'
            }, {
                'MeasureName': 'invalid',
                'MeasureValue': '0.0'
            }],
            error=expected_error
        )
        loader, mock_session = self._make_loader(mock_client)

        result = loader._TimestreamLoader__load(
            data=[{'valid': 1.0, 'invalid': 0.0}],
            capture_time=datetime(2021, 1, 1, 0, 0, 0, 0, pytz.UTC),
            source=expected_source
        )
        self._assert_partial_rejection(result, expected_reason)
        mock_session.assert_expectations()
        mock_client.assert_expectations()

        # Test upsert: each record carries its own 'Time', and the common
        # 'Version' is the loader's current_time_ms rather than the item's
        # own 'version' value.
        mock_client = self._failing_write_client(
            common_attributes={
                'Dimensions': [],
                'MeasureValueType': 'DOUBLE',
                'Version': int(expected_capture_time)
            },
            records=[{
                'MeasureName': 'valid',
                'MeasureValue': '1.0',
                'Time': expected_capture_time
            }, {
                'MeasureName': 'invalid',
                'MeasureValue': '0.0',
                'Time': expected_capture_time
            }],
            error=expected_error
        )
        loader, mock_session = self._make_loader(
            mock_client, current_time_ms=int(expected_capture_time)
        )

        result = loader._TimestreamLoader__load(
            data=[{'valid': 1.0, 'invalid': 0.0, 'time': int(expected_capture_time), 'version': 2}],
            capture_time=datetime(2021, 1, 1, 0, 0, 0, 0, pytz.UTC),
            upsert=True
        )
        self._assert_partial_rejection(result, expected_reason)
        mock_session.assert_expectations()
        mock_client.assert_expectations()

    @pytest.mark.unit
    def test_write_insert_records(self):
        """``__write_insert_records`` splits records when the write raises ``ClientError``."""
        # Set expectations
        expected_reason = 'test reason'
        expected_error = self._rejection_error(expected_reason)
        expected_valid_record = TimestreamRecord('valid', 1.0)
        expected_invalid_record = TimestreamRecord('invalid', 0.0)
        expected_capture_time = '1609459200000'

        # Set up mocks
        mock_client = self._failing_write_client(
            common_attributes={
                'Dimensions': [],
                'MeasureValueType': 'DOUBLE',
                'Time': expected_capture_time,
                'Version': 1
            },
            records=[{
                'MeasureName': 'valid',
                'MeasureValue': '1.0'
            }, {
                'MeasureName': 'invalid',
                'MeasureValue': '0.0'
            }],
            error=expected_error
        )
        loader, mock_session = self._make_loader(mock_client)

        # Test ClientError occurs during Timestream write
        accepted_records, rejected_records = loader._TimestreamLoader__write_insert_records(
            dimensions=TimestreamDimensions(),
            records=[expected_valid_record, expected_invalid_record],
            capture_time=datetime(2021, 1, 1, 0, 0, 0, 0, tzinfo=pytz.UTC)
        )
        self._assert_rejection_split(
            accepted_records, rejected_records,
            expected_valid_record, expected_invalid_record, expected_reason
        )

        mock_session.assert_expectations()
        mock_client.assert_expectations()

    @pytest.mark.unit
    def test_write_upsert_records(self):
        """``__write_upsert_records`` splits records when the write raises ``ClientError``.

        The expected common ``Version`` is the ``current_time_ms`` the loader
        was constructed with.
        """
        # Set expectations
        expected_version = 1609459200000
        expected_reason = 'test reason'
        expected_error = self._rejection_error(expected_reason)
        expected_valid_record = TimestreamRecord('valid', 1.0)
        expected_invalid_record = TimestreamRecord('invalid', 0.0)

        # Set up mocks
        mock_client = self._failing_write_client(
            common_attributes={
                'Dimensions': [],
                'MeasureValueType': 'DOUBLE',
                'Version': expected_version
            },
            records=[{
                'MeasureName': 'valid',
                'MeasureValue': '1.0'
            }, {
                'MeasureName': 'invalid',
                'MeasureValue': '0.0'
            }],
            error=expected_error
        )
        loader, mock_session = self._make_loader(
            mock_client, current_time_ms=expected_version
        )

        # Test ClientError occurs during Timestream write
        accepted_records, rejected_records = loader._TimestreamLoader__write_upsert_records(
            dimensions=TimestreamDimensions(),
            records=[expected_valid_record, expected_invalid_record]
        )
        self._assert_rejection_split(
            accepted_records, rejected_records,
            expected_valid_record, expected_invalid_record, expected_reason
        )

        mock_session.assert_expectations()
        mock_client.assert_expectations()

    @pytest.mark.unit
    def test_process_client_error(self):
        """``__process_client_error`` maps ``RejectedRecords`` indices onto the input records."""
        loader, mock_session = self._make_loader()
        mock_session.assert_expectations()

        expected_reason = 'test reason'
        expected_error = self._rejection_error(expected_reason)
        expected_valid_record = TimestreamRecord('valid', 1.0)
        expected_invalid_record = TimestreamRecord('invalid', 0.0)

        accepted_records, rejected_records = loader._TimestreamLoader__process_client_error(
            e=expected_error,
            records=[expected_valid_record, expected_invalid_record]
        )
        self._assert_rejection_split(
            accepted_records, rejected_records,
            expected_valid_record, expected_invalid_record, expected_reason
        )

    # ------------------------------------------------------------------
    # Parsing
    # ------------------------------------------------------------------

    @pytest.mark.unit
    def test_parse_data(self):
        """``__parse_data`` handles empty input and a mix of valid and invalid items."""
        loader, mock_session = self._make_loader()
        # Verify the session expectation, as every other test in this class does.
        mock_session.assert_expectations()

        # Test no data
        result = loader._TimestreamLoader__parse_data([], False, False)
        self._assert_empty_load_result(result)
        # The result must also have a non-empty string form.
        assert str(result)

        # Test mix of valid and invalid items
        expected_invalid_item = {
            'unparseable': datetime(2021, 1, 1, 0, 0, 0, 0, tzinfo=pytz.UTC)
        }
        expected_error_msg = 'Some items could not be parsed into Timestream records'

        result = loader._TimestreamLoader__parse_data(
            data=[{'integer': 10}, expected_invalid_item],
            upsert=False,
            ignore_parse_error=False
        )
        assert type(result) is TimestreamLoadResult
        assert len(result.invalid_items) == 1
        assert len(result.accepted_records) == 1
        assert len(result.rejected_records) == 0
        assert str(result.error) == expected_error_msg
        assert result.invalid_items[0]['item'] == expected_invalid_item

        record = result.accepted_records[0]
        assert record and type(record) is TimestreamRecord
        assert record.measure_name == 'integer'
        assert record.measure_value == 10.0
        assert record.mstime is None
        assert record.time is None

    @pytest.mark.unit
    def test_parse_item(self):
        """``__parse_item`` covers empty, insert, upsert, error and ignored-error cases."""
        loader, mock_session = self._make_loader()
        mock_session.assert_expectations()

        # Test no item
        assert loader._TimestreamLoader__parse_item({}, False, False) == []

        # Test parse insert item
        # Build the expected dimensions via add-then-remove so that
        # remove_dimension is exercised as well.
        expected_dimensions = TimestreamDimensions({
            'string': 'dimension',
            'remove this': 'dimension'
        })
        expected_dimensions.remove_dimension('remove this')
        records = loader._TimestreamLoader__parse_item({
            'integer': 10,
            'float': 20.0,
            'list': [10, 20.0],
            'string': 'dimension',
            'none': None
        }, False, False)
        # The int, float and list values yield records; the string value is
        # expected as a dimension and the None value yields no record.
        assert len(records) == 3
        assert str(records[0])

        record = records[0]
        assert record and type(record) is TimestreamRecord
        assert record.measure_name == 'integer'
        assert record.measure_value == 10.0
        assert record.mstime is None
        assert record.time is None
        assert record.dimensions == expected_dimensions
        assert record.dimensions.get_dimension('string') == 'dimension'
        assert str(record.dimensions)

        record = records[1]
        assert record and type(record) is TimestreamRecord
        assert record.measure_name == 'float'
        assert record.measure_value == 20.0
        assert record.mstime is None
        assert record.time is None
        assert record.dimensions == expected_dimensions

        record = records[2]
        assert record and type(record) is TimestreamRecord
        assert record.measure_name == 'list'
        # NOTE(review): 2.0 presumably reflects the list's length - confirm
        # against the loader implementation.
        assert record.measure_value == 2.0
        assert record.mstime is None
        assert record.time is None
        assert record.dimensions == expected_dimensions

        # Test parse upsert item
        records = loader._TimestreamLoader__parse_item({
            'time': 1609459200000,
            'integer': 10
        }, True, False)
        assert len(records) == 1

        record = records[0]
        assert record and type(record) is TimestreamRecord
        assert record.measure_name == 'integer'
        assert record.measure_value == 10.0
        assert record.mstime == 1609459200000
        assert record.time == datetime.combine(date(2021, 1, 1), datetime.min.time()).replace(tzinfo=pytz.UTC)

        # Test parse error
        with pytest.raises(AssertionError) as e_info:
            loader._TimestreamLoader__parse_item({
                'unparseable': {'key': 'value'}
            }, False, False)
        assert str(e_info.value) == 'Could not create a record or dimension from type "<class \'dict\'>" at key unparseable'

        # Test ignore parse error
        records = loader._TimestreamLoader__parse_item({
            'unparseable': {'key': 'value'}
        }, False, True)
        assert len(records) == 0

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    @pytest.mark.unit
    def test_dt_to_ms(self):
        """``__dt_to_ms`` converts an aware datetime to epoch ms and ``None`` to -1."""
        loader, mock_session = self._make_loader()
        mock_session.assert_expectations()

        dt = datetime.combine(date(2021, 1, 1), datetime.min.time()).replace(tzinfo=pytz.UTC)
        assert loader._TimestreamLoader__dt_to_ms(dt) == 1609459200000
        assert loader._TimestreamLoader__dt_to_ms(None) == -1

    @pytest.mark.unit
    def test_get_logger(self):
        """Smoke test: ``_get_logger`` accepts each level name without raising."""
        loader, mock_session = self._make_loader()
        mock_session.assert_expectations()

        # No return value is asserted; the test only checks these calls succeed.
        for level in ('debug', 'warning', 'error', 'critical', 'info'):
            loader._get_logger(level)
