import pickle
import importlib
import os
import sys
import glob
import math
import re
from collections import deque
import pandas as pd
import numpy as np

# DTSP-14650: Code cleanup
class PaceConfig:
    """Column-name configuration read by mapper() and reducer().

    Settings are plain class attributes; callers may override them on the
    class itself or on an instance.
    """

    # Range names to analyse. Left as None here, so it must be assigned
    # before mapper()/reducer() iterate over it.
    pace_name_list = None

    # Column of the ranges table identifying each range: either the name
    # itself or a string-table id (e.g. CUDA='shortName', NVTX='text').
    nameColumn = None

    # Begin/end timestamp columns of the ranges table.
    startColumn = 'start'
    endColumn = 'end'

    # Column that mapper() fills with end - start.
    durationColumn = 'duration'
    
def mapper(ranges_df, #ex file[nsys_rep.k_table_cuda_kernel]
           str_df, #file[nsys_rep.k_table_str]
           session_start_df, #file[nsys_rep.k_table_session_start]
           cfg, #PaceConfig
           results):
    """Compute duration stats and per-iteration pace tables for one report.

    Args:
        ranges_df: ranges table with cfg.nameColumn, cfg.startColumn and
            cfg.endColumn columns. When str_df is given, cfg.nameColumn holds
            string-table ids; otherwise it holds the range names directly.
        str_df: string table with 'id' and 'value' columns, or None.
        session_start_df: table whose 'utcEpochNs' value at row label 0 is
            stored as the session start.
        cfg: PaceConfig-like object; only its attributes are read.
        results: dict that receives:
            'stats'         - describe() of durations, one row per name,
                              index named "Name";
            'session_start' - the session start value;
            'pace'          - {pace_name: DataFrame with start, end, delta,
                              delta_accum, duration, duration_accum columns,
                              one row per iteration in start order}.

    Returns:
        False when none of cfg.pace_name_list has a complete range in this
        report, otherwise True. Names with no complete range are left out of
        results['pace'] instead of raising KeyError.
    """
    start_col = cfg.startColumn
    end_col = cfg.endColumn
    duration_col = cfg.durationColumn

    if str_df is None:
        pace_name_str_df = None
        name_mask = ranges_df[cfg.nameColumn].isin(cfg.pace_name_list)
    else:
        # Resolve the requested names to string-table ids.
        pace_name_str_df = str_df.loc[str_df['value'].isin(cfg.pace_name_list)]
        if pace_name_str_df.shape[0] == 0:
            return False
        name_mask = ranges_df[cfg.nameColumn].isin(pace_name_str_df['id'])

    # Keep only wanted, complete ranges before grouping (less groupby work).
    # The explicit .copy() lets the duration assignment below modify an
    # independent frame rather than a slice of the caller's table (which
    # would trigger pandas' SettingWithCopyWarning).
    complete_mask = ranges_df[start_col].notna() & ranges_df[end_col].notna()
    ranges_df = ranges_df.loc[name_mask & complete_mask,
                              [cfg.nameColumn, start_col, end_col]].copy()

    if ranges_df.shape[0] == 0:
        return False

    ranges_df[duration_col] = ranges_df[end_col] - ranges_df[start_col]
    ranges_df = ranges_df.rename({cfg.nameColumn: "Name"}, axis='columns')

    ranges_gdf = ranges_df.groupby("Name", sort=False)
    stats_df = ranges_gdf[duration_col].describe()
    # Group keys actually present: string ids when str_df is given, names otherwise.
    present_keys = set(stats_df.index)

    if pace_name_str_df is None:
        key_by_name = None
    else:
        key_by_name = dict(zip(pace_name_str_df['value'], pace_name_str_df['id']))
        name_by_key = dict(zip(pace_name_str_df['id'], pace_name_str_df['value']))
        # Relabel stats rows from ids to names by key rather than by position,
        # so the result does not depend on groupby key ordering.
        stats_df.index = pd.Index([name_by_key[key] for key in stats_df.index], name='Name')

    results['stats'] = stats_df
    results['session_start'] = session_start_df.at[0, 'utcEpochNs']

    pace_df_map = dict()
    for pace_name in cfg.pace_name_list:
        pace_key = pace_name if key_by_name is None else key_by_name.get(pace_name)
        if pace_key is None or pace_key not in present_keys:
            # Name missing from this report's string table, or every range of
            # it was incomplete: skip it instead of raising KeyError.
            continue

        # The deltas and running sums below assume chronological order; sort
        # by start (stable, so ties keep their table order).
        pace_df = (ranges_gdf.get_group(pace_key)
                   .sort_values(start_col, kind='mergesort')
                   .reset_index(drop=True))

        # Gap between each range's start and the previous range's end. The
        # first range has no predecessor, so its "previous end" is 0 and its
        # delta equals its absolute start.
        pace_df['delta'] = pace_df[start_col] - pace_df[end_col].shift(periods=1, fill_value=0)
        pace_df['delta_accum'] = pace_df['delta'].cumsum()
        pace_df['duration_accum'] = pace_df[duration_col].cumsum()

        pace_df_map[pace_name] = pace_df[[start_col, end_col, 'delta', 'delta_accum',
                                          duration_col, 'duration_accum']]

    results['pace'] = pace_df_map
    return True


def reducer(mapper_results_list, mapper_instance_ctx_list, mapper_shared_ctx, reducer_ctx, cfg, results):
    """Merge per-report mapper results into cross-rank stats and pace tables.

    Args:
        mapper_results_list: result dicts from mapper(). Each must also carry
            'rank' and 'filePath' entries (presumably added by the calling
            framework; mapper() does not set them).
        mapper_instance_ctx_list, mapper_shared_ctx, reducer_ctx: not used here.
        cfg: PaceConfig-like object.
        results: dict that receives 'stats' (one row per range name) and
            'pace' ({pace_name: {column_name: DataFrame rows=Iteration,
            columns=Rank, plus 'delta_stats'}}).

    Side effects:
        Writes 'session_offset' into every mapper result dict. The mapper
        'stats' frames are no longer modified.

    Returns:
        results.

    Raises:
        ValueError (from pd.concat) when mapper_results_list is empty.
    """
    results['stats'] = _reduce_pace_stats(mapper_results_list)
    _assign_session_offsets(mapper_results_list)
    pace_builders_by_name = _collect_pace_frames(mapper_results_list, cfg)
    results['pace'] = _build_pace_maps(pace_builders_by_name, cfg)
    return results


def _reduce_pace_stats(mapper_results_list):
    """Combine the per-report describe() tables into one summary row per name."""
    # assign() works on copies, so the mapper's own stats frames stay untouched.
    per_file_stats = [
        mapper_result['stats'].assign(
            sum=mapper_result['stats']['mean'] * mapper_result['stats']['count'],
            rank=mapper_result['rank'],
            file=mapper_result['filePath'],
        )
        for mapper_result in mapper_results_list
    ]
    # 'Name' is the index level name mapper() gives its stats table.
    stats_gdf = pd.concat(per_file_stats).groupby('Name')

    sum_total_ds = stats_gdf['sum'].sum()
    count_total_ds = stats_gdf['count'].sum()

    # Quartiles cannot be merged exactly from per-report summaries; the
    # min/median/max of the per-report values are used as approximations.
    return pd.DataFrame({
        "Min": stats_gdf['min'].min(),
        "Max": stats_gdf['max'].max(),
        "Mean": sum_total_ds / count_total_ds,
        "Q1 (approx)": stats_gdf['25%'].min(),
        "Median (approx)": stats_gdf['50%'].median(),
        "Q3 (approx)": stats_gdf['75%'].max(),
        "Sum Total": sum_total_ds,
        "Sum Min": stats_gdf['sum'].min(),
        "Sum Median": stats_gdf['sum'].median(),
        "Sum Max": stats_gdf['sum'].max(),
        "Count Total": count_total_ds,
        "Count Min": stats_gdf['count'].min(),
        "Count Median": stats_gdf['count'].median(),
        "Count Max": stats_gdf['count'].max(),
        },
        index=list(stats_gdf.groups.keys()))


def _assign_session_offsets(mapper_results_list):
    """Store each result's session start relative to the earliest as 'session_offset'."""
    # Timeline zero differs slightly in wall-clock time between machines;
    # measuring from the earliest session start puts all ranks on one timeline.
    session_start_min = min(
        (mapper_result['session_start'] for mapper_result in mapper_results_list),
        default=0)
    for mapper_result in mapper_results_list:
        mapper_result['session_offset'] = mapper_result['session_start'] - session_start_min


def _collect_pace_frames(mapper_results_list, cfg):
    """Return {pace_name: {str(rank): pace DataFrame shifted by its session offset}}."""
    pace_builders_by_name = {pace_name: {} for pace_name in cfg.pace_name_list}
    for mapper_result in mapper_results_list:
        mapper_rank = mapper_result['rank']
        mapper_file = mapper_result['filePath']
        session_offset = mapper_result['session_offset']
        for pace_name, mapper_pace_df in mapper_result['pace'].items():
            pace_df = mapper_pace_df.reset_index(drop=True)
            pace_df['rank'] = mapper_rank
            pace_df['file'] = mapper_file
            if session_offset != 0:
                pace_df['session_offset'] = session_offset
                pace_df[cfg.startColumn] = pace_df[cfg.startColumn] + pace_df['session_offset']
                pace_df[cfg.endColumn] = pace_df[cfg.endColumn] + pace_df['session_offset']
            # str key: parquet doesn't like numbers as column names.
            # NOTE(review): two results reporting the same rank overwrite each other here.
            pace_builders_by_name[pace_name][str(mapper_rank)] = pace_df
    return pace_builders_by_name


def _build_pace_maps(pace_builders_by_name, cfg):
    """Build per-column wide tables (rows=Iteration, columns=Rank) plus delta stats per name."""
    pace_map_by_name = dict()
    for pace_name, pace_builders_by_rank in pace_builders_by_name.items():
        if not pace_builders_by_rank:
            # No report contained this range; describe() below would raise on
            # an empty table, so the name is left out of the result.
            continue

        pace_map_by_column = dict()
        for column_name in (cfg.startColumn, cfg.endColumn, 'duration_accum',
                            'delta_accum', cfg.durationColumn):
            build_pace_df(pace_map_by_column, pace_builders_by_rank, column_name)
        delta_df = build_pace_df(pace_map_by_column, pace_builders_by_rank, 'delta')

        # Tables are rows=iteration, columns=rank; transpose so describe()
        # summarises each iteration across ranks.
        delta_stats_df = delta_df.T.describe().T
        delta_stats_df = delta_stats_df.rename_axis("Iteration")
        delta_stats_df = delta_stats_df.rename(
            {"min": "Min",
             "25%": "Q1",
             "50%": "Median",
             "75%": "Q3",
             "max": "Max",
             "count": "Count",
             "std": "Std",
             "mean": "Mean",
            },
            axis='columns')
        pace_map_by_column['delta_stats'] = delta_stats_df

        pace_map_by_name[pace_name] = pace_map_by_column
    return pace_map_by_name

def build_pace_df( pace_map_by_column, pace_builders_by_rank, columnName ):
    """Build a wide table of one column across ranks and store it in pace_map_by_column.

    Args:
        pace_map_by_column: dict that receives the table under columnName.
        pace_builders_by_rank: {rank_key: DataFrame}; each frame's columnName
            becomes one column of the result, keyed by rank_key.
        columnName: name of the column to gather.

    Returns:
        The new DataFrame: rows are positions 0..n-1 (axis named "Iteration"),
        columns are rank keys (axis named "Rank"). Shorter ranks are padded
        with NaN.
    """
    # Drop each frame's own index so the ranks line up by position.
    series_by_rank = {
        rank_key: rank_df[columnName].reset_index(drop=True)
        for rank_key, rank_df in pace_builders_by_rank.items()
    }

    wide_df = (
        pd.DataFrame(series_by_rank)
        .rename_axis("Rank", axis='columns')
        .rename_axis("Iteration")
    )

    pace_map_by_column[columnName] = wide_df
    return wide_df
    
def display_pace_graph(figs, pace_name, pace_map_by_column, pace_column, start=1):
    """Plot one of the reducer's pace tables from iteration label `start` onward.

    Args:
        figs: list the created figures are appended to.
        pace_name: range name, used in the chart titles.
        pace_map_by_column: {column_name: DataFrame} for this range name.
        pace_column: key of the table to plot (e.g. 'delta_accum').
        start: first iteration label kept (label-based .loc slice). The
            default of 1 drops iteration 0, presumably because mapper()
            measures its delta from time 0.
    """
    __display_pace_graph(figs, pace_map_by_column[pace_column].loc[start:], pace_name, pace_column)
    
    
def display_pace_graph_delta_minus_median(figs, pace_name, pace_map_by_column, start=1):
    """Plot each rank's per-iteration delta minus that iteration's median delta.

    Args:
        figs: list the created figures are appended to.
        pace_name: range name, used in the chart titles.
        pace_map_by_column: {column_name: DataFrame} for this range name. It
            must contain 'delta' (rows=iteration, columns=rank) and
            'delta_stats' (with a 'Median' column indexed by iteration).
        start: first iteration label kept (label-based .loc slice).
    """
    pace_column = 'delta'
    median_ds = pace_map_by_column['delta_stats']['Median']

    # Subtract the per-iteration median from every rank column, aligned on the
    # iteration index. This replaces a DataFrame.iteritems() loop;
    # iteritems() was removed in pandas 2.0.
    pace_df = pace_map_by_column[pace_column].sub(median_ds, axis=0)

    pace_df = pace_df.loc[start:]
    __display_pace_graph(figs, pace_df, pace_name, "variance of "+pace_column)
    

def __display_pace_graph(figs, pace_df, pace_name, pace_column):
    """Show two line charts of pace_df and append both figures to figs.

    The "Progress" chart plots pace_df as-is. The "Consistency" chart plots
    its transpose. With the tables built by build_pace_df (rows=iteration,
    columns=rank), that gives one line per rank and one line per iteration
    respectively.

    NOTE(review): fig.update_layout is a plotly Figure method, so this assumes
    the pandas plotting backend has been set to plotly elsewhere - confirm.
    """
    import warnings

    charts = (
        ("Progress", pace_df),
        ("Consistency", pace_df.T),
    )

    # Suppress warnings only while plotting. A bare filterwarnings('ignore')
    # would silence every warning in the process from the first call onward.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        for chart_kind, chart_df in charts:
            fig = chart_df.plot.line()
            fig.update_layout(
                yaxis_title = "Time",
                title = chart_kind + " - Iterations defined by " + pace_column + " of " + pace_name,
                )
            fig.show()
            figs.append(fig)

