github.com/tickoalcantara12/micro/v3@v3.0.0-20221007104245-9d75b9bcbab9/util/wrapper/wrapper.go (about)

     1  package wrapper
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"reflect"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/tickoalcantara12/micro/v3/service/auth"
    11  	"github.com/tickoalcantara12/micro/v3/service/client"
    12  	"github.com/tickoalcantara12/micro/v3/service/context/metadata"
    13  	"github.com/tickoalcantara12/micro/v3/service/debug"
    14  	"github.com/tickoalcantara12/micro/v3/service/debug/trace"
    15  	"github.com/tickoalcantara12/micro/v3/service/errors"
    16  	"github.com/tickoalcantara12/micro/v3/service/logger"
    17  	"github.com/tickoalcantara12/micro/v3/service/metrics"
    18  	"github.com/tickoalcantara12/micro/v3/service/server"
    19  	inauth "github.com/tickoalcantara12/micro/v3/util/auth"
    20  	"github.com/tickoalcantara12/micro/v3/util/cache"
    21  )
    22  
    23  type authWrapper struct {
    24  	client.Client
    25  }
    26  
    27  func (a *authWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error {
    28  	ctx = a.wrapContext(ctx, opts...)
    29  	return a.Client.Call(ctx, req, rsp, opts...)
    30  }
    31  
    32  func (a *authWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) {
    33  	ctx = a.wrapContext(ctx, opts...)
    34  	return a.Client.Stream(ctx, req, opts...)
    35  }
    36  
    37  func (a *authWrapper) wrapContext(ctx context.Context, opts ...client.CallOption) context.Context {
    38  	// parse the options
    39  	var options client.CallOptions
    40  	for _, o := range opts {
    41  		o(&options)
    42  	}
    43  
    44  	// set the namespace header if it has not been set (e.g. on a service to service request)
    45  	authOpts := auth.DefaultAuth.Options()
    46  	if _, ok := metadata.Get(ctx, "Micro-Namespace"); !ok {
    47  		ctx = metadata.Set(ctx, "Micro-Namespace", authOpts.Issuer)
    48  	}
    49  
    50  	// We dont't override the header unless the AuthToken option has been specified
    51  	if !options.AuthToken {
    52  		return ctx
    53  	}
    54  
    55  	// check to see if we have a valid access token
    56  	if authOpts.Token != nil && !authOpts.Token.Expired() {
    57  		ctx = metadata.Set(ctx, "Authorization", inauth.BearerScheme+authOpts.Token.AccessToken)
    58  		return ctx
    59  	}
    60  
    61  	// call without an auth token
    62  	return ctx
    63  }
    64  
    65  // AuthClient wraps requests with the auth header
    66  func AuthClient(c client.Client) client.Client {
    67  	return &authWrapper{c}
    68  }
    69  
    70  // AuthHandler wraps a server handler to perform auth
    71  func AuthHandler() server.HandlerWrapper {
    72  	return func(h server.HandlerFunc) server.HandlerFunc {
    73  		return func(ctx context.Context, req server.Request, rsp interface{}) error {
    74  			// Extract the token if the header is present. We will inspect the token regardless of if it's
    75  			// present or not since noop auth will return a blank account upon Inspecting a blank token.
    76  			var token string
    77  			if header, ok := metadata.Get(ctx, "Authorization"); ok {
    78  				// Ensure the correct scheme is being used
    79  				switch {
    80  				case strings.HasPrefix(header, inauth.BearerScheme):
    81  					// Strip the bearer scheme prefix
    82  					token = strings.TrimPrefix(header, inauth.BearerScheme)
    83  				case strings.HasPrefix(header, "Basic "):
    84  					b, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(header, "Basic "))
    85  					if err != nil {
    86  						return errors.Unauthorized(req.Service(), "invalid authorization header. Incorrect format")
    87  					}
    88  					parts := strings.SplitN(string(b), ":", 2)
    89  					if len(parts) != 2 {
    90  						return errors.Unauthorized(req.Service(), "invalid authorization header. Incorrect format")
    91  					}
    92  
    93  					token = parts[1]
    94  				default:
    95  					return errors.Unauthorized(req.Service(), "invalid authorization header. Expected Bearer or Basic schema")
    96  				}
    97  			}
    98  
    99  			// Determine the namespace
   100  			ns := auth.DefaultAuth.Options().Issuer
   101  
   102  			var acc *auth.Account
   103  			if a, err := auth.Inspect(token); err == nil {
   104  				ctx = auth.ContextWithAccount(ctx, a)
   105  				acc = a
   106  			}
   107  
   108  			// construct the resource
   109  			res := &auth.Resource{
   110  				Type:     "service",
   111  				Name:     req.Service(),
   112  				Endpoint: req.Endpoint(),
   113  			}
   114  
   115  			// Verify the caller has access to the resource.
   116  			err := auth.Verify(acc, res, auth.VerifyNamespace(ns))
   117  			if err == auth.ErrForbidden && acc != nil {
   118  				return errors.Forbidden(req.Service(), "Forbidden call made to %v:%v by %v", req.Service(), req.Endpoint(), acc.ID)
   119  			} else if err == auth.ErrForbidden {
   120  				return errors.Unauthorized(req.Service(), "Unauthorized call made to %v:%v", req.Service(), req.Endpoint())
   121  			} else if err != nil {
   122  				return errors.InternalServerError(req.Service(), "Error authorizing request: %v", err)
   123  			}
   124  
   125  			// The user is authorised, allow the call
   126  			return h(ctx, req, rsp)
   127  		}
   128  	}
   129  }
   130  
   131  type logWrapper struct {
   132  	client.Client
   133  }
   134  
   135  func (l *logWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error {
   136  	logger.Debugf("Calling service %s endpoint %s", req.Service(), req.Endpoint())
   137  	return l.Client.Call(ctx, req, rsp, opts...)
   138  }
   139  
   140  func (l *logWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) {
   141  	logger.Debugf("Streaming service %s endpoint %s", req.Service(), req.Endpoint())
   142  	return l.Client.Stream(ctx, req, opts...)
   143  }
   144  
   145  func LogClient(c client.Client) client.Client {
   146  	return &logWrapper{c}
   147  }
   148  
   149  func LogHandler() server.HandlerWrapper {
   150  	// return a handler wrapper
   151  	return func(h server.HandlerFunc) server.HandlerFunc {
   152  		// return a function that returns a function
   153  		return func(ctx context.Context, req server.Request, rsp interface{}) error {
   154  			logger.Debugf("Serving request for service %s endpoint %s", req.Service(), req.Endpoint())
   155  			return h(ctx, req, rsp)
   156  		}
   157  	}
   158  }
   159  
   160  // HandlerStats wraps a server handler to generate request/error stats
   161  func HandlerStats() server.HandlerWrapper {
   162  	// return a handler wrapper
   163  	return func(h server.HandlerFunc) server.HandlerFunc {
   164  		// return a function that returns a function
   165  		return func(ctx context.Context, req server.Request, rsp interface{}) error {
   166  			// execute the handler
   167  			err := h(ctx, req, rsp)
   168  			// record the stats
   169  			debug.DefaultStats.Record(err)
   170  			// return the error
   171  			return err
   172  		}
   173  	}
   174  }
   175  
   176  type traceWrapper struct {
   177  	client.Client
   178  }
   179  
   180  func (c *traceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error {
   181  	newCtx, s := debug.DefaultTracer.Start(ctx, req.Service()+"."+req.Endpoint())
   182  
   183  	s.Type = trace.SpanTypeRequestOutbound
   184  	err := c.Client.Call(newCtx, req, rsp, opts...)
   185  	if err != nil {
   186  		s.Metadata["error"] = err.Error()
   187  	}
   188  
   189  	// finish the trace
   190  	debug.DefaultTracer.Finish(s)
   191  
   192  	return err
   193  }
   194  
   195  // TraceCall is a call tracing wrapper
   196  func TraceCall(c client.Client) client.Client {
   197  	return &traceWrapper{
   198  		Client: c,
   199  	}
   200  }
   201  
   202  // TraceHandler wraps a server handler to perform tracing
   203  func TraceHandler() server.HandlerWrapper {
   204  	// return a handler wrapper
   205  	return func(h server.HandlerFunc) server.HandlerFunc {
   206  		// return a function that returns a function
   207  		return func(ctx context.Context, req server.Request, rsp interface{}) error {
   208  			// don't store traces for debug
   209  			if strings.HasPrefix(req.Endpoint(), "Debug.") {
   210  				return h(ctx, req, rsp)
   211  			}
   212  
   213  			// get the span
   214  			newCtx, s := debug.DefaultTracer.Start(ctx, req.Service()+"."+req.Endpoint())
   215  			s.Type = trace.SpanTypeRequestInbound
   216  
   217  			err := h(newCtx, req, rsp)
   218  			if err != nil {
   219  				s.Metadata["error"] = err.Error()
   220  			}
   221  
   222  			// finish
   223  			debug.DefaultTracer.Finish(s)
   224  
   225  			return err
   226  		}
   227  	}
   228  }
   229  
   230  type cacheWrapper struct {
   231  	Cache *cache.Cache
   232  	client.Client
   233  }
   234  
   235  // Call executes the request. If the CacheExpiry option was set, the response will be cached using
   236  // a hash of the metadata and request as the key.
   237  func (c *cacheWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error {
   238  	// parse the options
   239  	var options client.CallOptions
   240  	for _, o := range opts {
   241  		o(&options)
   242  	}
   243  
   244  	// if the client doesn't have a cacbe setup don't continue
   245  	if c.Cache == nil {
   246  		return c.Client.Call(ctx, req, rsp, opts...)
   247  	}
   248  
   249  	cacheOpts, ok := cache.GetOptions(options.Context)
   250  	if !ok {
   251  		return c.Client.Call(ctx, req, rsp, opts...)
   252  	}
   253  
   254  	// if the cache expiry is not set, execute the call without the cache
   255  	if cacheOpts.Expiry == 0 || rsp == nil {
   256  		return c.Client.Call(ctx, req, rsp, opts...)
   257  	}
   258  
   259  	// check to see if there is a response cached, if there is assign it
   260  	if r, ok := c.Cache.Get(ctx, req); ok {
   261  		val := reflect.ValueOf(rsp).Elem()
   262  		val.Set(reflect.ValueOf(r).Elem())
   263  		return nil
   264  	}
   265  
   266  	// don't cache the result if there was an error
   267  	if err := c.Client.Call(ctx, req, rsp, opts...); err != nil {
   268  		return err
   269  	}
   270  
   271  	// set the result in the cache
   272  	c.Cache.Set(ctx, req, rsp, cacheOpts.Expiry)
   273  	return nil
   274  }
   275  
   276  // CacheClient wraps requests with the cache wrapper
   277  func CacheClient(c client.Client) client.Client {
   278  	return &cacheWrapper{
   279  		Cache:  cache.New(),
   280  		Client: c,
   281  	}
   282  }
   283  
   284  // MetricsHandler wraps a server handler to instrument calls
   285  func MetricsHandler() server.HandlerWrapper {
   286  	// return a handler wrapper
   287  	return func(h server.HandlerFunc) server.HandlerFunc {
   288  		// return a function that returns a function
   289  		return func(ctx context.Context, req server.Request, rsp interface{}) error {
   290  
   291  			// Don't instrument debug calls:
   292  			if strings.HasPrefix(req.Endpoint(), "Debug.") {
   293  				return h(ctx, req, rsp)
   294  			}
   295  
   296  			// Build some tags to describe the call:
   297  			tags := metrics.Tags{
   298  				"method": req.Method(),
   299  			}
   300  
   301  			// Start the clock:
   302  			callTime := time.Now()
   303  
   304  			// Run the handlerFunction:
   305  			err := h(ctx, req, rsp)
   306  
   307  			// Add a result tag:
   308  			if err != nil {
   309  				tags["result"] = "failure"
   310  			} else {
   311  				tags["result"] = "success"
   312  			}
   313  
   314  			// Instrument the result (if the DefaultClient has been configured):
   315  			metrics.Timing("service.handler", time.Since(callTime), tags)
   316  
   317  			return err
   318  		}
   319  	}
   320  }