github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/accessor/verifier.go (about)

     1  package accessor
     2  
     3  import (
     4  	"errors"
     5  	"net/http"
     6  	"strings"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/pf-qiu/concourse/v6/atc/db"
    11  	"gopkg.in/square/go-jose.v2/jwt"
    12  )
    13  
    14  var (
    15  	ErrVerificationNoToken         = errors.New("token not provided")
    16  	ErrVerificationInvalidToken    = errors.New("token provided is invalid")
    17  	ErrVerificationTokenExpired    = errors.New("token is expired")
    18  	ErrVerificationInvalidAudience = errors.New("token has invalid audience")
    19  )
    20  
//go:generate counterfeiter .  AccessTokenFetcher

// AccessTokenFetcher looks up a stored access token by the raw token string
// a client presented. The go:generate directive above generates a
// counterfeiter fake of this interface.
type AccessTokenFetcher interface {
	// GetAccessToken returns the stored token for rawToken. The bool reports
	// whether such a token exists; the error reports a failure of the lookup
	// itself.
	// NOTE(review): the returned AccessToken's contents when the bool is false
	// are not specified here; callers in this file ignore it in that case.
	GetAccessToken(rawToken string) (db.AccessToken, bool, error)
}
    26  
    27  func NewVerifier(accessTokenFetcher AccessTokenFetcher, audience []string) *verifier {
    28  	return &verifier{
    29  		accessTokenFetcher: accessTokenFetcher,
    30  		audience:           audience,
    31  	}
    32  }
    33  
// verifier authenticates HTTP requests that carry bearer tokens. It resolves
// each token through accessTokenFetcher, then checks the token's claims
// against the current time and the accepted audiences.
type verifier struct {
	// NOTE(review): Verify and verify never lock this Mutex. Embedding it also
	// promotes Lock/Unlock onto *verifier. Confirm nothing else in the package
	// relies on it before removing.
	sync.Mutex
	accessTokenFetcher AccessTokenFetcher // resolves raw bearer tokens to stored access tokens
	audience           []string           // accepted audiences; a token needs at least one (empty rejects all)
}
    39  
    40  func (v *verifier) Verify(r *http.Request) (map[string]interface{}, error) {
    41  
    42  	header := r.Header.Get("Authorization")
    43  	if header == "" {
    44  		return nil, ErrVerificationNoToken
    45  	}
    46  
    47  	parts := strings.Split(header, " ")
    48  	if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
    49  		return nil, ErrVerificationInvalidToken
    50  	}
    51  
    52  	return v.verify(parts[1])
    53  }
    54  
    55  func (v *verifier) verify(rawToken string) (map[string]interface{}, error) {
    56  	token, found, err := v.accessTokenFetcher.GetAccessToken(rawToken)
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	if !found {
    61  		return nil, ErrVerificationInvalidToken
    62  	}
    63  
    64  	claims := token.Claims
    65  	err = claims.Validate(jwt.Expected{Time: time.Now()})
    66  	if err != nil {
    67  		return nil, ErrVerificationTokenExpired
    68  	}
    69  
    70  	for _, aud := range v.audience {
    71  		if claims.Audience.Contains(aud) {
    72  			return claims.RawClaims, nil
    73  		}
    74  	}
    75  
    76  	return nil, ErrVerificationInvalidAudience
    77  }