import pytest

from timestream.batch import _batch_records_helper, Batch
from timestream.dimensions import TimestreamDimensions
from timestream.record import TimestreamRecord


class TestBatch_UnitTests:
    """Unit tests for the record-batching helpers in ``timestream.batch``."""

    @pytest.mark.unit
    @pytest.mark.skip(reason="TODO: batch_records has no test coverage yet")
    def test_batch_records(self):
        # This used to be a bare ``assert True``. That always passed and reported
        # coverage that does not exist; skipping makes the gap visible instead.
        pass

    @pytest.mark.unit
    def test_batch_records_helper(self):
        """Exercise argument validation, the empty case, and dimension-based batching."""
        # Out-of-range arguments are expected to raise AssertionError.
        # NOTE(review): if the helper validates via bare ``assert`` statements,
        # those checks vanish under ``python -O``; consider raising ValueError.
        with pytest.raises(AssertionError):
            _batch_records_helper([], max_batch_size=0)
        with pytest.raises(AssertionError):
            _batch_records_helper([], max_batch_size=101)
        with pytest.raises(AssertionError):
            _batch_records_helper([], min_common_dimensions=-1)

        # The values just inside the rejected ones above must be accepted.
        _batch_records_helper([], max_batch_size=1)
        _batch_records_helper([], max_batch_size=100)

        # No records: an empty Batch and an empty remainder list.
        batch, remaining = _batch_records_helper([])
        assert isinstance(remaining, list)
        assert len(remaining) == 0
        assert isinstance(batch, Batch)
        assert isinstance(batch.records, list)
        assert len(batch.records) == 0

        # Fixture: four records that all share exactly two dimensions
        # (shared0, shared1) plus a unique 'label'. record_a and record_b each
        # carry one extra dimension no other record has, so record_a has 4
        # dimensions in total.
        record_a = TimestreamRecord('metric', 2,
            TimestreamDimensions({
                'shared0': '0',
                'shared1': '1',
                'label': 'a',
                'notshared0': 'n/a'
            })
        )
        record_b = TimestreamRecord('metric', 3,
            TimestreamDimensions({
                'shared0': '0',
                'shared1': '1',
                'label': 'b',
                'notshared1': 'n/a'
            })
        )
        record_c = TimestreamRecord('metric', 4,
            TimestreamDimensions({
                'shared0': '0',
                'shared1': '1',
                'label': 'c'
            })
        )
        record_d = TimestreamRecord('metric', 5,
            TimestreamDimensions({
                'shared0': '0',
                'shared1': '1',
                'label': 'd'
            })
        )
        records = [record_a, record_b, record_c, record_d]

        # min_common_dimensions=0: everything fits in one batch, and the common
        # set is the two dimensions present on every record.
        batch, remaining = _batch_records_helper(records, min_common_dimensions=0)
        assert len(batch.records) == 4
        assert len(batch.common_dimensions) == 2
        assert len(remaining) == 0

        # max_batch_size=2 caps the batch at the first two records; the other
        # two must be handed back in the remainder.
        batch, remaining = _batch_records_helper(records, max_batch_size=2, min_common_dimensions=0)
        assert len(batch.records) == 2
        assert batch.records[0] == record_a
        assert batch.records[1] == record_b
        assert len(batch.common_dimensions) == 2
        assert len(remaining) == 2
        assert all(r in remaining for r in (record_c, record_d))

        # min_common_dimensions=5: no group of records can satisfy it. The
        # helper is expected to still emit record_a alone, with its own four
        # dimensions as the common set, and leave the rest for later.
        batch, remaining = _batch_records_helper(records, min_common_dimensions=5)
        assert len(batch.records) == 1
        assert batch.records[0] == record_a
        assert len(batch.common_dimensions) == 4
        assert len(remaining) == 3
        assert all(r in remaining for r in (record_b, record_c, record_d))

        # min_common_dimensions=2: exactly satisfied by shared0/shared1, so all
        # four records batch together and nothing is left over.
        batch, remaining = _batch_records_helper(records, min_common_dimensions=2)
        assert batch.records == records
        assert len(remaining) == 0
        assert len(batch.common_dimensions) == 2
        assert batch.common_dimensions.get_dimension('shared0') == '0'
        assert batch.common_dimensions.get_dimension('shared1') == '1'
