import os
import pandas as pd

from datetime import datetime

from nsys_recipe.lib import helpers
from nsys_recipe.lib import recipe
from nsys_recipe.lib.args import Option
from nsys_recipe.lib.recipe import Context

class NvtxGpuProjTrace(recipe.Recipe):
    """Recipe that writes a combined NVTX GPU projection trace.

    Each input report is mapped to a table by ``mapper_func``; the tables are
    concatenated by ``reducer_func`` and written to ``trace.parquet``. A
    notebook and an analysis metadata file are produced afterwards.
    """

    display_name = 'NVTX GPU Trace'
    description = """
    This recipe provides a trace of NVTX time ranges projected from the CPU
    onto the GPU. Each NVTX range contains one or more GPU operations. A GPU
    operation is considered to be "contained" by an NVTX range if the CUDA API
    call used to launch the operation is within the NVTX range. Only ranges
    that start and end on the same thread are taken into account.

    The projected range will have the start timestamp of the first enclosed GPU
    operation and the end timestamp of the last enclosed GPU operation, as well
    as the stack state and relationship to other NVTX ranges.
"""

    @staticmethod
    def mapper_func(nsysrep, parsed_args):
        """Produce the per-report result for a single input report.

        Args:
            nsysrep: The report handed to ``helpers.nsysrep_to_sqlite``.
            parsed_args: Parsed recipe arguments, forwarded unchanged to
                ``helpers.stats_cls_to_df``.

        Returns:
            ``None`` when ``helpers.nsysrep_to_sqlite`` yields no SQLite file;
            otherwise whatever ``helpers.stats_cls_to_df`` returns for the
            ``NvtxGpuProjTrace`` stats class (presumably a DataFrame, since
            the reducer concatenates these with ``pd.concat``).
        """
        sqlite_file = helpers.nsysrep_to_sqlite(nsysrep)
        if sqlite_file is None:
            return None

        stats_cls = helpers.get_stats_cls('nvtx_gpu_proj_trace', 'NvtxGpuProjTrace')
        return helpers.stats_cls_to_df(sqlite_file, parsed_args, stats_cls)

    def reducer_func(self, dfs):
        """Merge the mapper results and write them to ``trace.parquet``.

        Args:
            dfs: Iterable of mapper results; ``None`` entries are removed by
                ``helpers.filter_none`` before concatenation.

        Raises:
            ValueError: If nothing is left to concatenate after filtering.
        """
        # Materialize the filtered results so the emptiness check below works
        # for any iterable that filter_none may return.
        dfs = list(helpers.filter_none(dfs))
        if not dfs:
            # pd.concat would raise a bare ValueError here; raise the same
            # exception type with a message that points at the actual cause.
            raise ValueError(
                'No objects to concatenate: no input report produced trace data.'
            )
        df = pd.concat(dfs)

        # Remove any tags or hidden columns that are for internal use.
        # The regex strips everything from the first ':' in a column name;
        # columns whose (stripped) name starts with '_' are then dropped.
        df.columns = df.columns.str.replace('(:).*', '', regex=True)
        df = df.loc[:, ~df.columns.str.startswith('_')]

        df.to_parquet(self.add_output_file('trace.parquet'))

    def save_metadata(self):
        """Record end time, input reports and outputs, then write the analysis file."""
        self._analysis_dict.update({
            'EndTime': str(datetime.now()),
            'InputReports': self._parsed_args.dir,
            'Outputs': self._output_files
        })
        self.create_analysis_file()

    def run(self, context):
        """Run the recipe: map over the input reports, reduce, then emit outputs.

        Args:
            context: Execution context providing ``map`` and ``wait``.
        """
        super().run(context)
        mapper_res = context.wait(context.map(
            self.mapper_func,
            self._parsed_args.dir,
            parsed_args=self._parsed_args
        ))
        # reducer_func writes its output to disk and returns nothing, so there
        # is no result to keep.
        self.reducer_func(mapper_res)

        self.create_notebook('trace.ipynb', 'nsys_display.py')
        self.save_metadata()

    @classmethod
    def get_argument_parser(cls):
        """Return the base argument parser extended with this recipe's options."""
        parser = super().get_argument_parser()

        parser.add_recipe_argument(Option.OUTPUT)
        parser.add_recipe_argument(Option.FORCE_OVERWRITE)
        parser.add_recipe_argument(Option.START)
        parser.add_recipe_argument(Option.END)
        parser.add_recipe_argument(Option.NVTX)
        parser.add_recipe_argument(Option.REPORT_DIR, required=True)

        return parser
