import glob

import numpy
import torch
import threading
import functools
from torch.nn import functional as F
import numpy as np

def read_data_files(path):
    """Yield the text of every ``.txt`` file found recursively under *path*.

    A progress line of the form ``=== <index>/<total> <filename>`` is printed
    before each file is read. Every yielded string is the file's contents
    prefixed with a separator line consisting of the single character ⁂.

    Args:
        path: Directory to search; ``/**/*.txt`` is appended to build the
            recursive glob pattern.

    Yields:
        str: The separator line followed by the file's full text.
    """
    files = list(glob.iglob(path + '/**/*.txt', recursive=True))
    num_files = len(files)
    for curr_file, data in enumerate(files, start=1):
        print(f"=== {curr_file}/{num_files}", data)
        # Read inside a context manager so the handle is closed right away
        # instead of leaking one open file per corpus entry.
        with open(data, 'r') as handle:
            contents = handle.read()
        yield "⁂\n" + contents

def build_engram(forward, tokens, shift=10000, factor=20000,
                 rampdown=lambda x: x / 2):
    """Collapse a model's hidden states into a single 1-D "engram" vector.

    The last 512 tokens are run through ``forward`` with
    ``output_hidden_states=True``. Every returned hidden state except the
    first is averaged over axis 1 and scaled by a linearly increasing weight
    (``1/L, 2/L, ..., 1`` for ``L`` layers); the weighted layers are then
    summed. Only the first batch element contributes to the result.

    Args:
        forward: Callable accepting ``input_ids`` and ``output_hidden_states``
            keyword arguments and returning an object with a
            ``hidden_states`` sequence of tensors.
        tokens: 2-D tensor of token ids; it is sliced as ``tokens[:, -512:]``.
        shift: Constant added to the combined vector before scaling.
        factor: Constant the shifted vector is divided by.
        rampdown: Currently unused (see the TODO below).

    Returns:
        numpy.ndarray: float32 array holding ``(combined + shift) / factor``.
    """
    # Run on CUDA when present (the original behaviour) but fall back to the
    # CPU so the function does not crash on hosts without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # get hidden states
    outputs = forward(input_ids=tokens[:, -512:].long().to(device),
                      output_hidden_states=True)
    layers = [state.to(device) for state in outputs.hidden_states[1:]]

    # TODO: use rampdown
    step = 1.0 / float(len(layers))
    weight = 0

    # combine hidden states (token axis)
    # we use double() here to reduce accuracy loss from overflowing. it's safe
    # to go back to float() after the math is done. There is probably a more
    # efficient way to do this
    weighted = []
    for state in layers:
        weight = weight + step
        weighted.append(
            torch.mean(state.detach().double(), dim=(1,)) * weight)

    # Stack to (batch, layer, hidden), keep batch element 0, sum over layers.
    combined = torch.sum(torch.stack(weighted, dim=1)[0], dim=(0,))

    # note: static values are used here to make sorting more consistent.
    # Previously I normalized per-engram but that reduced the overall accuracy
    # of the sorting
    return ((combined + shift) / factor).float().to("cpu").numpy()


def process_scores(chosen: torch.Tensor,
                   before: torch.Tensor,
                   after: torch.Tensor,
                   num_logprobs: int,
                   filter_inf: bool = False):
    """Summarise the scores of a chosen token and the top-N alternatives.

    ``before`` and ``after`` are two score tensors over the same tokens; both
    are converted to half precision before any score is read. The indexing
    used below (``[0, :num_logprobs]``) expects them to be 2-D with a leading
    dimension whose first row is used.

    Args:
        chosen: Single-element tensor (or NumPy array) holding the index of
            the selected token.
        before: Scores before processing.
        after: Scores after processing.
        num_logprobs: ``-1`` to skip all score lookups, ``0`` to report only
            the chosen token's scores, otherwise the number of top tokens to
            report for each of ``before`` and ``after``.
        filter_inf: When true, reported scores are rounded to 4 decimals and
            negative infinity is reported as ``None``.

    Returns:
        A 3-tuple ``(chosen_entry, top_before, top_after)``.
        ``chosen_entry`` is ``[([token], (before_score, after_score))]``.
        ``top_before`` / ``top_after`` are lists of
        ``([token], scores)`` pairs ordered by descending ``before`` /
        ``after`` score respectively, where ``scores`` holds the token's
        before and after score; they are ``None`` when ``num_logprobs`` is
        ``-1`` or ``0``.
    """
    if num_logprobs == -1:
        return [([chosen.item()], (None, None))], None, None

    if not torch.is_tensor(chosen):
        chosen = torch.from_numpy(chosen).to(before).long()

    before = before.half()
    after = after.half()

    # If we've been asked to filter infinites.
    if filter_inf:
        # float("-inf") instead of numpy.NINF: that alias no longer exists
        # in NumPy 2.0 and would raise AttributeError here.
        neg_inf = float("-inf")

        def ninf_none_fn(t):
            return round(t, 4) if t != neg_inf else None

        def filter_fn(t):
            return (ninf_none_fn(t[0]), ninf_none_fn(t[1]))
    else:
        def ninf_none_fn(t):
            return t

        filter_fn = ninf_none_fn

    chosen_scores = (ninf_none_fn(torch.take(before, chosen).item()),
                     ninf_none_fn(torch.take(after, chosen).item()))
    if num_logprobs == 0:
        return [([chosen.item()], chosen_scores)], None, None

    # Get the top-N tokens for before and after processing, sorted
    # by logprobs.
    top_before = torch.argsort(before, descending=True)[0, :num_logprobs]
    top_after = torch.argsort(after, descending=True)[0, :num_logprobs]

    # Get the scores for each token and drop neginf scores.
    scores_before = torch.take_along_dim(before, top_before)
    scores_after = torch.take_along_dim(after, top_after)
    scores_before = scores_before[~torch.isneginf(scores_before)]
    scores_after = scores_after[~torch.isneginf(scores_after)]

    # Reduce both sets of tokens to only those tokens that have scores.
    # The scores are sorted descending, so neginf entries sit at the tail
    # and truncating to the surviving count keeps tokens aligned.
    top_before = top_before[:scores_before.shape[0]]
    top_after = top_after[:scores_after.shape[0]]

    # For both sets, obtain the inverse relationship, before -> after,
    # and after -> before
    pairs_before = torch.stack(
        (scores_before, torch.take_along_dim(after, top_before)), dim=1)
    pairs_after = torch.stack(
        (torch.take_along_dim(before, top_after), scores_after), dim=1)

    # Transform into token to alternate probabilities data structures.
    def xfm_fn(top, tsr):
        return list(zip(map(lambda t: [t.item()], top),
                        map(filter_fn, tsr.tolist())))

    tk_s_before = xfm_fn(top_before, pairs_before)
    tk_s_after = xfm_fn(top_after, pairs_after)
    return [([chosen.item()], chosen_scores)], tk_s_before, tk_s_after


def get_token_probabilities(forward, tokens, n=0, idx=0):
    """Yield, per token, its log-probability under the preceding logits.

    ``forward`` is called once on all of ``tokens``; the logits at each
    position are turned into log-probabilities with ``log_softmax`` and used
    to score the token at the following loop position. Only the first batch
    element is examined.

    Args:
        forward: Callable accepting an ``input_ids`` keyword argument and
            returning an object with a ``logits`` tensor indexed as
            ``[batch, position, vocab]``.
        tokens: 2-D tensor of token ids.
        n: When non-zero, each entry also lists the ``n`` highest-scoring
            tokens under the same logits.
        idx: Position whose logits seed the first score. With ``idx == 0``
            an initial entry with no score is yielded for the first token.

    Yields:
        dict: ``{"chosen": [ndarray([token]), [logprob, None]]}`` plus, when
        ``n`` is non-zero, a ``"choices"`` list of entries in the same
        ``[ndarray([token]), [logprob, None]]`` form.
    """
    # Run on CUDA when present (the original behaviour) but fall back to the
    # CPU so the function does not crash on hosts without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    f = forward(input_ids=tokens.long().to(device)).logits
    prior_logit = F.log_softmax(f[:, idx, :][0], dim=0)
    if idx == 0:
        begin = 1
        # The first token has no preceding logits, so it carries no score.
        yield {"chosen": [np.array([tokens[0][0], ]),
                          [None, None]]}
    else:
        # NOTE(review): here the first scored token is tokens[0][idx - 1]
        # while prior_logit comes from position idx -- confirm this offset
        # is intended before relying on idx > 0.
        begin = idx - 1
    for logit_idx in range(begin, f.shape[1]):
        curr_logit = F.log_softmax(f[:, logit_idx, :][0], dim=0)
        token = tokens[0][logit_idx]
        entry = {"chosen": [np.array([token, ]),
                            [prior_logit[token].to("cpu"), None]]}
        if n:
            choices = []
            for candidate in torch.topk(prior_logit, n).indices:
                candidate = candidate.to("cpu")
                choices.append([np.array([candidate, ]),
                                [prior_logit[candidate].to("cpu"), None]])
            entry['choices'] = choices
        yield entry
        # The logits at this position score the token at the next one.
        prior_logit = curr_logit


def get_next_words(forward, tokenizer, tokens, top_k=100):
    """Return the highest-scoring next tokens with their raw logits.

    ``forward`` is run on ``tokens`` and only the logits at the last
    position of the first batch element are examined.

    Args:
        forward: Callable accepting an ``input_ids`` keyword argument and
            returning an object with a ``logits`` tensor indexed as
            ``[batch, position, vocab]``.
        tokenizer: Object whose ``decode`` method is called with a
            one-element NumPy array holding a token id.
        tokens: 2-D tensor of token ids.
        top_k: Maximum number of candidates to return (default 100). It is
            clamped to the vocabulary size so small vocabularies do not make
            ``torch.topk`` raise.

    Returns:
        list: ``[decoded_text, logit]`` pairs ordered by descending logit.
    """
    # Run on CUDA when present (the original behaviour) but fall back to the
    # CPU so the function does not crash on hosts without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logits = forward(input_ids=tokens.long().to(device)).logits

    # Only the final position matters, so rank that row alone instead of
    # running top-k across every position in the sequence.
    last = logits[:, -1, :][0]
    original = last.to("cpu")
    k = min(top_k, last.shape[0])
    best = torch.topk(last, k).indices.to("cpu")

    return [[tokenizer.decode(np.array([token, ])), float(original[token])]
            for token in best]
