import os
import pandas as pd

from datetime import datetime

import nsysstats

from nsys_recipe.lib import helpers
from nsys_recipe.lib import recipe
from nsys_recipe.lib.analysis_details import AnalysisDetails
from nsys_recipe.lib.args import Option
from nsys_recipe.lib.recipe import Context

class GpuMetricUtilReport(nsysstats.Report):
    """Report that bins GPU metric samples into time-weighted averages.

    The capture is split into fixed-size time bins. For every bin and GPU the
    query computes the time-weighted average of the 'SM Active', 'SM Issue'
    and 'Tensor Active' values stored in the JSON payload of GENERIC_EVENTS.
    """

    # Temporary helper table holding one row per bin index.
    create_bin_table = """
    CREATE TEMP TABLE BIN (
        rangeId   INTEGER PRIMARY KEY   NOT NULL
    )
"""

    # Fills temp.BIN with the indices 0 .. NUM_BINS - 1; the LIMIT clause is
    # what terminates the recursive CTE.
    insert_bin_table = """
    INSERT INTO temp.BIN
    WITH RECURSIVE
        range AS (
            SELECT
                0 AS rangeId
            UNION ALL
            SELECT
                rangeId + 1 AS rangeId
            FROM
                range
            LIMIT {NUM_BINS}
        )
    SELECT rangeId FROM range
"""

    # Main query, formatted with BIN_SIZE and ROW_LIMIT in setup():
    #   analysis    - reads the duration from ANALYSIS_DETAILS.
    #   metrics     - one row per sample; 'end' is the timestamp of the next
    #                 sample with the same typeId. The window is ordered by
    #                 rawTimestamp so that "next" means next in time rather
    #                 than next in an arbitrary scan order. The last sample of
    #                 each typeId has a NULL end and therefore never satisfies
    #                 the join condition in 'utilization'.
    #   range       - one row per GPU spanning [0, duration).
    #   bin         - start/end of every non-empty bin for every GPU.
    #   utilization - per bin and GPU, sum(overlap * value) divided by the
    #                 bin length, i.e. a time-weighted average over the bin.
    # The "Duration" output column is the bin start offset scaled by 1e-9
    # (presumably nanoseconds to seconds -- confirm the timestamp unit).
    query_metrics = """
WITH
    analysis AS (
        SELECT
            duration
        FROM
            ANALYSIS_DETAILS
    ),
    metrics AS (
        SELECT
            rawTimestamp AS start,
            typeId & 0xFF AS gpu,
            LEAD (rawTimestamp) OVER (PARTITION BY typeId ORDER BY rawTimestamp) end,
            CAST(JSON_EXTRACT(data, '$.SM Active') as INT) AS smActive,
            CAST(JSON_EXTRACT(data, '$.SM Issue') as INT) AS smIssue,
            CAST(JSON_EXTRACT(data, '$.Tensor Active') as INT) AS tensorActive
        FROM
            GENERIC_EVENTS
    ),
    range AS (
        SELECT
            0 AS start,
            duration AS end,
            {BIN_SIZE} AS binSize,
            gpu
        FROM
            metrics
        JOIN
            analysis
        GROUP BY gpu
    ),
    bin AS (
        SELECT
            bin.rangeId,
            bin.rangeId * range.binSize + range.start AS cstart,
            min(bin.rangeId * range.binSize + range.start + range.binSize, range.end) AS cend,
            binSize,
            range.gpu
        FROM
            temp.BIN AS bin
        JOIN
            range
            ON cstart < cend
    ),
    utilization AS (
        SELECT
            bin.rangeId,
            bin.cstart AS start,
            bin.cend AS end,
            sum(CAST(min(metrics.end, bin.cend) - max(metrics.start, bin.cstart) AS FLOAT) * smActive) / (bin.cend - bin.cstart) AS smActiveAverage,
            sum(CAST(min(metrics.end, bin.cend) - max(metrics.start, bin.cstart) AS FLOAT) * smIssue) / (bin.cend - bin.cstart) AS smIssueAverage,
            sum(CAST(min(metrics.end, bin.cend) - max(metrics.start, bin.cstart) AS FLOAT) * tensorActive) / (bin.cend - bin.cstart) AS tensorActiveAverage,
            bin.gpu
        FROM
            bin
        LEFT JOIN
            metrics
            ON      metrics.gpu == bin.gpu
                AND metrics.start < bin.cend
                AND metrics.end > bin.cstart
        GROUP BY
            bin.rangeId, bin.gpu
    )
SELECT
    rangeId * {BIN_SIZE} * 1e-9 AS "Duration",
    round(smActiveAverage, 1) AS "SmActive",
    round(smIssueAverage, 1) AS "SmIssue",
    round(tensorActiveAverage, 1) AS "TensorActive",
    gpu AS "GPU"
FROM
    utilization
LIMIT {ROW_LIMIT}
"""

    # Tables the query reads, each mapped to the message to show when it is
    # missing (presumably consumed by nsysstats.Report -- confirm).
    table_checks = {
        'ANALYSIS_DETAILS':
            '{DBFILE} does not contain analysis details.',
        'GENERIC_EVENTS':
            '{DBFILE} does not contain GPU metric data.'
    }

    def setup(self):
        """Prepare the bin-table statements and the formatted metrics query.

        Reads ``bins``, ``rows`` and ``binSize`` from ``self.parsed_args``.

        Returns:
            Whatever the parent ``setup()`` returned if that is not ``None``
            (treated as an error and passed straight through); otherwise
            ``None`` once ``self.statements`` and ``self.query`` are set.
        """
        err = super().setup()
        if err is not None:
            return err

        # Executed in order: create the temp table, then populate it.
        self.statements = [
            self.create_bin_table,
            self.insert_bin_table.format(NUM_BINS=self.parsed_args.bins),
        ]

        self.query = self.query_metrics.format(
            ROW_LIMIT=self.parsed_args.rows,
            BIN_SIZE=self.parsed_args.binSize,
        )

class GpuMetricUtilMap(recipe.Recipe):
    """Recipe that collects binned GPU metric utilization across reports.

    Each input report is processed with GpuMetricUtilReport; the per-report
    results are concatenated into one parquet file, and a heatmap notebook
    plus an analysis metadata file are generated alongside it.
    """

    display_name = 'GPU Metric Utilization Heatmap'
    description = """
    This recipe calculates the percentage of SM Active, SM Issue, and
    Tensor Active metrics.
"""

    @staticmethod
    def mapper_func(nsysrep, parsed_args):
        """Run GpuMetricUtilReport on a single report.

        Args:
            nsysrep: Report to process; handed to helpers.nsysrep_to_sqlite.
            parsed_args: Parsed recipe arguments forwarded to the report.

        Returns:
            The result of helpers.stats_cls_to_df, or None when no SQLite
            file could be obtained for the report.
        """
        db_path = helpers.nsysrep_to_sqlite(nsysrep)
        if db_path is not None:
            return helpers.stats_cls_to_df(
                db_path, parsed_args, GpuMetricUtilReport)
        return None

    def reducer_func(self, dfs):
        """Merge the mapper results and write them to 'analysis.parquet'.

        Args:
            dfs: Mapper results; entries are passed through
                helpers.filter_none before being concatenated.
        """
        # Columns kept in the output, in this order. 'Report' is expected to
        # be present in the mapper output (presumably added by
        # helpers.stats_cls_to_df -- confirm).
        output_columns = [
            'Duration', 'SmActive', 'SmIssue', 'TensorActive', 'GPU', 'Report']

        combined = pd.concat(helpers.filter_none(dfs))
        selected = combined[output_columns]
        selected.to_parquet(self.add_output_file('analysis.parquet'))

    def save_metadata(self):
        """Record end time, inputs and outputs, then write the analysis file."""
        metadata = {
            'EndTime': str(datetime.now()),
            'InputReports': self._parsed_args.dir,
            'Outputs': self._output_files
        }
        self._analysis_dict.update(metadata)
        self.create_analysis_file()

    def run(self, context):
        """Execute the recipe: map over the reports, reduce, emit outputs.

        Args:
            context: Execution context providing ``map`` and ``wait``.
        """
        super().run(context)

        analysis_details_df, filtered_dir = AnalysisDetails.get_details(
            context, self._parsed_args)

        # Every report uses the same bin size, derived from the largest
        # 'duration' value so that the requested number of bins spans it.
        longest_duration = analysis_details_df['duration'].max()
        self._parsed_args.binSize = longest_duration // self._parsed_args.bins

        pending = context.map(
            self.mapper_func,
            filtered_dir,
            parsed_args=self._parsed_args
        )
        self.reducer_func(context.wait(pending))

        notebook_substitutions = {'REPLACE_BIN': self._parsed_args.bins}
        self.create_notebook('heatmap.ipynb', replace_dict=notebook_substitutions)
        self.save_metadata()

    @classmethod
    def get_argument_parser(cls):
        """Return the parent's argument parser extended with this recipe's options."""
        parser = super().get_argument_parser()

        # Options registered without extra settings, in registration order.
        plain_options = (
            Option.OUTPUT,
            Option.FORCE_OVERWRITE,
            Option.START,
            Option.END,
            Option.NVTX,
            Option.ROWS,
            Option.BINS,
        )
        for option in plain_options:
            parser.add_recipe_argument(option)

        # The report directory is the only mandatory argument.
        parser.add_recipe_argument(Option.REPORT_DIR, required=True)

        return parser
