github.com/abayer/test-infra@v0.0.5/prow/cmd/branchprotector/protect.go (about)

     1  /*
     2  Copyright 2018 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package main
    18  
    19  import (
    20  	"errors"
    21  	"flag"
    22  	"fmt"
    23  	"net/url"
    24  	"strings"
    25  	"sync"
    26  
    27  	"k8s.io/test-infra/prow/config"
    28  	"k8s.io/test-infra/prow/flagutil"
    29  	"k8s.io/test-infra/prow/github"
    30  	"k8s.io/test-infra/prow/logrusutil"
    31  
    32  	"github.com/sirupsen/logrus"
    33  )
    34  
    35  type options struct {
    36  	config    string
    37  	jobConfig string
    38  	token     string
    39  	confirm   bool
    40  	endpoint  flagutil.Strings
    41  }
    42  
    43  func (o *options) Validate() error {
    44  	if o.config == "" {
    45  		return errors.New("empty --config-path")
    46  	}
    47  
    48  	if o.token == "" {
    49  		return errors.New("empty --github-token-path")
    50  	}
    51  
    52  	for _, ep := range o.endpoint.Strings() {
    53  		_, err := url.Parse(ep)
    54  		if err != nil {
    55  			return fmt.Errorf("invalid --endpoint URL %q: %v", ep, err)
    56  		}
    57  	}
    58  
    59  	return nil
    60  }
    61  
    62  func gatherOptions() options {
    63  	o := options{
    64  		endpoint: flagutil.NewStrings("https://api.github.com"),
    65  	}
    66  	flag.StringVar(&o.config, "config-path", "", "Path to prow config.yaml")
    67  	flag.StringVar(&o.jobConfig, "job-config-path", "", "Path to prow job configs.")
    68  	flag.BoolVar(&o.confirm, "confirm", false, "Mutate github if set")
    69  	flag.Var(&o.endpoint, "github-endpoint", "Github api endpoint, may differ for enterprise")
    70  	flag.StringVar(&o.token, "github-token-path", "", "Path to github token")
    71  	flag.Parse()
    72  	return o
    73  }
    74  
    75  type requirements struct {
    76  	Org     string
    77  	Repo    string
    78  	Branch  string
    79  	Request *github.BranchProtectionRequest
    80  }
    81  
    82  // Errors holds a list of errors, including a method to concurrently append.
    83  type Errors struct {
    84  	lock sync.Mutex
    85  	errs []error
    86  }
    87  
    88  func (e *Errors) add(err error) {
    89  	e.lock.Lock()
    90  	logrus.Info(err)
    91  	defer e.lock.Unlock()
    92  	e.errs = append(e.errs, err)
    93  }
    94  
    95  func main() {
    96  	logrus.SetFormatter(
    97  		logrusutil.NewDefaultFieldsFormatter(nil, logrus.Fields{"component": "branchprotector"}),
    98  	)
    99  
   100  	o := gatherOptions()
   101  	if err := o.Validate(); err != nil {
   102  		logrus.Fatal(err)
   103  	}
   104  
   105  	cfg, err := config.Load(o.config, o.jobConfig)
   106  	if err != nil {
   107  		logrus.WithError(err).Fatalf("Failed to load --config-path=%s", o.config)
   108  	}
   109  
   110  	secretAgent := &config.SecretAgent{}
   111  	if err := secretAgent.Start([]string{o.token}); err != nil {
   112  		logrus.WithError(err).Fatal("Error starting secrets agent.")
   113  	}
   114  
   115  	var c *github.Client
   116  
   117  	if o.confirm {
   118  		c = github.NewClient(secretAgent.GetTokenGenerator(o.token), o.endpoint.Strings()...)
   119  	} else {
   120  		c = github.NewDryRunClient(secretAgent.GetTokenGenerator(o.token), o.endpoint.Strings()...)
   121  	}
   122  	c.Throttle(300, 100) // 300 hourly tokens, bursts of 100
   123  
   124  	p := protector{
   125  		client:         c,
   126  		cfg:            cfg,
   127  		updates:        make(chan requirements),
   128  		errors:         Errors{},
   129  		completedRepos: make(map[string]bool),
   130  		done:           make(chan []error),
   131  	}
   132  
   133  	go p.configureBranches()
   134  	p.protect()
   135  	close(p.updates)
   136  	errors := <-p.done
   137  	if n := len(errors); n > 0 {
   138  		for i, err := range errors {
   139  			logrus.WithError(err).Error(i)
   140  		}
   141  		logrus.Fatalf("Encountered %d errors protecting branches", n)
   142  	}
   143  }
   144  
   145  type client interface {
   146  	RemoveBranchProtection(org, repo, branch string) error
   147  	UpdateBranchProtection(org, repo, branch string, config github.BranchProtectionRequest) error
   148  	GetBranches(org, repo string, onlyProtected bool) ([]github.Branch, error)
   149  	GetRepos(org string, user bool) ([]github.Repo, error)
   150  }
   151  
   152  type protector struct {
   153  	client         client
   154  	cfg            *config.Config
   155  	updates        chan requirements
   156  	errors         Errors
   157  	completedRepos map[string]bool
   158  	done           chan []error
   159  }
   160  
   161  func (p *protector) configureBranches() {
   162  	for u := range p.updates {
   163  		if u.Request == nil {
   164  			if err := p.client.RemoveBranchProtection(u.Org, u.Repo, u.Branch); err != nil {
   165  				p.errors.add(fmt.Errorf("remove %s/%s=%s protection failed: %v", u.Org, u.Repo, u.Branch, err))
   166  			}
   167  			continue
   168  		}
   169  
   170  		if err := p.client.UpdateBranchProtection(u.Org, u.Repo, u.Branch, *u.Request); err != nil {
   171  			p.errors.add(fmt.Errorf("update %s/%s=%s protection to %v failed: %v", u.Org, u.Repo, u.Branch, *u.Request, err))
   172  		}
   173  	}
   174  	p.done <- p.errors.errs
   175  }
   176  
   177  // protect protects branches specified in the presubmit and branch-protection config sections.
   178  func (p *protector) protect() {
   179  	bp := p.cfg.BranchProtection
   180  
   181  	// Scan the branch-protection configuration
   182  	for orgName, org := range bp.Orgs {
   183  		if err := p.UpdateOrg(orgName, org, bp.HasProtect()); err != nil {
   184  			p.errors.add(err)
   185  		}
   186  	}
   187  
   188  	// Do not automatically protect tested repositories
   189  	if !bp.ProtectTested {
   190  		return
   191  	}
   192  
   193  	// Some repos with presubmits might not be listed in the branch-protection
   194  	for repo := range p.cfg.Presubmits {
   195  		if p.completedRepos[repo] == true {
   196  			continue
   197  		}
   198  		parts := strings.Split(repo, "/")
   199  		if len(parts) != 2 { // TODO(fejta): use a strong type here instead
   200  			logrus.Fatalf("Bad repo: %s", repo)
   201  		}
   202  		orgName := parts[0]
   203  		repoName := parts[1]
   204  		if err := p.UpdateRepo(orgName, repoName, config.Repo{}); err != nil {
   205  			p.errors.add(err)
   206  		}
   207  	}
   208  }
   209  
   210  // UpdateOrg updates all repos in the org with the specified defaults
   211  func (p *protector) UpdateOrg(orgName string, org config.Org, allRepos bool) error {
   212  	var repos []string
   213  	allRepos = allRepos || org.HasProtect()
   214  	if allRepos {
   215  		// Strongly opinionated org, configure every repo in the org.
   216  		rs, err := p.client.GetRepos(orgName, false)
   217  		if err != nil {
   218  			return fmt.Errorf("GetRepos(%s) failed: %v", orgName, err)
   219  		}
   220  		for _, r := range rs {
   221  			repos = append(repos, r.Name)
   222  		}
   223  	} else {
   224  		// Unopinionated org, just set explicitly defined repos
   225  		for r := range org.Repos {
   226  			repos = append(repos, r)
   227  		}
   228  	}
   229  
   230  	for _, repoName := range repos {
   231  		err := p.UpdateRepo(orgName, repoName, org.Repos[repoName])
   232  		if err != nil {
   233  			return err
   234  		}
   235  	}
   236  	return nil
   237  }
   238  
   239  // UpdateRepo updates all branches in the repo with the specified defaults
   240  func (p *protector) UpdateRepo(orgName string, repo string, repoDefaults config.Repo) error {
   241  	p.completedRepos[orgName+"/"+repo] = true
   242  
   243  	branches := map[string]github.Branch{}
   244  	for _, onlyProtected := range []bool{false, true} { // put true second so it becomes the value
   245  		bs, err := p.client.GetBranches(orgName, repo, onlyProtected)
   246  		if err != nil {
   247  			return fmt.Errorf("GetBranches(%s, %s, %t) failed: %v", orgName, repo, onlyProtected, err)
   248  		}
   249  		for _, b := range bs {
   250  			branches[b.Name] = b
   251  		}
   252  	}
   253  
   254  	for bn, branch := range branches {
   255  		if err := p.UpdateBranch(orgName, repo, bn, branch.Protected); err != nil {
   256  			return fmt.Errorf("UpdateBranch(%s, %s, %s, %t) failed: %v", orgName, repo, bn, branch.Protected, err)
   257  		}
   258  	}
   259  	return nil
   260  }
   261  
   262  // UpdateBranch updates the branch with the specified configuration
   263  func (p *protector) UpdateBranch(orgName, repo string, branchName string, protected bool) error {
   264  	bp, err := p.cfg.GetBranchProtection(orgName, repo, branchName)
   265  	if err != nil {
   266  		return err
   267  	}
   268  	if bp == nil || bp.Protect == nil {
   269  		return nil
   270  	}
   271  	if !protected && !*bp.Protect {
   272  		logrus.Infof("%s/%s=%s: already unprotected", orgName, repo, branchName)
   273  		return nil
   274  	}
   275  	var req *github.BranchProtectionRequest
   276  	if *bp.Protect {
   277  		r := makeRequest(*bp)
   278  		req = &r
   279  	}
   280  	p.updates <- requirements{
   281  		Org:     orgName,
   282  		Repo:    repo,
   283  		Branch:  branchName,
   284  		Request: req,
   285  	}
   286  
   287  	return nil
   288  }