github.com/argoproj/argo-cd/v2@v2.10.9/applicationset/services/pull_request/github.go (about)

     1  package pull_request
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"os"
     7  
     8  	"github.com/google/go-github/v35/github"
     9  	"golang.org/x/oauth2"
    10  )
    11  
    12  type GithubService struct {
    13  	client *github.Client
    14  	owner  string
    15  	repo   string
    16  	labels []string
    17  }
    18  
    19  var _ PullRequestService = (*GithubService)(nil)
    20  
    21  func NewGithubService(ctx context.Context, token, url, owner, repo string, labels []string) (PullRequestService, error) {
    22  	var ts oauth2.TokenSource
    23  	// Undocumented environment variable to set a default token, to be used in testing to dodge anonymous rate limits.
    24  	if token == "" {
    25  		token = os.Getenv("GITHUB_TOKEN")
    26  	}
    27  	if token != "" {
    28  		ts = oauth2.StaticTokenSource(
    29  			&oauth2.Token{AccessToken: token},
    30  		)
    31  	}
    32  	httpClient := oauth2.NewClient(ctx, ts)
    33  	var client *github.Client
    34  	if url == "" {
    35  		client = github.NewClient(httpClient)
    36  	} else {
    37  		var err error
    38  		client, err = github.NewEnterpriseClient(url, url, httpClient)
    39  		if err != nil {
    40  			return nil, err
    41  		}
    42  	}
    43  	return &GithubService{
    44  		client: client,
    45  		owner:  owner,
    46  		repo:   repo,
    47  		labels: labels,
    48  	}, nil
    49  }
    50  
    51  func (g *GithubService) List(ctx context.Context) ([]*PullRequest, error) {
    52  	opts := &github.PullRequestListOptions{
    53  		ListOptions: github.ListOptions{
    54  			PerPage: 100,
    55  		},
    56  	}
    57  	pullRequests := []*PullRequest{}
    58  	for {
    59  		pulls, resp, err := g.client.PullRequests.List(ctx, g.owner, g.repo, opts)
    60  		if err != nil {
    61  			return nil, fmt.Errorf("error listing pull requests for %s/%s: %w", g.owner, g.repo, err)
    62  		}
    63  		for _, pull := range pulls {
    64  			if !containLabels(g.labels, pull.Labels) {
    65  				continue
    66  			}
    67  			pullRequests = append(pullRequests, &PullRequest{
    68  				Number:       *pull.Number,
    69  				Branch:       *pull.Head.Ref,
    70  				TargetBranch: *pull.Base.Ref,
    71  				HeadSHA:      *pull.Head.SHA,
    72  				Labels:       getGithubPRLabelNames(pull.Labels),
    73  			})
    74  		}
    75  		if resp.NextPage == 0 {
    76  			break
    77  		}
    78  		opts.Page = resp.NextPage
    79  	}
    80  	return pullRequests, nil
    81  }
    82  
    83  // containLabels returns true if gotLabels contains expectedLabels
    84  func containLabels(expectedLabels []string, gotLabels []*github.Label) bool {
    85  	for _, expected := range expectedLabels {
    86  		found := false
    87  		for _, got := range gotLabels {
    88  			if got.Name == nil {
    89  				continue
    90  			}
    91  			if expected == *got.Name {
    92  				found = true
    93  				break
    94  			}
    95  		}
    96  		if !found {
    97  			return false
    98  		}
    99  	}
   100  	return true
   101  }
   102  
   103  // Get the Github pull request label names.
   104  func getGithubPRLabelNames(gitHubLabels []*github.Label) []string {
   105  	var labelNames []string
   106  	for _, gitHubLabel := range gitHubLabels {
   107  		labelNames = append(labelNames, *gitHubLabel.Name)
   108  	}
   109  	return labelNames
   110  }