github.com/xmidt-org/webpa-common@v1.11.9/secure/tools/cmd/keyserver/issueHandler.go (about)

     1  package main
     2  
     3  import (
     4  	"crypto/rand"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"github.com/SermoDigital/jose/crypto"
     9  	"github.com/SermoDigital/jose/jws"
    10  	"github.com/SermoDigital/jose/jwt"
    11  	"github.com/gorilla/schema"
    12  	"io/ioutil"
    13  	"net/http"
    14  	"strconv"
    15  	"time"
    16  )
    17  
    18  const (
    19  	KeyIDVariableName        = "kid"
    20  	DefaultExpireDuration    = time.Duration(24 * time.Hour)
    21  	DefaultNotBeforeDuration = time.Duration(1 * time.Hour)
    22  )
    23  
    24  var (
    25  	ErrorMissingKeyID = errors.New("A kid parameter is required")
    26  
    27  	zeroTime                                  = time.Time{}
    28  	defaultSigningMethod crypto.SigningMethod = crypto.SigningMethodRS256
    29  
    30  	supportedSigningMethods = map[string]crypto.SigningMethod{
    31  		defaultSigningMethod.Alg():      defaultSigningMethod,
    32  		crypto.SigningMethodRS384.Alg(): crypto.SigningMethodRS384,
    33  		crypto.SigningMethodRS512.Alg(): crypto.SigningMethodRS512,
    34  	}
    35  
    36  	supportedNumericDateLayouts = []string{
    37  		time.RFC3339,
    38  		time.RFC822,
    39  		time.RFC822Z,
    40  	}
    41  
    42  	// zeroNumericDate is a singleton value indicating a blank value
    43  	zeroNumericDate = NumericDate{}
    44  )
    45  
    46  // NumericDate represents a JWT NumericDate as specified in:
    47  // https://tools.ietf.org/html/rfc7519#section-2
    48  //
    49  // A number of formats for numeric dates are allowed, and each
    50  // is converted appropriately:
    51  //
    52  // (1) An int64 value, which is interpreted as the exact value to use
    53  // (2) A valid time.Duration, which is added to time.Now() to compute the value
    54  // (3) An absolute date specified in RFC33399 or RFC822 formates.  See the time package for details.
    55  type NumericDate struct {
    56  	duration time.Duration
    57  	absolute time.Time
    58  }
    59  
    60  func (nd *NumericDate) UnmarshalText(raw []byte) error {
    61  	if len(raw) == 0 {
    62  		*nd = zeroNumericDate
    63  		return nil
    64  	}
    65  
    66  	text := string(raw)
    67  
    68  	if value, err := strconv.ParseInt(text, 10, 64); err == nil {
    69  		*nd = NumericDate{duration: 0, absolute: time.Unix(value, 0)}
    70  		return nil
    71  	}
    72  
    73  	if duration, err := time.ParseDuration(text); err == nil {
    74  		*nd = NumericDate{duration: duration, absolute: zeroTime}
    75  		return nil
    76  	}
    77  
    78  	for _, layout := range supportedNumericDateLayouts {
    79  		if value, err := time.Parse(layout, text); err == nil {
    80  			*nd = NumericDate{duration: 0, absolute: value}
    81  			return nil
    82  		}
    83  	}
    84  
    85  	return fmt.Errorf("Unparseable datetime: %s", text)
    86  }
    87  
    88  // IsZero tests whether this NumericDate is blank, as would be the case when
    89  // the original request assigns a value to the empty string.  This is useful to
    90  // have the server generate a default value appropriate for the field.
    91  func (nd *NumericDate) IsZero() bool {
    92  	return nd.duration == 0 && nd.absolute.IsZero()
    93  }
    94  
    95  // Compute calculates the time.Time value given a point in time
    96  // assumed to be "now".  Use of this level of indirection allows a
    97  // single time value to be used in all calculations when issuing JWTs.
    98  func (nd *NumericDate) Compute(now time.Time) time.Time {
    99  	if nd.duration != 0 {
   100  		return now.Add(nd.duration)
   101  	}
   102  
   103  	return nd.absolute
   104  }
   105  
   106  // SigningMethod is a custom type which holds the alg value.
   107  type SigningMethod struct {
   108  	crypto.SigningMethod
   109  }
   110  
   111  func (s *SigningMethod) UnmarshalText(raw []byte) error {
   112  	if len(raw) == 0 {
   113  		*s = SigningMethod{defaultSigningMethod}
   114  		return nil
   115  	}
   116  
   117  	text := string(raw)
   118  	value, ok := supportedSigningMethods[text]
   119  	if ok {
   120  		*s = SigningMethod{value}
   121  		return nil
   122  	}
   123  
   124  	return fmt.Errorf("Unsupported algorithm: %s", text)
   125  }
   126  
// IssueRequest contains the information necessary for issuing a JWS.
// Any custom claims must be transmitted separately.
//
// Fields are populated from request parameters named by their schema tags.
// Pointer fields distinguish a parameter that was not decoded (nil) from one
// that was; AddToClaims only emits the corresponding claim for non-nil values.
type IssueRequest struct {
	// Now is the single reference time used for every claim of one token.
	// It carries the schema tag "-" and is assigned by NewIssueRequest.
	Now time.Time `schema:"-"`

	// KeyID identifies the signing key.  NewIssueRequest rejects requests
	// where it is empty.
	KeyID     string         `schema:"kid"`
	// Algorithm is the requested signing method; when nil the default
	// signing method is used.
	Algorithm *SigningMethod `schema:"alg"`

	// Expires and NotBefore, when non-nil, produce the expiration and
	// not-before claims.  A blank value is replaced by a default offset
	// from Now.
	Expires   *NumericDate `schema:"exp"`
	NotBefore *NumericDate `schema:"nbf"`

	// JWTID, when non-nil, produces the JWT ID claim.  An empty string
	// makes AddToClaims generate a random identifier.
	JWTID    *string   `schema:"jti"`
	// Subject produces the subject claim when non-empty.
	Subject  string    `schema:"sub"`
	// Audience, when non-nil, produces the audience claim.
	Audience *[]string `schema:"aud"`
}
   142  
   143  func (ir *IssueRequest) SigningMethod() crypto.SigningMethod {
   144  	if ir.Algorithm != nil {
   145  		return ir.Algorithm.SigningMethod
   146  	}
   147  
   148  	return defaultSigningMethod
   149  }
   150  
   151  // AddToHeader adds the appropriate header information from this issue request
   152  func (ir *IssueRequest) AddToHeader(header map[string]interface{}) error {
   153  	// right now, we just add the kid
   154  	header[KeyIDVariableName] = ir.KeyID
   155  	return nil
   156  }
   157  
   158  // AddToClaims takes the various parts of this issue request and formats them
   159  // appropriately into a supplied jwt.Claims object.
   160  func (ir *IssueRequest) AddToClaims(claims jwt.Claims) error {
   161  	claims.SetIssuedAt(ir.Now)
   162  
   163  	if ir.Expires != nil {
   164  		if ir.Expires.IsZero() {
   165  			claims.SetExpiration(ir.Now.Add(DefaultExpireDuration))
   166  		} else {
   167  			claims.SetExpiration(ir.Expires.Compute(ir.Now))
   168  		}
   169  	}
   170  
   171  	if ir.NotBefore != nil {
   172  		if ir.NotBefore.IsZero() {
   173  			claims.SetNotBefore(ir.Now.Add(DefaultNotBeforeDuration))
   174  		} else {
   175  			claims.SetNotBefore(ir.NotBefore.Compute(ir.Now))
   176  		}
   177  	}
   178  
   179  	if ir.JWTID != nil {
   180  		jti := *ir.JWTID
   181  		if len(jti) == 0 {
   182  			// generate a type 4 UUID
   183  			buffer := make([]byte, 16)
   184  			if _, err := rand.Read(buffer); err != nil {
   185  				return err
   186  			}
   187  
   188  			buffer[6] = (buffer[6] | 0x40) & 0x4F
   189  			buffer[8] = (buffer[8] | 0x80) & 0x8F
   190  
   191  			// dashes are just noise!
   192  			jti = fmt.Sprintf("%X", buffer)
   193  		}
   194  
   195  		claims.SetJWTID(jti)
   196  	}
   197  
   198  	if len(ir.Subject) > 0 {
   199  		claims.SetSubject(ir.Subject)
   200  	}
   201  
   202  	if ir.Audience != nil {
   203  		claims.SetAudience((*ir.Audience)...)
   204  	}
   205  
   206  	return nil
   207  }
   208  
   209  func NewIssueRequest(decoder *schema.Decoder, source map[string][]string) (*IssueRequest, error) {
   210  	issueRequest := &IssueRequest{}
   211  	if err := decoder.Decode(issueRequest, source); err != nil {
   212  		return nil, err
   213  	}
   214  
   215  	if len(issueRequest.KeyID) == 0 {
   216  		return nil, ErrorMissingKeyID
   217  	}
   218  
   219  	issueRequest.Now = time.Now()
   220  	return issueRequest, nil
   221  }
   222  
// IssueHandler issues JWS tokens
type IssueHandler struct {
	// BasicHandler is declared outside this file.  The keyStore and
	// httpError members used by this handler's methods are presumably
	// promoted from it — confirm against its declaration.
	BasicHandler
	// issuer is written as the issuer claim of every token issued.
	issuer  string
	// decoder turns request form values into an IssueRequest.
	decoder *schema.Decoder
}
   229  
   230  // issue handles all the common logic for issuing a JWS token
   231  func (handler *IssueHandler) issue(response http.ResponseWriter, issueRequest *IssueRequest, claims jwt.Claims) {
   232  	issueKey, ok := handler.keyStore.PrivateKey(issueRequest.KeyID)
   233  	if !ok {
   234  		handler.httpError(response, http.StatusBadRequest, fmt.Sprintf("No such key: %s", issueRequest.KeyID))
   235  		return
   236  	}
   237  
   238  	if claims == nil {
   239  		claims = make(jwt.Claims)
   240  	}
   241  
   242  	issuedJWT := jws.NewJWT(jws.Claims(claims), issueRequest.SigningMethod())
   243  	if err := issueRequest.AddToClaims(issuedJWT.Claims()); err != nil {
   244  		handler.httpError(response, http.StatusInternalServerError, err.Error())
   245  		return
   246  	}
   247  
   248  	issuedJWT.Claims().SetIssuer(handler.issuer)
   249  	issuedJWS := issuedJWT.(jws.JWS)
   250  	if err := issueRequest.AddToHeader(issuedJWS.Protected()); err != nil {
   251  		handler.httpError(response, http.StatusInternalServerError, err.Error())
   252  		return
   253  	}
   254  
   255  	compact, err := issuedJWS.Compact(issueKey)
   256  	if err != nil {
   257  		handler.httpError(response, http.StatusInternalServerError, err.Error())
   258  		return
   259  	}
   260  
   261  	response.Header().Set("Content-Type", "application/jwt")
   262  	response.Write(compact)
   263  }
   264  
   265  // SimpleIssue handles requests with no body, appropriate for simple use cases.
   266  func (handler *IssueHandler) SimpleIssue(response http.ResponseWriter, request *http.Request) {
   267  	if err := request.ParseForm(); err != nil {
   268  		handler.httpError(response, http.StatusBadRequest, err.Error())
   269  		return
   270  	}
   271  
   272  	issueRequest, err := NewIssueRequest(handler.decoder, request.Form)
   273  	if err != nil {
   274  		handler.httpError(response, http.StatusBadRequest, err.Error())
   275  		return
   276  	}
   277  
   278  	handler.issue(response, issueRequest, nil)
   279  }
   280  
   281  // IssueUsingBody accepts a JSON claims document, to which it then adds all the standard
   282  // claims mentioned in request parameters, e.g. exp.  It then uses the merged claims
   283  // in an issued JWS.
   284  func (handler *IssueHandler) IssueUsingBody(response http.ResponseWriter, request *http.Request) {
   285  	if err := request.ParseForm(); err != nil {
   286  		handler.httpError(response, http.StatusBadRequest, err.Error())
   287  		return
   288  	}
   289  
   290  	issueRequest, err := NewIssueRequest(handler.decoder, request.Form)
   291  	if err != nil {
   292  		handler.httpError(response, http.StatusBadRequest, err.Error())
   293  		return
   294  	}
   295  
   296  	// this variant reads the claims directly from the request body
   297  	claims := make(jwt.Claims)
   298  	if request.Body != nil {
   299  		body, err := ioutil.ReadAll(request.Body)
   300  		if err != nil {
   301  			handler.httpError(response, http.StatusBadRequest, fmt.Sprintf("Unable to read request body: %s", err))
   302  			return
   303  		}
   304  
   305  		if len(body) > 0 {
   306  			// we don't want to uses the Claims unmarshalling logic, as that assumes base64
   307  			if err := json.Unmarshal(body, (*map[string]interface{})(&claims)); err != nil {
   308  				handler.httpError(response, http.StatusBadRequest, fmt.Sprintf("Unable to parse JSON in request body: %s", err))
   309  				return
   310  			}
   311  		}
   312  	}
   313  
   314  	handler.issue(response, issueRequest, claims)
   315  }