import importlib
import numpy as np
import os
import pandas as pd

from collections import OrderedDict
from datetime import datetime

from nsys_recipe.lib import helpers
from nsys_recipe.lib import recipe
from nsys_recipe.lib.args import Option
from nsys_recipe.lib.recipe import Context

class CudaGpuKernSum(recipe.Recipe):
    """Recipe that summarizes CUDA GPU kernel statistics over several reports.

    Every input report is mapped to a stats DataFrame. The reducer then
    concatenates those frames and writes two parquet files: one with the
    per-report rows and one aggregated per kernel name over all reports.
    """

    display_name = 'CUDA GPU Kernel Summary'
    index_col = 'Name'
    description = """
    This recipe provides a summary of CUDA kernels and their execution times.
"""

    @staticmethod
    def mapper_func(nsysrep, parsed_args):
        """Load the kernel summary stats of a single report.

        Args:
            nsysrep: Report handed to ``helpers.nsysrep_to_sqlite``.
            parsed_args: Parsed recipe arguments, forwarded to
                ``helpers.stats_cls_to_df``.

        Returns:
            Whatever ``helpers.stats_cls_to_df`` returns for the report
            (the reducer treats it as a DataFrame), or None when
            ``helpers.nsysrep_to_sqlite`` returns None.
        """
        sqlite_file = helpers.nsysrep_to_sqlite(nsysrep)
        if sqlite_file is None:
            return None

        stats_cls = helpers.get_stats_cls('cuda_gpu_kern_sum', 'CudaGpuKernSum')
        return helpers.stats_cls_to_df(sqlite_file, parsed_args, stats_cls)

    def save_per_rank_df(self, df):
        """Write the per-report rows, indexed by kernel name, to parquet.

        Args:
            df: Concatenated stats of all reports; must contain the
                ``index_col`` column and every column selected below.
        """
        df = df.set_index(self.index_col)
        # Fix the column selection and order of the output file.
        df = df[['Time', 'Total Time', 'Instances', 'Avg', 'Q1', 'Med', 'Q3',
            'Min', 'Max', 'StdDev', 'Report']]

        df.to_parquet(self.add_output_file('per-rank.parquet'))

    def save_all_ranks_df(self, df):
        """Aggregate the stats per kernel name over all reports and save them.

        The result is sorted by descending 'Total Time' and written to
        'all-ranks.parquet'.

        Args:
            df: Concatenated stats of all reports.
        """
        grouped = df.groupby(self.index_col)

        def reports_matching(column, extremes):
            # For each kernel, list every report whose value in `column`
            # equals the overall extreme for that kernel (ties are all kept).
            return grouped.apply(
                lambda x: ', '.join(
                    x.loc[x[column] == extremes.loc[x.name], 'Report']))

        # NOTE: insertion order matters twice: it defines the column order of
        # the output, and `helpers.stddev` receives `d` with exactly the
        # entries filled in before 'StdDev'.
        d = OrderedDict()
        total_time = grouped['Total Time'].sum()
        # Share of each kernel in the overall total, in percent.
        d['Time'] = (total_time / total_time.sum() * 100).round(1)
        d['Total Time'] = total_time
        d['Instances'] = grouped['Instances'].sum()
        # Per-report averages are combined weighted by their instance counts.
        d['Avg'] = grouped.apply(
            lambda x: np.average(x['Avg'], weights=x['Instances']))
        # Quartiles cannot be recombined exactly from per-report summaries;
        # the smallest Q1, the median of medians and the largest Q3 are used.
        d['Q1'] = grouped['Q1'].min()
        d['Med'] = grouped['Med'].median()
        d['Q3'] = grouped['Q3'].max()
        d['Min'] = grouped['Min'].min()
        d['Max'] = grouped['Max'].max()
        d['StdDev'] = grouped.apply(
            lambda x: helpers.stddev(x, d))
        # Reuse the extremes computed above instead of aggregating again.
        d['Min Report'] = reports_matching('Min', d['Min'])
        d['Max Report'] = reports_matching('Max', d['Max'])

        df = pd.concat(d.values(), axis=1, keys=d.keys())
        df = df.sort_values(by=['Total Time'], ascending=False)
        df.to_parquet(self.add_output_file('all-ranks.parquet'))

    def reducer_func(self, dfs):
        """Combine the mapper results and write both output files.

        Args:
            dfs: Mapper results, one per report; None entries are dropped.

        Raises:
            ValueError: If no mapper result is left after dropping the None
                entries, i.e. there is nothing to summarize.
        """
        dfs = list(helpers.filter_none(dfs))
        if not dfs:
            # pd.concat would raise an opaque "No objects to concatenate";
            # keep the exception type but say what actually went wrong.
            raise ValueError(
                'No data to summarize: none of the input reports produced '
                'CUDA GPU kernel statistics.')
        df = pd.concat(dfs)

        # Remove any tags or hidden columns that are for internal use.
        df.columns = df.columns.str.replace('(:).*', '', regex=True)
        df.columns = df.columns.str.lstrip('_')

        self.save_per_rank_df(df)
        self.save_all_ranks_df(df)

    def save_metadata(self):
        """Record end time, inputs and outputs, then write the analysis file."""
        self._analysis_dict.update({
            'EndTime': str(datetime.now()),
            'InputReports': self._parsed_args.dir,
            'Outputs': self._output_files
        })
        self.create_analysis_file()

    def run(self, context):
        """Run the recipe: either the diff mode or the map/reduce summary.

        Args:
            context: Execution context providing ``map`` and ``wait``.
        """
        super().run(context)
        if self._parsed_args.diff:
            # Diff mode is delegated entirely to the diff_sum module, which
            # is imported only when it is actually requested.
            importlib.import_module('nsys_recipe.diff_sum.diff_sum').run(self, context)
            return

        mapper_res = context.wait(context.map(
            self.mapper_func,
            self._parsed_args.dir,
            parsed_args=self._parsed_args
        ))
        self.reducer_func(mapper_res)

        self.create_notebook('stats.ipynb', 'nsys_display.py')
        self.save_metadata()

    @classmethod
    def get_argument_parser(cls):
        """Build the argument parser of this recipe.

        The report directory and the diff option are registered in a required
        mutually exclusive group; the remaining options are added as regular
        recipe arguments.

        Returns:
            The parser obtained from the parent class with the recipe
            options added.
        """
        parser = super().get_argument_parser()

        mutually_exclusive_group = parser.recipe_group.add_mutually_exclusive_group(required=True)
        parser.add_argument_to_group(mutually_exclusive_group, Option.REPORT_DIR)
        parser.add_argument_to_group(mutually_exclusive_group, Option.DIFF)

        parser.add_recipe_argument(Option.OUTPUT)
        parser.add_recipe_argument(Option.FORCE_OVERWRITE)
        parser.add_recipe_argument(Option.START)
        parser.add_recipe_argument(Option.END)
        parser.add_recipe_argument(Option.NVTX)
        parser.add_recipe_argument(Option.BASE)
        parser.add_recipe_argument(Option.MANGLED)

        return parser
