go-micro.dev/v5@v5.12.0/util/wrapper/wrapper.go (about)

     1  package wrapper
     2  
     3  import (
     4  	"context"
     5  	"strings"
     6  
     7  	"go-micro.dev/v5/auth"
     8  	"go-micro.dev/v5/client"
     9  	"go-micro.dev/v5/debug/stats"
    10  	"go-micro.dev/v5/debug/trace"
    11  	"go-micro.dev/v5/metadata"
    12  	"go-micro.dev/v5/server"
    13  	"go-micro.dev/v5/transport/headers"
    14  )
    15  
// fromServiceWrapper decorates a client.Client so that every outgoing Call,
// Stream and Publish carries a fixed set of metadata headers.
type fromServiceWrapper struct {
	// The wrapped client; any method this type does not define is promoted
	// from it unchanged.
	client.Client

	// headers to inject into the outgoing context. Keys already present on
	// the context are not overwritten (setHeaders merges with overwrite
	// disabled).
	headers metadata.Metadata
}
    22  
    23  func (f *fromServiceWrapper) setHeaders(ctx context.Context) context.Context {
    24  	// don't overwrite keys
    25  	return metadata.MergeContext(ctx, f.headers, false)
    26  }
    27  
    28  func (f *fromServiceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error {
    29  	ctx = f.setHeaders(ctx)
    30  	return f.Client.Call(ctx, req, rsp, opts...)
    31  }
    32  
    33  func (f *fromServiceWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) {
    34  	ctx = f.setHeaders(ctx)
    35  	return f.Client.Stream(ctx, req, opts...)
    36  }
    37  
    38  func (f *fromServiceWrapper) Publish(ctx context.Context, p client.Message, opts ...client.PublishOption) error {
    39  	ctx = f.setHeaders(ctx)
    40  	return f.Client.Publish(ctx, p, opts...)
    41  }
    42  
    43  // FromService wraps a client to inject service and auth metadata.
    44  func FromService(name string, c client.Client) client.Client {
    45  	return &fromServiceWrapper{
    46  		c,
    47  		metadata.Metadata{
    48  			headers.Prefix + "From-Service": name,
    49  		},
    50  	}
    51  }
    52  
    53  // HandlerStats wraps a server handler to generate request/error stats.
    54  func HandlerStats(stats stats.Stats) server.HandlerWrapper {
    55  	// return a handler wrapper
    56  	return func(h server.HandlerFunc) server.HandlerFunc {
    57  		// return a function that returns a function
    58  		return func(ctx context.Context, req server.Request, rsp interface{}) error {
    59  			// execute the handler
    60  			err := h(ctx, req, rsp)
    61  			// record the stats
    62  			stats.Record(err)
    63  			// return the error
    64  			return err
    65  		}
    66  	}
    67  }
    68  
// traceWrapper decorates a client.Client so that Call is recorded as an
// outbound span on the configured tracer. Other client methods are promoted
// from the embedded client untouched.
type traceWrapper struct {
	client.Client

	// trace is the tracer used to start and finish spans.
	trace trace.Tracer

	// name is the value passed to TraceCall.
	// NOTE(review): Call does not read it — confirm whether it is used
	// anywhere else or can be dropped.
	name string
}
    77  func (c *traceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error {
    78  	newCtx, s := c.trace.Start(ctx, req.Service()+"."+req.Endpoint())
    79  
    80  	s.Type = trace.SpanTypeRequestOutbound
    81  	err := c.Client.Call(newCtx, req, rsp, opts...)
    82  	if err != nil {
    83  		s.Metadata["error"] = err.Error()
    84  	}
    85  
    86  	// finish the trace
    87  	c.trace.Finish(s)
    88  
    89  	return err
    90  }
    91  
    92  // TraceCall is a call tracing wrapper.
    93  func TraceCall(name string, t trace.Tracer, c client.Client) client.Client {
    94  	return &traceWrapper{
    95  		name:   name,
    96  		trace:  t,
    97  		Client: c,
    98  	}
    99  }
   100  
   101  // TraceHandler wraps a server handler to perform tracing.
   102  func TraceHandler(t trace.Tracer) server.HandlerWrapper {
   103  	// return a handler wrapper
   104  	return func(h server.HandlerFunc) server.HandlerFunc {
   105  		// return a function that returns a function
   106  		return func(ctx context.Context, req server.Request, rsp interface{}) error {
   107  			// don't store traces for debug
   108  			if strings.HasPrefix(req.Endpoint(), "Debug.") {
   109  				return h(ctx, req, rsp)
   110  			}
   111  
   112  			// get the span
   113  			newCtx, s := t.Start(ctx, req.Service()+"."+req.Endpoint())
   114  			s.Type = trace.SpanTypeRequestInbound
   115  
   116  			err := h(newCtx, req, rsp)
   117  			if err != nil {
   118  				s.Metadata["error"] = err.Error()
   119  			}
   120  
   121  			// finish
   122  			t.Finish(s)
   123  
   124  			return err
   125  		}
   126  	}
   127  }
   128  
   129  func AuthCall(a func() auth.Auth, c client.Client) client.Client {
   130  	return &authWrapper{Client: c, auth: a}
   131  }
   132  
// authWrapper decorates a client.Client so that Call can attach namespace
// and Authorization metadata taken from an auth.Auth.
type authWrapper struct {
	client.Client
	// auth returns the auth.Auth to read the namespace and token from. It may
	// return nil, in which case Call forwards requests without adding them.
	auth func() auth.Auth
}
   137  
   138  func (a *authWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error {
   139  	// parse the options
   140  	var options client.CallOptions
   141  	for _, o := range opts {
   142  		o(&options)
   143  	}
   144  
   145  	// check to see if the authorization header has already been set.
   146  	// We dont't override the header unless the ServiceToken option has
   147  	// been specified or the header wasn't provided
   148  	if _, ok := metadata.Get(ctx, "Authorization"); ok && !options.ServiceToken {
   149  		return a.Client.Call(ctx, req, rsp, opts...)
   150  	}
   151  
   152  	// if auth is nil we won't be able to get an access token, so we execute
   153  	// the request without one.
   154  	aa := a.auth()
   155  	if aa == nil {
   156  		return a.Client.Call(ctx, req, rsp, opts...)
   157  	}
   158  
   159  	// set the namespace header if it has not been set (e.g. on a service to service request)
   160  	if _, ok := metadata.Get(ctx, headers.Namespace); !ok {
   161  		ctx = metadata.Set(ctx, headers.Namespace, aa.Options().Namespace)
   162  	}
   163  
   164  	// check to see if we have a valid access token
   165  	aaOpts := aa.Options()
   166  	if aaOpts.Token != nil && !aaOpts.Token.Expired() {
   167  		ctx = metadata.Set(ctx, "Authorization", auth.BearerScheme+aaOpts.Token.AccessToken)
   168  		return a.Client.Call(ctx, req, rsp, opts...)
   169  	}
   170  
   171  	// call without an auth token
   172  	return a.Client.Call(ctx, req, rsp, opts...)
   173  }