import numpy
import pandas
import pytest

from psi_awswrangler.aws_wrangler import AwsWrangler
from test_utils.mock import MockApi, ANY_NOT_NONE
from unittest.mock import MagicMock, patch


class TestAwsWrangler:
    """Unit tests for ``AwsWrangler`` with the underlying ``awswrangler`` calls patched.

    Every happy-path test patches a single ``awswrangler`` function, invokes the
    matching ``AwsWrangler`` method, and then checks both the value handed back
    to the caller and the keyword arguments forwarded to the patched function.

    Every ``*_error`` test checks that a malformed S3 path is rejected with an
    ``AssertionError`` whose message starts with ``"Invalid S3 path"``.
    """

    # Deliberately malformed path shared by all ``*_error`` tests.
    BAD_S3_PATH = "bad s3 path"

    @staticmethod
    def _make_raw_df() -> pandas.DataFrame:
        """Build the un-normalised frame that a patched reader returns.

        The frame contains a string padded with whitespace, an explicit NaN,
        a missing key (the second record has no ``key1``) and a nested dict.
        A new frame is built on every call so tests never share a mutable
        object.
        """
        return pandas.DataFrame.from_records(
            [
                {
                    "key1": "   str with whitespace   ",
                    "key2": numpy.nan,
                },
                {
                    "key2": {"subkey1": "subval1"},
                },
            ]
        )

    @staticmethod
    def _make_expected_records() -> list:
        """Return the records the reader tests expect for ``_make_raw_df()``.

        Compared with the raw frame, the string is stripped of surrounding
        whitespace and NaN / missing values appear as ``None``; the nested
        dict is unchanged.
        """
        return [
            {
                "key1": "str with whitespace",
                "key2": None,
            },
            {
                "key1": None,
                "key2": {"subkey1": "subval1"},
            },
        ]

    @staticmethod
    def _assert_invalid_s3_path(method, *args, **kwargs) -> None:
        """Call ``method`` and assert it fails with an "Invalid S3 path" error.

        Args:
            method: Bound ``AwsWrangler`` method under test.
            *args: Positional arguments forwarded to ``method``.
            **kwargs: Keyword arguments forwarded to ``method``.

        Raises:
            AssertionError: If ``method`` does not raise ``AssertionError`` or
                the message does not start with ``"Invalid S3 path"``.
        """
        with pytest.raises(AssertionError) as e_info:
            method(*args, **kwargs)
        assert str(e_info.value).startswith("Invalid S3 path")

    @pytest.mark.unit
    @patch("awswrangler.athena.read_sql_query")
    def test_athena_read_sql_query(self, mock_read_sql_query: MagicMock):
        """Query results are normalised and arguments are forwarded to Athena."""
        # Set expected values
        expected_sql = "SELECT * FROM test_table"
        expected_database = "test_database"
        expected_ctas_approach = True
        expected_unload_approach = True
        expected_encryption = "SSE_KMS"
        expected_use_threads = False
        expected_awswrangler_additional_kwargs = {"pyarrow_kwarg1": 1}
        # No data catalog is passed to the wrapper, so None is expected downstream.
        expected_data_catalog = None
        expected_records = self._make_expected_records()

        # Set up mocks
        mock_aws_session = MockApi()
        mock_read_sql_query.return_value = self._make_raw_df()

        # Test athena_read_sql_query()
        aw = AwsWrangler(
            athena_output_path="s3://bucket/athena/output/path/",
            aws_session=mock_aws_session,
        )
        actual_df = aw.athena_read_sql_query(
            sql=expected_sql,
            database=expected_database,
            ctas_approach=expected_ctas_approach,
            unload_approach=expected_unload_approach,
            encryption=expected_encryption,
            use_threads=expected_use_threads,
            awswrangler_additional_kwargs=expected_awswrangler_additional_kwargs,
        )
        assert actual_df.to_dict(orient="records") == expected_records

        mock_read_sql_query.assert_called_once_with(
            sql=expected_sql,
            database=expected_database,
            ctas_approach=expected_ctas_approach,
            unload_approach=expected_unload_approach,
            s3_output=ANY_NOT_NONE,
            workgroup=None,
            data_source=expected_data_catalog,
            encryption=expected_encryption,
            use_threads=expected_use_threads,
            boto3_session=mock_aws_session,
            awswrangler_additional_kwargs=expected_awswrangler_additional_kwargs,
        )
        mock_aws_session.assert_expectations()

    @pytest.mark.unit
    def test_athena_read_sql_query_error(self):
        """A malformed ``s3_output_path`` is rejected."""
        aw = AwsWrangler(
            athena_output_path="s3://bucket/athena/output/path/",
            aws_session=MockApi(),
        )
        self._assert_invalid_s3_path(
            aw.athena_read_sql_query,
            sql="SELECT * FROM test_table",
            database="test_database",
            s3_output_path=self.BAD_S3_PATH,
        )

    @pytest.mark.unit
    @patch("awswrangler.s3.read_csv")
    def test_s3_read_csv(self, mock_read_csv: MagicMock):
        """CSV data is normalised and arguments are forwarded to ``read_csv``."""
        # Set expected values
        expected_path = "s3://bucket/path/to/object.csv"
        expected_use_threads = False
        expected_s3_additional_kwargs = {"s3_kwarg1": 1}
        expected_records = self._make_expected_records()

        # Set up mocks
        mock_aws_session = MockApi()
        mock_read_csv.return_value = self._make_raw_df()

        # Test s3_read_csv()
        aw = AwsWrangler(
            aws_session=mock_aws_session,
        )
        actual_df = aw.s3_read_csv(
            path=expected_path,
            use_threads=expected_use_threads,
            s3_additional_kwargs=expected_s3_additional_kwargs,
        )
        assert actual_df.to_dict(orient="records") == expected_records

        mock_read_csv.assert_called_once_with(
            path=expected_path,
            boto3_session=ANY_NOT_NONE,
            use_threads=expected_use_threads,
            s3_additional_kwargs=expected_s3_additional_kwargs,
        )
        mock_aws_session.assert_expectations()

    @pytest.mark.unit
    def test_s3_read_csv_error(self):
        """A malformed ``path`` is rejected."""
        aw = AwsWrangler(
            aws_session=MockApi(),
        )
        self._assert_invalid_s3_path(aw.s3_read_csv, path=self.BAD_S3_PATH)

    @pytest.mark.unit
    @patch("awswrangler.s3.read_json")
    def test_s3_read_json(self, mock_read_json: MagicMock):
        """JSON data is normalised and arguments are forwarded to ``read_json``."""
        # Set expected values
        expected_path = "s3://bucket/path/to/object.json"
        expected_orient = "records"
        expected_use_threads = True
        expected_s3_additional_kwargs = {"s3_kwarg1": 1}
        expected_records = self._make_expected_records()

        # Set up mocks
        mock_aws_session = MockApi()
        mock_read_json.return_value = self._make_raw_df()

        # Test s3_read_json()
        aw = AwsWrangler(
            aws_session=mock_aws_session,
        )
        actual_df = aw.s3_read_json(
            s3_uri=expected_path,
            orient=expected_orient,
            use_threads=expected_use_threads,
            s3_additional_kwargs=expected_s3_additional_kwargs,
        )
        assert actual_df.to_dict(orient="records") == expected_records

        # The wrapper takes ``s3_uri`` but is expected to forward it as ``path``;
        # ``dtype`` is not passed above, so True is the expected forwarded value.
        mock_read_json.assert_called_once_with(
            path=expected_path,
            orient=expected_orient,
            use_threads=expected_use_threads,
            boto3_session=mock_aws_session,
            s3_additional_kwargs=expected_s3_additional_kwargs,
            dtype=True,
        )
        mock_aws_session.assert_expectations()

    @pytest.mark.unit
    def test_s3_read_json_error(self):
        """A malformed S3 URI passed positionally is rejected."""
        aw = AwsWrangler(
            aws_session=MockApi(),
        )
        self._assert_invalid_s3_path(aw.s3_read_json, self.BAD_S3_PATH)

    @pytest.mark.unit
    @patch("awswrangler.s3.read_parquet")
    def test_s3_read_parquet(self, mock_read_parquet: MagicMock):
        """Parquet data is normalised and arguments are forwarded to ``read_parquet``."""
        # Set expected values
        expected_path = "s3://bucket/path/to/object.parquet"
        expected_use_threads = False
        expected_s3_additional_kwargs = {"s3_kwarg1": 1}
        expected_records = self._make_expected_records()

        # Set up mocks
        mock_aws_session = MockApi()
        mock_read_parquet.return_value = self._make_raw_df()

        # Test s3_read_parquet()
        aw = AwsWrangler(
            aws_session=mock_aws_session,
        )
        actual_df = aw.s3_read_parquet(
            path=expected_path,
            use_threads=expected_use_threads,
            s3_additional_kwargs=expected_s3_additional_kwargs,
        )
        assert actual_df.to_dict(orient="records") == expected_records

        mock_read_parquet.assert_called_once_with(
            path=expected_path,
            boto3_session=mock_aws_session,
            use_threads=expected_use_threads,
            s3_additional_kwargs=expected_s3_additional_kwargs,
        )
        mock_aws_session.assert_expectations()

    @pytest.mark.unit
    def test_s3_read_parquet_error(self):
        """A malformed ``path`` is rejected."""
        aw = AwsWrangler(
            aws_session=MockApi(),
        )
        self._assert_invalid_s3_path(aw.s3_read_parquet, path=self.BAD_S3_PATH)

    @pytest.mark.unit
    @patch("awswrangler.s3.list_directories")
    def test_s3_list_directories(self, mock_list_directories: MagicMock):
        """The directory listing is returned as-is and arguments are forwarded."""
        # Set expected values
        expected_path = "s3://bucket/path/"
        expected_chunked = True
        expected_s3_additional_kwargs = {"s3_kwarg1": 1}
        expected_dirs = ["fst/", "snd/"]

        # Set up mocks
        mock_aws_session = MockApi()
        mock_list_directories.return_value = expected_dirs

        # Test s3_list_directories()
        aw = AwsWrangler(
            aws_session=mock_aws_session,
        )
        actual_dirs = aw.s3_list_directories(
            path=expected_path,
            chunked=expected_chunked,
            s3_additional_kwargs=expected_s3_additional_kwargs,
        )
        assert actual_dirs == expected_dirs

        mock_list_directories.assert_called_once_with(
            path=expected_path,
            chunked=expected_chunked,
            boto3_session=mock_aws_session,
            s3_additional_kwargs=expected_s3_additional_kwargs,
        )
        mock_aws_session.assert_expectations()

    @pytest.mark.unit
    def test_s3_list_directories_error(self):
        """A malformed ``path`` is rejected."""
        aw = AwsWrangler(
            aws_session=MockApi(),
        )
        self._assert_invalid_s3_path(aw.s3_list_directories, path=self.BAD_S3_PATH)

    @pytest.mark.unit
    @patch("awswrangler.s3.to_parquet")
    def test_s3_to_parquet(self, mock_to_parquet: MagicMock):
        """The backend's return value is passed through and arguments are forwarded."""
        # Set expected values
        expected_path = "s3://bucket/path/"
        expected_compression = "gzip"
        # ``use_threads`` is not passed to the wrapper below, so this is the
        # value the wrapper is expected to forward on its own.
        expected_use_threads = False
        expected_dataset = False
        expected_mode = "append"
        expected_return_val = "return val"
        expected_partition_cols = ["ListString"]
        expected_database = "Database"
        expected_table = "Table"

        # Set up mocks
        mock_aws_session = MockApi()
        mock_to_parquet.return_value = expected_return_val

        # Test s3_to_parquet()
        aw = AwsWrangler(
            aws_session=mock_aws_session,
        )
        actual_return_val = aw.s3_to_parquet(
            raw_data=pandas.DataFrame.from_records([{"key": "val"}]),
            s3_path=expected_path,
            partition_cols=expected_partition_cols,
            compression=expected_compression,
            dataset=expected_dataset,
            mode=expected_mode,
            athena_database=expected_database,
            athena_table=expected_table,
        )
        assert actual_return_val == expected_return_val

        # The wrapper's ``raw_data`` / ``s3_path`` / ``athena_*`` arguments are
        # expected to reach the backend as ``df`` / ``path`` / ``database`` / ``table``.
        mock_to_parquet.assert_called_once_with(
            df=ANY_NOT_NONE,
            path=expected_path,
            partition_cols=expected_partition_cols,
            compression=expected_compression,
            use_threads=expected_use_threads,
            boto3_session=mock_aws_session,
            dataset=expected_dataset,
            mode=expected_mode,
            database=expected_database,
            table=expected_table,
        )
        mock_aws_session.assert_expectations()

    @pytest.mark.unit
    def test_s3_to_parquet_error(self):
        """A malformed ``s3_path`` is rejected."""
        aw = AwsWrangler(
            aws_session=MockApi(),
        )
        self._assert_invalid_s3_path(
            aw.s3_to_parquet,
            raw_data=pandas.DataFrame(),
            s3_path=self.BAD_S3_PATH,
            partition_cols=["bad partition cols"],
        )
