github.com/tickoalcantara12/micro/v3@v3.0.0-20221007104245-9d75b9bcbab9/util/wrapper/wrapper.go (about) 1 package wrapper 2 3 import ( 4 "context" 5 "encoding/base64" 6 "reflect" 7 "strings" 8 "time" 9 10 "github.com/tickoalcantara12/micro/v3/service/auth" 11 "github.com/tickoalcantara12/micro/v3/service/client" 12 "github.com/tickoalcantara12/micro/v3/service/context/metadata" 13 "github.com/tickoalcantara12/micro/v3/service/debug" 14 "github.com/tickoalcantara12/micro/v3/service/debug/trace" 15 "github.com/tickoalcantara12/micro/v3/service/errors" 16 "github.com/tickoalcantara12/micro/v3/service/logger" 17 "github.com/tickoalcantara12/micro/v3/service/metrics" 18 "github.com/tickoalcantara12/micro/v3/service/server" 19 inauth "github.com/tickoalcantara12/micro/v3/util/auth" 20 "github.com/tickoalcantara12/micro/v3/util/cache" 21 ) 22 23 type authWrapper struct { 24 client.Client 25 } 26 27 func (a *authWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { 28 ctx = a.wrapContext(ctx, opts...) 29 return a.Client.Call(ctx, req, rsp, opts...) 30 } 31 32 func (a *authWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { 33 ctx = a.wrapContext(ctx, opts...) 34 return a.Client.Stream(ctx, req, opts...) 35 } 36 37 func (a *authWrapper) wrapContext(ctx context.Context, opts ...client.CallOption) context.Context { 38 // parse the options 39 var options client.CallOptions 40 for _, o := range opts { 41 o(&options) 42 } 43 44 // set the namespace header if it has not been set (e.g. on a service to service request) 45 authOpts := auth.DefaultAuth.Options() 46 if _, ok := metadata.Get(ctx, "Micro-Namespace"); !ok { 47 ctx = metadata.Set(ctx, "Micro-Namespace", authOpts.Issuer) 48 } 49 50 // We dont't override the header unless the AuthToken option has been specified 51 if !options.AuthToken { 52 return ctx 53 } 54 55 // check to see if we have a valid access token 56 if authOpts.Token != nil && !authOpts.Token.Expired() { 57 ctx = metadata.Set(ctx, "Authorization", inauth.BearerScheme+authOpts.Token.AccessToken) 58 return ctx 59 } 60 61 // call without an auth token 62 return ctx 63 } 64 65 // AuthClient wraps requests with the auth header 66 func AuthClient(c client.Client) client.Client { 67 return &authWrapper{c} 68 } 69 70 // AuthHandler wraps a server handler to perform auth 71 func AuthHandler() server.HandlerWrapper { 72 return func(h server.HandlerFunc) server.HandlerFunc { 73 return func(ctx context.Context, req server.Request, rsp interface{}) error { 74 // Extract the token if the header is present. We will inspect the token regardless of if it's 75 // present or not since noop auth will return a blank account upon Inspecting a blank token. 76 var token string 77 if header, ok := metadata.Get(ctx, "Authorization"); ok { 78 // Ensure the correct scheme is being used 79 switch { 80 case strings.HasPrefix(header, inauth.BearerScheme): 81 // Strip the bearer scheme prefix 82 token = strings.TrimPrefix(header, inauth.BearerScheme) 83 case strings.HasPrefix(header, "Basic "): 84 b, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(header, "Basic ")) 85 if err != nil { 86 return errors.Unauthorized(req.Service(), "invalid authorization header. Incorrect format") 87 } 88 parts := strings.SplitN(string(b), ":", 2) 89 if len(parts) != 2 { 90 return errors.Unauthorized(req.Service(), "invalid authorization header. Incorrect format") 91 } 92 93 token = parts[1] 94 default: 95 return errors.Unauthorized(req.Service(), "invalid authorization header. Expected Bearer or Basic schema") 96 } 97 } 98 99 // Determine the namespace 100 ns := auth.DefaultAuth.Options().Issuer 101 102 var acc *auth.Account 103 if a, err := auth.Inspect(token); err == nil { 104 ctx = auth.ContextWithAccount(ctx, a) 105 acc = a 106 } 107 108 // construct the resource 109 res := &auth.Resource{ 110 Type: "service", 111 Name: req.Service(), 112 Endpoint: req.Endpoint(), 113 } 114 115 // Verify the caller has access to the resource. 116 err := auth.Verify(acc, res, auth.VerifyNamespace(ns)) 117 if err == auth.ErrForbidden && acc != nil { 118 return errors.Forbidden(req.Service(), "Forbidden call made to %v:%v by %v", req.Service(), req.Endpoint(), acc.ID) 119 } else if err == auth.ErrForbidden { 120 return errors.Unauthorized(req.Service(), "Unauthorized call made to %v:%v", req.Service(), req.Endpoint()) 121 } else if err != nil { 122 return errors.InternalServerError(req.Service(), "Error authorizing request: %v", err) 123 } 124 125 // The user is authorised, allow the call 126 return h(ctx, req, rsp) 127 } 128 } 129 } 130 131 type logWrapper struct { 132 client.Client 133 } 134 135 func (l *logWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { 136 logger.Debugf("Calling service %s endpoint %s", req.Service(), req.Endpoint()) 137 return l.Client.Call(ctx, req, rsp, opts...) 138 } 139 140 func (l *logWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { 141 logger.Debugf("Streaming service %s endpoint %s", req.Service(), req.Endpoint()) 142 return l.Client.Stream(ctx, req, opts...) 143 } 144 145 func LogClient(c client.Client) client.Client { 146 return &logWrapper{c} 147 } 148 149 func LogHandler() server.HandlerWrapper { 150 // return a handler wrapper 151 return func(h server.HandlerFunc) server.HandlerFunc { 152 // return a function that returns a function 153 return func(ctx context.Context, req server.Request, rsp interface{}) error { 154 logger.Debugf("Serving request for service %s endpoint %s", req.Service(), req.Endpoint()) 155 return h(ctx, req, rsp) 156 } 157 } 158 } 159 160 // HandlerStats wraps a server handler to generate request/error stats 161 func HandlerStats() server.HandlerWrapper { 162 // return a handler wrapper 163 return func(h server.HandlerFunc) server.HandlerFunc { 164 // return a function that returns a function 165 return func(ctx context.Context, req server.Request, rsp interface{}) error { 166 // execute the handler 167 err := h(ctx, req, rsp) 168 // record the stats 169 debug.DefaultStats.Record(err) 170 // return the error 171 return err 172 } 173 } 174 } 175 176 type traceWrapper struct { 177 client.Client 178 } 179 180 func (c *traceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { 181 newCtx, s := debug.DefaultTracer.Start(ctx, req.Service()+"."+req.Endpoint()) 182 183 s.Type = trace.SpanTypeRequestOutbound 184 err := c.Client.Call(newCtx, req, rsp, opts...) 185 if err != nil { 186 s.Metadata["error"] = err.Error() 187 } 188 189 // finish the trace 190 debug.DefaultTracer.Finish(s) 191 192 return err 193 } 194 195 // TraceCall is a call tracing wrapper 196 func TraceCall(c client.Client) client.Client { 197 return &traceWrapper{ 198 Client: c, 199 } 200 } 201 202 // TraceHandler wraps a server handler to perform tracing 203 func TraceHandler() server.HandlerWrapper { 204 // return a handler wrapper 205 return func(h server.HandlerFunc) server.HandlerFunc { 206 // return a function that returns a function 207 return func(ctx context.Context, req server.Request, rsp interface{}) error { 208 // don't store traces for debug 209 if strings.HasPrefix(req.Endpoint(), "Debug.") { 210 return h(ctx, req, rsp) 211 } 212 213 // get the span 214 newCtx, s := debug.DefaultTracer.Start(ctx, req.Service()+"."+req.Endpoint()) 215 s.Type = trace.SpanTypeRequestInbound 216 217 err := h(newCtx, req, rsp) 218 if err != nil { 219 s.Metadata["error"] = err.Error() 220 } 221 222 // finish 223 debug.DefaultTracer.Finish(s) 224 225 return err 226 } 227 } 228 } 229 230 type cacheWrapper struct { 231 Cache *cache.Cache 232 client.Client 233 } 234 235 // Call executes the request. If the CacheExpiry option was set, the response will be cached using 236 // a hash of the metadata and request as the key. 237 func (c *cacheWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { 238 // parse the options 239 var options client.CallOptions 240 for _, o := range opts { 241 o(&options) 242 } 243 244 // if the client doesn't have a cacbe setup don't continue 245 if c.Cache == nil { 246 return c.Client.Call(ctx, req, rsp, opts...) 247 } 248 249 cacheOpts, ok := cache.GetOptions(options.Context) 250 if !ok { 251 return c.Client.Call(ctx, req, rsp, opts...) 252 } 253 254 // if the cache expiry is not set, execute the call without the cache 255 if cacheOpts.Expiry == 0 || rsp == nil { 256 return c.Client.Call(ctx, req, rsp, opts...) 257 } 258 259 // check to see if there is a response cached, if there is assign it 260 if r, ok := c.Cache.Get(ctx, req); ok { 261 val := reflect.ValueOf(rsp).Elem() 262 val.Set(reflect.ValueOf(r).Elem()) 263 return nil 264 } 265 266 // don't cache the result if there was an error 267 if err := c.Client.Call(ctx, req, rsp, opts...); err != nil { 268 return err 269 } 270 271 // set the result in the cache 272 c.Cache.Set(ctx, req, rsp, cacheOpts.Expiry) 273 return nil 274 } 275 276 // CacheClient wraps requests with the cache wrapper 277 func CacheClient(c client.Client) client.Client { 278 return &cacheWrapper{ 279 Cache: cache.New(), 280 Client: c, 281 } 282 } 283 284 // MetricsHandler wraps a server handler to instrument calls 285 func MetricsHandler() server.HandlerWrapper { 286 // return a handler wrapper 287 return func(h server.HandlerFunc) server.HandlerFunc { 288 // return a function that returns a function 289 return func(ctx context.Context, req server.Request, rsp interface{}) error { 290 291 // Don't instrument debug calls: 292 if strings.HasPrefix(req.Endpoint(), "Debug.") { 293 return h(ctx, req, rsp) 294 } 295 296 // Build some tags to describe the call: 297 tags := metrics.Tags{ 298 "method": req.Method(), 299 } 300 301 // Start the clock: 302 callTime := time.Now() 303 304 // Run the handlerFunction: 305 err := h(ctx, req, rsp) 306 307 // Add a result tag: 308 if err != nil { 309 tags["result"] = "failure" 310 } else { 311 tags["result"] = "success" 312 } 313 314 // Instrument the result (if the DefaultClient has been configured): 315 metrics.Timing("service.handler", time.Since(callTime), tags) 316 317 return err 318 } 319 } 320 }