go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/logger/context.go (about) 1 // Copyright (c) Mondoo, Inc. 2 // SPDX-License-Identifier: BUSL-1.1 3 4 package logger 5 6 import ( 7 "context" 8 9 "github.com/google/uuid" 10 "github.com/rs/zerolog" 11 "github.com/rs/zerolog/log" 12 ) 13 14 const RequestIDFieldKey = "req-id" 15 16 type tagsCtxKey struct{} 17 18 type tags struct { 19 m map[string]string 20 } 21 22 func AddTag(ctx context.Context, tagName string, value string) { 23 if v, ok := ctx.Value(tagsCtxKey{}).(*tags); ok { 24 v.m[tagName] = value 25 } 26 } 27 28 func GetTags(ctx context.Context) map[string]string { 29 if v, ok := ctx.Value(tagsCtxKey{}).(*tags); ok { 30 return v.m 31 } 32 return nil 33 } 34 35 func WithTagsContext(ctx context.Context) context.Context { 36 return context.WithValue(ctx, tagsCtxKey{}, &tags{m: map[string]string{}}) 37 } 38 39 // RequestScopedContext returns a context that contains a logger which logs the request ID 40 // Given a context, a logger can be retrieved as follows 41 // 42 // ctx := RequestScopedContext(context.Background(), "req-id") 43 // log := FromContext(ctx) 44 // log.Debug().Msg("hello") 45 func RequestScopedContext(ctx context.Context, reqID string) context.Context { 46 if reqID == "" { 47 // The leading underscore indicates the request id was generated on the 48 // server instead of the client. This could be temporary and be useful 49 // for debugging which client calls are not passing the request id 50 reqID = "_" + uuid.New().String() 51 } 52 l := log.With().Str(RequestIDFieldKey, reqID).Logger() 53 return WithTagsContext(l.WithContext(ctx)) 54 } 55 56 // FromContext returns the logger in the context if present, otherwise the it 57 // returns the default logger 58 func FromContext(ctx context.Context) *zerolog.Logger { 59 l := log.Ctx(ctx) 60 if l.GetLevel() == zerolog.Disabled { 61 // If a context logger was not set, we'll return a global 62 // logger instead of the default noop logger 63 l := log.With().Str(RequestIDFieldKey, "global").Logger() 64 return &l 65 } 66 tags := GetTags(ctx) 67 if len(tags) > 0 { 68 dict := zerolog.Dict() 69 for k, v := range tags { 70 dict.Str(k, v) 71 } 72 lv := l.With().Dict("ctags", dict).Logger() 73 return &lv 74 } 75 return l 76 }