#!/usr/bin/env python3

#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A command line utility to pull multiple change lists from Gerrit."""

from __future__ import print_function

import argparse
import collections
import itertools
import json
import multiprocessing
import os
import os.path
import re
import sys
import xml.dom.minidom

from gerrit import (
    create_url_opener_from_args, find_gerrit_name, query_change_lists, run
)
from subprocess import PIPE

try:
    # pylint: disable=redefined-builtin
    from __builtin__ import raw_input as input  # PY2
except ImportError:
    pass

try:
    from shlex import quote as _sh_quote  # PY3.3
except ImportError:
    # Shell language simple string pattern.  If a string matches this pattern,
    # it doesn't have to be quoted.  `\Z` is used instead of `$` so that a
    # string with a trailing newline is not mistaken for a simple string.
    _SHELL_SIMPLE_PATTERN = re.compile(r'\A[a-zA-Z0-9_./-]+\Z')

    def _sh_quote(txt):
        """Quote a string if it contains special characters.

        Simple strings are returned unchanged.  Anything else (including the
        empty string) is wrapped in single quotes so that the shell performs
        no expansion at all; an embedded single quote is emitted as `'"'"'`
        (close the quote, a double-quoted single quote, reopen the quote).
        """
        if _SHELL_SIMPLE_PATTERN.match(txt):
            return txt
        return "'" + txt.replace("'", "'\"'\"'") + "'"


def write_bytes(data, file):
    """Write bytes to a file.

    On Python 2 (where `bytes is str`) the data is written to `file` itself.
    On Python 3 the data is written to the underlying binary stream
    (`file.buffer`) when `file` has one; otherwise `file` is assumed to
    already be a binary stream and is written to directly.

    Args:
        data: The bytes to be written.
        file: A file object, e.g. `sys.stderr` or a binary stream.
    """
    # pylint: disable=redefined-builtin
    stream = file if bytes is str else getattr(file, 'buffer', file)
    if stream is not file:
        # Text that was written earlier through the text layer may still be
        # pending there.  Flush it first so that the output keeps its order.
        file.flush()
    stream.write(data)


def _confirm(question, default, file=sys.stderr):
    """Prompt a yes/no question and convert the answer to a boolean value.

    The question is repeated until the user enters `y`, `yes`, `n`, `no`
    (case-insensitive, surrounding whitespace ignored) or an empty line.

    Args:
        question: The question text to be written before the `[Y/n]` hint.
        default: The value to be returned for an empty answer or when the
            standard input reaches end-of-file.
        file: The file object that the prompt is written to.

    Returns:
        True for a positive answer, False for a negative answer, or `default`.
    """
    # pylint: disable=redefined-builtin
    answers = {'': default, 'y': True, 'yes': True, 'n': False, 'no': False}
    suffix = ' [Y/n] ' if default else ' [y/N] '
    while True:
        file.write(question + suffix)
        file.flush()
        try:
            reply = input()
        except EOFError:
            # No more input is available (e.g. stdin is closed or redirected
            # from an exhausted file); asking again would never succeed.
            return default
        ans = answers.get(reply.strip().lower())
        if ans is not None:
            return ans


class ChangeList(object):
    """One change list, together with the information needed to fetch it.

    Attributes set by the constructor: `project`, `number`, `fetch`,
    `fetch_url`, `fetch_ref`, `commit_sha1`, `commit`, `parents` and
    `change_list`.
    """
    # pylint: disable=too-few-public-methods,too-many-instance-attributes

    def __init__(self, project, fetch, commit_sha1, commit, change_list):
        """Build a ChangeList from the pieces of a queried change.

        Args:
            project: The project name of the change.
            fetch: A dict that maps protocol names to dicts with `url` and
                `ref` entries.
            commit_sha1: The commit hash of the revision.
            commit: A dict with (at least) a `parents` entry.
            change_list: The change list dict; its `_number` entry is read.

        Raises:
            ValueError: If `fetch` has no usable `http`, `sso` or `rpc` entry.
        """
        # pylint: disable=too-many-arguments
        self.project = project
        self.number = change_list['_number']
        self.fetch = fetch

        # Take the first protocol (in order of preference) with a non-empty
        # entry.
        candidates = (fetch.get(name) for name in ('http', 'sso', 'rpc'))
        fetch_git = next((info for info in candidates if info), None)
        if fetch_git is None:
            raise ValueError(
                'unknown fetch protocols: ' + str(list(fetch.keys())))

        self.fetch_url = fetch_git['url']
        self.fetch_ref = fetch_git['ref']

        self.commit_sha1 = commit_sha1
        self.commit = commit
        self.parents = commit['parents']

        self.change_list = change_list

    def is_merge(self):
        """Check whether this change list is a merge commit, i.e. whether it
        has more than one parent."""
        return len(self.parents) > 1


def find_repo_top(curdir):
    """Return the closest directory, starting from `curdir` and walking up
    through its parents, that contains a `.repo` entry.

    Raises:
        ValueError: If no such directory is found.
    """
    candidate, previous = curdir, None
    # `os.path.dirname()` returns its argument unchanged once the top has been
    # reached; that fixed point terminates the walk.
    while candidate != previous:
        marker = os.path.join(candidate, '.repo')
        if os.path.exists(marker):
            return candidate
        previous, candidate = candidate, os.path.dirname(candidate)
    raise ValueError('.repo dir not found')


def build_project_name_dir_dict(manifest_name):
    """Run `repo manifest` and return a dict that maps each project name to
    its `path` attribute (or to the project name itself if the `path`
    attribute is absent or empty).

    Args:
        manifest_name: The manifest file name to be passed with `-m`.  The
            option is omitted if this is empty or None.
    """
    extra_args = ['-m', manifest_name] if manifest_name else []
    proc = run(['repo', 'manifest'] + extra_args, stdout=PIPE, check=True)

    document = xml.dom.minidom.parseString(proc.stdout)
    names_and_paths = (
        (node.getAttribute('name'), node.getAttribute('path'))
        for node in document.getElementsByTagName('project'))
    return {name: path or name for name, path in names_and_paths}


def _select_revision(change_list):
    """Return the `(commit_sha1, revision)` pair to be pulled for a change.

    The revision named by the `current_revision` entry is preferred.  If that
    entry is missing (or names an unknown revision), the last entry of the
    `revisions` dict is used.

    Raises:
        ValueError: If the change list has no usable revision.
    """
    revisions = change_list['revisions']

    current = change_list.get('current_revision')
    if current and current in revisions:
        return current, revisions[current]

    commit_sha1 = revision = None
    for commit_sha1, revision in revisions.items():
        pass
    if not commit_sha1:
        raise ValueError('bad revision')
    return commit_sha1, revision


def _sort_project_change_lists(changes):
    """Sort the change lists of one project in post order, i.e. each change
    comes after those of its parents that are in `changes`.

    Args:
        changes: A dict that maps commit hashes to ChangeList objects.
    """
    visited_changes = set()
    sorted_changes = []

    def _post_order_traverse(root):
        # An explicit stack is used instead of recursion so that a long chain
        # of dependent changes cannot exceed the interpreter recursion limit.
        visited_changes.add(root)
        stack = [(root, iter(root.parents))]
        while stack:
            change, remaining_parents = stack[-1]
            for parent in remaining_parents:
                parent_change = changes.get(parent['commit'])
                if (parent_change is not None and
                        parent_change not in visited_changes):
                    visited_changes.add(parent_change)
                    stack.append(
                        (parent_change, iter(parent_change.parents)))
                    break
            else:
                # All parents have been emitted; now emit the change itself.
                stack.pop()
                sorted_changes.append(change)

    for change in sorted(changes.values(), key=lambda x: x.number):
        if change not in visited_changes:
            _post_order_traverse(change)

    return sorted_changes


def group_and_sort_change_lists(change_lists):
    """Group change lists by project and sort each group topologically.

    Args:
        change_lists: An iterable of change list dicts.  Each must have
            `project`, `_number` and `revisions` entries, and each revision
            must have `fetch` and `commit` entries.

    Returns:
        A list with one item per project (ordered by project name).  Each
        item is a list of ChangeList objects in which every change comes after
        those of its parents that belong to the same group.

    Raises:
        ValueError: If a change list has no revision.
        KeyError: If a commit occurs twice in the same project.
    """

    # Build a dict that maps projects to dicts that map commits to changes.
    projects = collections.defaultdict(dict)
    for change_list in change_lists:
        commit_sha1, revision = _select_revision(change_list)
        fetch = revision['fetch']
        commit = revision['commit']

        project = change_list['project']

        project_changes = projects[project]
        if commit_sha1 in project_changes:
            raise KeyError('repeated commit sha1 "{}" in project "{}"'.format(
                commit_sha1, project))

        project_changes[commit_sha1] = ChangeList(
            project, fetch, commit_sha1, commit, change_list)

    # Sort the changes in each project.
    return [_sort_project_change_lists(projects[project])
            for project in sorted(projects)]


def _main_json(args):
    """Query the change lists and write them to the standard output as an
    indented JSON document, followed by an end-of-line."""
    change_lists = _get_change_lists_from_args(args)
    document = json.dumps(change_lists, indent=4, separators=(', ', ': '))
    print(document)


# Git commands that apply a fetched merge commit, keyed by the value of the
# `--merge` command line option.
_MERGE_COMMANDS = {
    'checkout': ['git', 'checkout'],
    'merge': ['git', 'merge', '--no-edit'],
    'merge-ff-only': ['git', 'merge', '--no-edit', '--ff-only'],
    'merge-no-ff': ['git', 'merge', '--no-edit', '--no-ff'],
    'reset': ['git', 'reset', '--hard'],
}


# Git commands that apply a fetched non-merge commit, keyed by the value of
# the `--pick` command line option: the same choices as for merge commits
# plus `pick`.  The argument lists are copied so that the two tables do not
# share list objects.
_PICK_COMMANDS = {name: list(cmd) for name, cmd in _MERGE_COMMANDS.items()}
_PICK_COMMANDS['pick'] = ['git', 'cherry-pick', '--allow-empty']


def build_pull_commands(change, branch_name, merge_opt, pick_opt):
    """Return the list of command lines (each one a list of arguments for
    subprocess.run()) that pull one change.

    Args:
        change: The ChangeList to be pulled.
        branch_name: The branch for `repo start`; the command is left out if
            this is None.
        merge_opt: A key of `_MERGE_COMMANDS`, used if the change is a merge
            commit.
        pick_opt: A key of `_PICK_COMMANDS`, used otherwise.
    """
    start_cmds = [] if branch_name is None else [
        ['repo', 'start', branch_name]]
    fetch_cmd = ['git', 'fetch', change.fetch_url, change.fetch_ref]

    if change.is_merge():
        apply_cmd = _MERGE_COMMANDS[merge_opt]
    else:
        apply_cmd = _PICK_COMMANDS[pick_opt]

    return start_cmds + [fetch_cmd, apply_cmd + ['FETCH_HEAD']]


def _sh_quote_command(cmd):
    """Return the shell command string for one command (a list of arguments
    as passed to subprocess.run()): each argument is quoted and the results
    are joined by spaces."""
    quoted_args = [_sh_quote(arg) for arg in cmd]
    return ' '.join(quoted_args)


def _sh_quote_commands(cmds):
    """Return one shell command string for several commands (each a list of
    arguments as passed to subprocess.run()), chained with `&&`."""
    command_strings = [_sh_quote_command(cmd) for cmd in cmds]
    return ' && '.join(command_strings)


def _main_bash(args):
    """Print bash commands that pull the change lists.

    The output starts with a `pushd` to the top of the source tree and ends
    with the matching `popd`.  In between there is one line per change that
    enters the project directory, runs the pull commands and leaves the
    directory again.
    """
    repo_top = find_repo_top(os.getcwd())
    project_dirs = build_project_name_dir_dict(args.manifest)
    branch_name = _get_local_branch_name_from_args(args)

    change_list_groups = group_and_sort_change_lists(
        _get_change_lists_from_args(args))

    print(_sh_quote_command(['pushd', repo_top]))
    for change in (c for group in change_list_groups for c in group):
        # Fall back to the project name if the manifest doesn't list it.
        project_dir = project_dirs.get(change.project, change.project)
        pull_cmds = build_pull_commands(
            change, branch_name, args.merge, args.pick)
        print(_sh_quote_commands(
            [['pushd', project_dir]] + pull_cmds + [['popd']]))
    print(_sh_quote_command(['popd']))


def _do_pull_change_lists_for_project(task):
    """Pull a list of changes one after another and stop at the first failure.

    Args:
        task: A `(changes, task_opts)` pair.  `task_opts` is a dict with
            `branch_name`, `merge_opt`, `pick_opt`, `project_dirs` and
            `repo_top` entries.

    Returns:
        None if every command succeeded.  Otherwise a tuple of the failed
        change, the changes that were not attempted, the failed command (an
        empty list if the project directory is unknown) and the error output
        as bytes.
    """
    changes, task_opts = task

    branch_name = task_opts['branch_name']
    merge_opt = task_opts['merge_opt']
    pick_opt = task_opts['pick_opt']
    project_dirs = task_opts['project_dirs']
    repo_top = task_opts['repo_top']

    for position, change in enumerate(changes, 1):
        pending = changes[position:]

        if change.project not in project_dirs:
            message = 'error: project "{}" cannot be found in manifest.xml\n'
            message = message.format(change.project)
            return (change, pending, [], message.encode('utf-8'))
        cwd = project_dirs[change.project]

        print(change.commit_sha1[0:10], position, cwd)
        for cmd in build_pull_commands(
                change, branch_name, merge_opt, pick_opt):
            proc = run(cmd, cwd=os.path.join(repo_top, cwd), stderr=PIPE)
            if proc.returncode != 0:
                return (change, pending, cmd, proc.stderr)
    return None


def _print_pull_failures(failures, file=sys.stderr):
    """Print pull failures and the error output of the failed commands.

    Args:
        failures: A list of `(failed_change, skipped_changes, cmd, errors)`
            tuples, where `errors` is the error output as bytes.
        file: The file object that the report is written to.  Everything,
            including the raw error output, goes to this file.
    """
    # pylint: disable=redefined-builtin

    separator = '=' * 78
    separator_sub = '-' * 78

    print(separator, file=file)
    for failed_change, skipped_changes, cmd, errors in failures:
        print('PROJECT:', failed_change.project, file=file)
        print('FAILED COMMIT:', failed_change.commit_sha1, file=file)
        for change in skipped_changes:
            print('PENDING COMMIT:', change.commit_sha1, file=file)
        print(separator_sub, file=file)
        print('FAILED COMMAND:', _sh_quote_command(cmd), file=file)
        # The error output is written as raw bytes.  Flush the text written
        # so far to keep the report in order.
        file.flush()
        write_bytes(errors, file=file)
        print(separator, file=file)


def _main_pull(args):
    """Pull the change lists.

    The changes of one project are pulled sequentially.  If `args.parallel`
    is greater than one, different projects are handled by a pool of worker
    processes.  The process exits with status 1 if any pull failed.
    """
    repo_top = find_repo_top(os.getcwd())
    project_dirs = build_project_name_dir_dict(args.manifest)
    branch_name = _get_local_branch_name_from_args(args)

    # Collect change lists
    change_lists = _get_change_lists_from_args(args)
    change_list_groups = group_and_sort_change_lists(change_lists)

    # Build the options list for tasks
    task_opts = {
        'branch_name': branch_name,
        'merge_opt': args.merge,
        'pick_opt': args.pick,
        'project_dirs': project_dirs,
        'repo_top': repo_top,
    }

    # Run the commands to pull the change lists
    if args.parallel <= 1:
        results = [_do_pull_change_lists_for_project((changes, task_opts))
                   for changes in change_list_groups]
    else:
        pool = multiprocessing.Pool(processes=args.parallel)
        try:
            results = pool.map(
                _do_pull_change_lists_for_project,
                zip(change_list_groups, itertools.repeat(task_opts)))
        finally:
            # `map()` has either returned every result or raised, so the
            # workers are no longer needed.  Release them explicitly instead
            # of leaving the pool to the garbage collector.
            pool.terminate()
            pool.join()

    # Print failures and tracebacks
    failures = [result for result in results if result]
    if failures:
        _print_pull_failures(failures)
        sys.exit(1)


def _parse_args():
    """Parse command line options.

    Returns:
        The argparse namespace with `command`, `query`, `gerrit`,
        `gitcookies`, `manifest`, `limits`, `merge`, `pick`, `branch` and
        `parallel` attributes.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('command', choices=['pull', 'bash', 'json'],
                        help='Commands')

    parser.add_argument('query', help='Change list query string')
    parser.add_argument('-g', '--gerrit', help='Gerrit review URL')

    parser.add_argument('--gitcookies',
                        default=os.path.expanduser('~/.gitcookies'),
                        help='Gerrit cookie file')
    parser.add_argument('--manifest', help='Manifest')
    # Convert to int so that a user-specified value has the same type as the
    # default value.
    parser.add_argument('--limits', default=1000, type=int,
                        help='Max number of change lists')

    parser.add_argument('-m', '--merge',
                        choices=sorted(_MERGE_COMMANDS.keys()),
                        default='merge-ff-only',
                        help='Method to pull merge commits')

    parser.add_argument('-p', '--pick',
                        choices=sorted(_PICK_COMMANDS.keys()),
                        default='pick',
                        help='Method to pull non-merge commits')

    parser.add_argument('-b', '--branch',
                        help='Local branch name for `repo start`')

    parser.add_argument('-j', '--parallel', default=1, type=int,
                        help='Number of parallel running commands')

    return parser.parse_args()


def _get_change_lists_from_args(args):
    """Create a URL opener from `args` and use it to query the change lists
    that match `args.query` on `args.gerrit` (with `args.limits` passed as
    the limit argument)."""
    return query_change_lists(
        create_url_opener_from_args(args),
        args.gerrit, args.query, args.limits)


def _get_local_branch_name_from_args(args):
    """Return `args.branch`.

    If no branch name was given, the user is asked whether to continue
    without one; the process exits with status 1 if the user declines.
    """
    branch = args.branch
    if branch:
        return branch

    if _confirm('Do you want to continue without local branch name?', False):
        return branch

    print('error: `-b` or `--branch` must be specified', file=sys.stderr)
    sys.exit(1)


def main():
    """Main function: parse the command line options, determine the Gerrit
    instance, and run the requested command (`json`, `bash` or `pull`)."""
    args = _parse_args()

    if not args.gerrit:
        try:
            args.gerrit = find_gerrit_name()
        # `Exception` (instead of a bare `except`) lets KeyboardInterrupt and
        # SystemExit propagate.
        # pylint: disable=broad-except
        except Exception:
            print('gerrit instance not found, use [-g GERRIT]',
                  file=sys.stderr)
            sys.exit(1)

    if args.command == 'json':
        _main_json(args)
    elif args.command == 'bash':
        _main_bash(args)
    elif args.command == 'pull':
        _main_pull(args)
    else:
        raise KeyError('unknown command')


if __name__ == '__main__':
    main()