'''
This script generates scores for the TREC KBA 2012 Cumulative Citation
Recommendation Task, described here:

http://trec-kba.org/kba-ccr-2012.shtml

last updated September 26, 2012

Direct questions & comments to the TREC KBA forums:
http://groups.google.com/group/trec-kba

'''
## use float division instead of integer division
from __future__ import division

## Example command line; passed to argparse.ArgumentParser(usage=...) in the
## __main__ section, so it replaces the auto-generated usage message
__usage__ = '''
python KBAscore.py --annotation trec-kba-ccr-2012-judgments-2012JUN22-final.filter-run.txt --run-dir submissions
'''

import os
import csv
import gzip
import argparse
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

def write_graph (path_to_write_graph, Scores):
    '''
    Plots precision, recall, F-score and scaled utility against the cutoff
    and saves the figure to disk.

    path_to_write_graph: string, destination path handed to plt.savefig()
    Scores: dict mapping each cutoff to a dict with the keys 'P', 'R', 'F'
      and 'SU', as computed by performance_metrics()
    '''
    plt.figure()

    ## x-axis: the cutoffs in ascending order
    cutoffs = sorted(Scores)

    ## Gather the y-values of every curve before drawing anything; each
    ## entry pairs a legend label with the values of one metric
    curves = []
    for metric, label in (('P', 'Precision'), ('R', 'Recall'),
                          ('F', 'F-Score'), ('SU', 'Scaled Utility')):
        curves.append((label, [Scores[cutoff][metric] for cutoff in cutoffs]))

    for label, values in curves:
        plt.plot(cutoffs, values, label=label)

    plt.xlabel('Cutoff')
    ## y-range starts just below 0 and extends to 1.3
    plt.ylim(-0.01, 1.3)
    plt.legend(loc='upper right')
    plt.savefig(path_to_write_graph)
    plt.close()

def write_performance_metrics (path_to_write_csv, CM, Scores):
    '''
    Writes a CSV file with the performance metrics at each cutoff

    The file starts with a header row, followed by one row per cutoff in
    ascending cutoff order.

    path_to_write_csv: string with CSV file destination
    CM: dict, Confusion matrix generated from score_confusion_matrix();
      maps each cutoff to a dict with the keys 'TP', 'FP', 'FN', 'TN'
    Scores: dict containing the score metrics computed using
      performance_metrics(); must hold every cutoff present in CM, each
      mapped to a dict with the keys 'P', 'R', 'F', 'SU'
    '''
    import sys

    ## The csv module expects a binary-mode file on Python 2, but a text-mode
    ## file opened with newline='' on Python 3
    if sys.version_info[0] >= 3:
        csv_file = open(path_to_write_csv, 'w', newline='')
    else:
        csv_file = open(path_to_write_csv, 'wb')

    ## The with-block guarantees the file is flushed and closed, even if a
    ## cutoff is missing from Scores and a KeyError propagates
    with csv_file:
        writer = csv.writer(csv_file, delimiter=',')

        ## Write a header
        writer.writerow(['cutoff', 'TP', 'FP', 'FN', 'TN', 'P', 'R', 'F', 'SU'])

        ## Write the metrics for each cutoff on a different line
        for cutoff in sorted(CM):
            counts = CM[cutoff]
            metrics = Scores[cutoff]
            writer.writerow([cutoff,
                             counts['TP'], counts['FP'], counts['FN'], counts['TN'],
                             metrics['P'], metrics['R'], metrics['F'],
                             metrics['SU']])
        
def performance_metrics (CM):
    '''
    Computes the performance metrics (precision, recall, F-score, scaled utility)

    CM: dict containing the confusion matrix calculated from
      score_confusion_matrix(); maps each cutoff to a dict with (at least)
      the keys 'TP', 'FP' and 'FN'

    Returns a dict mapping each cutoff to a dict of floats:
      'P':  precision, TP / (TP + FP)
      'R':  recall, TP / (TP + FN)
      'F':  harmonic mean of precision and recall
      'SU': scaled utility (T11SU) from
            http://trec.nist.gov/pubs/trec11/papers/OVER.FILTERING.pdf
    Any metric whose denominator is zero is reported as 0.0.
    '''
    ## MinNU is a tunable parameter: the floor applied to the normalized
    ## utility before it is rescaled
    MinNU = -0.5

    Scores = dict()
    for cutoff in CM:
        TP = CM[cutoff]['TP']
        FP = CM[cutoff]['FP']
        FN = CM[cutoff]['FN']

        ## Precision
        if TP + FP > 0:
            P = float(TP) / (TP + FP)
        else:
            P = 0.0

        ## Recall
        if TP + FN > 0:
            R = float(TP) / (TP + FN)
        else:
            R = 0.0

        ## F-score
        if P + R > 0:
            F = 2 * P * R / (P + R)
        else:
            F = 0.0

        ## Scaled Utility: T11U rewards each true positive with 2 and charges
        ## 1 per false positive.  MaxU is the utility of a perfect run, i.e.
        ## twice the number of relevant documents (TP + FN).  This is computed
        ## even when TP == 0, so a run that retrieves nothing above the cutoff
        ## gets T11NU == 0 and therefore SU == 1/3.
        if TP + FN > 0:
            T11U = 2 * TP - FP
            MaxU = 2 * (TP + FN)
            T11NU = float(T11U) / MaxU
            SU = (max(T11NU, MinNU) - MinNU) / (1 - MinNU)
        else:
            ## No relevant documents at all: utility cannot be normalized
            SU = 0.0

        Scores[cutoff] = dict(P=P, R=R, F=F, SU=SU)

    return Scores

def score_confusion_matrix (path_to_run_file, annotation, cutoff_step, unannotated_is_TN):
    '''
    This function generates the confusion matrix (number of true/false
    positives and true/false negatives) of one run at a series of cutoffs.

    Each non-comment row of the run is split on whitespace; columns 3 and 4
    are read as stream_id and urlname and column 5 as an integer score.
    A row counts as a positive prediction at a cutoff when score > cutoff.
    Rows are counted as they appear; repeated rows are counted each time.

    path_to_run_file: str, a filesystem link to the run submission; read
      through gzip when the name ends in '.gz'
    annotation: dict, containing the annotation data; maps
      (stream_id, urlname) to a true value for relevant items
    cutoff_step: int, increment between cutoffs; cutoffs are
      range(0, 999, cutoff_step)
    unannotated_is_TN: boolean, true to count unannotated as negatives;
      otherwise rows missing from the annotation are ignored

    Returns a dict mapping each cutoff to dict(TP=..., FP=..., FN=..., TN=...)
    '''
    import sys

    ## Create a dictionary containing the confusion matrix (CM)
    cutoffs = range(0, 999, cutoff_step)
    CM = dict()
    for cutoff in cutoffs:
        CM[cutoff] = dict(TP=0, FP=0, FN=0, TN=0)

    ## Open the run file
    if path_to_run_file.endswith('.gz'):
        ## On Python 3 gzip.open() yields bytes unless text mode is requested
        if sys.version_info[0] >= 3:
            run_file = gzip.open(path_to_run_file, 'rt')
        else:
            run_file = gzip.open(path_to_run_file, 'r')
    else:
        run_file = open(path_to_run_file, 'r')

    try:
        ## Iterate through every row of the run
        for onerow in run_file:
            ## Skip Comments
            if onerow.startswith('#'):
                continue

            row = onerow.split()
            ## Skip blank lines
            if not row:
                continue

            key = (row[2], row[3])
            score = int(row[4])

            if key in annotation:
                relevant = bool(annotation[key])
            elif unannotated_is_TN:
                ## Not in the annotation set so its a negative
                relevant = False
            else:
                ## Unannotated rows are ignored unless the flag is set
                continue

            for cutoff in cutoffs:
                if score > cutoff:
                    ## Above the cutoff: true- or false-positive
                    if relevant:
                        CM[cutoff]['TP'] += 1
                    else:
                        CM[cutoff]['FP'] += 1
                elif not relevant:
                    ## Below the cutoff and non-relevant: true-negative
                    CM[cutoff]['TN'] += 1
                ## Relevant rows below the cutoff are false negatives; they
                ## are accounted for after the loop
    finally:
        ## Release the file handle even when a row fails to parse
        run_file.close()

    ## FN covers relevant items scored at or below the cutoff as well as
    ## relevant items that are NOT in the run at all.
    ## First, calculate number of true things in the annotation set
    annotation_positives = sum(annotation.values())
    for cutoff in CM:
        ## Then subtract the number of TP at each cutoffs
        ## (since FN+TP==True things in annotation set)
        CM[cutoff]['FN'] = annotation_positives - CM[cutoff]['TP']
    return CM
    
def load_annotation (path_to_annotation_file):
    '''
    Loads the annotation file into a dict

    The file is read as tab-separated rows.  Columns 3 and 4 are used as
    stream_id and urlname, column 6 as the integer rating.  Blank lines and
    rows whose first field starts with '#' are skipped.

    path_to_annotation_file: string filesystem path to the annotation file

    Returns a dict mapping (stream_id, urlname) to a boolean: True only if
    every row for that pair carries rating 2 (the annotators gave it a yes
    for centrality), False as soon as any row carries another rating.
    '''
    annotation = dict()

    ## The with-block closes the file even when a row fails to parse
    with open(path_to_annotation_file, 'r') as annotation_file:
        for row in csv.reader(annotation_file, delimiter='\t'):
            ## Skip blank lines and comments
            if not row or row[0].startswith('#'):
                continue

            key = (row[2], row[3])

            ## 2 means the annotators gave it a yes for centrality
            is_central = int(row[5]) == 2

            ## A pair seen for the first time takes this row's verdict; a
            ## pair seen before stays True only while every row agrees
            annotation[key] = annotation.get(key, True) and is_central

    return annotation
            
if __name__ == '__main__':
    ## Command-line driver: scores every .gz run file found in --run-dir
    ## against the annotation file and writes a CSV table (and a PNG plot
    ## when matplotlib is importable) next to each run.
    ##
    ## print is called with a single parenthesized argument throughout so the
    ## statements behave identically on Python 2 and Python 3.
    parser = argparse.ArgumentParser(description=__doc__, usage=__usage__)
    parser.add_argument(
        '--run-dir', required=True, dest='run_dir',
        help='path to the directory containing run files')
    parser.add_argument('--annotation', help='path to the annotation file', required=True)
    parser.add_argument(
        '--cutoff-step', type=int, default=50, dest = 'cutoff_step',
        help='step size used in computing scores tables and plots')
    parser.add_argument(
        '--unannotated-is-true-negative', default=False, action='store_true', dest='unan_is_true',
        help='compute scores using assumption that all unannotated documents are true negatives')
    args = parser.parse_args()

    ## Load in the annotation data
    annotation = load_annotation(args.annotation)

    print('This assumes that all run file names end in .gz')

    ## os.listdir() returns names in arbitrary order; sorting makes the
    ## processing order (and therefore the console output) reproducible
    for run_file in sorted(os.listdir(args.run_dir)):
        if not run_file.endswith('.gz'):
            continue

        ## take the name without the .gz
        run_file_name = run_file[:-len('.gz')]
        print('processing: %s.gz' % run_file_name)

        ## Generate the confusion matrix for a run
        CM = score_confusion_matrix(
            os.path.join(args.run_dir, run_file),
            annotation, args.cutoff_step, args.unan_is_true)

        ## Generate performance statistics for a run
        Scores = performance_metrics(CM)

        ## Print the top F-Score
        print(' Best F-Score: %.3f' % max(Scores[cutoff]['F'] for cutoff in CM))

        ## Both outputs share the run name followed by the cutoff step
        output_prefix = os.path.join(args.run_dir, run_file_name + str(args.cutoff_step))

        ## Output the key performance statistics
        output_filepath = output_prefix + '.csv'
        write_performance_metrics(output_filepath, CM, Scores)
        print(' wrote metrics table to %s' % output_filepath)

        if plt is None:
            print(' not generating plot, because could not import matplotlib')
        else:
            ## Output a graph of the key performance statistics
            graph_filepath = output_prefix + '.png'
            write_graph(graph_filepath, Scores)
            print(' wrote plot image to %s' % graph_filepath)
