import os
import pandas as pd

from datetime import datetime

from nsys_recipe.lib import helpers
from nsys_recipe.lib import loader
from nsys_recipe.lib import nvtx
from nsys_recipe.lib import recipe
from nsys_recipe.lib.args import Option
from nsys_recipe.lib.recipe import Context

class NcclSum(recipe.Recipe):
    """Recipe that summarizes the durations of NVTX ranges in the 'NCCL' domain.

    ``mapper_func`` turns each input report into per-range-name duration
    statistics; ``reducer_func`` merges the per-report results and writes
    them out as parquet tables.
    """

    display_name = 'NCCL Summary'
    description = """
    This recipe provides a summary of NCCL functions and their execution
    times.
"""

    @staticmethod
    def mapper_func(nsysrep, parsed_args):
        """Compute duration statistics of the NCCL NVTX ranges of one report.

        Args:
            nsysrep: Path of the report to process.
            parsed_args: Parsed recipe arguments; only ``gpu`` is read here.

        Returns:
            A dict with ``'filePath'`` (report base name without extension)
            and ``'stats'`` (``describe()`` of the range durations grouped by
            range text), or ``None`` when the report could not be exported to
            SQLite or contains no NVTX domain named 'NCCL'.
        """
        import logging

        sqlite_file = helpers.nsysrep_to_sqlite(nsysrep)
        if sqlite_file is None:
            return None

        # Tables to load, mapped to the columns requested for each.
        # NOTE(review): None presumably means "all columns" - confirm
        # against loader.load_file.
        table_names = {
            loader.k_table_str: None,
            loader.k_table_nvtx: None,
        }

        # The CUDA tables are only needed for the GPU projection.
        if parsed_args.gpu:
            table_names[loader.k_table_cuda_api] = [
                loader.k_column_correlationId, 'globalTid', 'start']
            table_names[loader.k_table_cuda_kernel] = [
                loader.k_column_correlationId, 'start', 'end']
            table_names[loader.k_table_cuda_memcpy] = [
                loader.k_column_correlationId, 'start', 'end']
            table_names[loader.k_table_cuda_memset] = [
                loader.k_column_correlationId, 'start', 'end']

        tables = loader.load_file(sqlite_file, table_names)

        str_df = tables[loader.k_table_str]
        nvtx_df = tables[loader.k_table_nvtx]

        # eventType 75 - NvtxDomainCreate; its 'text' is the domain name.
        nccl_domain_ids = nvtx_df.loc[
            (nvtx_df['eventType'] == 75) & (nvtx_df['text'] == 'NCCL'),
            'domainId',
        ]
        if nccl_domain_ids.empty:
            # Without an NCCL domain there is nothing to summarize; returning
            # None lets the reducer drop this report instead of the lookup
            # below failing with a KeyError.
            logging.getLogger(__name__).warning(
                "No NVTX domain named 'NCCL' found in %s; skipping report.",
                nsysrep)
            return None
        target_domain_id = nccl_domain_ids.iloc[0]

        # Combine text & textId together.
        ranges_df = nvtx.combine_text_fields(str_df, nvtx_df, nvtx_df)

        # Keep only ranges that
        #   - belong to the target domain,
        #   - are eventType 59 (NvtxPushPopRange) or 60 (NvtxStartEndRange),
        #   - are valid, i.e. have both a start and an end.
        # The explicit copy makes the column assignment below operate on an
        # independent frame rather than on a slice of the loaded table.
        is_nccl_range = (
            (ranges_df['domainId'] == target_domain_id)
            & ranges_df['eventType'].isin([59, 60])
            & ranges_df['start'].notnull()
            & ranges_df['end'].notnull()
        )
        ranges_df = ranges_df.loc[is_nccl_range].copy()

        if parsed_args.gpu:
            # Do the GPU projection to get new ranges.
            cuda_api_df = tables[loader.k_table_cuda_api]
            cuda_graph_api_df = tables.get(loader.k_table_cuda_graph_api, None)
            cuda_kernel_df = tables[loader.k_table_cuda_kernel]
            cuda_memcopies_df = tables[loader.k_table_cuda_memcpy]
            cuda_memsets_df = tables[loader.k_table_cuda_memset]

            ranges_df = nvtx.cuda_projection(
                str_df, ranges_df, cuda_api_df, cuda_graph_api_df,
                cuda_kernel_df, cuda_memcopies_df, cuda_memsets_df)

            # Swap the projected columns in for the original ones.
            ranges_df = ranges_df.drop(columns=['start', 'end'])
            ranges_df = ranges_df.rename(
                {'gpu_start': 'start',
                 'gpu_end': 'end',
                 'gpu_duration': 'duration'},
                axis='columns')
        else:
            # Calculate duration (NOTE: done by the projection in the GPU case).
            ranges_df['duration'] = ranges_df['end'] - ranges_df['start']

        # Remove ranges where duration is null, zero, or negative.
        ranges_df = ranges_df.loc[
            ranges_df['duration'].notnull() & (ranges_df['duration'] > 0)]

        # One row of summary statistics per range text, in first-seen order.
        stats_df = (
            ranges_df[['text', 'duration']]
            .groupby('text', sort=False)['duration']
            .describe()
        )

        return {
            'filePath': os.path.splitext(os.path.basename(nsysrep))[0],
            'stats': stats_df,
        }

    def save_data(self, data):
        """Write the reducer results to parquet output files.

        Args:
            data: Dict holding the DataFrames ``'files'``, ``'all_stats'`` and
                ``'rank_stats'``.
        """
        # (result key, output file name, whether the index is written)
        outputs = (
            ('files', 'files.parquet', False),
            ('all_stats', 'all_stats.parquet', True),
            ('rank_stats', 'rank_stats.parquet', True),
        )
        for key, file_name, keep_index in outputs:
            table_df = data[key]
            dst_path = self.add_output_file(file_name, 'Table')
            table_df.to_parquet(dst_path, index=keep_index)

    def reducer_func(self, mapper_results_list):
        """Merge the per-report statistics and save the resulting tables.

        Reports are ranked by the sort order of their file names. Three
        tables are produced and handed to ``save_data``:

        * ``files``: file name of each rank,
        * ``rank_stats``: the per-report statistics of every rank,
        * ``all_stats``: statistics aggregated over all ranks per range text.
          Quartiles cannot be merged exactly from per-rank summaries, so the
          columns marked "(approx)" are derived from the per-rank quartiles.

        Args:
            mapper_results_list: Results of ``mapper_func``; ``None`` entries
                are ignored.

        Raises:
            ValueError: If no usable mapper result is left after filtering.
        """
        # Filter files that didn't export/map properly.
        mapper_results = helpers.filter_none(mapper_results_list)
        if not mapper_results:
            raise ValueError(
                "No input report produced NCCL statistics; "
                "nothing to summarize.")

        # Sort by file name so ranks are assigned in a reasonable order.
        mapper_results.sort(key=lambda res: res.get('filePath'))

        results = dict()
        results['files'] = pd.DataFrame({
            "File": pd.Series({
                rank: str(res['filePath'])
                for rank, res in enumerate(mapper_results)
            })
        })

        # Per-rank statistics, extended with the total duration and the
        # origin of each row. assign() returns new frames, so the mapper
        # results are left untouched.
        per_rank_stats = [
            res['stats'].assign(
                sum=res['stats']['mean'] * res['stats']['count'],
                rank=rank,
                file=res['filePath'],
            )
            for rank, res in enumerate(mapper_results)
        ]

        # Always aggregate, also for a single report, so that 'all_stats'
        # has the same schema regardless of the number of input reports.
        stats_concat_df = pd.concat(per_rank_stats)
        stats_gdf = stats_concat_df.groupby('text')

        sum_total_ds = stats_gdf['sum'].sum()
        count_total_ds = stats_gdf['count'].sum()
        # Mean over all ranks, weighted by each rank's range count.
        mean_ds = sum_total_ds / count_total_ds

        all_stats_df = pd.DataFrame({
            "Min": stats_gdf['min'].min(),
            "Max": stats_gdf['max'].max(),
            "Mean": mean_ds,
            "Q1 (approx)": stats_gdf['25%'].min(),
            "Median (approx)": stats_gdf['50%'].median(),
            "Q3 (approx)": stats_gdf['75%'].max(),
            "Sum Total": sum_total_ds,
            "Sum Min": stats_gdf['sum'].min(),
            "Sum Median": stats_gdf['sum'].median(),
            "Sum Max": stats_gdf['sum'].max(),
            "Count Total": count_total_ds,
            "Count Min": stats_gdf['count'].min(),
            "Count Median": stats_gdf['count'].median(),
            "Count Max": stats_gdf['count'].max(),
            },
            index=list(stats_gdf.groups))

        rank_stats_df = stats_concat_df.rename({
            "min": "Min",
            "max": "Max",
            "mean": "Mean",
            "25%": "Q1",
            "50%": "Median",
            "75%": "Q3",
            "std": "StdDev",
            "sum": "Sum",
            'count': "Count"
            }, axis='columns')

        results['all_stats'] = all_stats_df
        results['rank_stats'] = rank_stats_df

        self.save_data(results)

    def save_metadata(self):
        """Record end time, input reports and outputs, then write the analysis file."""
        self._analysis_dict.update({
            'EndTime': str(datetime.now()),
            'InputReports': self._parsed_args.dir,
            'Outputs': self._output_files
        })
        self.create_analysis_file()

    def run(self, context):
        """Run the recipe: map over the input reports, reduce, write outputs.

        Args:
            context: Execution context; its ``map`` and ``wait`` methods are
                used to apply ``mapper_func`` to every entry of the parsed
                ``dir`` argument.
        """
        super().run(context)
        mapper_res = context.wait(context.map(
            self.mapper_func,
            self._parsed_args.dir,
            parsed_args=self._parsed_args
        ))
        self.reducer_func(mapper_res)

        self.create_notebook('stats.ipynb', 'nsys_pres.py')
        self.save_metadata()

    @classmethod
    def get_argument_parser(cls):
        """Return the base argument parser extended with this recipe's options.

        Adds the output, force-overwrite and (required) report-directory
        options as well as the ``--gpu`` flag that enables GPU projection.
        """
        parser = super().get_argument_parser()

        parser.add_recipe_argument(Option.OUTPUT)
        parser.add_recipe_argument(Option.FORCE_OVERWRITE)
        parser.add_recipe_argument(
            '--gpu',
            action='store_true',
            help="GPU projection")
        parser.add_recipe_argument(Option.REPORT_DIR, required=True)

        return parser
