github.com/mendersoftware/go-lib-micro@v0.0.0-20240304135804-e8e39c59b148/identity/middleware.go (about)

     1  // Copyright 2023 Northern.tech AS
     2  //
     3  //    Licensed under the Apache License, Version 2.0 (the "License");
     4  //    you may not use this file except in compliance with the License.
     5  //    You may obtain a copy of the License at
     6  //
     7  //        http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  //    Unless required by applicable law or agreed to in writing, software
    10  //    distributed under the License is distributed on an "AS IS" BASIS,
    11  //    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  //    See the License for the specific language governing permissions and
    13  //    limitations under the License.
    14  
    15  package identity
    16  
    17  import (
    18  	"net/http"
    19  	"regexp"
    20  
    21  	"github.com/ant0ine/go-json-rest/rest"
    22  	"github.com/gin-gonic/gin"
    23  
    24  	"github.com/mendersoftware/go-lib-micro/log"
    25  	urest "github.com/mendersoftware/go-lib-micro/rest.utils"
    26  )
    27  
// MiddlewareOptions carries the optional settings accepted by Middleware.
// Fields are pointers so that "not set" (nil) can be told apart from an
// explicitly provided value when several option sets are merged.
type MiddlewareOptions struct {
	// PathRegex sets the regex for the path for which this middleware
	// applies. The expression is matched against the gin route path
	// (c.FullPath()). When left nil, Middleware installs no path filter
	// and runs for every request.
	// NOTE(review): this comment previously advertised a default of
	// "^/api/management/v[0-9.]{1,6}/.+"; Middleware as written in this
	// file does not apply such a default.
	PathRegex *string

	// UpdateLogger adds the decoded identity to the log context.
	// Middleware treats a nil value as true.
	UpdateLogger *bool
}
    36  
    37  func NewMiddlewareOptions() *MiddlewareOptions {
    38  	return new(MiddlewareOptions)
    39  }
    40  
    41  func (opts *MiddlewareOptions) SetPathRegex(regex string) *MiddlewareOptions {
    42  	opts.PathRegex = &regex
    43  	return opts
    44  }
    45  
    46  func (opts *MiddlewareOptions) SetUpdateLogger(updateLogger bool) *MiddlewareOptions {
    47  	opts.UpdateLogger = &updateLogger
    48  	return opts
    49  }
    50  
    51  func middlewareWithLogger(c *gin.Context) {
    52  	var (
    53  		err    error
    54  		jwt    string
    55  		idty   Identity
    56  		logCtx = log.Ctx{}
    57  		key    = "sub"
    58  		ctx    = c.Request.Context()
    59  		l      = log.FromContext(ctx)
    60  	)
    61  	jwt, err = ExtractJWTFromHeader(c.Request)
    62  	if err != nil {
    63  		goto exitUnauthorized
    64  	}
    65  	idty, err = ExtractIdentity(jwt)
    66  	if err != nil {
    67  		goto exitUnauthorized
    68  	}
    69  	ctx = WithContext(ctx, &idty)
    70  	if idty.IsDevice {
    71  		key = "device_id"
    72  	} else if idty.IsUser {
    73  		key = "user_id"
    74  	}
    75  	logCtx[key] = idty.Subject
    76  	if idty.Tenant != "" {
    77  		logCtx["tenant_id"] = idty.Tenant
    78  	}
    79  	if idty.Plan != "" {
    80  		logCtx["plan"] = idty.Plan
    81  	}
    82  	ctx = log.WithContext(ctx, l.F(logCtx))
    83  
    84  	c.Request = c.Request.WithContext(ctx)
    85  	return
    86  exitUnauthorized:
    87  	c.Header("WWW-Authenticate", `Bearer realm="ManagementJWT"`)
    88  	urest.RenderError(c, http.StatusUnauthorized, err)
    89  	c.Abort()
    90  }
    91  
    92  func middlewareBase(c *gin.Context) {
    93  	var (
    94  		err  error
    95  		jwt  string
    96  		idty Identity
    97  		ctx  = c.Request.Context()
    98  	)
    99  	jwt, err = ExtractJWTFromHeader(c.Request)
   100  	if err != nil {
   101  		goto exitUnauthorized
   102  	}
   103  	idty, err = ExtractIdentity(jwt)
   104  	if err != nil {
   105  		goto exitUnauthorized
   106  	}
   107  	ctx = WithContext(ctx, &idty)
   108  	c.Request = c.Request.WithContext(ctx)
   109  	return
   110  exitUnauthorized:
   111  	c.Header("WWW-Authenticate", `Bearer realm="ManagementJWT"`)
   112  	urest.RenderError(c, http.StatusUnauthorized, err)
   113  	c.Abort()
   114  }
   115  
   116  func Middleware(opts ...*MiddlewareOptions) gin.HandlerFunc {
   117  
   118  	var middleware gin.HandlerFunc
   119  
   120  	// Initialize default options
   121  	opt := NewMiddlewareOptions().
   122  		SetUpdateLogger(true)
   123  	for _, o := range opts {
   124  		if o == nil {
   125  			continue
   126  		}
   127  		if o.PathRegex != nil {
   128  			opt.PathRegex = o.PathRegex
   129  		}
   130  		if o.UpdateLogger != nil {
   131  			opt.UpdateLogger = o.UpdateLogger
   132  		}
   133  	}
   134  
   135  	if *opt.UpdateLogger {
   136  		middleware = middlewareWithLogger
   137  	} else {
   138  		middleware = middlewareBase
   139  	}
   140  
   141  	if opt.PathRegex != nil {
   142  		pathRegex := regexp.MustCompile(*opt.PathRegex)
   143  		return func(c *gin.Context) {
   144  			if !pathRegex.MatchString(c.FullPath()) {
   145  				return
   146  			}
   147  			middleware(c)
   148  		}
   149  	}
   150  	return middleware
   151  }
   152  
// IdentityMiddleware adds the identity extracted from the JWT token to the
// request's context.
// IdentityMiddleware does not perform any form of token signature verification.
// If no JWT can be extracted from the request header, the request is passed on
// unchanged and nothing is logged. If a JWT is extracted but the identity
// cannot be parsed from it, a warning is logged.
// IdentityMiddleware will not stop control propagating through the chain in any case.
// It is recommended to use IdentityMiddleware with RequestLogMiddleware, and
// RequestLogMiddleware should be placed before IdentityMiddleware.
// Otherwise, the log generated by IdentityMiddleware will not contain the
// "request_id" field.
type IdentityMiddleware struct {
	// If set to true, the middleware will update the context logger, setting
	// 'user_id' or 'device_id' to the value of the subject field. If the
	// token is neither a user nor a device token, the middleware will add a
	// 'sub' field to the logger instead. 'tenant_id' and 'plan' are added
	// as well when the identity carries them.
	UpdateLogger bool
}
   167  
   168  // MiddlewareFunc makes IdentityMiddleware implement the Middleware interface.
   169  func (mw *IdentityMiddleware) MiddlewareFunc(h rest.HandlerFunc) rest.HandlerFunc {
   170  	return func(w rest.ResponseWriter, r *rest.Request) {
   171  		jwt, err := ExtractJWTFromHeader(r.Request)
   172  		if err != nil {
   173  			h(w, r)
   174  			return
   175  		}
   176  
   177  		ctx := r.Context()
   178  		l := log.FromContext(ctx)
   179  
   180  		identity, err := ExtractIdentity(jwt)
   181  		if err != nil {
   182  			l.Warnf("Failed to parse extracted JWT: %s",
   183  				err.Error(),
   184  			)
   185  		} else {
   186  			if mw.UpdateLogger {
   187  				logCtx := log.Ctx{}
   188  
   189  				key := "sub"
   190  				if identity.IsDevice {
   191  					key = "device_id"
   192  				} else if identity.IsUser {
   193  					key = "user_id"
   194  				}
   195  
   196  				logCtx[key] = identity.Subject
   197  
   198  				if identity.Tenant != "" {
   199  					logCtx["tenant_id"] = identity.Tenant
   200  				}
   201  
   202  				if identity.Plan != "" {
   203  					logCtx["plan"] = identity.Plan
   204  				}
   205  
   206  				l = l.F(logCtx)
   207  				ctx = log.WithContext(ctx, l)
   208  			}
   209  			ctx = WithContext(ctx, &identity)
   210  			r.Request = r.WithContext(ctx)
   211  		}
   212  
   213  		h(w, r)
   214  	}
   215  }