#from transformers import generation_utils
from collections import defaultdict
from functools import reduce
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from transformers import (
    GPT2Tokenizer, 
    AutoTokenizer, 
    AutoModelWithLMHead,
    BeamSearchScorer,
    LogitsProcessorList,
    StoppingCriteriaList,
    MinLengthLogitsProcessor,
    MaxLengthCriteria,
)
import torch
import time

#Padding strategy handed to the tokenizer (as `padding=`) in group_decode.
PADDING_MODE="longest"

#Utility for flattening a list
def flatten(t):
    """Flatten one level of nesting.

    Args:
        t: An iterable whose elements are themselves iterable.

    Returns:
        A new list holding every item of every sub-iterable, in order.
    """
    return [item for sublist in t for item in sublist]

#Super simple generation function. Just make sure to pass all decode parameters through
#logits processor and stopping criteria.
#TODO: Top k/top p
def generate(model,
    input_ids,
    model_kwargs: Optional[Dict],
    decode_kwargs: Optional[Dict] = None,
    num_beams: Optional[int] = 1,
    device: Optional[str] = "cuda"
    ):
    """Run greedy or beam-search decoding on ``model`` without tracking gradients.

    Args:
        model: Object exposing ``greedy_search`` / ``beam_search`` and, when the
            default logits processor is needed, ``config.eos_token_id``.
        input_ids: Prompt token ids. For beam search the first dimension is
            divided by ``num_beams`` to obtain the scorer's batch size, so it
            is presumably already expanded to batch * num_beams rows
            (TODO confirm against callers).
        model_kwargs: Extra keyword arguments forwarded to the search method.
            ``None`` is treated as an empty dict.
        decode_kwargs: Decoding keyword arguments. If ``logits_processor`` is
            absent a minimum length of 5 is enforced; if ``stopping_criteria``
            is absent decoding stops at a maximum length of 10. The dict
            passed in is never modified.
        num_beams: 1 (or ``None``) selects greedy search, anything else beam
            search.
        device: Device handed to ``BeamSearchScorer`` (beam search only).

    Returns:
        Whatever the model's ``greedy_search`` / ``beam_search`` returns.
    """
    # Copy so that neither a shared default nor the caller's dict is mutated;
    # the fallback processors below are then rebuilt for each call/model.
    decode_kwargs = dict(decode_kwargs) if decode_kwargs else {}
    if model_kwargs is None:
        model_kwargs = {}
    if num_beams is None:
        num_beams = 1

    if 'logits_processor' not in decode_kwargs:
        decode_kwargs['logits_processor'] = LogitsProcessorList([
            MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ])
    if 'stopping_criteria' not in decode_kwargs:
        decode_kwargs['stopping_criteria'] = StoppingCriteriaList([
            MaxLengthCriteria(max_length=10)
        ])

    if num_beams == 1:
        with torch.no_grad():
            outputs = model.greedy_search(input_ids, **decode_kwargs, **model_kwargs)
    else:
        beam_scorer = BeamSearchScorer(
            batch_size=input_ids.shape[0]//num_beams,
            num_beams=num_beams,
            device=device,
        )
        with torch.no_grad():
            outputs = model.beam_search(input_ids, beam_scorer, **decode_kwargs, **model_kwargs)

    return outputs

#Groups KG entries based off if their root node is identical. Used to remove redundant groups. Helpful for caching.
#KG entries must be in the format (e1, r) or more precisely: event, relationship
#Entries will be hashed on e1
def group_decode(
    model,
    tokenizer,
    tuples: List[List[str]], 
    batch_size: Optional[int] = -1,
    num_beams: Optional[int] = 1,
    device: Optional[str] = "cuda",
    **kwargs):
    """Decode continuations for (event, relationship) pairs, encoding each event once.

    Pairs sharing the same event (``t[0]``) are grouped. Every distinct event
    is run through the model a single time with ``use_cache=True``; its cached
    key/values are then repeated once per (relationship, beam) row and handed
    to ``generate`` together with the tokenized relationships.

    Args:
        model: Model called as ``model(**tokens, use_cache=True, return_dict=True)``
            whose output exposes ``"past_key_values"``.
        tokenizer: Callable tokenizer returning an object with ``.to(device)``
            and an ``'input_ids'`` entry.
        tuples: Sequence of ``[event, relationship]`` entries.
        batch_size: Currently unused.
        num_beams: Number of beams; 1 means greedy search.
        device: Device for the tokenized tensors and the beam scorer.
        **kwargs: Currently unused.

    Returns:
        The output of ``generate``. Rows are ordered by event (first-seen
        order), then by relationship, then by beam.
    """
    #Remove redundant elements, easier for encoding
    simplified = defaultdict(list)
    for t in tuples:
        simplified[t[0]].append(t[1])

    #Fetch keys, tokenize text
    keys = list(simplified.keys())
    suffix = list(simplified.values())
    tokenized_root = tokenizer(keys, return_tensors="pt", padding=PADDING_MODE).to(device)

    #Encode for lazy use later
    with torch.no_grad():
        encoded = model(**tokenized_root, use_cache=True, return_dict=True)
    #Fetch the past values we just encoded
    past_values = encoded["past_key_values"]

    #Convert tuples to one tensor: dim 0 indexes the outer tuple, dim 1 the inner tuple.
    #NOTE(review): presumably (layers, key/value, batch, ...) as for GPT-2 - confirm.
    past_values = torch.stack([torch.stack(layer, dim=0) for layer in past_values], dim=0)

    #Repeat each root's cache once per (suffix, beam) row along dim 2
    counts = torch.tensor([len(group) * num_beams for group in suffix]).to(device)
    past_values = torch.repeat_interleave(past_values, counts, dim=2)

    #Convert back to a tuple of tuples, unbinding the two leading dimensions
    past_values = tuple(tuple(layer.unbind(0)) for layer in past_values.unbind(0))

    #Call generate function
    model_kwargs = {
        "past":past_values,
        "return_dict":True,
        "use_cache":True,
    }
    #Beam search treats rows [i*num_beams, (i+1)*num_beams) as the beams of sequence i,
    #so each suffix must be repeated num_beams times *contiguously* (a, a, b, b),
    #not tiled (a, b, a, b), or beam groups would mix different suffixes.
    suffix = [[s for s in group for _ in range(num_beams)] for group in suffix]
    flattened_suffix = flatten(suffix)
    suffix_tokenized = tokenizer(flattened_suffix, return_tensors="pt", padding=PADDING_MODE).to(device)
    inp_ids = suffix_tokenized['input_ids']

    #Forward the device so the beam scorer lives where the tensors do
    return generate(model, inp_ids, model_kwargs, num_beams=num_beams, device=device)

#Generated attribute strings that are always discarded by get_KG
attr_filter_list = [" none", "  "]
#Builds a dictionary KG from sentences, their relations and the flat results list
#Prune being set to -1 means keep everything that survives filtering
def get_KG(sentences, relationslist, results, prune=-1):
    """Assemble a nested ``{sentence: {relation: [attributes]}}`` graph.

    ``results`` is consumed serially: its k-th entry belongs to the k-th
    (sentence, relation) pair when walking ``relationslist`` in order.
    Each consumed entry of ``results`` is replaced in place by its cleaned
    list, and that same list object is stored in the graph.

    Args:
        sentences: Sequence indexed by sentence number; its items become the
            top-level keys.
        relationslist: One iterable of relation names per sentence.
        results: Flat list of candidate-attribute lists (mutated in place).
        prune: Maximum number of attributes kept per relation; -1 keeps all.

    Returns:
        dict mapping each sentence to a dict of relation -> kept attributes.
    """
    def _is_informative(candidate):
        #Reject blacklisted outputs and strings made of one repeated character
        if candidate in attr_filter_list:
            return False
        return len(set(candidate)) != 1

    graph = {}
    #Running position in the flat results list
    cursor = 0
    for sentence_no, relations in enumerate(relationslist):
        children = {}
        for relation in relations:
            survivors = [c for c in results[cursor] if _is_informative(c)]
            if prune != -1:
                limit = min(len(survivors), prune)
                survivors = survivors[:limit]
            results[cursor] = survivors
            children[relation] = survivors
            cursor += 1
        graph[sentences[sentence_no]] = children
    return graph

#Manual smoke test kept as a plain string; nothing in the visible part of this
#file executes it. It loads a GPT-2 model on "cuda" and runs group_decode with
#two beams on one of the sample test cases.
literal = """
#Testing
test_case1 = [["This is the first sentence. ", "This is a "], ["This is the first sentence. ", "This is a "], ["This is the first sentence. ", "This is a "], ["This is the first sentence. ", "oIntent"],["This is the second sentence", "oIntent"],["This is the third sentence", "oIntent"]]
test_case2 = [["I am a cat. ", " The"],["This is the second sentence", " The"],["This is the third sentence", " The"]]
test_case3 = [["AAAAAAA.", " This"]]

curr_test_cast = test_case2

model = AutoModelWithLMHead.from_pretrained("gpt2").half().to("cuda")

#Set up tokenizer. TODO: Make fast tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id


output = group_decode(model=model, tokenizer=tokenizer, tuples=curr_test_cast, num_beams=2)
print("Generated: " + curr_test_cast[0][0] + tokenizer.batch_decode(output)[0])
"""