go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/logger/context.go (about)

     1  // Copyright (c) Mondoo, Inc.
     2  // SPDX-License-Identifier: BUSL-1.1
     3  
     4  package logger
     5  
import (
	"context"
	"sort"

	"github.com/google/uuid"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
)
    13  
    14  const RequestIDFieldKey = "req-id"
    15  
    16  type tagsCtxKey struct{}
    17  
    18  type tags struct {
    19  	m map[string]string
    20  }
    21  
    22  func AddTag(ctx context.Context, tagName string, value string) {
    23  	if v, ok := ctx.Value(tagsCtxKey{}).(*tags); ok {
    24  		v.m[tagName] = value
    25  	}
    26  }
    27  
    28  func GetTags(ctx context.Context) map[string]string {
    29  	if v, ok := ctx.Value(tagsCtxKey{}).(*tags); ok {
    30  		return v.m
    31  	}
    32  	return nil
    33  }
    34  
    35  func WithTagsContext(ctx context.Context) context.Context {
    36  	return context.WithValue(ctx, tagsCtxKey{}, &tags{m: map[string]string{}})
    37  }
    38  
    39  // RequestScopedContext returns a context that contains a logger which logs the request ID
    40  // Given a context, a logger can be retrieved as follows
    41  //
    42  //	ctx := RequestScopedContext(context.Background(), "req-id")
    43  //	log := FromContext(ctx)
    44  //	log.Debug().Msg("hello")
    45  func RequestScopedContext(ctx context.Context, reqID string) context.Context {
    46  	if reqID == "" {
    47  		// The leading underscore indicates the request id was generated on the
    48  		// server instead of the client. This could be temporary and be useful
    49  		// for debugging which client calls are not passing the request id
    50  		reqID = "_" + uuid.New().String()
    51  	}
    52  	l := log.With().Str(RequestIDFieldKey, reqID).Logger()
    53  	return WithTagsContext(l.WithContext(ctx))
    54  }
    55  
    56  // FromContext returns the logger in the context if present, otherwise the it
    57  // returns the default logger
    58  func FromContext(ctx context.Context) *zerolog.Logger {
    59  	l := log.Ctx(ctx)
    60  	if l.GetLevel() == zerolog.Disabled {
    61  		// If a context logger was not set, we'll return a global
    62  		// logger instead of the default noop logger
    63  		l := log.With().Str(RequestIDFieldKey, "global").Logger()
    64  		return &l
    65  	}
    66  	tags := GetTags(ctx)
    67  	if len(tags) > 0 {
    68  		dict := zerolog.Dict()
    69  		for k, v := range tags {
    70  			dict.Str(k, v)
    71  		}
    72  		lv := l.With().Dict("ctags", dict).Logger()
    73  		return &lv
    74  	}
    75  	return l
    76  }