import json
from typing import List

import torch
import numpy as np
import lzma
import base64
from nacl import pwhash, secret, utils
from nacl.hash import blake2b

# Embedding width assumed for legacy prefixes (model_version < 5): those
# payloads carry no width field, so encode/decode fall back to this value.
hidden_dim = 4096
# Key-derivation inputs. Both are literals in this source file, so the derived
# key obfuscates prefix bundles rather than keeping them secret from anyone
# who has this file.
password = b'novelai_16YQDi0u8DDQLDCvTZJPYuVTcJNLP7MG'
salt = b'__novelai_salt__'
kdf = pwhash.argon2i.kdf

# Minimum Argon2i cost limits: the password is public anyway, so a slow KDF
# would only add import-time latency without adding security.
ops = pwhash.argon2i.OPSLIMIT_MIN
mem = pwhash.argon2i.MEMLIMIT_MIN

# Derived once at import time; `box` is the shared symmetric box used by
# encode_prefix (encrypt) and decode_prefix (decrypt).
key = kdf(secret.SecretBox.KEY_SIZE, password, salt, opslimit=ops, memlimit=mem)
box = secret.SecretBox(key)


def bytes_to_prefix(dims: "int | List[int]", bin: bytes,
                    typ: "np.typing.DTypeLike") -> torch.Tensor:
    """Reinterpret a raw buffer as an embedding tensor.

    Args:
        dims: trailing shape of each row. An ``int`` yields a 2-D tensor of
            shape ``(-1, dims)``; a sequence of ints yields ``(-1, *dims)``.
        bin: raw bytes holding the elements back to back.
        typ: NumPy element dtype of the buffer, e.g. ``np.float32``.

    Returns:
        A new tensor holding a copy of the data (it does not alias ``bin``).

    Raises:
        ValueError: if the buffer length does not fit ``typ`` or ``dims``.
    """
    # reshape needs a flat tuple; an int (or numpy integer) is one trailing dim.
    if isinstance(dims, (int, np.integer)):
        trailing = (int(dims),)
    else:
        trailing = tuple(int(d) for d in dims)
    # frombuffer returns a read-only view; torch.tensor copies it, so no
    # defensive slice of the input is needed.
    embs = np.frombuffer(bin, dtype=typ).reshape((-1, *trailing))
    return torch.tensor(embs)


# Decodes a base64 bundle into a PyTorch embedding tensor and its model version.
def decode_prefix(encoded):
    """Decode a base64 prefix bundle into an embedding tensor and model version.

    ``encoded`` is base64 text wrapping a ciphertext from the module-level
    ``box``; the decrypted plaintext is LZMA-compressed. After decompression
    the payload layout is:

    * bytes 0-8: model version (int64, native byte order)
    * ``model_version >= 5``: bytes 8-16 hold the embedding width (int64),
      followed by float32 embeddings
    * otherwise: float32 embeddings of width ``hidden_dim`` from byte 8 on

    Args:
        encoded: base64 ``str`` or ``bytes`` as returned by ``encode_prefix``.

    Returns:
        ``(embs, model_version)``: a 2-D float32 tensor of shape
        ``(-1, width)`` and the model version as a Python ``int``.

    Raises:
        ValueError: if the payload is too short for its header or its size
            does not match the width.
    """
    decoded = base64.b64decode(encoded)
    compressed = box.decrypt(decoded)
    data = lzma.decompress(compressed)
    if len(data) < 8:
        raise ValueError("prefix payload too short to hold a model version")
    model_version = int(np.frombuffer(data[0:8], dtype=np.int64)[0])
    # encode_prefix writes the width header for every version >= 5, so the
    # decoder must accept that same range (matching only == 5 broke 6+).
    if model_version >= 5:
        if len(data) < 16:
            raise ValueError("prefix payload too short to hold a width header")
        dims = int(np.frombuffer(data[8:16], dtype=np.int64)[0])
        body = data[16:]
    else:
        dims = hidden_dim
        body = data[8:]
    embs = np.frombuffer(body, dtype=np.float32).reshape((-1, dims))
    return torch.tensor(embs), model_version


# Encodes a PyTorch embedding tensor and integer model version into a base64 bundle.
# Input e.g. {"embs": torch.zeros(20, 4096), "model_version": 0}
def encode_prefix(prefix):
    """Encode an embedding tensor and model version into a base64 bundle.

    Payload layout before compression: an int64 model version, then, for
    ``model_version >= 5`` only, an int64 embedding width, then the
    embeddings as float32. The payload is LZMA-compressed, encrypted with
    the module-level ``box`` under a fresh random nonce (so repeated calls on
    the same input give different output), and base64-encoded.

    Args:
        prefix: mapping with ``"embs"`` (a torch tensor, e.g.
            ``torch.zeros(20, 4096)``) and ``"model_version"`` (an int).

    Returns:
        The bundle as a ``str``.

    Raises:
        ValueError: if ``model_version < 5`` and the last dimension of the
            embeddings is not ``hidden_dim``.
    """
    embs = prefix["embs"].detach().cpu().float().numpy()
    model_version = prefix["model_version"]
    width = embs.shape[-1]
    if model_version >= 5:
        header = np.array([model_version, width], dtype=np.int64)
    else:
        # Legacy payloads carry no width field and are decoded with
        # hidden_dim, so any other width could not round-trip. Raise rather
        # than assert: asserts vanish under `python -O`.
        if width != hidden_dim:
            raise ValueError(
                f"model_version {model_version} prefixes must have width "
                f"{hidden_dim}, got {width}")
        header = np.array([model_version], dtype=np.int64)
    data = header.tobytes() + embs.astype(np.float32, copy=False).tobytes()
    compressed = lzma.compress(data)
    nonce = utils.random(secret.SecretBox.NONCE_SIZE)
    encrypted = box.encrypt(compressed, nonce)
    return base64.b64encode(encrypted).decode()


# Design notes (NovelAI Discord; may require server access): https://discord.com/channels/839050953118056488/839050953118056492/862299144488091698
# Self-encrypts an encoded prefix with a key derived from the prefix itself.
def self_encrypt_prefix(encoded):
    """Encrypt an encoded prefix under a key derived from its own contents.

    The key is a personalised BLAKE2b hash of the plaintext, so anyone who
    holds the same encoded prefix can re-derive it. A random nonce is drawn
    on every call, so the ciphertext, and therefore the id, differs between
    calls on the same input.

    Args:
        encoded: the encoded prefix ``str`` (e.g. from ``encode_prefix``).

    Returns:
        ``(encrypted, prefix_id)``: the SecretBox ciphertext and the base64
        text of a BLAKE2b hash of that ciphertext.
    """
    plaintext = encoded.encode()
    key_size = secret.SecretBox.KEY_SIZE
    personal = b'__novelai_self__'
    # NOTE(review): nacl.hash.blake2b hex-encodes its output by default (per
    # PyNaCl docs). If so, this slice keeps the first key_size hex characters,
    # i.e. only half of the digest's entropy. Changing it would change every
    # derived key and id, so it is kept as-is; confirm before touching.
    derived_key = blake2b(plaintext, digest_size=key_size,
                          person=personal)[:key_size]
    nonce = utils.random(secret.SecretBox.NONCE_SIZE)
    ciphertext = secret.SecretBox(derived_key).encrypt(plaintext, nonce)
    ciphertext_hash = blake2b(ciphertext, digest_size=key_size,
                              person=personal)
    prefix_id = base64.b64encode(ciphertext_hash).decode()
    return ciphertext, prefix_id


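# Returns the encoded prefix (as bytes) from a self-encrypted prefix and its prefix_key.
# Arguments are positional, e.g. self_decrypt_prefix(encrypted, prefix_key), where
# prefix_key is the raw SecretBox key bytes (it is passed to SecretBox unmodified).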
def self_decrypt_prefix(encrypted, prefix_key):
    """Decrypt a prefix produced by ``self_encrypt_prefix``.

    Args:
        encrypted: the ciphertext returned by ``self_encrypt_prefix``.
        prefix_key: the raw SecretBox key as ``bytes``. It is passed to
            ``secret.SecretBox`` unmodified; no base64 decoding is applied.

    Returns:
        The encoded prefix as ``bytes``; call ``.decode()`` if text is needed.

    Errors from ``SecretBox`` construction or decryption (e.g. a key of the
    wrong length, or a ciphertext that fails authentication) propagate
    unchanged.
    """
    return secret.SecretBox(prefix_key).decrypt(encrypted)


def write_json_embs(embs, path):
    """Write *embs* to *path* as a JSON nested list of float32 values.

    The tensor is detached, moved to the CPU and cast to float32 before
    conversion; ``read_json_embs`` loads this format back.
    """
    with open(path, "w") as out:
        as_lists = embs.detach().cpu().float().tolist()
        json.dump(as_lists, out)


def read_json_embs(path):
    """Load a JSON nested list of numbers from *path* as a float32 tensor."""
    with open(path, "r") as src:
        payload = json.load(src)
    return torch.tensor(payload).float()


if __name__ == "__main__":
    # Smoke run: load embeddings from JSON, pack them into an encoded prefix,
    # self-encrypt that prefix and print the resulting prefix id.
    embs = read_json_embs("test.json")
    # self_encrypt_prefix calls .encode() on its argument, so it needs the
    # encoded prefix string, not the raw tensor (passing the tensor crashed).
    # model_version 5 stores the embedding width, so it accepts any width;
    # versions < 5 require hidden_dim.
    # NOTE(review): use the version that matches the target model.
    encoded = encode_prefix({"embs": embs, "model_version": 5})
    encrypted, prefix_id = self_encrypt_prefix(encoded)
    print(prefix_id)
