#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Utility for creating well-formed pull request merges and pushing them to
# Apache.
#   usage: ./merge_arrow_pr.py  <pr-number>  (see config env vars below)
#
# This utility assumes:
#   - you already have a local Arrow git clone
#   - you have added remotes corresponding to both:
#       (i) the GitHub Apache Arrow mirror
#       (ii) the Apache git repo
#
# There are several pieces of authorization possibly needed via environment
# variables.
#
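# Configuration environment variables:
#   - ARROW_GITHUB_API_TOKEN: a GitHub API token to use for API requests
#     (an api_token entry in the [github] section of
#     ~/.config/arrow/merge.conf takes precedence when present)
#   - ARROW_GITHUB_ORG: the GitHub organisation ('apache' by default;
#     PR_REMOTE_NAME is still read for backward compatibility)
#   - ARROW_PROJECT_NAME: the GitHub repository name ('arrow' by default)
#   - DEBUG: use for testing to avoid pushing to apache (0 by default)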

import configparser
import os
import pprint
import re
import subprocess
import sys
import requests
import getpass

# GitHub organisation that owns the repository. It is used to build the
# GitHub API and issue URLs. ARROW_GITHUB_ORG wins over the legacy
# PR_REMOTE_NAME variable; 'apache' is used when neither is set (or both
# are empty).
ORG_NAME = (
    os.environ.get("ARROW_GITHUB_ORG") or
    os.environ.get("PR_REMOTE_NAME") or  # backward compatibility
    "apache"
)
# Repository name inside ORG_NAME; 'arrow' unless ARROW_PROJECT_NAME is set
PROJECT_NAME = os.environ.get('ARROW_PROJECT_NAME') or "arrow"

# For testing to avoid accidentally pushing to apache.
# The value goes through int(), so it has to be an integer string such as
# "0" or "1"; anything else raises ValueError at import time.
DEBUG = bool(int(os.environ.get("DEBUG", 0)))

if DEBUG:
    print("**************** DEBUGGING ****************")


def get_json(url, headers=None):
    """Fetch ``url`` with a GET request and return the decoded JSON body.

    GitHub paginates long listings and advertises the following page in a
    ``Link`` response header, see
    https://docs.github.com/en/rest/guides/using-pagination-in-the-rest-api#using-link-headers
    Every ``rel="next"`` link is followed and the items of the later pages
    are appended to the first page, so the caller gets the complete list.

    Parameters
    ----------
    url : str
        Address of the first (or only) page.
    headers : dict, optional
        HTTP headers sent with every request.

    Returns
    -------
    list or dict
        The decoded JSON document; for paginated responses a single list
        holding the items of all pages.

    Raises
    ------
    ValueError
        If a request does not answer with status 200, or if a paginated
        response contains something other than lists.
    """
    combined = None
    next_url = url
    # Pages are walked in a loop so that the number of pages is not limited
    # by the interpreter's recursion depth.
    while next_url:
        response = requests.get(next_url, headers=headers)
        if response.status_code != 200:
            # Error bodies are not guaranteed to be JSON (e.g. a gateway
            # error page); fall back to the raw text so that the real
            # failure is reported instead of a JSON decoding error.
            try:
                detail = response.json()
            except ValueError:
                detail = response.text
            raise ValueError(
                f"GET {next_url} failed with status "
                f"{response.status_code}: {detail}")

        page = response.json()
        if combined is None:
            combined = page
        elif page:
            if not isinstance(combined, list) or not isinstance(page, list):
                raise ValueError(
                    'GitHub response was paginated and is not a list')
            combined.extend(page)

        next_url = None
        for link in response.headers.get('link', '').split(', '):
            if 'rel="next"' in link:
                # Format: '<url>; rel="next"'
                next_url = link.split(";")[0].strip()[1:-1]
    return combined


def run_cmd(cmd):
    """Run an external command and return its standard output as text.

    Parameters
    ----------
    cmd : str or sequence of str
        The command to execute. A string is split on whitespace into the
        argument list (no shell is involved, so quoting is not interpreted);
        a sequence is passed to ``subprocess`` unchanged.

    Returns
    -------
    str
        The captured standard output, decoded as UTF-8.

    Raises
    ------
    subprocess.CalledProcessError
        If the command exits with a non-zero status. The captured output is
        printed before the exception is re-raised.
    """
    if isinstance(cmd, str):
        # split() without a separator ignores repeated whitespace, so an
        # accidental double space does not become an empty argument
        cmd = cmd.split()

    try:
        output = subprocess.check_output(cmd)
    except subprocess.CalledProcessError as e:
        # this avoids hiding the stdout / stderr of failed processes
        # (the 1-tuple keeps %-formatting valid when cmd itself is a tuple)
        print('Command failed: %s' % (cmd,))
        print('With output:')
        print('--------------')
        failed_output = e.output
        if isinstance(failed_output, bytes):
            # show readable text rather than a bytes repr
            failed_output = failed_output.decode('utf-8', errors='replace')
        print(failed_output)
        print('--------------')
        raise

    if isinstance(output, bytes):
        output = output.decode('utf-8')
    return output


# Any run of characters enclosed in one pair of square brackets
_REGEX_CI_DIRECTIVE = re.compile(r'\[[^\]]*\]')


def strip_ci_directives(commit_message):
    """Return ``commit_message`` with every ``[...]`` group deleted.

    This is meant to drop CI directives such as '[force ci]' or
    '[skip appveyor]' from the assembled commit message. Note that any
    other text enclosed in square brackets is removed as well, and the
    whitespace around a removed group is left in place.
    """
    without_directives = re.sub(_REGEX_CI_DIRECTIVE, '', commit_message)
    return without_directives


def fix_version_from_branch(versions):
    """Return the last entry of ``versions``.

    Note: the list is assumed to hold un-released versions sorted from
    newest to oldest, which makes the last entry the oldest one.
    """
    oldest_position = -1
    return versions[oldest_position]


# Matches the text "This issue has been migrated to [issue #<digits>" and
# captures the digits as the named group ``issue_id``.
# NOTE(review): nothing in the visible part of this file uses this pattern;
# confirm whether other modules or tests still import it.
MIGRATION_COMMENT_REGEX = re.compile(
    r"This issue has been migrated to \[issue #(?P<issue_id>(\d+))"
)


class GitHubIssue(object):
    """A GitHub issue together with the operations needed to resolve it.

    Parameters
    ----------
    github_api : object
        API client; must provide ``get_issue_data``, ``get_milestones``,
        ``assign_milestone`` and ``close_issue``.
    github_id : str
        Number of the issue.
    cmd : object
        Command interface; ``cmd.fail(msg)`` is called to report errors.
    """

    def __init__(self, github_api, github_id, cmd):
        self.github_api = github_api
        self.github_id = github_id
        self.cmd = cmd

        try:
            self.issue = self.github_api.get_issue_data(github_id)
        except Exception as e:
            self.cmd.fail("GitHub could not find %s\n%s" % (github_id, e))

    def get_label(self, prefix):
        """Return the values of all labels named ``"<prefix>: <value>"``."""
        prefix = f"{prefix}:"
        return [
            lbl["name"][len(prefix):].strip()
            for lbl in self.issue["labels"] if lbl["name"].startswith(prefix)
        ]

    @property
    def components(self):
        """Values of the issue's ``Component:`` labels."""
        return self.get_label("Component")

    @property
    def assignees(self):
        """Login names of the users assigned to the issue."""
        return [a["login"] for a in self.issue["assignees"]]

    @property
    def current_fix_versions(self):
        """Title of the assigned milestone, or None without a milestone."""
        # The milestone entry can be missing or None; both mean "no
        # milestone assigned".
        milestone = self.issue.get("milestone") or {}
        return milestone.get("title")

    @property
    def current_versions(self):
        """Titles of all milestones whose state is ``open``."""
        all_versions = self.github_api.get_milestones()
        return [x["title"] for x in all_versions if x["state"] == "open"]

    def resolve(self, fix_version, comment, pr_body):
        """Assign the milestone ``fix_version`` and close the issue.

        Parameters
        ----------
        fix_version : str
            Title of the milestone to assign.
        comment : str
            Comment posted when the issue is closed by this tool.
        pr_body : str or None
            Description of the merged pull request. When it contains
            ``Closes: #<issue>`` the issue is not closed here.

        ``cmd.fail`` is called when the issue is already closed. With DEBUG
        enabled nothing is modified. The issue data is reloaded and printed
        at the end in either case.
        """
        cur_status = self.issue["state"]

        if cur_status == "closed":
            self.cmd.fail("GitHub issue %s already has status '%s'"
                          % (self.github_id, cur_status))

        if DEBUG:
            print("GitHub issue %s untouched -> %s" %
                  (self.github_id, fix_version))
        else:
            self.github_api.assign_milestone(self.github_id, fix_version)
            # A pull request without description has a body of None, which
            # cannot be searched; treat it like an empty description.
            if f"Closes: #{self.github_id}" not in (pr_body or ""):
                self.github_api.close_issue(self.github_id, comment)
            print("Successfully resolved %s!" % (self.github_id))

        self.issue = self.github_api.get_issue_data(self.github_id)
        self.show()

    def show(self):
        """Print a formatted summary of the issue."""
        issue = self.issue
        print(format_issue_output("github", self.github_id, issue["state"],
                                  issue["title"], ', '.join(self.assignees),
                                  self.components))


def get_candidate_fix_version(mainline_versions,
                              maintenance_branches=()):
    """Pick the fix version to suggest by default.

    Parameters
    ----------
    mainline_versions : iterable
        Un-released versions, either plain strings or objects with a
        ``name`` attribute. An optional prefix up to the first ``-``
        (e.g. ``cpp-1.2.0``) is ignored when ordering.
    maintenance_branches : container of str, optional
        Names of existing maintenance branches; a version ``v`` is skipped
        when ``maint-v`` is among them.

    Returns
    -------
    The oldest remaining candidate, as chosen by
    ``fix_version_from_branch``. When the input mixes ``x.0.0`` versions
    with other versions, only the ``x.0.0`` ones are candidates.
    """
    def numeric_key(version):
        # Parquet versions are something like cpp-1.2.0
        digits = getattr(version, "name", version).split("-", 1)[-1]
        return tuple(map(int, digits.split(".")))

    names = [getattr(item, "name", item) for item in mainline_versions]
    names.sort(key=numeric_key, reverse=True)

    majors = [name for name in names if name.endswith('.0.0')]
    # If there is a future major release, suggest that
    candidates = majors if len(names) > len(majors) else names

    candidates = [name for name in candidates
                  if f"maint-{name}" not in maintenance_branches]
    return fix_version_from_branch(candidates)


def format_issue_output(issue_type, issue_id, status,
                        summary, assignee, components):
    """Build the multi-line text summary printed for an issue.

    Parameters
    ----------
    issue_type : str
        Kind of issue; shown upper-cased in the heading.
    issue_id : str
        Issue identifier; a ``GH-`` prefix is dropped for the URL.
    status, summary :
        Shown as given.
    assignee :
        Assignee text, or an object with a ``displayName`` attribute.
        A falsy value is reported as "NOT ASSIGNED!!!".
    components : sequence
        Component names, or objects with a ``name`` attribute. An empty
        sequence is reported as 'NO COMPONENTS!!!'.

    Returns
    -------
    str
    """
    assignee_text = (getattr(assignee, "displayName", assignee)
                     if assignee else "NOT ASSIGNED!!!")

    if len(components) == 0:
        components_text = 'NO COMPONENTS!!!'
    else:
        component_names = (getattr(c, "name", c) for c in components)
        components_text = ', '.join(component_names)

    url_id = issue_id.replace("GH-", "") if "GH" in issue_id else issue_id
    url = f'https://github.com/{ORG_NAME}/{PROJECT_NAME}/issues/{url_id}'

    template = """=== {} {} ===
Summary\t\t{}
Assignee\t{}
Components\t{}
Status\t\t{}
URL\t\t{}"""
    return template.format(issue_type.upper(), issue_id, summary,
                           assignee_text, components_text, status, url)


class GitHubAPI(object):
    """Minimal client for the GitHub REST API of one repository.

    Parameters
    ----------
    project_name : str
        Repository name inside the ``ORG_NAME`` organisation.
    cmd : object
        Command interface; ``cmd.prompt`` is used to ask for an API token
        when none is configured.

    The API token is taken from, in this order: the ``api_token`` entry of
    the ``[github]`` section of the configuration file, the
    ``ARROW_GITHUB_API_TOKEN`` environment variable, an interactive prompt.
    """

    def __init__(self, project_name, cmd):
        self.github_api = (
            f"https://api.github.com/repos/{ORG_NAME}/{project_name}"
        )

        token = None
        config = load_configuration()
        if "github" in config.sections():
            # A [github] section without api_token must not abort the
            # start-up; the other token sources are tried instead.
            token = config["github"].get("api_token")
        if not token:
            token = os.environ.get('ARROW_GITHUB_API_TOKEN')
        if not token:
            token = cmd.prompt('Env ARROW_GITHUB_API_TOKEN not set, '
                               'please enter your GitHub API token '
                               '(GitHub personal access token):')
        headers = {
            'Accept': 'application/vnd.github.v3+json',
            'Authorization': 'token {0}'.format(token),
        }
        self.headers = headers

    def get_milestones(self):
        """Return the repository's milestones as returned by GitHub."""
        return get_json("%s/milestones" % (self.github_api, ),
                        headers=self.headers)

    def get_milestone_number(self, version):
        """Return the number of the milestone titled ``version`` or None."""
        return next((
            m["number"] for m in self.get_milestones() if m["title"] == version
        ), None)

    def get_issue_data(self, number):
        """Return the JSON description of issue ``number``."""
        return get_json("%s/issues/%s" % (self.github_api, number),
                        headers=self.headers)

    def get_pr_data(self, number):
        """Return the JSON description of pull request ``number``."""
        return get_json("%s/pulls/%s" % (self.github_api, number),
                        headers=self.headers)

    def get_pr_commits(self, number):
        """Return the commits of pull request ``number``."""
        return get_json("%s/pulls/%s/commits" % (self.github_api, number),
                        headers=self.headers)

    def get_branches(self):
        """Return the branches of the repository."""
        return get_json("%s/branches" % (self.github_api),
                        headers=self.headers)

    def close_issue(self, number, comment):
        """Post ``comment`` on issue ``number`` and then close the issue.

        Raises
        ------
        ValueError
            If either request is answered with an error status.
        """
        issue_url = f'{self.github_api}/issues/{number}'
        comment_url = f'{self.github_api}/issues/{number}/comments'

        r = requests.post(comment_url, json={
                          "body": comment}, headers=self.headers)
        if not r.ok:
            raise ValueError(
                f"Failed request: {comment_url}:{r.status_code} -> {r.json()}")

        r = requests.patch(
            issue_url, json={"state": "closed"}, headers=self.headers)
        if not r.ok:
            raise ValueError(
                f"Failed request: {issue_url}:{r.status_code} -> {r.json()}")

    def assign_milestone(self, number, version):
        """Set the milestone titled ``version`` on issue ``number``.

        Returns
        -------
        The decoded JSON response of the update request.

        Raises
        ------
        ValueError
            If no milestone has that title or the request fails.
        """
        url = f'{self.github_api}/issues/{number}'
        milestone_number = self.get_milestone_number(version)
        if not milestone_number:
            raise ValueError(f"Invalid version {version}, milestone not found")
        payload = {
            'milestone': milestone_number
        }
        r = requests.patch(url, headers=self.headers, json=payload)
        if not r.ok:
            raise ValueError(
                f"Failed request: {url}:{r.status_code} -> {r.json()}")
        return r.json()

    def merge_pr(self, number, commit_title, commit_message):
        """Squash-merge pull request ``number``.

        Returns
        -------
        dict
            The decoded JSON response. On success the workflow state labels
            of the pull request are removed. On failure ``merged`` is set to
            False and ``message`` holds the reason followed by the request
            URL.
        """
        url = f'{self.github_api}/pulls/{number}/merge'
        payload = {
            'commit_title': commit_title,
            'commit_message': commit_message,
            'merge_method': 'squash',
        }
        response = requests.put(url, headers=self.headers, json=payload)
        result = response.json()
        if response.status_code == 200 and 'merged' in result:
            self.clear_pr_state_labels(number)
        else:
            result['merged'] = False
            # Not every error response is guaranteed to carry a message;
            # report the HTTP status in that case.
            reason = result.get('message') or f'HTTP {response.status_code}'
            result['message'] = f'{reason}: {url}'
        return result

    def clear_pr_state_labels(self, number):
        """Remove every label starting with ``awaiting`` from ``number``."""
        url = f'{self.github_api}/issues/{number}/labels'
        response = requests.get(url, headers=self.headers)
        labels = response.json()
        for label in labels:
            # All PR workflow state labels starts with "awaiting"
            if label['name'].startswith('awaiting'):
                label_url = f"{url}/{label['name']}"
                requests.delete(label_url, headers=self.headers)


class CommandInput(object):
    """
    Thin wrapper around the interactive built-ins (input, getpass) so that
    unit tests can replace user interaction with mocks
    """

    def fail(self, msg):
        """Abort by raising ``Exception(msg)``."""
        raise Exception(msg)

    def prompt(self, prompt):
        """Show ``prompt`` and return the line typed by the user."""
        answer = input(prompt)
        return answer

    def getpass(self, prompt):
        """Show ``prompt`` and return the secret typed by the user."""
        secret = getpass.getpass(prompt)
        return secret

    def continue_maybe(self, prompt):
        """Ask a yes/no question until it is answered with 'y' or 'n'.

        Returns None on 'y' and calls ``fail`` on 'n' (upper case is
        accepted as well). Any other answer repeats the question with a
        reminder of the accepted answers.
        """
        question = prompt
        while True:
            answer = input("\n%s (y/n): " % question).lower()
            if answer == "y":
                return
            if answer == "n":
                self.fail("Okay, exiting")
            else:
                question = "Please input 'y' or 'n'"


class PullRequest(object):
    """A GitHub pull request and the logic to squash-merge it.

    Parameters
    ----------
    cmd : object
        Command interface used for prompts and for reporting failures.
    github_api : object
        API client; must provide ``get_pr_data``, ``get_pr_commits``,
        ``get_branches`` and ``merge_pr``.
    git_remote : str
        Stored as ``git_remote``; not used by the methods of this class.
    number : str
        Number of the pull request.
    """

    # A title must start with "GH-<digits>" to be linked to an issue
    GITHUB_PR_TITLE_PATTERN = re.compile(r'^GH-([0-9]+)\b.*$')

    def __init__(self, cmd, github_api, git_remote, number):
        self.cmd = cmd
        self._github_api = github_api
        self.git_remote = git_remote
        self.number = number
        self._pr_data = github_api.get_pr_data(number)
        try:
            self.url = self._pr_data["url"]
            self.title = self._pr_data["title"]
            self.body = self._pr_data["body"]
            self.target_ref = self._pr_data["base"]["ref"]
            self.user_login = self._pr_data["user"]["login"]
            # NOTE: despite its name this is the *head* ref, i.e. the
            # branch the pull request was opened from.
            self.base_ref = self._pr_data["head"]["ref"]
        except KeyError:
            # Show what was received to make the missing field obvious
            pprint.pprint(self._pr_data)
            raise
        self.description = "%s/%s" % (self.user_login, self.base_ref)

        self.issue = self._get_issue()

    def show(self):
        """Print a summary of the pull request and of its issue, if any."""
        print("\n=== Pull Request #%s ===" % self.number)
        print("title\t%s\nsource\t%s\ntarget\t%s\nurl\t%s"
              % (self.title, self.description, self.target_ref, self.url))
        if self.issue is not None:
            self.issue.show()
        else:
            print("Minor PR.  Please ensure it meets guidelines for minor.\n")

    @property
    def is_merged(self):
        """Truth value of the ``merged`` field of the pull request."""
        return bool(self._pr_data["merged"])

    @property
    def is_mergeable(self):
        """Truth value of the ``mergeable`` field of the pull request."""
        return bool(self._pr_data["mergeable"])

    @property
    def maintenance_branches(self):
        """Names of the repository branches starting with ``maint-``."""
        return [x["name"] for x in self._github_api.get_branches()
                if x["name"].startswith("maint-")]

    def _get_issue(self):
        """Return the issue named in the title, or None for MINOR PRs.

        ``cmd.fail`` is called when the title neither starts with "MINOR:"
        nor with a "GH-<number>" prefix.
        """
        if self.title.startswith("MINOR:"):
            return None

        m = self.GITHUB_PR_TITLE_PATTERN.search(self.title)
        if m:
            github_id = m.group(1)
            return GitHubIssue(self._github_api, github_id, self.cmd)

        self.cmd.fail("PR title should be prefixed by a GitHub ID, like: "
                      "GH-XXX, but found {0}".format(self.title))

    def merge(self):
        """
        merge the requested PR and return the merge hash

        The squash commit gets the PR title (plus the PR number) as title
        and the PR description followed by author trailers as message.
        With DEBUG enabled nothing is merged and None is returned.
        ``cmd.fail`` is called when GitHub refuses the merge.
        """
        commits = self._github_api.get_pr_commits(self.number)

        def format_commit_author(commit):
            author = commit['commit']['author']
            name = author['name']
            email = author['email']
            return f'{name} <{email}>'
        commit_authors = [format_commit_author(commit) for commit in commits]
        co_authored_by_re = re.compile(
            r'^Co-authored-by:\s*(.*)', re.MULTILINE)

        commit_co_authors = []
        for commit in commits:
            # strip() drops trailing blanks and the '\r' of CRLF messages,
            # which would otherwise make the same person appear twice
            commit_co_authors.extend(
                found.strip()
                for found in co_authored_by_re.findall(
                    commit['commit']['message']))

        all_commit_authors = commit_authors + commit_co_authors
        # dict.fromkeys removes duplicates but keeps the order of first
        # appearance; together with the stable sort, authors with the same
        # number of commits are listed in a reproducible order.
        distinct_authors = sorted(dict.fromkeys(all_commit_authors),
                                  key=commit_authors.count,
                                  reverse=True)

        for i, author in enumerate(distinct_authors):
            print("Author {}: {}".format(i + 1, author))

        if len(distinct_authors) > 1:
            primary_author, distinct_other_authors = get_primary_author(
                self.cmd, distinct_authors)
        else:
            # If there is only one author, do not prompt for a lead author
            primary_author = distinct_authors[0]
            distinct_other_authors = []

        commit_title = f'{self.title} (#{self.number})'
        commit_message_chunks = []
        if self.body is not None:
            # Remove comments (i.e. <-- comment -->) from the PR description.
            body = re.sub(r"<!--.*?-->", "", self.body, flags=re.DOTALL)
            # avoid github user name references by inserting a space after @
            body = re.sub(r"@(\w+)", "@ \\1", body)
            commit_message_chunks.append(body)

        committer_name = run_cmd("git config --get user.name").strip()
        committer_email = run_cmd("git config --get user.email").strip()

        lead_label = ("Lead-authored-by:" if distinct_other_authors
                      else "Authored-by:")
        trailers = ["%s %s" % (lead_label, primary_author)]
        trailers.extend("Co-authored-by: %s" % a
                        for a in distinct_other_authors)
        trailers.append("Signed-off-by: %s <%s>" % (committer_name,
                                                    committer_email))
        commit_message_chunks.append("\n".join(trailers))

        commit_message = "\n\n".join(commit_message_chunks)

        # Normalize line ends and collapse extraneous newlines. We allow two
        # consecutive newlines for paragraph breaks but not more.
        commit_message = "\n".join(commit_message.splitlines())
        commit_message = re.sub("\n{2,}", "\n\n", commit_message)

        if DEBUG:
            print("*** Commit title ***")
            print(commit_title)
            print()
            print("*** Commit message ***")
            print(commit_message)

        if DEBUG:
            merge_hash = None
        else:
            result = self._github_api.merge_pr(self.number,
                                               commit_title,
                                               commit_message)
            if not result['merged']:
                message = result['message']
                self.cmd.fail(f'Failed to merge pull request: {message}')
            merge_hash = result['sha']

        print("Pull request #%s merged!" % self.number)
        print("Merge hash: %s" % merge_hash)
        return merge_hash


def get_primary_author(cmd, distinct_authors):
    """Ask which author is to be credited as the primary author.

    Parameters
    ----------
    cmd : object
        Command interface; ``cmd.prompt`` is used to ask the question.
    distinct_authors : list of str
        Candidate authors as ``name <email>``; the first one is offered as
        the default.

    Returns
    -------
    tuple
        ``(primary_author, other_authors)``. An empty answer selects the
        default and returns the remaining candidates. An answer that does
        not look like ``name <email>`` is rejected and the question is
        asked again.
    """
    default_author = distinct_authors[0]
    question = ("Enter primary author in the format of "
                "\"name <email>\" [%s]: " % default_author)
    looks_like_author = re.compile(r'(.*) <(.*)>')

    while True:
        answer = cmd.prompt(question)

        if answer == "":
            return default_author, distinct_authors[1:]

        if looks_like_author.match(answer):
            # When primary author is specified manually, de-dup it from
            # author list and put it at the head of author list.
            remaining = [candidate for candidate in distinct_authors
                         if candidate != answer]
            return answer, remaining

        print('Bad author "{}", please try again'.format(answer))


def prompt_for_fix_version(cmd, issue, maintenance_branches=()):
    """Ask the user which fix version (milestone) to assign to ``issue``.

    Parameters
    ----------
    cmd : object
        Command interface; ``cmd.prompt`` is used to ask the question.
    issue : object
        Issue wrapper providing ``current_versions``,
        ``current_fix_versions`` and the raw ``issue`` data.
    maintenance_branches : container of str, optional
        Passed on to ``get_candidate_fix_version``.

    Returns
    -------
    str
        The entered version without surrounding whitespace. When nothing
        (or only whitespace) is entered, the default is returned: the
        milestone already assigned to the issue if there is one, otherwise
        the candidate computed from the open milestones.
    """
    default_fix_version = get_candidate_fix_version(
        mainline_versions=issue.current_versions,
        maintenance_branches=maintenance_branches
    )

    current_fix_versions = issue.current_fix_versions
    if (current_fix_versions and
            current_fix_versions != default_fix_version):
        print("\n=== The assigned milestone is not the default ===")
        print(f"Assigned milestone: {current_fix_versions}")
        print(f"Current milestone: {default_fix_version}")
        if issue.issue["milestone"].get("state") == 'closed':
            print("The assigned milestone state is closed. Contact the ")
            print("Release Manager if it has to be added to a closed Release")
        print("Please ensure to assign the correct milestone.")
        # Default to existing assigned milestone
        default_fix_version = current_fix_versions

    # Strip before testing for emptiness: an answer made only of blanks
    # means "use the default", not "use an empty version".
    issue_fix_version = cmd.prompt("Enter fix version [%s]: "
                                   % default_fix_version).strip()
    if not issue_fix_version:
        issue_fix_version = default_fix_version.strip()
    return issue_fix_version


# Per-user configuration file; "~" is expanded when the file is loaded
CONFIG_FILE = "~/.config/arrow/merge.conf"


def load_configuration():
    """Read CONFIG_FILE and return the resulting ConfigParser.

    A missing file is not an error: the returned parser is simply empty.
    """
    config_path = os.path.expanduser(CONFIG_FILE)
    parser = configparser.ConfigParser()
    parser.read(config_path)
    return parser


def get_pr_num():
    """Return the pull request number to merge, as a string.

    The number is taken from the command line when exactly one argument
    was given; otherwise the user is asked for it.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) == 1:
        return cli_args[0]

    return input("Which pull request would you like to merge? (e.g. 34): ")


def cli():
    """Interactively merge one pull request and update its issue.

    Steps: determine the pull request number, load the pull request from
    GitHub, stop if it is already merged (exit status 0) or not mergeable
    (exit status 1), ask for confirmation, merge, and finally - unless the
    pull request is a MINOR one without issue - ask whether the associated
    issue should be updated and resolve it with the chosen fix version.
    """
    # Directory containing this script; it becomes the working directory
    # before GitHub is contacted
    ARROW_HOME = os.path.abspath(os.path.dirname(__file__))
    print(f"ARROW_HOME = {ARROW_HOME}")
    print(f"ORG_NAME = {ORG_NAME}")
    print(f"PROJECT_NAME = {PROJECT_NAME}")

    command_input = CommandInput()
    pr_num = get_pr_num()
    os.chdir(ARROW_HOME)

    api_client = GitHubAPI(PROJECT_NAME, command_input)
    pull_request = PullRequest(command_input, api_client, ORG_NAME, pr_num)

    if pull_request.is_merged:
        print("Pull request %s has already been merged" % pr_num)
        sys.exit(0)

    if not pull_request.is_mergeable:
        print("Pull request %s is not mergeable in its current form" % pr_num)
        sys.exit(1)

    pull_request.show()
    command_input.continue_maybe(
        "Proceed with merging pull request #%s?" % pr_num)
    pull_request.merge()

    linked_issue = pull_request.issue
    if linked_issue is None:
        print("Minor PR.  No issue to update.\n")
        return

    command_input.continue_maybe(
        "Would you like to update the associated issue?")
    pr_url = f"https://github.com/{ORG_NAME}/{PROJECT_NAME}/pull/{pr_num}"
    issue_comment = "Issue resolved by pull request %s\n%s" % (pr_num, pr_url)
    fix_version = prompt_for_fix_version(
        command_input, linked_issue, pull_request.maintenance_branches)
    linked_issue.resolve(fix_version, issue_comment, pull_request.body)


if __name__ == '__main__':
    # Exceptions raised by cli() are deliberately left to propagate, so a
    # failure ends the script with the usual traceback.
    cli()
