from typing import Any, Generator, List, Optional, Sequence, Union

from tokenizers import Tokenizer

import lm_node.unitrim


def split_chunks(text_seq: Union[Sequence[str],
                                 Generator[str, Any, None]],
                 size=2048, yield_tokens=True, boundary: str = "\n",
                 boundary_id: Optional[int] = None,
                 tokenizer: "Tokenizer" = None,
                 unitrim: "lm_node.unitrim.Unitrimmer" = None,
                 preamble: Optional[str] = None) -> \
        Generator[Union[List[int], str], Any, None]:
    """Split a stream of texts into fixed-size token chunks.

    Each text is tokenized and cut into windows. A text's tail that does not
    fill a window is carried over and packed in front of the next text. When
    there is room, an ``<|endoftext|>`` separator follows the tail. Every
    emitted chunk starts with the preamble tokens (if any) and is padded with
    ``<|endoftext|>`` tokens up to ``size`` tokens.

    Args:
        text_seq: Texts to split. They are consumed lazily, one at a time.
        size: Total tokens per emitted chunk, preamble included.
        yield_tokens: Yield lists of token ids when True. Otherwise yield the
            decoded string of each chunk.
        boundary: Text whose first token marks a preferred restart point.
        boundary_id: Token id of the boundary. Derived from ``boundary``
            when None.
        tokenizer: Object with ``encode(str)`` and ``decode(ids)``. The code
            indexes and list-concatenates ``encode`` output, so it must return
            a list of ids. NOTE(review): confirm which tokenizer callers pass.
        unitrim: Object providing ``trim(ids)`` and ``send_ready(ids)``.
            Presumably these trim and validate token runs so that they decode
            to complete unicode characters. Confirm against lm_node.unitrim.
        preamble: Optional text whose tokens prefix every chunk.

    Yields:
        One chunk per window, as a token-id list or a decoded string.

    Raises:
        ValueError: If ``size`` leaves no room for text tokens after the
            preamble.
    """
    if boundary_id is None:
        boundary_id = tokenizer.encode(boundary)[0]
    pad_token = tokenizer.encode("<|endoftext|>")
    end_of_text = tokenizer.encode("<|endoftext|>")
    prior = []
    preamble_tokens = []
    if preamble is not None:
        preamble_tokens = tokenizer.encode(preamble)
        # Reserve room for the whole preamble (not just one token) so that
        # preamble + text always totals the caller's `size`.
        size = size - len(preamble_tokens)
    if size < 1:
        raise ValueError(
            "size must leave room for at least one token after the preamble")
    # Length of every emitted chunk: the text budget plus the preamble.
    full_size = size + len(preamble_tokens)

    for text in text_seq:
        tokens = tokenizer.encode(text)
        num_tokens = len(tokens)
        idx = 0
        begin = 0
        boundary_idx = 0
        # Emit a chunk whenever the current window plus the carried-over
        # tokens reach the text budget.
        while idx < num_tokens:
            if tokens[idx] == boundary_id:
                boundary_idx = idx
            if idx - begin + len(prior) >= size:
                # Ensure that our chunk has complete unicode runes. The window
                # may shrink, so move idx back to where the trimmed run ends.
                trimmed = list(unitrim.trim(tokens[begin:idx]))
                idx = begin + len(trimmed)
                # The chunk is laid out as: preamble, then tokens carried over
                # from earlier texts, then this window. Building a new list
                # lets `prior` be reset without emptying the chunk.
                chunk = preamble_tokens + prior + trimmed
                prior = []
                # We do a decode and encode roundtrip.
                chunk = tokenizer.encode(tokenizer.decode(chunk))
                # The roundtrip may change the token count. Compute how many
                # tokens are left to fill this chunk; this may be negative.
                roundtrip_remainder = full_size - len(chunk)
                if roundtrip_remainder:
                    addl = unitrim.trim(tokens[idx:idx + roundtrip_remainder])
                    chunk.extend(addl)
                    idx += len(addl)
                    redecode = tokenizer.encode(tokenizer.decode(chunk))
                    # Drop tokens until the re-encoded chunk fits and passes
                    # unitrim.send_ready.
                    # NOTE(review): idx steps back one per popped token. That
                    # assumes each popped token came from `tokens`, which may
                    # not hold if the roundtrip re-split them.
                    while len(redecode) > full_size or \
                            not unitrim.send_ready(redecode):
                        chunk.pop()
                        idx -= 1
                        redecode = tokenizer.encode(tokenizer.decode(chunk))
                    chunk = redecode
                pad_size = full_size - len(chunk)
                if pad_size > 0:
                    chunk.extend(pad_token * pad_size)
                yield chunk if yield_tokens else tokenizer.decode(chunk)
                # Restart the next window just after the last boundary seen,
                # provided what follows it fits. Tokens between that boundary
                # and idx then appear in both chunks.
                if boundary_idx:
                    if idx - boundary_idx + 1 <= size:
                        idx = boundary_idx + 1
                    boundary_idx = 0
                begin = idx
            idx += 1
        # Carry this text's unemitted tail over to be packed with the next one.
        if begin < num_tokens:
            prior.extend(tokens[begin:])
            if len(prior) < size:
                prior.extend(end_of_text)
            if len(prior) == size:
                preambled = preamble_tokens + prior
                yield preambled if yield_tokens else tokenizer.decode(preambled)
                # Reset to empty; the preamble is added when a chunk is built.
                prior = []
    # Flush whatever is still carried over, padded to a full chunk.
    if prior:
        prior.extend(pad_token * (size - len(prior)))
        preambled = preamble_tokens + prior
        yield preambled if yield_tokens else tokenizer.decode(preambled)
