docbot/action.go (287 lines of code) (raw):
package main
import (
"context"
"errors"
"fmt"
"regexp"
"strings"
"github.com/apache/pulsar-test-infra/docbot/pkg/logger"
"github.com/google/go-github/v45/github"
"golang.org/x/oauth2"
)
// MessageLabelMultiple is the help text posted (and returned as an error) when
// more than one documentation checkbox is ticked but multiple labels are not
// enabled in the configuration.
const MessageLabelMultiple = "Please select only one documentation label in your PR description."

// Pull request event action types for which Run performs a label check.
const (
	openedActionType    = "opened"
	editedActionType    = "edited"
	labeledActionType   = "labeled"
	unlabeledActionType = "unlabeled"
)
// builtInDescriptions maps the built-in documentation labels to a short hint.
// getLabelMissingMessage renders the hint as an HTML comment next to the
// matching checkbox; labels without an entry get no hint.
var builtInDescriptions = map[string]string{
	"doc-required":   "Your PR changes impact docs and you will update later",
	"doc-not-needed": "Your PR changes do not impact docs",
	"doc":            "Your PR contains doc changes",
	"doc-complete":   "Docs have been already added",
}
// Action checks and fixes the documentation labels of a single pull request.
type Action struct {
	// config holds the owner/repo, token and label settings.
	config *ActionConfig
	// globalContext is passed to the GitHub API calls made by the action.
	globalContext context.Context
	// client is the GitHub API client used for all requests.
	client *github.Client
	// prNumber is the pull request being processed; it is set by Run.
	prNumber int
}
// NewAction builds an Action whose GitHub client authenticates with the token
// from ac (via an OAuth2 static token source) and whose context is
// context.Background().
func NewAction(ac *ActionConfig) *Action {
	ctx := context.Background()
	token := &oauth2.Token{AccessToken: ac.GetToken()}
	httpClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
	ghClient := github.NewClient(httpClient)
	return NewActionWithClient(ctx, ac, ghClient)
}
// NewActionWithClient builds an Action from an already-configured GitHub
// client. ctx is stored on the Action and used for its API calls.
func NewActionWithClient(ctx context.Context, ac *ActionConfig, client *github.Client) *Action {
	action := new(Action)
	action.config = ac
	action.globalContext = ctx
	action.client = client
	return action
}
// Run records prNumber on the action and, when actionType is one of the pull
// request events "opened", "edited", "labeled" or "unlabeled", runs the label
// check and returns its result. Any other actionType is ignored (nil error).
func (a *Action) Run(prNumber int, actionType string) error {
	a.prNumber = prNumber

	relevant := actionType == openedActionType ||
		actionType == editedActionType ||
		actionType == labeledActionType ||
		actionType == unlabeledActionType
	if !relevant {
		return nil
	}
	return a.checkLabels()
}
// checkLabels reconciles the PR's documentation labels with the checkboxes
// ticked in its description. It returns a non-nil error when the description
// is not acceptable, so the workflow run fails.
//
// Steps:
//  1. Fetch the PR and parse its description with extractLabels.
//  2. Keep only parsed labels that already exist in the repository.
//  3. If no checkbox is ticked: remove watched labels from the PR, add the
//     configured "label missing" label, post or de-duplicate the help comment,
//     and return an error carrying the help message.
//  4. If several are ticked and multiple labels are disabled: post the
//     MessageLabelMultiple comment and return it as an error.
//  5. Otherwise: remove the "label missing" label and add or remove watched
//     labels so the PR matches the ticked checkboxes.
func (a *Action) checkLabels() error {
	pr, _, err := a.client.PullRequests.Get(a.globalContext, a.config.GetOwner(), a.config.GetRepo(), a.prNumber)
	if err != nil {
		return fmt.Errorf("get PR: %w", err)
	}

	// A PR without a description has a nil Body; only parse it when present.
	var bodyLabels map[string]bool
	if pr.Body != nil {
		bodyLabels = a.extractLabels(*pr.Body)
	}
	// GetBody is nil-safe (returns ""). Dereferencing pr.Body here used to
	// panic for PRs with an empty description.
	logger.Infof("PR description: %v\n", pr.GetBody())

	logger.Infoln("@List repo labels")
	repoLabels, err := a.getRepoLabels()
	if err != nil {
		return fmt.Errorf("list repo labels: %w", err)
	}
	logger.Infof("Repo labels: %v\n", repoLabels)

	prLabels := a.labelsToMap(pr.Labels)
	logger.Infof("PR labels: %v\n", prLabels)

	// Expected labels: checkbox state of each parsed label that already exists
	// in the repo. Labels unknown to the repo are skipped.
	expectedLabelsMap := make(map[string]bool)
	checkedCount := 0
	for label, checked := range bodyLabels {
		if _, exist := repoLabels[label]; !exist {
			logger.Infof("Found label %v not exist in repo\n", label)
			continue
		}
		expectedLabelsMap[label] = checked
		if checked {
			checkedCount++
		}
	}
	logger.Infof("Expected labels: %v\n", expectedLabelsMap)

	labelMissing := a.config.GetLabelMissing()
	labelsToRemove := make(map[string]struct{})
	labelsToAdd := make(map[string]struct{})
	if checkedCount == 0 {
		logger.Infoln("Label missing")
		// Nothing ticked: every watched label currently on the PR is stale.
		for label := range a.config.labelWatchSet {
			if _, found := prLabels[label]; found {
				labelsToRemove[label] = struct{}{}
			}
		}
		if _, found := prLabels[labelMissing]; found {
			logger.Infoln("Already added missing label.")
			// NOTE(review): this early return skips removing the stale watched
			// labels collected above and skips the help-comment check below.
			return errors.New(a.getLabelMissingMessage())
		}
		labelsToAdd[labelMissing] = struct{}{}
	} else {
		if !a.config.GetEnableLabelMultiple() && checkedCount > 1 {
			logger.Infoln("Multiple labels not enabled")
			if err := a.addAndCleanupHelpComment(pr.User.GetLogin(), MessageLabelMultiple); err != nil {
				return err
			}
			return errors.New(MessageLabelMultiple)
		}
		// At least one box is ticked, so the "label missing" marker no longer applies.
		if _, found := prLabels[labelMissing]; found {
			labelsToRemove[labelMissing] = struct{}{}
		}
		for label, checked := range expectedLabelsMap {
			_, found := prLabels[label]
			switch {
			case found && !checked:
				labelsToRemove[label] = struct{}{}
			case !found && checked:
				labelsToAdd[label] = struct{}{}
			}
		}
	}

	if len(labelsToAdd) == 0 {
		logger.Infoln("No labels to add.")
	} else {
		labels := a.labelsSetToString(labelsToAdd)
		logger.Infof("Labels to add: %v\n", labels)
		if err := a.addLabels(labels); err != nil {
			logger.Errorf("Failed add labels %v: %v\n", labelsToAdd, err)
			return err
		}
	}

	if len(labelsToRemove) == 0 {
		logger.Infoln("No labels to remove.")
	} else {
		labels := a.labelsSetToString(labelsToRemove)
		logger.Infof("Labels to remove: %v\n", labels)
		for _, label := range labels {
			if err := a.removeLabel(label); err != nil {
				logger.Errorf("Failed remove labels %v: %v\n", labelsToRemove, err)
				return err
			}
		}
	}

	if checkedCount == 0 {
		msg := a.getLabelMissingMessage()
		if err := a.addAndCleanupHelpComment(pr.User.GetLogin(), msg); err != nil {
			return err
		}
		return errors.New(msg)
	}
	return nil
}
// extractLabels matches prBody against the configured label pattern and
// returns, for each watched label it finds, whether its checkbox is ticked.
//
// The pattern's first capture group is the checkbox content. It counts as
// ticked when, after trimming and lower-casing, it equals "x". The second
// group is the label name, trimmed. Labels outside the watch set are dropped.
// If a label occurs more than once, the last occurrence wins.
func (a *Action) extractLabels(prBody string) map[string]bool {
	pattern := regexp.MustCompile(a.config.GetLabelPattern())
	result := make(map[string]bool)
	for _, match := range pattern.FindAllStringSubmatch(prBody, -1) {
		box := strings.TrimSpace(match[1])
		name := strings.TrimSpace(match[2])
		if _, watched := a.config.labelWatchSet[name]; watched {
			result[name] = strings.ToLower(box) == "x"
		}
	}
	return result
}
// getRepoLabels returns the set of label names defined in the configured
// repository, paging through the label list 100 entries at a time.
//
// It uses the action's context like the other API helpers. Previously it used
// context.Background(), which ignored a context supplied via
// NewActionWithClient. The first API error is returned unchanged.
func (a *Action) getRepoLabels() (map[string]struct{}, error) {
	opts := &github.ListOptions{PerPage: 100}
	repoLabels := make(map[string]struct{})
	for {
		page, resp, err := a.client.Issues.ListLabels(a.globalContext, a.config.GetOwner(), a.config.GetRepo(), opts)
		if err != nil {
			return nil, err
		}
		for _, label := range page {
			repoLabels[label.GetName()] = struct{}{}
		}
		// NextPage is 0 once the last page has been read.
		if resp.NextPage == 0 {
			return repoLabels, nil
		}
		opts.Page = resp.NextPage
	}
}
// labelsToMap converts a list of GitHub labels into a set of their names.
func (a *Action) labelsToMap(labels []*github.Label) map[string]struct{} {
	names := make(map[string]struct{}, len(labels))
	for i := range labels {
		names[labels[i].GetName()] = struct{}{}
	}
	return names
}
// labelsSetToString returns the names in the labels set as a slice. The slice
// is non-nil even when the set is empty. Order follows map iteration and is
// therefore unspecified.
func (a *Action) labelsSetToString(labels map[string]struct{}) []string {
	names := make([]string, 0, len(labels))
	for name := range labels {
		names = append(names, name)
	}
	return names
}
// getLabelInvalidCommentIDs returns the IDs of every comment on the PR whose
// body contains body, in the order the comment list API returns them. It pages
// through the comments 100 at a time.
//
// It uses the action's context, so a caller-supplied context is honoured.
// It reads fields through the nil-safe GetBody/GetID getters instead of
// dereferencing the pointers directly.
func (a *Action) getLabelInvalidCommentIDs(body string) ([]int64, error) {
	opts := &github.IssueListCommentsOptions{}
	opts.PerPage = 100
	commentIDs := make([]int64, 0)
	for {
		comments, resp, err := a.client.Issues.ListComments(a.globalContext, a.config.GetOwner(), a.config.GetRepo(),
			a.prNumber, opts)
		if err != nil {
			return nil, err
		}
		for _, comment := range comments {
			if strings.Contains(comment.GetBody(), body) {
				commentIDs = append(commentIDs, comment.GetID())
			}
		}
		// NextPage is 0 once the last page has been read.
		if resp.NextPage == 0 {
			return commentIDs, nil
		}
		opts.Page = resp.NextPage
	}
}
// createComment posts body as a new comment on the PR.
func (a *Action) createComment(body string) error {
	comment := &github.IssueComment{Body: &body}
	owner, repo := a.config.GetOwner(), a.config.GetRepo()
	_, _, err := a.client.Issues.CreateComment(a.globalContext, owner, repo, a.prNumber, comment)
	return err
}
// deleteComment removes the issue comment identified by commentID from the
// configured repository.
func (a *Action) deleteComment(commentID int64) error {
	owner, repo := a.config.GetOwner(), a.config.GetRepo()
	_, err := a.client.Issues.DeleteComment(a.globalContext, owner, repo, commentID)
	return err
}
// addLabels adds the given label names to the PR in a single API call.
func (a *Action) addLabels(labels []string) error {
	owner, repo := a.config.GetOwner(), a.config.GetRepo()
	_, _, err := a.client.Issues.AddLabelsToIssue(a.globalContext, owner, repo, a.prNumber, labels)
	return err
}
// removeLabel removes a single label from the PR.
func (a *Action) removeLabel(label string) error {
	owner, repo := a.config.GetOwner(), a.config.GetRepo()
	_, err := a.client.Issues.RemoveLabelForIssue(a.globalContext, owner, repo, a.prNumber, label)
	return err
}
// addAndCleanupHelpComment tries to leave exactly one help comment containing
// body on the PR.
//
// If no existing comment contains body, it posts "@login body". If several do,
// it keeps the first one returned by the comment list and deletes the rest.
// Errors from listing, creating or deleting comments are logged and returned.
// Cleanup stops at the first failed delete.
func (a *Action) addAndCleanupHelpComment(login, body string) error {
	existing, err := a.getLabelInvalidCommentIDs(body)
	if err != nil {
		logger.Errorf("Failed to get the comment list: %v", err)
		return err
	}

	if len(existing) == 0 {
		if err := a.createComment(fmt.Sprintf("@%s %s", login, body)); err != nil {
			logger.Errorf("Failed to create %s comment: %v", body, err)
			return err
		}
		return nil
	}

	// Keep existing[0]; anything after it is a duplicate.
	for _, duplicateID := range existing[1:] {
		if err := a.deleteComment(duplicateID); err != nil {
			logger.Errorf("Failed to delete %v comment: %v", duplicateID, err)
			return err
		}
	}
	return nil
}
func (a *Action) getLabelMissingMessage() string {
msg := "Please add the following content to your PR description and select a checkbox:\n```\n"
for _, label := range a.config.labelWatchList {
desc := ""
if value, found := builtInDescriptions[label]; found {
desc = fmt.Sprintf("<!-- %s -->", value)
}
msg += fmt.Sprintf("- [ ] `%s` %s\n", label, desc)
}
msg += "```"
return msg
}