import torch
import pytorch_lightning
import argparse
from pathlib import Path

def prune_state_dict(state_dict, prefix="model."):
    """Drop prefixed entries from *state_dict* and cast float tensors to fp16.

    The dict is modified in place and also returned for convenience.

    Args:
        state_dict: Mapping of parameter/buffer names to values, as found
            under a checkpoint's ``"state_dict"`` key.
        prefix: Every key starting with this string is deleted. The default
            ``"model."`` includes the dot, so a key such as
            ``"model_ema.weight"`` does NOT match and is kept.
            NOTE(review): confirm that discarding the ``model.`` subset is
            the intended pruning for the checkpoints this is run on.

    Returns:
        The same ``state_dict`` object, after pruning.
    """
    # Snapshot the matching keys first: deleting while iterating the live
    # dict would raise "dictionary changed size during iteration".
    for key in [k for k in state_dict if k.startswith(prefix)]:
        del state_dict[key]

    for key, value in list(state_dict.items()):
        # Only floating-point tensors are halved. Integer/bool tensors
        # (e.g. step or update counters) would lose exactness in float16
        # beyond 2048, and non-tensor values have no .half() at all.
        if torch.is_tensor(value) and value.is_floating_point():
            state_dict[key] = value.half()
    return state_dict


def main():
    """Prune the checkpoint named by ``--path`` and save it as ``pruned.ckpt``.

    The output file is written into the same directory as the input
    checkpoint. Top-level checkpoint keys other than ``"state_dict"`` are
    written back unchanged.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, required=True, help='path to model')
    args = parser.parse_args()
    path = Path(args.path)

    # NOTE(review): torch.load unpickles arbitrary objects, so only run this
    # on trusted checkpoints. Newer PyTorch releases default to
    # weights_only=True, which may reject Lightning checkpoints that pickle
    # extra objects; confirm against the installed torch version.
    model_dict = torch.load(path, map_location='cpu')
    prune_state_dict(model_dict["state_dict"])
    torch.save(model_dict, path.parent / "pruned.ckpt")


if __name__ == "__main__":
    main()