import importlib
import inspect
import json
import logging
import multiprocessing
import os
import shutil
import sys

from datetime import datetime
from enum import Enum
from functools import partial

from nsys_recipe.lib import exceptions
from nsys_recipe.lib import helpers
from nsys_recipe.lib import args
from nsys_recipe.lib import nsys_path

class Mode(Enum):
    """Data processing modes.

    Each member selects one Context implementation (see
    Context.create_context).
    """
    # Plain sequential execution in the calling process.
    NONE = 0
    # Parallel execution through concurrent.futures.
    CONCURRENT = 1
    # Execution through a dask.distributed client.
    DASK_FUTURES = 2

class Context:
    """Base class for data processing engines, each associated to a Mode.

    Subclasses implement launch/map/wait for a specific execution backend.
    Instances are context managers so backends can release their resources
    at the end of a ``with`` statement.
    """
    # Mode implemented by the concrete subclass.
    mode = None
    # Mode -> Context subclass mapping, built lazily by create_context.
    context_map = None

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # The base context owns no resources, so there is nothing to release.
        pass

    def launch(self, function, *args, **kwargs):
        """Wrapper for task execution.

        Parameters
        ----------
        function : callable
            Function to execute.
        *args : tuple
            Positional arguments.
        **kwargs : dict
            Dict of keyword arguments.

        Returns
        -------
        result : return type of function
        """
        raise NotImplementedError

    def map(self, function, *iterables, **kwargs):
        """Execute task for each iterable.

        The function is applied to each element in iterables, which should
        have the same length. The kwargs remains packed and is passed as
        argument to the function. Results are guaranteed to be in the same
        order as the input.

        Parameters
        ----------
        function : callable
            Function to execute.
        *iterables : iterables
            Objects to map over.
        **kwargs : dict
            Dict of keyword arguments.

        Returns
        -------
        result : list
        """
        raise NotImplementedError

    def wait(self, waitable):
        """Wrapper for task completion.

        If waitable is a remote result, wait until computation completes and
        return the value of the variable. This is meant to be used on results
        of launch and map functions.

        Returns
        -------
        result : depends on input
            Returns a list of results if the waitable is a list.
            Returns single element otherwise.
        """
        raise NotImplementedError

    @staticmethod
    def import_module(name):
        """Import a module by name for use by a context backend.

        Parameters
        ----------
        name : str
            Fully qualified module name.

        Returns
        -------
        module : module

        Raises
        ------
        exceptions.ModeModuleNotFoundError
            If the module is not installed.
        """
        try:
            return importlib.import_module(name)
        except ModuleNotFoundError as e:
            # Keep the original error chained so the traceback shows which
            # import actually failed.
            raise exceptions.ModeModuleNotFoundError(e) from e

    @classmethod
    def create_context(cls, mode=Mode.CONCURRENT):
        """Create an instance of Context corresponding to mode.

        The first time this is called, create context_map that maps each
        context to its mode.

        Parameters
        ----------
        mode : Mode
            Mode of the context to create.

        Returns
        -------
        context : Context

        Raises
        ------
        NotImplementedError
            If no context class is registered for mode.
        """
        if cls.context_map is None:
            # Explicit pairs: adding or reordering Mode members cannot
            # silently associate a mode with the wrong context class.
            cls.context_map = {
                Mode.NONE: ContextNone,
                Mode.CONCURRENT: ContextConcurrent,
                Mode.DASK_FUTURES: ContextDaskFutures,
            }

        if mode not in cls.context_map:
            raise NotImplementedError("Unsupported mode: {}".format(mode))

        return cls.context_map[mode]()

class ContextNone(Context):
    """Sequential mode: every task runs inline in the calling thread."""
    mode = Mode.NONE

    def launch(self, function, *args, **kwargs):
        # Nothing to schedule; the call happens immediately.
        return function(*args, **kwargs)

    def map(self, function, *iterables, **kwargs):
        # Bind the shared keyword arguments once, then dispatch each element
        # through launch so subclasses overriding launch are honoured.
        bound_func = partial(function, **kwargs)
        run_one = partial(self.launch, bound_func)
        return list(map(run_one, *iterables))

    def wait(self, waitable):
        # Results are already concrete values in this mode.
        return waitable

class ContextConcurrent(Context):
    """Parallel mode built on a concurrent.futures executor."""
    mode = Mode.CONCURRENT

    def __init__(self, pool=None):
        # A caller-supplied pool belongs to the caller and is left running on
        # exit; otherwise this context creates and owns a process pool.
        self._custom = pool is not None
        if self._custom:
            self._pool = pool
        else:
            futures_mod = Context.import_module('concurrent.futures')
            self._pool = futures_mod.ProcessPoolExecutor()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Only shut down a pool this context created itself.
        if not self._pool or self._custom:
            return
        self._pool.__exit__(*args)

    def launch(self, function, *args, **kwargs):
        future = self._pool.submit(function, *args, **kwargs)
        # Block until completion so the caller receives the value.
        return future.result()

    def map(self, function, *iterables, **kwargs):
        bound_func = partial(function, **kwargs)
        return list(self._pool.map(bound_func, *iterables))

    def wait(self, waitable):
        # launch and map already return concrete values.
        return waitable

class ContextDaskFutures(Context):
    """Distributed mode built on a dask.distributed Client."""
    mode = Mode.DASK_FUTURES

    def __init__(self, cluster=None):
        distributed = Context.import_module('distributed')

        # With an explicit cluster the caller owns it; otherwise connect via
        # the scheduler file named in the environment (if any).
        self._custom = cluster is not None
        if self._custom:
            self._dask_client = distributed.Client(cluster)
        else:
            scheduler_file = os.getenv('NSYS_DASK_SCHEDULER_FILE')
            if scheduler_file and not os.path.exists(scheduler_file):
                raise exceptions.ValueError("File '{}' does not exist."
                    .format(scheduler_file))
            self._dask_client = distributed.Client(scheduler_file=scheduler_file)

        def _dask_worker_callback(recipe_pkg_path):
            # Make the recipe packages importable inside each worker.
            sys.path.insert(0, recipe_pkg_path)

        # We assume that all worker nodes have the same recipe path.
        pkg_path = nsys_path.find_installed_file('python/packages')
        self._dask_client.register_worker_callbacks(
            setup=partial(_dask_worker_callback, pkg_path))

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Only close a client this context created itself.
        if not self._dask_client or self._custom:
            return
        self._dask_client.__exit__(*args)

    def launch(self, function, *args, **kwargs):
        # Returns a dask Future; resolve it with wait().
        return self._dask_client.submit(function, *args, **kwargs)

    def map(self, function, *iterables, **kwargs):
        bound_func = partial(function, **kwargs)
        return self._dask_client.map(bound_func, *iterables)

    def wait(self, waitable):
        if not isinstance(waitable, list):
            return waitable.result()
        return self._dask_client.gather(waitable)

class Recipe:
    """Base class for all recipes.

    A recipe lazily creates one output directory, collects the names of the
    files it writes there, and records run metadata that is dumped to an
    ``.nsys-analysis`` JSON file. When used as a context manager, the output
    directory is removed if the ``with`` body raises.
    """
    # Recipe script file name; derived from the defining module on first use.
    script_name = None
    display_name = 'NO NAME'
    description = None

    def __init__(self, parsed_args):
        """Initialize.

        Parameters
        ----------
        parsed_args : argparse.Namespace
            Parsed arguments.
        """
        self._parsed_args = parsed_args
        self._output_dir = None
        # Output file name -> file type, for files that should be recorded.
        self._output_files = {}
        self._analysis_dict = {}

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove the output directory if the managed block raised."""
        if exc_type is not None and self._output_dir is not None:
            # Best-effort cleanup: a failure here must not mask the exception
            # that is already propagating.
            shutil.rmtree(self._output_dir, ignore_errors=True)

    @property
    def output_dir(self):
        """ Get output directory name without creating it. """
        return self._output_dir

    def _get_unique_output_dir(self):
        """Return '<script_name>-<i>' for the smallest i >= 1 not yet on disk."""
        i = 1
        if 'script_name' in self._parsed_args and self._parsed_args.script_name is not None:
            script_name = self._parsed_args.script_name
        else:
            script_name = self.get_script_name()
        while os.path.exists('{}-{}'.format(script_name, i)):
            i += 1
        return '{}-{}'.format(script_name, i)

    @staticmethod
    def _remove_path(path):
        """Delete path, whether it is a directory tree, a file or a symlink."""
        # rmtree refuses files and symlinks, so only use it on real directories.
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        else:
            os.remove(path)

    def _process_output_arg(self):
        """Resolve the output directory name from the '--output' argument.

        Falls back to a generated unique name when '--output' is absent, or
        when it names an existing path and '--force-overwrite' is not set.
        With '--force-overwrite', the existing path is deleted first.
        """
        if self._parsed_args.output is None:
            return self._get_unique_output_dir()

        if os.path.exists(self._parsed_args.output):
            if not self._parsed_args.force_overwrite:
                print("Failed to create '{}': Directory exists."
                    " Use '--force-overwrite' to overwrite existing directories."
                    .format(self._parsed_args.output))
                return self._get_unique_output_dir()
            else:
                self._remove_path(self._parsed_args.output)

        return self._parsed_args.output

    def get_output_dir(self):
        """Return a unique output directory name.

        The first time this is called, create a unique directory where the
        output files are stored.
        Unless the '--output' option is specified with a directory name that
        does not exist or used along the '--force-overwrite' option, a unique
        directory name will be generated with an incrementing id.
        """
        if self._output_dir is None:
            self._output_dir = self._process_output_arg()
            os.makedirs(self._output_dir)

        return self._output_dir

    def add_output_file(self, filename, filetype=None):
        """Get path of the output file.

        Prepend the output directory name to filename.
        If filetype is not None, add it to _output_files so it can later be
        recorded in the nsys-analysis json file.

        Parameters
        ----------
        filename : str
            Output file name.
        filetype : str
            File type to be recorded.

        Returns
        -------
        filepath : str
            Output file path.
        """
        if filetype:
            self._output_files[filename] = filetype
        return os.path.join(self.get_output_dir(), filename)

    def create_analysis_file(self):
        """Create the nsys-analysis json file containing metadata.

        The file is written inside the output directory and named after the
        directory's last path component: '<dir>/<name>.nsys-analysis'.
        """
        output_dir = self.get_output_dir()
        # Use only the final component: joining the full (possibly absolute
        # or nested) path onto output_dir would put the file outside the
        # directory or under a sub-directory that does not exist.
        dir_name = os.path.basename(os.path.normpath(output_dir))
        analysis_filename = dir_name + '.nsys-analysis'
        with open(self.add_output_file(analysis_filename), 'w') as f:
            json.dump(self._analysis_dict, f, indent=4)

    def create_notebook(self, notebook_name, helper_filename=None, replace_dict=None):
        """Create output jupyter notebook from an existing template notebook.

        The output notebook is created under the same name as the template.
        Any key strings contained in replace_dict will be replaced by its value.

        Parameters
        ----------
        notebook_name : str
            Name of the template notebook file located in the same directory as
            the recipe script.
        helper_filename : str
            Name of the helper file located in the lib directory to deploy with
            the notebook.
        replace_dict : dict
            Dictionary that contains the string to be replaced and the new
            value.
        """
        nb_output_file = self.add_output_file(notebook_name, 'Notebook')
        nb_template = os.path.join(self.get_script_dir(), notebook_name)

        if replace_dict:
            # Explicit UTF-8 so substitution does not depend on the locale's
            # default encoding.
            with open(nb_template, 'r', encoding='utf-8') as f:
                file_content = f.read()

            for key, value in replace_dict.items():
                file_content = file_content.replace(str(key), str(value))

            with open(nb_output_file, 'w', encoding='utf-8') as f:
                f.write(file_content)

        else:
            shutil.copy(nb_template, nb_output_file)

        if helper_filename:
            lib_dir = os.path.dirname(__file__)
            helper_output_file = self.add_output_file(helper_filename)
            shutil.copy(os.path.join(lib_dir, helper_filename), helper_output_file)

    def get_script_dir(self):
        """Return '--script_dir' if given, else the recipe module's directory."""
        if 'script_dir' in self._parsed_args and self._parsed_args.script_dir is not None:
            return self._parsed_args.script_dir
        return os.path.dirname(inspect.getmodule(self).__file__)

    def run(self, context):
        """Record the common run metadata; subclasses extend this.

        Parameters
        ----------
        context : Context
            Execution context; its mode name is recorded in the metadata.
        """
        self._analysis_dict = {
            'RecipeJsonVersion': '1',
            'RecipeRuntimeVersion': '1.0.0',
            'Name': self.display_name,
            'RecipeScript': self.get_script_name(True),
            'Mode': context.mode.name.replace('_', '-').lower(),
            'StartTime': str(datetime.now())
        }

    @classmethod
    def get_script_name(cls, extension=False):
        """Get script name of the (derived) Recipe class.

        Parameters
        ----------
        extension : bool
            If set to True, give the script name including the extension.
        """
        if cls.script_name is None:
            cls.script_name = os.path.basename(inspect.getmodule(cls).__file__)
        if not extension:
            return os.path.splitext(cls.script_name)[0]
        return cls.script_name

    @classmethod
    def get_argument_parser(cls):
        """Get default argument parser."""
        parser = args.ArgumentParser(
            prog=cls.get_script_name(True),
            description=cls.description)

        return parser
