import logging
import urllib.parse as urllib_parse
from datetime import datetime
from typing import Any, Optional

import requests
from requests.auth import HTTPBasicAuth
from typeguard import typechecked

from psi_elt.abstract_elt_job import AbstractELTJob, EXTRACTOR_ELT_JOB_TYPE


@typechecked
class CustomExtractor(AbstractELTJob):
    """A CustomExtractor is the first of the ELT jobs that is responsible for
    extracting data from an external API and returning the raw data without any
    transformation.

    This class should be treated as abstract and should be extended.
    Child classes must override the _extract function.
    This class includes utilities for running HTTP requests.
    """

    # HTTP verbs accepted by run_http_request (compared after upper-casing).
    _ALLOWED_HTTP_METHODS = frozenset(
        {"CONNECT", "DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT", "TRACE"}
    )

    def __init__(
        self,
        name: str,
        capture_time: datetime,
        s3_loader: Optional[Any] = None,
        aws_session: Optional[Any] = None,
        logger: Optional[logging.Logger] = None,
        api: Optional[Any] = None,
        http_session: Optional[Any] = None,
    ):
        """Create an extractor job.

        Args:
            name: Job name, forwarded to AbstractELTJob.
            capture_time: Capture timestamp, forwarded to AbstractELTJob.
            s3_loader: Optional loader object, stored as ``self.s3_loader``.
            aws_session: Forwarded to AbstractELTJob.
            logger: Forwarded to AbstractELTJob.
            api: Optional API client object, stored as ``self.api``.
            http_session: Session used by ``run_http_request``. When omitted a new
                ``requests.Session`` is created. Any object exposing a
                requests-compatible ``request(...)`` method may be injected.
        """
        super().__init__(
            name=name,
            elt_job_type=EXTRACTOR_ELT_JOB_TYPE,
            capture_time=capture_time,
            aws_session=aws_session,
            logger=logger,
        )
        self.s3_loader = s3_loader
        self.api = api
        self.__http_session = http_session or requests.Session()

    @property
    def http_session(self) -> Any:
        """The HTTP session used for requests (read-only)."""
        return self.__http_session

    def extract(
        self,
        is_test: bool,
        extract_kwargs: Optional[dict] = None,
        s3_load_kwargs: Optional[dict] = None,
    ):
        """Run the extraction by delegating to ``_extract``.

        Both kwargs dicts are first passed through ``remove_none_from_dict``
        (inherited from AbstractELTJob). ``extract_kwargs`` is then splatted into
        ``_extract`` as keyword arguments, so the cleaned value must be a mapping.
        NOTE(review): assumes ``remove_none_from_dict(None)`` returns an empty
        dict; confirm against AbstractELTJob.

        Returns:
            Whatever the subclass's ``_extract`` returns.
        """
        extract_kwargs = self.remove_none_from_dict(extract_kwargs)
        s3_load_kwargs = self.remove_none_from_dict(s3_load_kwargs)
        return self._extract(is_test, s3_load_kwargs=s3_load_kwargs, **extract_kwargs)

    def _extract(self, is_test: bool, s3_load_kwargs: Optional[dict] = None, **extract_kwargs):
        """Perform the actual extraction; must be overridden by child classes.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # NotImplementedError subclasses Exception, so existing handlers still match.
        raise NotImplementedError("CustomExtractor._extract must be overridden by the child class")

    def run_http_request(
        self,
        url: str,
        method: str = "GET",
        path_params: Optional[list[str]] = None,
        query_params: Optional[dict[str, Any]] = None,
        headers: Optional[dict[str, str]] = None,
        body: Optional[Any] = None,
        json_body: Optional[dict] = None,
        basic_auth_username: Optional[str] = None,
        basic_auth_password: Optional[str] = None,
        bearer_token: Optional[str] = None,
        url_encode_query_params: bool = False,
        timeout: Optional[Any] = None,
    ) -> Any:
        """Send an HTTP request through ``self.http_session`` and return the response.

        Args:
            url: Base URL. Required.
            method: HTTP verb (case-insensitive). Must be one of
                ``_ALLOWED_HTTP_METHODS``.
            path_params: Segments joined with "/" and appended to ``url``.
            query_params: Appended as ``k=v`` pairs joined by "&". They are added
                after "?" or, if ``url`` already contains "?", after "&".
            headers: Request headers. The caller's dict is never modified.
            body: Raw request body (sent as ``data``). Mutually exclusive with
                ``json_body``.
            json_body: Body serialised as JSON by the session.
            basic_auth_username / basic_auth_password: HTTP basic auth is used
                only when both are truthy.
            bearer_token: When given, an ``Authorization: Bearer ...`` header is set.
            url_encode_query_params: Percent-encode query keys and values with
                ``urllib.parse.quote`` before building the query string.
            timeout: Forwarded to ``session.request`` only when not None, so
                injected sessions without a ``timeout`` kwarg keep working.

        Returns:
            The response object returned by ``self.http_session.request``.

        Raises:
            AssertionError: on missing URL, unsupported method, or when both
                ``body`` and ``json_body`` are given.
            Exception: anything raised while building or sending the request,
                including ``raise_for_status`` errors. It is logged and re-raised.
        """
        assert url, "A URL is required to run a HTTP request"
        method = method.upper()
        assert method in self._ALLOWED_HTTP_METHODS, f"Unsupported HTTP method: {method}"
        assert not (body and json_body), "Must specify either body or json_body, not both"
        http_resp = None
        try:
            # Format path parameters
            if path_params:
                url += "/" + "/".join(path_params)

            # Format query parameters; don't start a second query string if one exists.
            if query_params:
                query_params = self.__url_encode_query_params(query_params, url_encode_query_params)
                query_params_str = "&".join(f"{k}={v}" for k, v in query_params.items())
                separator = "&" if "?" in url else "?"
                url += f"{separator}{query_params_str}"

            # Format basic auth
            basic_auth = None
            if basic_auth_username and basic_auth_password:
                basic_auth = HTTPBasicAuth(basic_auth_username, basic_auth_password)

            # Format bearer token on a copy so the caller's headers dict is untouched.
            if bearer_token:
                headers = dict(headers or {})
                headers["Authorization"] = f"Bearer {bearer_token}"

            request_kwargs = {
                "method": method,
                "url": url,
                "headers": headers or None,
                "data": body or None,
                "json": json_body or None,
                "auth": basic_auth or None,
            }
            if timeout is not None:
                request_kwargs["timeout"] = timeout

            http_resp = self.http_session.request(**request_kwargs)
            http_resp.raise_for_status()
        except Exception as e:
            self.logger.error(f"Exception occurred when running HTTP request: {e}")
            # Compare against None explicitly: requests.Response.__bool__ returns
            # `ok`, so a failed (4xx/5xx) response would otherwise be skipped here.
            if http_resp is not None:
                self.logger.error(f"Status code: {http_resp.status_code}")
                self.logger.error(f"Response text: {http_resp.text}")
            raise
        return http_resp

    def __url_encode_query_params(self, query_params: dict, url_encode: bool = True) -> dict:
        """Return ``query_params`` with keys and values percent-encoded when ``url_encode``.

        Values are converted with ``str`` before quoting. ``urllib.parse.quote``
        leaves "/" unescaped by default. When ``url_encode`` is False the input
        dict is returned unchanged.
        """
        if url_encode:
            return {
                urllib_parse.quote(k): urllib_parse.quote(str(v)) for k, v in query_params.items()
            }
        return query_params
