#!/usr/bin/python

"""
Given the output of doctest and a file, adjust the doctests so they won't fail.

Doctest failures due to exceptions are ignored.

AUTHORS::

- Nicolas M. Thiéry <nthiery at users dot sf dot net>  Initial version (2008?)

- Andrew Mathas <andrew dot mathas at sydney dot edu dot au> 2013-02-14
  Cleaned up the code and hacked it so that the script can now cope with the
  situations when either the expected output or computed output are empty.
  Added doctest to sage.tests.cmdline
"""

# ****************************************************************************
#       Copyright (C) 2006 William Stein
#                     2009 Nicolas M. Thiery
#                     2013 Andrew Mathas
#                     2014 Volker Braun
#                     2020 Jonathan Kliem
#                     2021 Frédéric Chapoton
#                     2023 Matthias Koeppe
#
#  Distributed under the terms of the GNU General Public License (GPL)
#  as published by the Free Software Foundation; either version 2 of
#  the License, or (at your option) any later version.
#                  https://www.gnu.org/licenses/
# ****************************************************************************

import itertools
import os
import re
import shlex
import subprocess
import sys

from argparse import ArgumentParser, FileType
from pathlib import Path

from sage.doctest.control import DocTestDefaults, DocTestController
from sage.doctest.parsing import parse_file_optional_tags, parse_optional_tags, unparse_optional_tags, update_optional_tags
from sage.env import SAGE_ROOT
from sage.features import PythonModule
from sage.features.all import all_features, module_feature, name_feature
from sage.misc.temporary_file import tmp_filename

# Command-line interface of the script.
parser = ArgumentParser(description="Given an input file with doctests, this creates a modified file that passes the doctests (modulo any raised exceptions). By default, the input file is modified. You can also name an output file.")
# Options that select what is tested and in which setup
parser.add_argument('-l', '--long', dest='long', action="store_true", default=False,
                    help="include tests tagged '# long time'")
parser.add_argument("--distribution", type=str, default='',
                    help="distribution package to test, e.g., 'sagemath-graphs', 'sagemath-combinat[modules]'; sets defaults for --venv and --environment")
parser.add_argument("--venv", type=str, default='',
                    help="directory name of a venv where 'sage -t' is to be run")
parser.add_argument("--environment", type=str, default='',
                    help="name of a module that provides the global environment for tests, e.g., 'sage.all__sagemath_modules'; implies --keep-both and --full-tracebacks")
# Options that select which kinds of fixes are made
parser.add_argument("--no-test", default=False, action="store_true",
                    help="do not run the doctester, only rewrite '# optional/needs' tags; implies --only-tags")
parser.add_argument("--full-tracebacks", default=False, action="store_true",
                    help="include full tracebacks rather than '...'")
parser.add_argument("--only-tags", default=False, action="store_true",
                    help="only add '# optional/needs' tags where needed, ignore other failures")
parser.add_argument("--probe", metavar="FEATURES", type=str, default='',
                    help="check whether '# optional/needs' tags are still needed, remove these")
parser.add_argument("--keep-both", default=False, action="store_true",
                    help="do not replace test results; duplicate the test instead, showing both results, and mark both copies '# optional'")
# Options that select where the fixed files are written
parser.add_argument("--overwrite", default=False, action="store_true",
                    help="never interpret a second filename as OUTPUT; overwrite the source files")
parser.add_argument("--no-overwrite", default=False, action="store_true",
                    help="never interpret a second filename as OUTPUT; output goes to files named INPUT.fixed")
parser.add_argument("filename", nargs='*', help="input filenames; or (deprecated) INPUT_FILENAME OUTPUT_FILENAME if exactly two filenames are given and neither --overwrite nor --no-overwrite is present",
                    type=str)

# The parsed options are kept in the global ``args``, which the functions
# below read (and, for the distribution/venv/environment defaults, update).
args = parser.parse_args()


# Environment module used when neither --distribution nor --environment selects
# another one; '--environment' is only passed on to 'sage -t' when the chosen
# environment differs from this value.
runtest_default_environment = "sage.repl.ipython_kernel.all_jupyter"

def default_venv_environment_from_distribution():
    """
    Return the defaults for ``--venv`` and ``--environment`` implied by ``args.distribution``.

    As a side effect, ``args.distribution`` is normalized: underscores are
    replaced by dashes, and the prefix ``sagemath-`` is added unless the name
    already starts with ``sagemath-`` or ``sage-``.

    OUTPUT: a pair ``(default_venv, default_environment)`` of strings.  Without
    a distribution, this is the empty string and ``runtest_default_environment``.
    Otherwise it is the tox environment directory below ``SAGE_ROOT/pkgs`` for
    the distribution (with the extras appended to the tox environment name)
    and the module name ``sage.all__<distribution>``.
    """
    if not args.distribution:
        return '', runtest_default_environment

    # shortcuts / variants
    distribution = args.distribution.replace('_', '-')
    if not distribution.startswith(('sagemath-', 'sage-')):
        distribution = 'sagemath-' + distribution
    args.distribution = distribution

    # split "NAME[EXTRA1,EXTRA2]" into the plain name and the extras
    parsed = re.fullmatch(r'([^[]*)(\[([^]]*)\])?', distribution)
    base_name = parsed.group(1)
    extra_names = parsed.group(3)

    tox_env_parts = ['sagepython-sagewheels-nopypi-norequirements']
    if extra_names:
        tox_env_parts.append(extra_names.replace(',', '-'))

    venv_dir = os.path.join(SAGE_ROOT, 'pkgs', base_name, '.tox', '-'.join(tox_env_parts))
    return venv_dir, 'sage.all__' + base_name.replace('-', '_')

default_venv, default_environment = default_venv_environment_from_distribution()

# Options that were not given fall back to the defaults implied by --distribution
args.venv = args.venv or default_venv
args.environment = args.environment or default_environment

# A distribution, or any non-default venv/environment, switches on
# --keep-both and --full-tracebacks
if args.distribution or args.venv != default_venv or args.environment != default_environment:
    args.keep_both = True
    args.full_tracebacks = True

# Pieces of the note that is attached to the tags written by process_block
venv_explainers = []

# If the venv looks like the tox environment of a distribution, recompute the
# defaults for that distribution so that the note can be shortened to
# '--distribution ...' when the venv or the environment agrees with them.
if args.venv and (m := re.search('pkgs/(sage[^/]*)/[.]tox/((sagepython|sagewheels|nopypi|norequirements)-*)*([^/]*)$',
                                 args.venv)):
    args.distribution = m.group(1)
    extras = m.group(4)
    if extras:
        args.distribution += '[' + extras.replace('-', ',') + ']'
    default_venv_given_distribution, default_environment_given_distribution = default_venv_environment_from_distribution()

    if (Path(args.venv).resolve() == Path(default_venv_given_distribution).resolve()
            or args.environment == default_environment_given_distribution):
        venv_explainers.append(f'--distribution {shlex.quote(args.distribution)}')
        default_venv = default_venv_given_distribution
        default_environment = default_environment_given_distribution

# Mention whatever still differs from the defaults
if Path(args.venv).resolve() != Path(default_venv).resolve():
    venv_explainers.append(f'--venv {shlex.quote(args.venv)}')
if args.environment != default_environment:
    venv_explainers.append(f'--environment {args.environment}')

venv_explainer = ' (with ' + ' '.join(venv_explainers) + ')' if venv_explainers else ''


def process_block(block, src_in_lines, file_optional_tags):
    """
    Apply the fix suggested by one block of doctester output to the source lines.

    INPUT:

    - ``block`` -- string; one section of the doctester output.  Blocks that do
      not start with a header of the form ``File "...", line N, in`` are ignored.

    - ``src_in_lines`` -- list of the lines of the source file; it is modified
      in place.  So that the 1-based line numbers reported by the doctester stay
      valid, entries are never inserted or removed: an entry is either replaced
      by a string that contains several lines joined by newlines, or by ``None``
      to mark the line as deleted.

    - ``file_optional_tags`` -- collection of the tags that apply to the whole
      file; these are not added again to individual doctests.

    OUTPUT: ``None``

    The function also reads the global options ``args.full_tracebacks``,
    ``args.only_tags`` and ``args.keep_both`` and the global ``venv_explainer``.
    """
    # Extract the line, what was expected, and was got.
    if not (m := re.match('File "([^"]*)", line ([0-9]+), in ', block)):
        return
    first_line_num = line_num = int(m.group(2))  # 1-based line number of the first line of the example

    if m := re.search(r"using.*block-scoped tag.*'(sage: .*)'.*to avoid repeating the tag", block):
        indent = (len(src_in_lines[first_line_num - 1]) - len(src_in_lines[first_line_num - 1].lstrip()))
        # Add the suggested block-scoped tag line after the line preceding the example
        src_in_lines[line_num - 2] += '\n' + ' ' * indent + m.group(1)

    if m := re.search(r"updating.*block-scoped tag.*'sage: (.*)'.*to avoid repeating the tag", block):
        src_in_lines[first_line_num - 1] = update_optional_tags(src_in_lines[first_line_num - 1],
                                                                tags=parse_optional_tags('# ' + m.group(1)))

    if m := re.search(r"referenced here was set only in doctest marked '# (optional|needs)[-: ]*([^;']*)", block):
        optional = m.group(2).split()
        if src_in_lines[first_line_num - 1].strip() in ['"""', "'''"]:
            # This happens due to a virtual doctest in src/sage/repl/user_globals.py
            return
        optional = set(optional) - set(file_optional_tags)
        src_in_lines[first_line_num - 1] = update_optional_tags(src_in_lines[first_line_num - 1],
                                                                add_tags=optional)

    if m := re.search(r"tag '# (optional|needs)[-: ]([^;']*)' may no longer be needed", block):
        optional = m.group(2).split()
        src_in_lines[first_line_num - 1] = update_optional_tags(src_in_lines[first_line_num - 1],
                                                                remove_tags=optional)

    # Blocks without both the "Failed example:" and the "Expected..."/"Exception raised:"
    # markers carry no test result to fix.
    m1 = re.search('Failed example:\n', block)
    m2 = re.search('(Expected:|Expected nothing|Exception raised:)\n', block)
    if not (m1 and m2):
        return

    line_num += block[m1.end() : m2.start()].count('\n') - 1
    # Now line_num is the 1-based line number of the last line of the example

    if m2.group(1) == 'Expected nothing':
        expected = ''
        block = '\n' + block[m2.end():]  # so that split('\nGot:\n') does not fail below
    elif m2.group(1) == 'Exception raised:':
        # In this case, the doctester does not show the expected output,
        # so we do not know how many lines it spans; so we check for the next prompt or
        # docstring end (or the end of the file).
        expected = []
        indentation = ' ' * (len(src_in_lines[line_num - 1]) - len(src_in_lines[line_num - 1].lstrip()))
        i = line_num
        while (i < len(src_in_lines)
               and (not src_in_lines[i].rstrip() or src_in_lines[i].startswith(indentation))
               and not re.match(' *(sage:|""")', src_in_lines[i])):
            expected.append(src_in_lines[i])
            i += 1
        block = '\n'.join(expected) + '\nGot:\n' + block[m2.end():]
    else:
        block = block[m2.end():]

    # Error testing.
    if m := re.search(r"(?:ModuleNotFoundError: No module named|ImportError: cannot import name '([^']*)' from) '([^']*)'", block):
        if m.group(1):
            # "ImportError: cannot import name 'function_field_polymod' from 'sage.rings.function_field' (unknown location)"
            module = m.group(2) + '.' + m.group(1)
        else:
            module = m.group(2)
        asked_why = re.search('#.*(why|explain)', src_in_lines[first_line_num - 1])
        optional = module_feature(module)
        if optional and optional.name not in file_optional_tags:
            src_in_lines[first_line_num - 1] = update_optional_tags(src_in_lines[first_line_num - 1],
                                                                    add_tags=[optional.name])
            if not asked_why:
                # When no explanation has been demanded,
                # we just mark the doctest with the feature
                return
            # Otherwise, continue and show the backtrace as 'GOT'

    if 'Traceback (most recent call last):' in block:

        expected, got = block.split('\nGot:\n')
        if args.full_tracebacks:
            # drop a leading whitespace-only line
            if m := re.match(' *\n', got):
                got = got[m.end(0):]
            # don't show doctester internals (anything before first "<doctest...>" frame
            if m := re.search('( *Traceback.*\n *)(?s:.*?)(^ *File "<doctest)( [^>]*)>', got, re.MULTILINE):
                got = m.group(1) + '...\n' + m.group(2) + '...' + got[m.end(3):]
            while m := re.search(' *File "<doctest( [^>]*)>', got):
                got = got[:m.start(1)] + '...' + got[m.end(1):]
            # simplify filenames shown in backtrace
            while m := re.search('"([-a-zA-Z0-9._/]*/site-packages)/sage/', got):
                got = got[:m.start(1)] + '...' + got[m.end(1):]

            last_frame = got.rfind('File "')
            if (last_frame >= 0
                    and (index_NameError := got.rfind("NameError:")) >= 0
                    and got[last_frame:].startswith('File "<doctest')):
                # NameError from top level, so keep it brief
                if m := re.match("NameError: name '(.*)'", got[index_NameError:]):
                    name = m.group(1)
                    if name == 'x':  # Don't mark it '# needs sage.symbolic'; that's almost always wrong
                        return
                    if feature := name_feature(name):
                        add_tags = [feature.name]
                    else:
                        if args.only_tags:
                            return
                        add_tags = [f"NameError: '{name}'{venv_explainer}"]
                else:
                    if args.only_tags:
                        return
                    add_tags = [f"NameError{venv_explainer}"]
                src_in_lines[first_line_num - 1] = update_optional_tags(src_in_lines[first_line_num - 1],
                                                                        add_tags=add_tags)
                return
            got = got.splitlines()
        else:
            # Abbreviated traceback: header, '...', and the final exception line
            got = got.splitlines()
            got = ['Traceback (most recent call last):', '...', got[-1].lstrip()]
    elif block[-21:] == 'Got:\n    <BLANKLINE>\n':
        expected = block[:-22]
        got = ['']
    else:
        expected, got = block.split('\nGot:\n')
        got = got.splitlines()      # got can't be the empty string

    if args.only_tags:
        return

    expected = expected.splitlines()

    if args.keep_both:
        # Build a second copy of the example, tagged 'GOT', and tag the original 'EXPECTED'
        test_lines = ([update_optional_tags(src_in_lines[first_line_num - 1],
                                            add_tags=[f'GOT{venv_explainer}'])]
                      + src_in_lines[first_line_num : line_num])
        src_in_lines[first_line_num - 1] = update_optional_tags(src_in_lines[first_line_num - 1],
                                                                add_tags=['EXPECTED'])
        line_num += len(expected)  # skip to the last line of the expected output
        src_in_lines[line_num - 1] += '\n'.join([''] + test_lines)  # 2nd copy of the test
        # now line_num is the last line of the 2nd copy of the test
        expected = []

    # If we expected nothing, and got something, then we need to insert the line before line_num
    # and match indentation with line number line_num-1
    if not expected:
        indent = (len(src_in_lines[first_line_num - 1]) - len(src_in_lines[first_line_num - 1].lstrip()))
        src_in_lines[line_num - 1] += '\n' + '\n'.join('%s%s' % (' ' * indent, line.lstrip()) for line in got)
        return

    # Guess how much extra indenting ``got`` needs to match with the indentation
    # of src_in_lines - we match the indentation with the line in ``got`` which
    # has the smallest indentation after lstrip(). Note that the amount of indentation
    # required could be negative if the ``got`` block is indented. In this case
    # ``indent`` is set to zero.
    indent = max(0, (len(src_in_lines[line_num]) - len(src_in_lines[line_num].lstrip())
                     - min(len(line) - len(line.lstrip()) for line in got)))

    # Double check that what was expected was indeed in the source file and if
    # it is not then print a warning for the user which contains the
    # problematic lines.
    if any(expected_line.strip() != src_in_lines[line_num + i].strip()
           for i, expected_line in enumerate(expected)):
        import warnings
        txt = "Did not manage to replace\n%s\n%s\n%s\nwith\n%s\n%s\n%s"
        warnings.warn(txt % ('>' * 40, '\n'.join(expected), '>' * 40,
                             '<' * 40, '\n'.join(got), '<' * 40))
        return

    # If we got nothing when we expected something then we delete the line from the
    # output, otherwise, put all of what we `got` into src_in_lines[line_num]
    if got == ['']:
        src_in_lines[line_num] = None
    else:
        src_in_lines[line_num] = '\n'.join(' ' * indent + line for line in got)

    # Mark any remaining `expected` lines as ``None`` so as to preserve the line numbering
    for i in range(1, len(expected)):
        src_in_lines[line_num + i] = None


def output_filename(filename):
    """
    Return the name of the file to which the fixed version of ``filename`` is written.

    INPUT:

    - ``filename`` -- string; name of an input file

    OUTPUT:

    - In the deprecated calling convention ``INPUT_FILENAME OUTPUT_FILENAME``
      (exactly two filenames, neither ``--overwrite`` nor ``--no-overwrite``):
      the second filename if ``filename`` is the first one, and ``None``
      (meaning: do not save) for any other file.

    - With ``--no-overwrite``: ``filename`` with the suffix ``.fixed``.

    - Otherwise ``filename`` itself, that is, the input file is overwritten.
    """
    if len(args.filename) == 2 and not args.overwrite and not args.no_overwrite:
        if args.filename[0] == filename:
            return args.filename[1]
        return None
    if args.no_overwrite:
        return filename + ".fixed"
    return filename

# Determine the input files.  This is done regardless of --no-test, because
# expanded_filename_args() reads ``input_filenames`` in that mode as well.
if len(args.filename) == 2 and not args.overwrite and not args.no_overwrite:
    print("sage-fixdoctests: When passing two filenames, the second one is taken as an output filename; "
          "this is deprecated. To pass two input filenames, use the option --overwrite.")
    input_filenames = [args.filename[0]]
else:
    input_filenames = args.filename

if args.no_test:
    # No doctester output to process; only the tags of the input files are rewritten
    doc_out = ''
else:
    executable = f'{os.path.relpath(args.venv)}/bin/sage' if args.venv else 'sage'
    environment_args = f'--environment {args.environment} ' if args.environment != runtest_default_environment else ''
    long_args = '--long ' if args.long else ''
    probe_args = f'--probe {shlex.quote(args.probe)} ' if args.probe else ''
    lib_args = '--if-installed ' if args.venv else ''
    doc_file = tmp_filename()
    if args.venv or environment_args:
        # Test the doctester, putting the output of the test into sage's temporary directory;
        # give up if it cannot even test src/sage/version.py in the requested setup
        version_file = os.path.join(os.path.relpath(SAGE_ROOT), 'src', 'sage', 'version.py')
        cmdline = f'{shlex.quote(executable)} -t {environment_args}{long_args}{probe_args}{lib_args}{shlex.quote(version_file)}'
        print(f'Running "{cmdline}"')
        if status := os.waitstatus_to_exitcode(os.system(f'{cmdline} > {shlex.quote(doc_file)}')):
            print(f'Doctester exited with error status {status}')
            sys.exit(status)
    # Run the doctester, putting the output of the test into sage's temporary directory.
    # The exit status is deliberately not checked here.
    input_args = " ".join(shlex.quote(f) for f in input_filenames)
    cmdline = f'{shlex.quote(executable)} -t -p {environment_args}{long_args}{probe_args}{lib_args}{input_args}'
    print(f'Running "{cmdline}"')
    os.system(f'{cmdline} > {shlex.quote(doc_file)}')

    with open(doc_file, 'r') as doc:
        doc_out = doc.read()

    # echo control messages
    for m in re.finditer('^Skipping .*', doc_out, re.MULTILINE):
        print('sage-runtests: ' + m.group(0))

# The doctester separates the reports on individual failures by lines of asterisks
sep = "**********************************************************************\n"
doctests = doc_out.split(sep)

# Names of the files that process_grouped_blocks has handled already
seen = set()

def block_filename(block):
    """
    Return the file name from the header of a block of doctester output.

    OUTPUT: the name given in a leading ``File "NAME", line N, in``, or
    ``None`` if ``block`` does not start with such a header.
    """
    header = re.match('File "([^"]*)", line ([0-9]+), in ', block)
    return header.group(1) if header else None

def expanded_filename_args():
    """
    Iterate over the paths of the sources that the doctester finds for the input files.

    A ``DocTestController`` is set up for the global ``input_filenames`` (with
    ``optional='all'`` and ``warn_long=10000``), its files are added and expanded
    into sources, and the ``path`` of each source is yielded.
    """
    defaults = DocTestDefaults(optional='all', warn_long=10000)
    controller = DocTestController(defaults, input_filenames)
    controller.add_files()
    controller.expand_files_into_sources()
    yield from (source.path for source in controller.sources)

def process_grouped_blocks(grouped_iterator):
    """
    Fix the files named by ``grouped_iterator`` and save the results.

    INPUT:

    - ``grouped_iterator`` -- iterable of pairs ``(filename, blocks)``, where
      ``blocks`` is an iterable of blocks of doctester output that refer to
      the file.  Pairs with a false ``filename`` are skipped, and so are files
      that are already recorded in the global set ``seen``.

    For each file, every block is passed to :func:`process_block`; then the
    optional tags of all ``sage:`` lines that contain a ``#`` are rewritten.
    If this changed anything, the result is written to the file given by
    :func:`output_filename`.
    """
    for source_name, blocks in grouped_iterator:

        # Skip blocks of noise and files that were handled before
        if not source_name or source_name in seen:
            continue
        seen.add(source_name)

        with open(source_name, 'r') as test_file:
            original_lines = test_file.read().splitlines()
        src_in_lines = list(original_lines)   # working copy; ``original_lines`` stays untouched
        file_optional_tags = set(parse_file_optional_tags(enumerate(src_in_lines)))

        for block in blocks:
            process_block(block, src_in_lines, file_optional_tags)

        # Now source line numbers do not matter any more, and lines can be real lines again:
        # drop the entries marked ``None`` and split the multi-line entries
        flattened = []
        for entry in src_in_lines:
            if entry is None:
                continue
            flattened.extend(entry.splitlines() if entry else [''])
        src_in_lines = flattened

        # Remove duplicate optional tags and rewrite all '# optional' that should be '# needs'
        persistent_optional_tags = {}
        for i, line in enumerate(src_in_lines):
            if re.match(' *sage: *(.*)#', line):
                tags, _, is_persistent = parse_optional_tags(line, return_string_sans_tags=True)
                if is_persistent:
                    persistent_optional_tags = {
                        tag: explanation
                        for tag, explanation in tags.items()
                        if explanation or tag not in file_optional_tags}
                    line = update_optional_tags(line, tags=persistent_optional_tags, force_rewrite='standard')
                    if not line.rstrip():
                        # persistent (block-scoped or file-scoped) tag was removed, so remove the whole line
                        line = None
                else:
                    # Keep a tag if it has an explanation or is not already in force
                    kept_tags = {
                        tag: explanation
                        for tag, explanation in tags.items()
                        if explanation or (tag not in file_optional_tags
                                           and tag not in persistent_optional_tags)}
                    line = update_optional_tags(line, tags=kept_tags, force_rewrite='standard')
                src_in_lines[i] = line
            elif line.strip() in ['', '"""', "'''"]:    # Blank line or end of docstring
                persistent_optional_tags = {}

        if src_in_lines == original_lines:
            print(f"No fixes made in '{source_name}'")
            continue

        output = output_filename(source_name)
        if output is None:
            print(f"Not saving modifications made in '{source_name}'")
            continue

        with open(output, 'w') as test_output:
            test_output.writelines(line + '\n' for line in src_in_lines if line is not None)

        # Show summary of changes
        if source_name != output:
            print("The fixed doctests have been saved as '{0}'.".format(output))
        else:
            relative = os.path.relpath(output, SAGE_ROOT)
            print(f"The input file '{output}' has been overwritten.")
            if not relative.startswith('..'):
                subprocess.call(['git', '--no-pager', 'diff', relative], cwd=SAGE_ROOT)

# First handle the blocks of doctester output, grouped by the file named in
# their headers (consecutive blocks with the same file name form one group).
# Then visit every source that the doctester finds for the input files with an
# empty list of blocks, so that files without reported failures still get
# their tags rewritten; files handled in the first pass are skipped there.
process_grouped_blocks(
    itertools.chain(itertools.groupby(doctests, block_filename),
                    ((filename, []) for filename in expanded_filename_args())))
