github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/jwt/handler.go (about)

     1  package jwt
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/dgrijalva/jwt-go"
    11  	log "github.com/sirupsen/logrus"
    12  	"golang.org/x/oauth2"
    13  
    14  	"github.com/hellofresh/janus/pkg/config"
    15  	"github.com/hellofresh/janus/pkg/jwt/provider"
    16  	"github.com/hellofresh/janus/pkg/render"
    17  )
    18  
    19  const (
    20  	bearer = "bearer"
    21  )
    22  
    23  // Handler struct
    24  type Handler struct {
    25  	Guard Guard
    26  }
    27  
    28  // Login can be used by clients to get a jwt token.
    29  // Payload needs to be json in the form of {"username": "<USERNAME>", "password": "<PASSWORD>"}.
    30  // Reply will be of the form {"token": "<TOKEN>"}.
    31  func (j *Handler) Login(config config.Credentials) http.HandlerFunc {
    32  	return func(w http.ResponseWriter, r *http.Request) {
    33  		accessToken, err := extractAccessToken(r)
    34  
    35  		if err != nil {
    36  			log.WithError(err).Debug("failed to extract access token")
    37  		}
    38  
    39  		httpClient := getClient(accessToken)
    40  		factory := provider.Factory{}
    41  		p := factory.Build(r.URL.Query().Get("provider"), config)
    42  
    43  		verified, err := p.Verify(r, httpClient)
    44  
    45  		if err != nil || !verified {
    46  			log.WithError(err).Debug(err.Error())
    47  			render.JSON(w, http.StatusUnauthorized, err.Error())
    48  			return
    49  		}
    50  
    51  		if 0 == j.Guard.Timeout {
    52  			j.Guard.Timeout = time.Hour
    53  		}
    54  
    55  		claims, err := p.GetClaims(httpClient)
    56  		if err != nil {
    57  			render.JSON(w, http.StatusBadRequest, err.Error())
    58  			return
    59  		}
    60  
    61  		token, err := IssueAdminToken(j.Guard.SigningMethod, claims, j.Guard.Timeout)
    62  
    63  		if err != nil {
    64  			render.JSON(w, http.StatusUnauthorized, "problem issuing JWT")
    65  			return
    66  		}
    67  
    68  		render.JSON(w, http.StatusOK, token)
    69  	}
    70  }
    71  
    72  // Refresh can be used to refresh existing and valid jwt token.
    73  // Reply will be of the form {"token": "<TOKEN>", "expire": "<DateTime in RFC-3339 format>"}.
    74  func (j *Handler) Refresh() http.HandlerFunc {
    75  	return func(w http.ResponseWriter, r *http.Request) {
    76  		parser := Parser{j.Guard.ParserConfig}
    77  		token, _ := parser.ParseFromRequest(r)
    78  		claims := token.Claims.(jwt.MapClaims)
    79  
    80  		origIat := int64(claims["iat"].(float64))
    81  
    82  		if origIat < time.Now().Add(-j.Guard.MaxRefresh).Unix() {
    83  			render.JSON(w, http.StatusUnauthorized, "token is expired")
    84  			return
    85  		}
    86  
    87  		// Create the token
    88  		newToken := jwt.New(jwt.GetSigningMethod(j.Guard.SigningMethod.Alg))
    89  		newClaims := newToken.Claims.(jwt.MapClaims)
    90  
    91  		for key := range claims {
    92  			newClaims[key] = claims[key]
    93  		}
    94  
    95  		expire := time.Now().Add(j.Guard.Timeout)
    96  		newClaims["sub"] = claims["sub"]
    97  		newClaims["exp"] = expire.Unix()
    98  		newClaims["iat"] = origIat
    99  
   100  		// currently only HSXXX algorithms are supported for issuing admin token, so we cast key to bytes array
   101  		tokenString, err := newToken.SignedString([]byte(j.Guard.SigningMethod.Key))
   102  		if err != nil {
   103  			render.JSON(w, http.StatusUnauthorized, "create JWT Token failed")
   104  			return
   105  		}
   106  
   107  		render.JSON(w, http.StatusOK, render.M{
   108  			"token":  tokenString,
   109  			"type":   "Bearer",
   110  			"expire": expire.Format(time.RFC3339),
   111  		})
   112  	}
   113  }
   114  
   115  func extractAccessToken(r *http.Request) (string, error) {
   116  	// We're using OAuth, start checking for access keys
   117  	authHeaderValue := r.Header.Get("Authorization")
   118  	parts := strings.Split(authHeaderValue, " ")
   119  	if len(parts) < 2 {
   120  		return "", errors.New("attempted access with malformed header, no auth header found")
   121  	}
   122  
   123  	if strings.ToLower(parts[0]) != bearer {
   124  		return "", errors.New("bearer token malformed")
   125  	}
   126  
   127  	return parts[1], nil
   128  }
   129  
   130  func getClient(token string) *http.Client {
   131  	ctx := context.Background()
   132  	ts := oauth2.StaticTokenSource(
   133  		&oauth2.Token{AccessToken: token},
   134  	)
   135  	return oauth2.NewClient(ctx, ts)
   136  }