import os
import pandas as pd

from datetime import datetime

from nsys_recipe.lib import helpers
from nsys_recipe.lib import recipe
from nsys_recipe.lib.args import Option
from nsys_recipe.lib.recipe import Context

class GpuGaps(recipe.Recipe):
    """Recipe that reports GPU idle periods ("gaps") across Nsight Systems reports.

    Each report is processed independently by ``mapper_func`` using the
    ``gpu_gaps`` stats class; ``reducer_func`` merges the per-report results
    into ``analysis.parquet``. ``run`` then generates an analysis notebook and
    writes the recipe metadata file.
    """

    display_name = 'GPU Gaps'
    description = """
    This recipe identifies time regions where a GPU is idle for longer than a
    set threshold. For each process, each GPU device is examined, and gaps are
    found within the time range that starts with the beginning of the first GPU
    operation on that device and ends with the end of the last GPU operation on
    that device. Profiling overheads are taken into account to exclude GPU gaps
    that cannot be addressed by the user.
"""

    @staticmethod
    def mapper_func(nsysrep, parsed_args):
        """Run the ``gpu_gaps`` stats report on a single report.

        Static so it carries no instance state; ``run`` hands it to
        ``context.map`` once per input report.

        Args:
            nsysrep: One input report (presumably a path to an ``.nsys-rep``
                file -- confirm against ``Option.REPORT_DIR`` parsing).
            parsed_args: Parsed command-line arguments, forwarded unchanged to
                the stats class (they carry options such as ``--gap``).

        Returns:
            ``None`` if ``helpers.nsysrep_to_sqlite`` could not produce a
            SQLite file; otherwise the result of ``helpers.stats_cls_to_df``
            (presumably a pandas DataFrame).
        """
        sqlite_file = helpers.nsysrep_to_sqlite(nsysrep)
        if sqlite_file is None:
            return None

        stats_cls = helpers.get_stats_cls('gpu_gaps', 'GpuGaps')
        return helpers.stats_cls_to_df(sqlite_file, parsed_args, stats_cls)

    def reducer_func(self, dfs):
        """Merge per-report results and write them to ``analysis.parquet``.

        Args:
            dfs: Mapper results, one per report. ``None`` entries (reports
                whose SQLite conversion failed) are filtered out via
                ``helpers.filter_none``.

        Raises:
            ValueError: If no report produced a result, so there is nothing
                to merge.
        """
        # Materialize so the emptiness check below works whatever container
        # type helpers.filter_none returns.
        dfs = list(helpers.filter_none(dfs))
        if not dfs:
            # pd.concat([]) would fail with an opaque "No objects to
            # concatenate"; report the real cause with the same exception type.
            raise ValueError(
                'No GPU gaps data to reduce: every input report was skipped '
                'because it could not be converted to SQLite.')
        df = pd.concat(dfs)

        # Remove any tags or hidden columns that are for internal use.
        # Tags are ':'-suffixes on column names (everything from the first ':'
        # on is dropped); hidden columns are those whose name starts with '_'.
        df.columns = df.columns.str.replace('(:).*', '', regex=True)
        df = df.loc[:, ~df.columns.str.startswith('_')]

        df.to_parquet(self.add_output_file('analysis.parquet'))

    def save_metadata(self):
        """Record end time, input reports and output files, then write them out.

        Updates ``self._analysis_dict`` and delegates writing to
        ``create_analysis_file``.
        """
        self._analysis_dict.update({
            'EndTime': str(datetime.now()),
            'InputReports': self._parsed_args.dir,
            'Outputs': self._output_files
        })
        self.create_analysis_file()

    def run(self, context):
        """Execute the recipe end to end.

        Maps ``mapper_func`` over every input report via ``context``, waits
        for all results, reduces them into the Parquet output, generates the
        analysis notebook and finally saves the metadata.

        Args:
            context: Execution context providing ``map``/``wait``.
        """
        super().run(context)
        mapper_res = context.wait(context.map(
            self.mapper_func,
            self._parsed_args.dir,
            parsed_args=self._parsed_args
        ))
        self.reducer_func(mapper_res)

        self.create_notebook('analysis.ipynb', 'nsys_display.py')
        self.save_metadata()

    @classmethod
    def get_argument_parser(cls):
        """Return the base recipe parser extended with this recipe's options.

        Besides the shared output/filtering options, adds ``--gap`` (the
        minimum gap duration in milliseconds, default 500) and a required
        report directory.
        """
        parser = super().get_argument_parser()

        parser.add_recipe_argument(Option.OUTPUT)
        parser.add_recipe_argument(Option.FORCE_OVERWRITE)
        parser.add_recipe_argument(Option.START)
        parser.add_recipe_argument(Option.END)
        parser.add_recipe_argument(Option.NVTX)
        parser.add_recipe_argument(Option.ROWS)
        parser.add_recipe_argument(
            '--gap',
            metavar='threshold',
            type=int,
            default=500,
            help="Minimum duration of GPU gaps in milliseconds")
        parser.add_recipe_argument(Option.REPORT_DIR, required=True)

        return parser
