import copy
import logging
from typing import List, Tuple

from timestream import get_logger
from timestream.dimensions import TimestreamDimensions
from timestream.record import TimestreamRecord

# Maximum number of records per batch, set by AWS Timestream service limits.
DEFAULT_MAX_BATCH_SIZE = 100
# Default minimum number of dimensions a record must share with a batch's
# common dimensions in order to be added to that batch.
DEFAULT_MIN_COMMON_DIMENSIONS = 3
# Maximum number of records _batch_records_helper scans in a single call;
# records beyond this window are deferred to subsequent calls.
MAX_RECORDS_SIZE = 10000


class Batch:
    """A group of records together with the dimensions they have in common.

    Attributes:
        records: Records belonging to this batch; a new empty list when the
            caller passes nothing (or an empty/falsy value).
        common_dimensions: Dimensions shared by the batch's records; a new
            ``TimestreamDimensions()`` when the caller passes nothing (or an
            empty/falsy value).
    """

    def __init__(
        self,
        records: List[TimestreamRecord] = None,
        common_dimensions: TimestreamDimensions = None
    ):
        # Falsy arguments are replaced with freshly created objects so that
        # separate instances never share one default container.
        self.records = records if records else []
        self.common_dimensions = (
            common_dimensions if common_dimensions else TimestreamDimensions()
        )
        


def batch_records(
    records: List[TimestreamRecord],
    max_batch_size: int = DEFAULT_MAX_BATCH_SIZE,
    min_common_dimensions: int = DEFAULT_MIN_COMMON_DIMENSIONS,
    logger: logging.Logger = None
) -> List[Batch]:
    """Partition ``records`` into batches whose members share dimensions.

    When ``min_common_dimensions`` is falsy (``0`` or ``None``), dimensions are
    ignored and the work is delegated to ``dimensionless_batch``. Otherwise
    ``_batch_records_helper`` is called repeatedly. Each call peels one batch
    off the front of the pending records, until nothing is left.

    Args:
        records: Records to batch.
        max_batch_size: Upper bound on the number of records per batch.
        min_common_dimensions: Minimum number of dimensions a record must share
            with a batch to join it; a falsy value disables dimension grouping.
        logger: Logger for progress messages; defaults to ``get_logger()``.

    Returns:
        The batches in the order they were produced.
    """
    log = logger or get_logger()
    if not min_common_dimensions:
        return dimensionless_batch(records, max_batch_size)

    log.info(f"MAX BATCH || {max_batch_size}  MIN COMMON DIM || {min_common_dimensions}")
    pending = records
    result: List[Batch] = []
    while len(pending) != 0:
        log.info(f"Remaining records to batch || {len(pending)}")
        next_batch, pending = _batch_records_helper(
            pending, max_batch_size, min_common_dimensions, log
        )
        result.append(next_batch)
    return result


def _batch_records_helper(
    records: List[TimestreamRecord],
    max_batch_size: int = DEFAULT_MAX_BATCH_SIZE,
    min_common_dimensions: int = DEFAULT_MIN_COMMON_DIMENSIONS,
    logger: logging.Logger = None
) -> Tuple[Batch, list]:
    """Greedily build one batch of records that share dimensions.

    The first record seeds the batch, and a copy of its dimensions becomes the
    candidate set of common dimensions. Each later record is scanned in order,
    up to ``MAX_RECORDS_SIZE`` records. A record joins the batch when it
    matches at least ``min_common_dimensions`` of the current common
    dimensions; this threshold is capped at the first record's dimension
    count. When a record joins, any common dimension it does not match is
    dropped. Scanning stops once the batch holds ``max_batch_size`` records.

    Args:
        records: Records to batch. Neither the list nor the records'
            dimension objects are modified.
        max_batch_size: Maximum records per batch; ``None`` means
            ``DEFAULT_MAX_BATCH_SIZE``. Must satisfy
            ``0 < max_batch_size <= DEFAULT_MAX_BATCH_SIZE``.
        min_common_dimensions: Minimum number of dimensions a record must share
            with the batch; ``None`` means ``DEFAULT_MIN_COMMON_DIMENSIONS``.
            Must be ``>= 0``.
        logger: Logger; defaults to ``get_logger()``.

    Returns:
        ``(batch, remaining)``, where ``remaining`` holds every record not
        placed in ``batch``, in their original relative order.

    Raises:
        AssertionError: If ``max_batch_size`` or ``min_common_dimensions`` is
            out of range.
    """
    logger = logger or get_logger()
    if max_batch_size is None:
        max_batch_size = DEFAULT_MAX_BATCH_SIZE
    if min_common_dimensions is None:
        min_common_dimensions = DEFAULT_MIN_COMMON_DIMENSIONS

    assert 0 < max_batch_size <= DEFAULT_MAX_BATCH_SIZE, \
        f'Invalid max batch size: {max_batch_size}'
    assert min_common_dimensions >= 0, \
        f'Invalid min common dimensions: {min_common_dimensions}'

    if len(records) < 2:
        return Batch(records, TimestreamDimensions()), []

    # Only a bounded window is scanned per call; records past it are left for
    # later calls.
    window = records[:MAX_RECORDS_SIZE]
    overflow = records[MAX_RECORDS_SIZE:]

    # Work on a copy. The common set is narrowed in place below, and narrowing
    # the first record's own dimensions object would silently strip that
    # record's non-shared dimensions.
    # NOTE(review): assumes TimestreamDimensions is deep-copyable — confirm.
    common_dimensions = copy.deepcopy(window[0].dimensions)
    min_common_dimensions = min(min_common_dimensions, len(common_dimensions))

    selected = [0]  # positions in `window` placed in this batch, in order
    for i in range(1, len(window)):
        if len(selected) >= max_batch_size:
            break

        candidate = window[i].dimensions
        # Snapshot the keys so removing dimensions below never mutates the
        # collection being iterated.
        keys = list(common_dimensions.get_dimension_keys())
        shared = {
            k for k in keys
            if candidate.get_dimension(k) == common_dimensions.get_dimension(k)
        }

        # Add the record if it shares enough dimensions, then narrow the
        # common set to what every member still agrees on.
        if len(shared) >= min_common_dimensions:
            for k in keys:
                if k not in shared:
                    common_dimensions.remove_dimension(k)
            selected.append(i)

    chosen = set(selected)  # O(1) membership for the leftover scan
    batch = Batch([window[i] for i in selected], common_dimensions)
    leftover = [rec for i, rec in enumerate(window) if i not in chosen]
    # Unbatched window records come before the overflow, preserving input order.
    return batch, leftover + overflow


def dimensionless_batch(
    records: List[TimestreamRecord],
    max_batch_size: int = DEFAULT_MAX_BATCH_SIZE
) -> List[Batch]:
    '''Batch records in groups of max_batch_size ignoring dimensions and dimension commonalities.

    Records are split into consecutive slices of at most ``max_batch_size``,
    preserving input order. Every resulting Batch gets default (empty) common
    dimensions.

    Args:
        records: Records to batch.
        max_batch_size: Records per batch; a falsy value (``None``/``0``) means
            ``DEFAULT_MAX_BATCH_SIZE``. Must be an int with
            ``0 < max_batch_size <= DEFAULT_MAX_BATCH_SIZE``.

    Returns:
        The list of batches; empty when ``records`` is empty.

    Raises:
        AssertionError: If ``max_batch_size`` is not an int or is out of range.
    '''
    max_batch_size = max_batch_size or DEFAULT_MAX_BATCH_SIZE
    assert isinstance(max_batch_size, int), f'Invalid max batch size: {max_batch_size}'
    # A negative step would make range() yield nothing and silently drop every
    # record. The upper bound matches the validation in _batch_records_helper.
    assert 0 < max_batch_size <= DEFAULT_MAX_BATCH_SIZE, f'Invalid max batch size: {max_batch_size}'
    return [
        Batch(records=records[start:start + max_batch_size])
        for start in range(0, len(records), max_batch_size)
    ]
