import sys
import torch
from transformers import GPTNeoForCausalLM, GPTNeoConfig, GPT2TokenizerFast
from transformers.models.gpt_neo.eight_bit_utils import bnbfy_

def no_init(loading_code):
    """Run ``loading_code()`` with module weight initialisation disabled.

    While ``loading_code`` runs, ``reset_parameters`` on ``torch.nn.Linear``,
    ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` is replaced by a no-op.
    Their constructors call ``reset_parameters``, so this skips the (random)
    initialisation that would be thrown away anyway when a checkpoint is
    loaded over the freshly built modules. It is only safe when every such
    parameter is subsequently overwritten.

    The patch is applied to the classes themselves (process-wide), and the
    original methods are always restored, even if ``loading_code`` raises.

    Args:
        loading_code: zero-argument callable, e.g. a lambda wrapping a
            ``from_pretrained`` call.

    Returns:
        Whatever ``loading_code()`` returns.
    """
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {mod: mod.reset_parameters for mod in modules}
    for mod in modules:
        mod.reset_parameters = dummy

    try:
        return loading_code()
    finally:
        # Restore unconditionally: without this, an exception while loading
        # would leave these torch classes permanently patched.
        for mod, method in original.items():
            mod.reset_parameters = method

# Usage: <script> MODEL_DIR EIGHT_BIT_LAYERS OUT_DIR [TOKENIZER_DIR]
#   MODEL_DIR         GPT-Neo checkpoint, loaded with local_files_only=True
#   EIGHT_BIT_LAYERS  comma-separated indices into model.transformer.h whose
#                     blocks are replaced by bnbfy_(...); empty entries
#                     (e.g. "" or a trailing comma) are ignored
#   OUT_DIR           receives the config (with config.eight_bit set), one
#                     b{i}.pt file per state_dict entry, and the m.pt index
#   TOKENIZER_DIR     optional; defaults to the previously hard-coded path
_DEFAULT_TOKENIZER_PATH = "/home/xuser/diffusionstorage/datasets/tokenizers/gpt2"

if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} MODEL_DIR EIGHT_BIT_LAYERS OUT_DIR [TOKENIZER_DIR]")

model_path = sys.argv[1]
out_dir = sys.argv[3]
tokenizer_path = sys.argv[4] if len(sys.argv) > 4 else _DEFAULT_TOKENIZER_PATH

eight_bit = [int(s) for s in sys.argv[2].split(",") if s.strip()]

config = GPTNeoConfig.from_pretrained(model_path)
config.eight_bit = eight_bit
config.save_pretrained(out_dir)

# Skip random init: every parameter is overwritten by the checkpoint.
model = no_init(lambda: GPTNeoForCausalLM.from_pretrained(model_path, local_files_only=True))
tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_path)


def _report_gpu_memory(label):
    """Print the CUDA memory currently allocated, in GiB, prefixed by *label*."""
    torch.cuda.empty_cache()
    # Resetting the peak first makes max_memory_allocated() report what is
    # allocated right now rather than the historical maximum.
    torch.cuda.reset_peak_memory_stats()
    print(f"{label}: {torch.cuda.max_memory_allocated() / 1024. / 1024. / 1024.:.4f}GB")


# NOTE(review): nothing visible here moves the model to CUDA. Unless the
# forked from_pretrained or bnbfy_ does, the "unquantized" figure excludes
# the model weights and generate() below receives a CUDA input for a CPU
# model. Confirm.
_report_gpu_memory("unquantized")

for i in eight_bit:
    model.transformer.h[i] = bnbfy_(model.transformer.h[i])

_report_gpu_memory("quantized")

# Write each state_dict tensor to its own file and an index mapping
# parameter name -> file name.
checkpoint = {}
for i, (name, tensor) in enumerate(model.state_dict().items()):
    filename = f"b{i}.pt"
    checkpoint[name] = filename
    torch.save(tensor, f"{out_dir}/{filename}")
torch.save(checkpoint, f"{out_dir}/m.pt")

# Smoke test: sample from the converted model.
# Token 198 is presumably "\n" and 50256 "<|endoftext|>" in the GPT-2 vocab.
# `tfs` (tail-free sampling) is not a stock transformers generate() option;
# this assumes the forked transformers accepts it.
prompt = torch.tensor([[198]]).long().cuda()
output = model.generate(
    prompt,
    do_sample=True,
    max_length=100,
    tfs=0.9,
    temperature=0.72,
    repetition_penalty=1.1,
    pad_token_id=50256,
    use_cache=True,
)
print(tokenizer.decode(output[0]))
