# requirements: gpt-neo-localattention3-rp-b branch of transformers and pynacl==1.4.0
# map files are just a flat array of uint16 tokens with an eot token between files; write one like this:
# with open("literature.map", "wb") as fh:
#     fh.write(np.array(tokens, dtype=np.uint16).tobytes())
# optional: with -r, prepend a single fake token to the map whose value is actually the number of training steps
import json
import numpy
from typing import Union

import numpy as np
import traceback
from dotmap import DotMap

from lm_node.base import GPTModel
from lm_node import prefix
import torch
import os

try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
from pathlib import Path
from tqdm import tqdm

# Module-level placeholders. ``model`` is rebound to a GPTModel instance when
# this file is run as a script; ``tokenizer`` is not assigned anywhere else in
# the visible part of this file.
model = None
tokenizer = None


def check_samples(embs,
                  test_model: GPTModel,
                  gen_samples: int = 5,
                  gen_len: int = 200,
                  text_adv: bool = False):
    """Generate text samples with a trained prefix embedding.

    This is a generator: for each of three fixed prompts it yields
    ``gen_samples`` strings, each made of a header naming the sample and the
    prompt followed by the decoded generated tokens.

    Args:
        embs: Prefix embedding tensor. ``embs.shape[0]`` is used as the number
            of placeholder tokens put in front of the prompt, and the tensor
            is handed to ``generate`` as ``embs=[(0, embs)]``.
        test_model: Model wrapper providing ``tokenizer``, ``eot_token`` and
            ``model.generate``.
        gen_samples: Number of samples generated per prompt.
        gen_len: Number of tokens generated on top of the (padded) prompt.
        text_adv: Use the text-adventure style prompts instead of the plain
            prose prompts.

    Yields:
        One formatted sample string per generated sample.
    """
    if text_adv:
        prompts = ["> You look around.",
                   "> You attack.",
                   "> You look at the beautiful woman's body."]
    else:
        prompts = ["The", "He", "She"]

    tok = test_model.tokenizer
    n_embs = embs.shape[0]
    rule = "-" * 30

    for prompt in prompts:
        ids = tok(prompt, return_tensors="pt").input_ids.to("cpu")

        # An empty tokenization is replaced by a single end-of-sequence token
        # so there is always at least one input id.
        if ids.shape[1] < 1:
            ids = torch.tensor([[tok.eos_token_id]])

        if n_embs > 0:
            # Reserve one placeholder id per prefix position in front of the
            # prompt. The dtype is given explicitly so the result does not
            # depend on how the installed torch infers it from the fill value.
            # NOTE(review): presumably generate() substitutes ``embs`` for
            # these positions starting at index 0 -- confirm in generate().
            placeholders = torch.full((ids.shape[0], n_embs),
                                      test_model.eot_token,
                                      dtype=torch.long)
            ids = torch.cat((placeholders, ids), dim=1)

        # min_length == max_length below forces exactly gen_len new tokens.
        max_length = ids.shape[1] + gen_len

        ids = ids.long().cuda()
        for i in range(gen_samples):
            header = f"{rule} SAMPLE {i + 1} {rule}\nPrompt: {prompt}\n\n"
            generated_tokens = []
            # generate() is consumed as an iterator of 4-tuples; only the
            # first element of the token entry is kept.
            for tokens, _finished, _scores_before, _scores_after in \
                    test_model.model.generate(
                        ids,
                        do_sample=True,
                        min_length=max_length,
                        max_length=max_length,
                        temperature=0.7,
                        tfs=None,
                        top_k=50,
                        top_p=0.9,
                        repetition_penalty=1.08,
                        repetition_penalty_range=2048,
                        repetition_penalty_slope=3.33,
                        repetition_penalty_frequency=0.0,
                        repetition_penalty_presence=0.0,
                        use_cache=True,
                        pad_token_id=tok.eos_token_id,
                        embs=[(0, embs)]):
                generated_tokens.append(int(tokens[0]))
            yield header + tok.decode(generated_tokens) + "\n\n"
        del ids


if __name__ == "__main__":
    # Command-line driver. For every ``.map`` file found under the input
    # folder it trains a prefix embedding, writes the embedding (and periodic
    # checkpoints plus a JSONL step log) next to the map file, then generates
    # text samples with the trained embedding. A failure on one map is written
    # to ``<map>.fail`` and processing continues with the next map.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-gpu",
                        help="CUDA id to use",
                        type=int,
                        required=False)
    parser.add_argument("-m", "--model-folder",
                        help="folder with split checkpoint",
                        type=str,
                        required=True)
    parser.add_argument("-i", "--input-folder",
                        help="folder with input numpy np.uint16 memmaps",
                        type=str,
                        required=True)
    parser.add_argument("-t", "--tokens-per-sample",
                        help="context size per training sample",
                        type=int,
                        default=256)
    parser.add_argument("-v", "--model-version",
                        help="version id of the model (sigurdv3 = 3)",
                        type=int, required=True)
    parser.add_argument("-s", "--steps",
                        help="number of training steps per map",
                        type=int,
                        default=3000)
    parser.add_argument("-p", "--prefix-len",
                        help="prefix length",
                        type=int,
                        default=20)
    parser.add_argument("-g", "--gen-samples",
                        help="number of samples to generate for each prompt",
                        type=int,
                        default=5)
    parser.add_argument("-l", "--gen-len",
                        help="length of each sample",
                        type=int,
                        default=200)
    parser.add_argument("-r", "--read-steps",
                        help="read steps from first map element",
                        action='store_true')
    parser.add_argument("-b", "--batch-size",
                        help="training batch size",
                        type=int,
                        default=1)
    parser.add_argument("-j", "--json-embeddings",
                        help="output embeddings as json",
                        action='store_true')
    parser.add_argument("-a", "--adventure",
                        help="enable text adventure tune mode",
                        action='store_true')
    parser.add_argument("--seed",
                        help="set a random seed",
                        type=int,
                        default=None)
    args = parser.parse_args()
    # NOTE(review): ``-gpu`` is parsed but not read anywhere in this file.

    config = DotMap()
    config.model_path = args.model_folder
    config.prefix_path = None
    model = GPTModel(config)
    print(model.getGPUram())


    def mean_without_nan(losses):
        """Average ``losses`` ignoring NaN entries.

        Returns NaN instead of dividing by zero when no usable value is left,
        so one bad window does not abort the whole training run.
        """
        usable = [x for x in losses if not numpy.isnan(x)]
        if not usable:
            return float("nan")
        return sum(usable) / float(len(usable))


    def write_prefix(result, base_path, step: Union[int, None] = None):
        """Write ``result["encoded_embedding"]`` next to the map file.

        With ``step`` given the file name gets a ``-<step>`` suffix
        (intermediate checkpoint); without it the final name is used. The
        format is JSON when ``--json-embeddings`` is set, otherwise the
        encoded string is written as-is to a ``.emb`` file.
        """
        step_repr = f"-{step}" if step is not None else ""
        if args.json_embeddings:
            decoded = prefix.decode_prefix(result["encoded_embedding"])
            prefix.write_json_embs(decoded[0],
                                   f"{base_path}{step_repr}.json")
        else:
            with open(f"{base_path}{step_repr}.emb", "w") as fh:
                fh.write(result["encoded_embedding"])


    def write_log(step_log, step_data):
        """Append one JSON line for a training update and flush it."""
        step_log.write(
            json.dumps(
                {"step": step_data['step'],
                 "avg_loss": step_data['loss'],
                 "last_avg": mean_without_nan(step_data['losses']),
                 "last_losses": step_data['losses']}) + "\n")
        step_log.flush()


    def report_progress(pbar, data, last):
        """Advance ``pbar`` to ``data['step']`` and print a status line.

        ``last`` is the step the bar was at before this update. Returns the
        new step so the caller can pass it back in on the next update.
        """
        pbar.update(data['step'] - last)
        gpu_status = model.getGPUram()
        losses = [f"{round(x, 3):0.3f}" for x in data['losses']]
        # Same NaN-ignoring average as the step log, so both agree.
        last_avg = mean_without_nan(data['losses'])
        pbar.write(f"steps: {data['step']},"
                   f" avg_loss: {data['loss']:0.3f},"
                   f" last_avg: {last_avg:0.3f},"
                   f" losses: {', '.join(losses)},"
                   f" {gpu_status}")
        return data['step']


    # Collect every *.map file (case-insensitive suffix) below the input dir.
    maps = []
    for root, _subdirs, files in os.walk(args.input_folder):
        maps.extend(str(Path(root) / name)
                    for name in files
                    if name.lower().endswith('.map'))

    for mmap in maps:
        print(f"processing {mmap} @ {args.tokens_per_sample}")
        map_path = Path(mmap)
        # Output files share the map's directory and name, minus the suffix.
        base_path = str(map_path.parent / map_path.stem)
        tokens = np.memmap(mmap, mode="r", dtype="uint16")
        est_steps = int(len(tokens) / args.tokens_per_sample)
        # --steps acts as a lower bound on the estimated step count.
        steps = max(args.steps, est_steps)

        if args.read_steps:
            # With -r the first element is the step count, not a real token.
            steps = int(tokens[0])
            tokens = tokens[1:]

        print(f"Estimated steps: {est_steps:,}")
        print(f"Used steps: {steps:,}")
        print(f"Token vocabulary size: {len(model.tokenizer)}")
        print(f"End of token id: {model.eot_token}")
        print(f"Prefix Length: {args.prefix_len}")

        # ``result`` always holds the most recent item yielded by training so
        # the failure report below can include it.
        result = {}
        try:
            # The context manager closes the step log once training for this
            # map is over, whether it finished or raised.
            with open(f"{base_path}.jsonl", 'w') as step_log:
                results = model.train(tokens,
                                      args.model_version,
                                      "",
                                      steps=steps,
                                      bs=args.batch_size,
                                      prefix_len=args.prefix_len,
                                      tokens_per_chunk=args.tokens_per_sample,
                                      seed=args.seed)
                last = 0
                with tqdm(total=steps) as pbar:
                    for result in results:
                        if result['event'] == "training_update":
                            data = json.loads(result['data']['data'])
                            last = report_progress(pbar, data, last)
                            write_log(step_log, data)
                            # Checkpoint when the reported step is a positive
                            # multiple of 1000.
                            if (last is not None
                                    and last % 1000 == 0
                                    and last > 0):
                                write_prefix(result, base_path, last)
                        elif result["ok"]:
                            write_prefix(result, base_path)
                        else:
                            data = json.loads(result['data']['data'])
                            report_progress(pbar, data, last)
                            raise RuntimeError(f"Failed: {result['event']}")

            embs = prefix.decode_prefix(result["encoded_embedding"])[0].cuda()
            samples = check_samples(embs,
                                    model,
                                    gen_samples=args.gen_samples,
                                    gen_len=args.gen_len,
                                    text_adv=args.adventure)
            # check_samples uses three prompts in either prompt mode.
            with tqdm(total=args.gen_samples * 3) as pbar, \
                    open(f"{base_path}.txt", "wb") as fh:
                for sample in samples:
                    pbar.update(1)
                    pbar.write(sample)
                    fh.write(sample.encode('utf-8',
                                           'surrogateescape'))
            print(f"{base_path} success")
        except Exception as e:
            # Top-level boundary for one map: record the failure and move on.
            print(e)
            print(traceback.format_exc())
            result["exception"] = str(e)
            with open(f"{base_path}.fail", "w") as fh:
                fh.write(json.dumps(result))
