import json
import networkx as nx
import matplotlib.pyplot as plt

from loguru import logger

# Path (relative to the current working directory) of the text file that
# write_vars_to_file() overwrites with the per-scene score variables.
OUTPUT_FILE = "variables.txt"

def min_max_scores(content, skipped_scenes=((0, 2), (1, 1))):
    """Compute min/max path scores for every scene of every chapter.

    Args:
        content: parsed Chapters.json data; ``content["Chapters"][i]["Scenes"][j]``
            must carry a "Title" (used as the result key) plus whatever
            find_scene_scores() reads.
        skipped_scenes: iterable of ``(chapter_index, scene_index)`` pairs,
            both 0-based, that get an empty dict instead of being scored.
            Defaults to the two scenes that used to be hard-coded —
            presumably because enumerating all their paths is too slow
            (TODO confirm).

    Returns:
        dict mapping scene title -> score dict ({} for skipped scenes).
        Scenes that share a title overwrite each other's entry.
    """
    logger.info("Computing scores for all scenes. It may take a bit of time...")
    skipped = set(skipped_scenes)
    scores = {}
    # Iterate over chapters, then over the scenes of each chapter.
    for idx_chapter, chapter in enumerate(content["Chapters"]):
        for idx_scene, scene in enumerate(chapter["Scenes"]):
            scene_name = scene["Title"]

            if (idx_chapter, idx_scene) in skipped:
                scores[scene_name] = {}
            else:
                scores[scene_name] = find_scene_scores(scene)
            # Progress report; the number is the 1-based scene index within its chapter.
            print(scene_name + " (" + str(idx_scene + 1) + ") done")
    return scores
            
# Edge attributes stored by create_graph(), in the order their min/max values
# are reported. "authenticiy" is spelled this way on purpose: it must match the
# edge attribute name create_graph() uses.
_SCORE_ATTRIBUTES = ("score", "authenticiy", "respect", "compassion", "hope", "empathy")


def _path_totals(graph, path):
    """Return {attribute: sum over the edges of *path*} for each of _SCORE_ATTRIBUTES."""
    totals = dict.fromkeys(_SCORE_ATTRIBUTES, 0)
    for u, v, key in path:
        # In a MultiDiGraph, get_edge_data(u, v) maps edge key -> attribute dict.
        edge = graph.get_edge_data(u, v)[key]
        for attribute in _SCORE_ATTRIBUTES:
            totals[attribute] += edge[attribute]
    return totals


def _iter_end_paths(graph, source, end_nodes):
    """Yield every simple edge path from *source* to each end node present in *graph*."""
    # networkx raises NodeNotFound for nodes absent from the graph; such nodes
    # cannot lie on any path from source, so they are simply skipped.
    if source not in graph:
        return
    for end_node in end_nodes:
        if end_node in graph:
            yield from nx.all_simple_edge_paths(graph, source=source, target=end_node)


def find_scene_scores(scene, source=1):
    """Compute the min and max cumulative value of each score attribute for a scene.

    The scene is turned into a graph by create_graph(). Every simple path from
    interaction ``source`` to an end interaction is then enumerated, and
    each attribute is summed along the path. Path enumeration grows quickly with
    the number of branches, so large scenes can be slow.

    Args:
        scene: scene dict whose "Interactions" are fed to create_graph().
        source: Id of the interaction where paths start. Defaults to 1, the
            value that was previously hard-coded.

    Returns:
        dict ordered as max_score, min_Score, max_authenticiy, min_authenticiy,
        max_respect, min_respect, max_compassion, min_compassion, max_hope,
        min_hope, max_empathy, min_empathy. The key spellings are historical
        and kept for existing consumers. If no path from ``source`` reaches an
        end node, every value is 0.
    """
    graph, end_nodes = create_graph(scene)

    minima = None
    maxima = None
    for path in _iter_end_paths(graph, source, end_nodes):
        totals = _path_totals(graph, path)
        if minima is None:
            # Seed with the first real path instead of arbitrary sentinels,
            # which used to clamp results outside [-100, 100].
            minima, maxima = dict(totals), dict(totals)
        else:
            for attribute, value in totals.items():
                minima[attribute] = min(minima[attribute], value)
                maxima[attribute] = max(maxima[attribute], value)

    if minima is None:
        # No complete path: report a neutral 0 instead of an inverted range
        # (the old sentinels produced max=-100, min=100 here).
        minima = dict.fromkeys(_SCORE_ATTRIBUTES, 0)
        maxima = dict.fromkeys(_SCORE_ATTRIBUTES, 0)

    result = {}
    for attribute in _SCORE_ATTRIBUTES:
        result["max_" + attribute] = maxima[attribute]
        # "min_Score" (capital S) is the historical key name; kept so existing
        # readers of this dict keep working.
        min_key = "min_Score" if attribute == "score" else "min_" + attribute
        result[min_key] = minima[attribute]
    return result

def create_graph(scene):
    """Build a directed multigraph from a scene's interactions.

    Every interaction "Id" becomes a node. Every response whose
    "NextInteractionID" is not -1 becomes an edge current -> next. The edge
    carries the response's values as attributes "authenticiy" (sic, the name
    find_scene_scores() reads), "respect", "compassion", "hope" and
    "empathy", plus "score", which is their sum. An interaction with at
    least one response pointing to -1 is recorded as an end node.

    Args:
        scene: dict with an "Interactions" list; each interaction has "Id" and
            "Responses", and each response has "NextInteractionID" and the
            five value fields.

    Returns:
        (graph, end_nodes): a networkx MultiDiGraph and a set of interaction Ids.
    """
    graph = nx.MultiDiGraph()
    end_nodes = set()

    for interaction in scene["Interactions"]:
        current_node = interaction["Id"]
        # Register the node even when it has no outgoing edge. Ending
        # interactions, and a start interaction that only ends the scene, must
        # exist in the graph or networkx path searches raise NodeNotFound.
        graph.add_node(current_node)

        for choice in interaction["Responses"]:
            child_node = choice["NextInteractionID"]
            if child_node == -1:
                # NOTE(review): this final response's values are not stored on
                # any edge, so they never count toward path totals — confirm
                # that is intended.
                end_nodes.add(current_node)
                continue

            values = {
                "authenticiy": choice["Authenticity"],
                "respect": choice["Respect"],
                "compassion": choice["Compassion"],
                "hope": choice["Hope"],
                "empathy": choice["Empathy"],
            }
            graph.add_edge(current_node, child_node, score=sum(values.values()), **values)

    return graph, end_nodes

def normalize_string(text):
    """Return *text* lower-cased, with spaces turned into underscores and the
    accented letters é, ê, è and à replaced by their unaccented forms."""
    substitutions = str.maketrans({
        " ": "_",
        "é": "e",
        "ê": "e",
        "è": "e",
        "à": "a",
    })
    lowered = text.lower()
    return lowered.translate(substitutions)

def show_graph(graph, attribut):
    """Draw *graph* with matplotlib, labelling each edge, and call plt.show().

    Args:
        graph: networkx graph, e.g. as returned by create_graph().
        attribut: name of the edge attribute whose value is shown as the label.
    """
    pos = nx.spring_layout(graph)
    # nx.draw also draws the nodes, so no separate draw_networkx_nodes call is needed.
    nx.draw(graph, pos, with_labels=True, connectionstyle='arc3, rad = 0.1')

    # Group labels per (u, v) pair. A MultiDiGraph can hold several parallel
    # edges between the same nodes, and a dict keyed by (u, v) would otherwise
    # keep only the last one.
    grouped = {}
    for u, v, data in graph.edges(data=True):
        grouped.setdefault((u, v), []).append(str(data[attribut]))
    edge_labels = {edge: ", ".join(values) for edge, values in grouped.items()}

    # The labels were previously computed but never drawn.
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels)
    plt.show()
           
def write_vars_to_file(scores, output_file=None):
    """Write every scene's score variables to a text file, overwriting it.

    For each scene and each score name, two lines are written:
    ``  - name: <scene>_<score name>`` and ``    value: <value>``. Both parts
    of the name go through normalize_string(), which lower-cases them. The
    file ends with an "@stop" line.

    Args:
        scores: mapping scene title -> dict of score name -> value, as
            returned by min_max_scores().
        output_file: destination path. Defaults to the module-level OUTPUT_FILE,
            which is looked up at call time.
    """
    if output_file is None:
        output_file = OUTPUT_FILE
    logger.info("Writing scores into output file " + output_file)

    lines = []
    for scene_name, scene_scores in scores.items():
        prefix = normalize_string(scene_name.upper()) + "_"
        for emotion, value in scene_scores.items():
            lines.append("  - name: " + prefix + normalize_string(emotion.upper()) + "\n")
            lines.append("    value: " + str(value) + "\n")
    # NOTE(review): "@stop" looks like a leftover from the earlier "@set ..."
    # output format and would not parse after the YAML-like list above —
    # confirm the consumer still expects it.
    lines.append("@stop")

    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(lines)

if __name__ == "__main__":
    # Smoke test: score a single scene from Chapters.json and print the result.
    chapter_index = 0
    scene_index = 3

    logger.info(
        "Testing find_scene_scores for scene index " + str(scene_index)
        + " of chapter index " + str(chapter_index) + "..."
    )

    with open("Chapters.json", 'r', encoding="utf-8") as j:
        contents = json.load(j)

    # The file is already closed before the potentially slow path enumeration.
    scene = contents["Chapters"][chapter_index]["Scenes"][scene_index]
    scores = find_scene_scores(scene)
    print(scores)