import pika, sys
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GPTNeoForCausalLM,
    AutoConfig
)

import time, logging, json, threading, argparse
import torch, base64, os, socket, subprocess, struct
import os
from prometheus_client import CollectorRegistry, Gauge, push_to_gateway
from datetime import datetime
from multiprocessing import Process, Queue
from .comet_class import *
from transformers import pipeline
# Runtime configuration is read from the environment. MODEL_PATH and MODEL
# are required: a missing one raises KeyError while the module is loading.
model_path = os.environ['MODEL_PATH']

# DEV is an optional switch. Only the exact value "True" selects the
# "_dev"-suffixed queue; an unset variable is treated like any other value
# instead of aborting startup with KeyError.
is_dev = "_dev" if os.environ.get('DEV') == "True" else ""

model_name = os.environ['MODEL']
version = "0.0.0.2smhs"

# Name of the job queue this node consumes from.
queue_name = "generation_jobs_" + model_name + is_dev

parser = argparse.ArgumentParser(description="Node arguments")
# The startup code reads ``args.show`` once the models are loaded, so the
# option has to be declared; without it that attribute lookup raises
# AttributeError on every start.
parser.add_argument(
    "--show",
    action="store_true",
    help="Exit right after the models have been loaded.",
)
args = parser.parse_args()

# Module logger: INFO and above, written through a single stream handler
# with timestamp, level, file name and process id in front of each message.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

fh_formatter = logging.Formatter(
    "%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
)
fh = logging.StreamHandler()
fh.setFormatter(fh_formatter)
logger.addHandler(fh)

# Startup banner.
logger.info(f"version: {version}")
logger.info("Node started")
logger.info(f"queue: {queue_name}")

# Carries per-request processing times from on_request() to spy().
q = Queue()

load_time = time.time()

# Only the model whose tag appears in MODEL is instantiated.
if "BART-comet" in model_name:
    comet = Comet(model_path)
    comet.model.zero_grad()

if "roberta-squadv2" in model_name:
    nlp = pipeline(
        'question-answering',
        model=model_path,
        tokenizer=model_name,
        device=0,
    )

# Create (or truncate) the startup marker file now that loading is done.
with open("/tmp/health_startup", "w"):
    pass
logger.info(f"Models loaded in {time.time() - load_time}seconds")


# Optional early exit once the models are loaded. ``args`` may carry no
# ``show`` attribute when the parser declares no such option; treat that as
# "not requested" rather than failing startup with AttributeError.
if getattr(args, "show", False):
    os._exit(0)

channel = None
total_sent = 0

# SECURITY NOTE(review): the broker address and credentials are hard-coded
# in source. They are kept as fallbacks so existing deployments keep
# working, but each value can be overridden through the environment. The
# literals should be rotated and removed once every deployment sets
# RABBITMQ_USER / RABBITMQ_PASSWORD / RABBITMQ_HOST.
credentials = pika.credentials.PlainCredentials(
    os.environ.get("RABBITMQ_USER", "kurumuz"),
    os.environ.get("RABBITMQ_PASSWORD", "IX0zuEY6mLqsqDN0xS90nI8cFDCrr47o"),
)
connection = pika.BlockingConnection(
    pika.ConnectionParameters(
        os.environ.get("RABBITMQ_HOST", "104.248.82.249"),
        credentials=credentials,
    )
)
channel = connection.channel()

def main():
    """Declare the job queue, register ``on_request`` and consume forever.

    Uses the module-level ``channel`` and ``queue_name``. The call to
    ``start_consuming`` blocks, so this function only returns when
    consuming stops.
    """
    channel.queue_declare(queue=queue_name, durable=True)
    # Hand this consumer one unacknowledged message at a time.
    channel.basic_qos(prefetch_count=1)

    # NOTE(review): "x-max-priority" is normally a queue-declaration argument
    # in RabbitMQ; confirm it has any effect when passed to basic_consume.
    consumer_arguments = {"x-max-priority": 11}
    channel.basic_consume(
        queue=queue_name,
        on_message_callback=on_request,
        arguments=consumer_arguments,
    )

    logger.info("Started consuming")
    channel.start_consuming()

processing_time = 0


def on_request(ch, method, props, body):
    """Handle one job message: validate it, run the model, reply and ack.

    The parameters are the ones pika passes to an ``on_message_callback``:
    the channel, the delivery method frame, the message properties and the
    raw message body (UTF-8 encoded JSON).

    The JSON result (or an ``{"error": ...}`` object) is published to
    ``props.reply_to`` with the request's correlation id, the delivery is
    acknowledged, the elapsed model time is put on ``q`` for ``spy`` and the
    readiness marker file is (re)created.
    """
    global sent_first_message
    global total_sent

    ctx = [ch, method, props, body]

    try:
        req_dict = json.loads(body.decode("utf-8"))
    except ValueError:
        # UnicodeDecodeError and json.JSONDecodeError are both ValueErrors.
        # Answer the client instead of letting a malformed message crash
        # the consumer callback.
        send_error(ctx, "invalid json")
        return

    req_dict = get_input_filtered(ctx, req_dict)
    if req_dict is None:
        # Nothing left to process; get_input_filtered reports rejected
        # requests to the client itself through send_error.
        return
    req_params = req_dict["parameters"]

    querieslist = req_dict.get("queries")
    relationslist = req_dict.get("relations")
    questionslist = req_dict.get("questions")
    contextlist = req_dict.get("context")

    curr_time = time.time()
    if querieslist and relationslist:
        returnx = run_comet(querieslist, relationslist, req_params)
    elif contextlist and questionslist:
        # The QA model takes the questions and contexts, not the COMET lists.
        returnx = run_roberta(questionslist, contextlist)
    else:
        returnx = {"error": "wrong input"}
    elapsed = time.time() - curr_time

    q.put(elapsed)
    logger.info(f"Request took {str(elapsed)} seconds")

    ch.basic_publish(
        exchange="",
        routing_key=props.reply_to,
        properties=pika.BasicProperties(
            correlation_id=props.correlation_id, content_type="application/json"
        ),
        body=json.dumps(returnx).encode("utf-8"),
    )
    ch.basic_ack(delivery_tag=method.delivery_tag)
    sent_first_message = True
    total_sent += 1

    # Create (or truncate) the readiness marker after a completed request.
    with open("/tmp/health_readiness", "w"):
        pass


def run_comet(querieslist, relationslist, req_params=None):
    """Build ``"<query> <relation> [GEN]"`` prompts and run them through COMET.

    Args:
        querieslist: list of query strings.
        relationslist: list of relation lists. A single relation list is
            applied to every query; otherwise the i-th relation list is used
            for the i-th query.
        req_params: optional keyword arguments forwarded to
            ``comet.generate``; ``None`` forwards none.

    Returns:
        Whatever ``comet.generate`` returns, or an ``{"error": ...}`` dict
        when the inputs are empty or there are fewer relation lists than
        queries (and more than one relation list).
    """
    if not querieslist or not relationslist:
        return {"error": "undefined error"}

    if len(relationslist) == 1:
        # One relation list shared by all queries.
        pairs = [(query, relationslist[0]) for query in querieslist]
    elif len(relationslist) < len(querieslist):
        # Indexing relationslist per query would run off the end.
        return {"error": "queries and relations length mismatch"}
    else:
        # Lock-step; surplus relation lists are ignored.
        pairs = zip(querieslist, relationslist)

    batch = [
        f"{query} {relation} [GEN]"
        for query, relations in pairs
        for relation in relations
    ]

    return comet.generate(batch, **(req_params or {}))

def run_roberta(questionslist, contextlist):
    """Answer questions against contexts with the question-answering pipeline.

    Args:
        questionslist: list of question strings.
        contextlist: list of context strings.

    Pairing rules: when either list has exactly one element, every question
    is paired with every context (question-major order); otherwise the lists
    are paired element by element, stopping at the shorter one.

    Returns:
        The result of the module-level ``nlp`` pipeline, or
        ``{"error": "wrong input"}`` when either list is empty or missing,
        so the pipeline is never called with nothing to answer.
    """
    if not questionslist or not contextlist:
        return {"error": "wrong input"}

    if len(questionslist) == 1 or len(contextlist) == 1:
        pairs = [
            (question, context)
            for question in questionslist
            for context in contextlist
        ]
    else:
        pairs = list(zip(questionslist, contextlist))

    questions = [question for question, _ in pairs]
    contexts = [context for _, context in pairs]

    return nlp(question=questions, context=contexts, handle_impossible_answer=False)

def get_input_filtered(ctx, req_dict):
    """Validate a decoded request and normalise its parameters in place.

    Args:
        ctx: ``[ch, method, props, body]`` of the delivery, handed to
            ``send_error`` when the request is rejected.
        req_dict: the decoded JSON request.

    Returns:
        ``req_dict`` itself after normalisation, or ``None`` when the request
        was rejected (an error reply has then already been sent through
        ``send_error``, so the caller must not reply again).

    Normalisation: missing or ``None`` parameters get their defaults,
    parameters outside the whitelist are dropped, and relation names not in
    ``all_relations`` are removed from each relation list.

    NOTE(review): an "input" key is not required, since nothing in this file
    reads it — confirm no client relies on it being validated.
    """
    if (
        not isinstance(req_dict, dict)
        or not req_dict
        or req_dict.get("parameters") is None
    ):
        send_error(ctx, "empty dict")
        return None

    has_comet_fields = "queries" in req_dict and "relations" in req_dict
    # Question-answering requests carry these fields instead of the COMET
    # ones and must pass the filter to reach the QA model.
    has_qa_fields = "questions" in req_dict and "context" in req_dict
    if not has_comet_fields and not has_qa_fields:
        send_error(ctx, "no queries or relations given")
        return None

    if has_comet_fields:
        querieslist = req_dict["queries"] or []
        relationslist = req_dict["relations"] or []

        if len(relationslist) != 1 and len(querieslist) != len(relationslist):
            send_error(ctx, "your query length should match your relations or wth")
            return None

        for relations in relationslist:
            # Rebuild in place: removing items while iterating the same
            # list skips the element that follows each removal.
            relations[:] = [
                relation for relation in relations if relation in all_relations
            ]

    req_params = req_dict["parameters"]

    whitelist = [
        "decode_method",
        "num_generate",
    ]

    default_values = {
        "decode_method": "beam",
        "num_generate": 1,
    }

    for key, default in default_values.items():
        if req_params.get(key) is None:
            req_params[key] = default

    for key in list(req_params):
        if key not in whitelist:
            del req_params[key]

    return req_dict


def send_error(ctx, message):
    """Reply with ``{"error": "<message>"}`` and acknowledge the delivery.

    Args:
        ctx: four-item sequence ``[ch, method, props, body]`` describing the
            delivery being answered; the body is not used.
        message: error description; it is converted with ``str``.
    """
    ch, method, props, _body = ctx

    payload = json.dumps({"error": str(message)}).encode("utf-8")
    reply_properties = pika.BasicProperties(
        correlation_id=props.correlation_id, content_type="application/json"
    )

    ch.basic_publish(
        exchange="",
        routing_key=props.reply_to,
        properties=reply_properties,
        body=payload,
    )
    ch.basic_ack(delivery_tag=method.delivery_tag)


def node_ok():
    """Placeholder hook: performs no work and returns ``None``."""
    return None

def spy():
    """Aggregate request timings from ``q`` and push them once a minute.

    Runs forever. Every 0.1 s it drains the processing times that
    ``on_request`` put on ``q``; once at least 60 s have passed it pushes two
    gauges (``compute_time`` and ``answered_per_min``) to the push gateway,
    logs the totals and starts a new window.

    ``total_sent`` is the module-level name; since this function is started
    through ``Process`` it works on that process's own copy.
    """
    global total_sent
    machine_id = socket.gethostname()
    processing_time = 0
    total_sent = 0
    curr_time = time.time()

    while True:
        # Read the size once so the request count and the number of drained
        # timings always agree, even if items arrive in between.
        pending = q.qsize()
        total_sent += pending
        for _ in range(pending):
            processing_time += q.get()

        if (time.time() - curr_time) >= 60:
            registry = CollectorRegistry()
            ga = Gauge('compute_time', 'Compute time used in a minute', registry=registry)
            gb = Gauge('answered_per_min', 'Answered requests in a minute', registry=registry)
            ga.set(processing_time)
            gb.set(total_sent)
            try:
                push_to_gateway(
                    'http://104.248.82.249:9091',
                    grouping_key={"instance": machine_id},
                    job=model_name,
                    registry=registry,
                )
            except OSError:
                # An unreachable or failing gateway must not end the loop;
                # the next window is pushed as usual.
                logger.exception("Could not push metrics to the gateway")
            logger.info(f"answered {str(total_sent)} requests in {str(processing_time)} seconds.")
            total_sent = 0
            processing_time = 0
            curr_time = time.time()

        time.sleep(0.1)

if __name__ == "__main__":
    # spy() never returns. Run it as a daemon process so the interpreter
    # does not wait for it at exit: a non-daemonic child is joined on
    # shutdown, which would keep a node whose main() has returned or raised
    # alive forever.
    th2 = Process(target=spy, daemon=True)
    th2.start()

    main()