from .comet_class import *
import time
from Inference.gen_util import get_KG
from transformers import pipeline, BartTokenizer, BartForConditionalGeneration
#sample usage

# Load the COMET-ATOMIC 2020 BART checkpoint from its local directory.
comet_checkpoint = "./comet-atomic_2020_BART"
print("model loading ...")
comet = Comet(comet_checkpoint)
# Reset any gradients held on the model's parameters; this script only runs inference.
comet.model.zero_grad()
print("model loaded")

# Wall-clock start for the COMET timing message printed after generation.
curr_time = time.time()


# Sample passages fed through the pipeline. Each assignment overwrites `prompt`, so only
# the LAST one (the skinsaw-mask description) is actually processed; the earlier ones
# are kept as alternative examples to switch between.
prompt = """Bullets whizzed past my head and ricocheted off the stone wall behind me, but I didn't have time to look back. 
The front door slammed shut as the gunman pulled his pistol from his jacket. 
He was wearing a black leather jacket, with some form of facial covering over his nose and mouth. I heard a metallic click as he pulled on the magazine.

"Don't move, or I'll kill you," he shouted, before firing at least four bullets into the door lock. 
I felt a warm wetness seep across my forehead, and then more, as it ran down my cheek and neck. 
As I turned to see what was happening, he pulled a small bottle from his jacket pocket and splashed some liquid onto my face.
It burned like hell, and stung like he'd slapped me. 
I opened my eyes in time to see the gunman pull out another pistol from his jacket, aim it at me, and fire two shots. 
One of them whizzed past my ear. The other hit the wall behind me, sending chunks of rock tumbling down into the courtyard. 
My muscles were leaden as I dragged myself up onto my feet. My head swam, and all I could think was that I wanted to go home."""

# Alternative example (overwritten by the next assignment).
prompt = """I stepped out of my lair at dusk, flaring my immense immense nostrils to feel the cool night air. 
Stepping with massive strides, I gazed upon the realm that was just outside a small, precarious perch, the exit of my lair.
I felt an intrinsic sense of authority over it. A sense of ownership. Stretching out my wings, I prepared to depart with a single goal in mind: to hunt.  
The wind caught beneath my wings, and with a mighty leap, I soared into the air. 

I took a short breath, drinking in the fresh wind as I began to flap my leathery airfoils, cutting through harsh airs like the sharpest knife through butter. 
With a whip of my sinewy neck, I surveyed the area beneath you, when something caught my eye.  
"""

# The passage actually used.
# NOTE(review): this literal opens with four quote characters, so the text itself starts
# with a `"` and ends with `."` — those quote marks flow into sentence splitting and the
# QA context; confirm that is intended.
prompt = """"The skinsaw mask is a magical face-covering which resembles both the stalker's mask and the reaper's mask in construction and design.
Crafted for use by members of the dreaded Skinsaw Cult, it is made using tanned human skin and resembles a deformed humanoid face.
It contains a single bulging eye, a grimacing mouth with jagged teeth and no nose. 
When worn, the mask whispers violent and murderous thoughts in the mind of wearer, encouraging the user towards greater debauchery. 
The mask heightens the senses in regards to detecting fear in others, and allows one to literally see the circulatory system of creatures. 
The latter ability allows the user to make more effective attacks using slashing weapons of any kind. 
The skinsaw mask eventually causes permanent damage to the user, as his or her thoughts become sullied with images of hate and mayhem."
"""

from nltk import sent_tokenize

# Flatten hard line breaks so sentence splitting sees continuous prose.
prompt = prompt.replace("\n", " ")

# Every sentence of the passage becomes one COMET query (head).
prompt_tokenized = sent_tokenize(prompt)
querieslist = prompt_tokenized

# Relations requested for every sentence. List multiplication repeats the SAME list
# object for each query, so mutating it would affect all queries at once.
comet_relations = ["isFilledBy", "SymbolOf", "RelatedTo", "ObjectUse", "LocatedNear", "HasProperty", "CreatedBy", "Desires", "CausesDesire", "CapableOf", "AtLocation"]
relationslist = [comet_relations] * len(querieslist)

# Taken from the relation list itself so it is defined even when the prompt has no
# sentences (relationslist[0] would raise IndexError in that case).
num_attributes = len(comet_relations)



# Build one "<query> <relation> [GEN]" COMET prompt per (query, relation) pair.
if len(querieslist) > 1 and len(relationslist) == 1:
    # A single relation list is broadcast to every query.
    per_query_relations = relationslist * len(querieslist)
elif querieslist and relationslist:
    # Otherwise relationslist[i] holds the relations for querieslist[i]
    # (a shorter relationslist raises IndexError below).
    per_query_relations = relationslist
else:
    # No queries or no relations: nothing to generate.
    per_query_relations = []

batch = []
if per_query_relations:
    for idx, query in enumerate(querieslist):
        batch.extend(f"{query} {relation} [GEN]" for relation in per_query_relations[idx])

# Ask COMET for 10 beam-search completions per prompt and report how long loading plus
# generation took since the timer was started.
results = comet.generate(batch, decode_method="beam", num_generate=10)
print(f"Took: {time.time() - curr_time} seconds for COMET.")
print(results)

# Generative QA model: local fine-tuned BART checkpoint, cast to fp16 and moved to the
# first GPU. The tokenizer is loaded from the "yjernite/bart_eli5" checkpoint.
# (An extractive `pipeline("question-answering", ...)` alternative is not used.)
model_name = "../nar_qa"

tokenizer = BartTokenizer.from_pretrained("yjernite/bart_eli5")
model_nqa = (
    BartForConditionalGeneration.from_pretrained(model_name)
    .half()
    .to("cuda:0")
)

# Convert the COMET output tuples into a graph keyed by query sentence; prune=2 is
# forwarded to get_KG, which defines its exact meaning.
graph = get_KG(querieslist, relationslist, results, prune=2)
    
    
# Construct one natural-language question per (sentence, relation, value) entry of the
# knowledge graph, each paired with the whole passage as QA context.
# NOTE(review): question_relation_tuple is not defined in this file; presumably it comes
# from `from .comet_class import *` — confirm.
questions = list()
inputs = list()
for key, child in graph.items():
    # key is the context sentence; child maps each relation to the values we query about.
    for attr, values in child.items():
        for value in values:
            # Drop one leading space from the generated value. startswith() also copes
            # with an empty generation, where value[0] would raise IndexError.
            if value.startswith(" "):
                value = value[1:]
            # The relation determines how the question is phrased.
            question = question_relation_tuple[attr] + "\"" + value + "\"?"
            questions.append(question)
            # Both labels end with a colon, matching the ELI5-BART "question: ... context: ..."
            # input layout (presumably what the QA checkpoint was fine-tuned on — confirm).
            inputs.append("question: " + question + " context: " + prompt)


# Answer every constructed question with the generative QA model in a single batch.
curr_time = time.time()
# Full passage per question; not read by the generative QA path below.
context = [" ".join(prompt_tokenized)] * len(questions)

if inputs:
    # Tokenize all QA prompts as one padded batch on the GPU; truncation keeps long
    # passages within the tokenizer's maximum length.
    inp = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True).to("cuda:0")
    out = model_nqa.generate(**inp, no_repeat_ngram_size=3, num_beams=3, use_cache=True)

print("Took: " + str(time.time() - curr_time) + " seconds for QA on " + str(len(questions)) + " questions.")

# Skip decoding when graph pruning left no questions: an empty batch presumably cannot
# be turned into tensors / generated from, and there is nothing to report anyway.
answers = tokenizer.batch_decode(out, skip_special_tokens=True) if inputs else []

# The leading 0 is a placeholder score: the generative model yields no relevance score,
# so results are printed in question order rather than sorted.
out_list = [(0, q, answer) for q, answer in zip(questions, answers)]

questions = list()
for elem in out_list:
    print(elem)


# Disabled alternative post-processing: scores each question with an extractive
# question-answering pipeline (`nlp`), sorts by that score and prints non-empty answers.
# It is kept as an inert string literal rather than executed — binding it to `literal`
# has no other effect, and this script never constructs an `nlp` pipeline.
literal = """
#Determine if the answer refers to an object
context = [" ".join(prompt_tokenized)] * len(questions)
QA_input = {
    'question': questions,
    'context': context
}
out = list()
res = nlp(**QA_input, handle_impossible_answer = True)
for i, q in enumerate(questions):
    out.append((res[i]['score'], q, res[i]['answer']))

#Sort by QA score (how relevant the vertex is to the text)
out.sort(key=lambda x: x[0], reverse=True)

questions = list()
for elem in out:
    if len(elem[-1].split()) >= 1:
        print(elem)
"""
