from datetime import date, datetime, timedelta
from typing import Optional

import pytest
import pytz

from psi_collector.collector_runner import CollectorRunner
from psi_collector.custom_collector import CustomCollector
from test_utils.mock import MockApi

# Fixed, timezone-aware (UTC) capture instant shared by the collect tests.
TEST_CAPTURE_TIME = datetime(2021, 6, 1, tzinfo=pytz.utc)

# Payload returned by the stub collector and expected by the mocked loaders.
EXPECTED_DATA = [
    {"dummy": "data"},
    {"empty": ""},
    {"number": 0},
]


class TestCollector(CustomCollector):
    """Stub collector named "test" that always yields ``EXPECTED_DATA``."""

    # The "Test" prefix would make pytest try to collect this helper class;
    # opt it out of collection explicitly.
    __test__ = False

    def __init__(self, capture_time, api, aws_session, http_session) -> None:
        """Forward all arguments to ``CustomCollector`` under the fixed name "test"."""
        super().__init__("test", capture_time, api, aws_session, http_session)

    def _collect(self, s3_path: Optional[str] = None):
        """Return the canned payload; ``s3_path`` is accepted but not used."""
        return EXPECTED_DATA

    def start_date(self, capture_time):
        """Return ``capture_time`` unchanged."""
        return capture_time

    def end_date(self, capture_time):
        """Return ``capture_time`` unchanged."""
        return capture_time


class TestCollectorRunner_UnitTests:
    """Unit tests for ``CollectorRunner`` using mocked S3 and Timestream loaders."""

    @staticmethod
    def _collect_event():
        """Build the event that asks the runner to run the "test" collector."""
        return {"collector_name": "test", "capture_time": TEST_CAPTURE_TIME.isoformat()}

    @staticmethod
    def _build_runner(timestream_result):
        """Create a runner wired to mock loaders with their ``load`` expectations set.

        Args:
            timestream_result: Value the mocked Timestream ``load`` call returns.

        Returns:
            A ``(runner, mock_s3_loader, mock_timestream_loader)`` tuple so the
            caller can verify the mock expectations after exercising the runner.
        """
        mock_s3_loader = MockApi()
        mock_s3_loader.on("load", "test", EXPECTED_DATA, TEST_CAPTURE_TIME).return_val(None)
        mock_timestream_loader = MockApi()
        mock_timestream_loader.on(
            "load",
            data=EXPECTED_DATA,
            capture_time=TEST_CAPTURE_TIME,
            source="test",
            upsert=True,
            ignore_parse_error=True,
            max_batch_size=None,
            min_common_dimensions=None,
            is_multi=False,
        ).return_val(timestream_result)
        runner = CollectorRunner(
            collectors={"test": TestCollector},
            s3_loader=mock_s3_loader,
            timestream_loader=mock_timestream_loader,
        )
        return runner, mock_s3_loader, mock_timestream_loader

    @pytest.mark.unit
    def test_collect(self):
        """``collect`` completes when Timestream reports nothing invalid or rejected."""
        runner, mock_s3_loader, mock_timestream_loader = self._build_runner(
            {"invalid_items": [], "rejected_records": []}
        )

        runner.collect(self._collect_event())

        mock_s3_loader.assert_expectations()
        mock_timestream_loader.assert_expectations()

    @pytest.mark.unit
    def test_collect_invalid_items(self):
        """``collect`` raises when the Timestream load reports invalid items."""
        runner, mock_s3_loader, mock_timestream_loader = self._build_runner(
            {"invalid_items": [1], "rejected_records": []}
        )

        with pytest.raises(Exception) as e_info:
            runner.collect(self._collect_event())

        assert str(e_info.value).startswith("Failed to write to Timestream with invalid items")
        mock_s3_loader.assert_expectations()
        mock_timestream_loader.assert_expectations()

    @pytest.mark.unit
    def test_collect_rejected_records(self):
        """``collect`` raises when the Timestream load reports rejected records."""
        runner, mock_s3_loader, mock_timestream_loader = self._build_runner(
            {"invalid_items": [], "rejected_records": [1]}
        )

        with pytest.raises(Exception) as e_info:
            runner.collect(self._collect_event())

        assert str(e_info.value).startswith("Failed to write to Timestream with rejected records")
        mock_s3_loader.assert_expectations()
        mock_timestream_loader.assert_expectations()

    @pytest.mark.unit
    def test_get_capture_time(self):
        """``get_capture_time`` reads ``capture_time``, then ``s3_path``, then falls back."""
        event = {
            "capture_time": "2021-07-31T00:00:00",
            # The trailing path segment is the same instant as epoch milliseconds.
            "s3_path": "s3://bucket/path/to/object/1627689600000",
        }
        expected_capture_time = datetime(2021, 7, 31, tzinfo=pytz.utc)
        runner = CollectorRunner({"test": TestCollector}, MockApi(), MockApi())

        # Test get_capture_time from capture_time
        dt = runner.get_capture_time(event)
        assert dt == expected_capture_time

        # Test get_capture_time from s3_path
        event.pop("capture_time")
        dt = runner.get_capture_time(event)
        assert dt == expected_capture_time

        # Test get_capture_time without capture_time or s3_path
        event.pop("s3_path")
        dt = runner.get_capture_time(event)
        assert dt.tzinfo == pytz.utc

    @pytest.mark.unit
    def test_get_collector(self):
        """``get_collector`` rejects unknown names and instantiates registered ones."""
        runner = CollectorRunner({"test": TestCollector}, MockApi(), MockApi())
        expected_capture_time = datetime.now(pytz.utc)

        # Test exception raised with a bad collector name
        with pytest.raises(Exception) as e_info:
            runner.get_collector("bad collector name", capture_time=expected_capture_time)
        assert str(e_info.value).startswith("Unrecognized collector")

        collector = runner.get_collector(
            "test", aws_session=MockApi(), capture_time=expected_capture_time
        )
        # Exact-type check: identity comparison, not equality, is the idiom for types.
        assert type(collector) is TestCollector

    @pytest.mark.unit
    def test_parse_ago_time(self):
        """A relative ``"<n>ms"`` capture time resolves to that long before now."""
        runner = CollectorRunner(
            collectors={"test": TestCollector}, s3_loader=MockApi(), timestream_loader=MockApi()
        )
        ago = timedelta(milliseconds=42069)

        # Bracket the call so a midnight rollover during the test cannot fail it.
        before = datetime.now(pytz.utc)
        capture_time = runner.get_capture_time({"capture_time": "42069ms"})
        after = datetime.now(pytz.utc)

        # Compare calendar dates in the zone the runner reported rather than the
        # machine's local zone; a naive result (tzinfo None) makes astimezone()
        # fall back to the local zone, which matches the previous comparison.
        zone = capture_time.tzinfo
        acceptable_dates = {(moment - ago).astimezone(zone).date() for moment in (before, after)}
        assert capture_time.date() in acceptable_dates
