github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/contextmd/contextmd.go (about)

     1  // Package contextmd allows attaching metadata to the context of RPC calls.
     2  package contextmd
     3  
     4  import (
     5  	"context"
     6  	"fmt"
     7  	"sort"
     8  	"strings"
     9  
    10  	log "github.com/golang/glog"
    11  	"github.com/google/uuid"
    12  	"google.golang.org/grpc/metadata"
    13  	"google.golang.org/protobuf/proto"
    14  
    15  	repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
    16  )
    17  
    18  const (
    19  	// The headers key of our RequestMetadata.
    20  	remoteHeadersKey     = "build.bazel.remote.execution.v2.requestmetadata-bin"
    21  	defaultMaxHeaderSize = 8 * 1024
    22  )
    23  
    24  // Metadata is optionally attached to RPC requests.
    25  type Metadata struct {
    26  	// ActionID is an optional id to use to identify an action.
    27  	ActionID string
    28  	// InvocationID is an optional id to use to identify an invocation spanning multiple commands.
    29  	InvocationID string
    30  	// CorrelatedInvocationID is an optional id to use to identify a build spanning multiple invocations.
    31  	CorrelatedInvocationID string
    32  	// ToolName is an optional tool name to pass to the remote server for logging.
    33  	ToolName string
    34  	// ToolVersion is an optional tool version to pass to the remote server for logging.
    35  	ToolVersion string
    36  }
    37  
    38  // Infof is equivalent to log.V(x).Infof(...) except it
    39  // also logs context metadata, if available.
    40  func Infof(ctx context.Context, v log.Level, format string, args ...any) {
    41  	if log.V(v) {
    42  		if m, err := ExtractMetadata(ctx); err == nil && m.ActionID != "" {
    43  			format = "%s: " + format
    44  			args = append([]any{m.ActionID}, args...)
    45  		}
    46  		log.InfoDepth(1, fmt.Sprintf(format, args...))
    47  	}
    48  }
    49  
    50  // ExtractMetadata parses the metadata from the given context, if it exists.
    51  // If metadata does not exist, empty values are returned.
    52  func ExtractMetadata(ctx context.Context) (m *Metadata, err error) {
    53  	md, ok := metadata.FromOutgoingContext(ctx)
    54  	if !ok {
    55  		return &Metadata{}, nil
    56  	}
    57  	vs := md.Get(remoteHeadersKey)
    58  	if len(vs) == 0 {
    59  		return &Metadata{}, nil
    60  	}
    61  	buf := []byte(vs[0])
    62  	meta := &repb.RequestMetadata{}
    63  	if err := proto.Unmarshal(buf, meta); err != nil {
    64  		return nil, err
    65  	}
    66  	return &Metadata{
    67  		ToolName:               meta.ToolDetails.GetToolName(),
    68  		ToolVersion:            meta.ToolDetails.GetToolVersion(),
    69  		ActionID:               meta.ActionId,
    70  		InvocationID:           meta.ToolInvocationId,
    71  		CorrelatedInvocationID: meta.CorrelatedInvocationsId,
    72  	}, nil
    73  }
    74  
    75  // WithMetadata attaches metadata to the passed-in context, returning a new
    76  // context. This function should be called in every test method after a context is created. It uses
    77  // the already created context to generate a new one containing the metadata header.
    78  func WithMetadata(ctx context.Context, ms ...*Metadata) (context.Context, error) {
    79  	m := MergeMetadata(ms...)
    80  	actionID := m.ActionID
    81  	if actionID == "" {
    82  		if id, err := uuid.NewRandom(); err == nil {
    83  			actionID = id.String()
    84  			log.V(2).Infof("Generated action_id %s for %s", actionID, m.ToolName)
    85  		} else {
    86  			log.Warningf("Failed to generate action_id: %s", err)
    87  		}
    88  	}
    89  	invocationID := m.InvocationID
    90  	if invocationID == "" {
    91  		if id, err := uuid.NewRandom(); err == nil {
    92  			invocationID = id.String()
    93  			log.V(2).Infof("Generated invocation_id %s for %s %s", invocationID, m.ToolName, actionID)
    94  		} else {
    95  			log.Warningf("Failed to generate invocation_id: %s", err)
    96  		}
    97  	}
    98  
    99  	meta := &repb.RequestMetadata{
   100  		ActionId:         actionID,
   101  		ToolInvocationId: invocationID,
   102  		ToolDetails: &repb.ToolDetails{
   103  			ToolName:    m.ToolName,
   104  			ToolVersion: m.ToolVersion,
   105  		},
   106  	}
   107  
   108  	// Marshal the proto to a binary buffer
   109  	buf, err := proto.Marshal(meta)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	// metadata package converts the binary buffer to a base64 string, so no need to encode before
   115  	// sending.
   116  	mdPair := metadata.Pairs(remoteHeadersKey, string(buf))
   117  	return metadata.NewOutgoingContext(ctx, mdPair), nil
   118  }
   119  
   120  // MergeMetadata returns a new instance that has the tool name, tool version and correlated action id from
   121  // the first argument, and joins a sorted and unique set of action IDs and invocation IDs from all arguments.
   122  // Nil is never returned.
   123  func MergeMetadata(metas ...*Metadata) *Metadata {
   124  	if len(metas) == 0 {
   125  		return &Metadata{}
   126  	}
   127  
   128  	md := metas[0]
   129  	actionIds := make(map[string]struct{}, len(metas))
   130  	invocationIds := make(map[string]struct{}, len(metas))
   131  	for _, m := range metas {
   132  		actionIds[m.ActionID] = struct{}{}
   133  		invocationIds[m.InvocationID] = struct{}{}
   134  	}
   135  	md.ActionID = mergeSet(actionIds)
   136  	md.InvocationID = mergeSet(invocationIds)
   137  	return md
   138  }
   139  
   140  // FromContexts returns a context derived from the first one with metadata merged from all of ctxs.
   141  //
   142  // If len(ctxs) == 0, ctx is returned as is.
   143  // Returns the first error or nil.
   144  func FromContexts(ctx context.Context, ctxs ...context.Context) (context.Context, error) {
   145  	if len(ctxs) == 0 {
   146  		return ctx, nil
   147  	}
   148  
   149  	metas := make([]*Metadata, len(ctxs)+1)
   150  	md, err := ExtractMetadata(ctx)
   151  	if err != nil {
   152  		return ctx, err
   153  	}
   154  	metas[0] = md
   155  	for i, c := range ctxs {
   156  		md, err := ExtractMetadata(c)
   157  		if err != nil {
   158  			return ctx, err
   159  		}
   160  		metas[i+1] = md
   161  	}
   162  
   163  	// We cap to a bit less than the maximum header size in order to allow
   164  	// for some proto fields serialization overhead.
   165  	m := capToLimit(MergeMetadata(metas...), defaultMaxHeaderSize-100)
   166  	return WithMetadata(ctx, m)
   167  }
   168  
   169  func mergeSet(set map[string]struct{}) string {
   170  	vals := make([]string, 0, len(set))
   171  	for v := range set {
   172  		vals = append(vals, v)
   173  	}
   174  	sort.Strings(vals)
   175  	return strings.Join(vals, ",")
   176  }
   177  
   178  // capToLimit ensures total length does not exceed max header size.
   179  func capToLimit(m *Metadata, limit int) *Metadata {
   180  	total := len(m.ToolName) + len(m.ToolVersion) + len(m.ActionID) + len(m.InvocationID) + len(m.CorrelatedInvocationID)
   181  	excess := total - limit
   182  	if excess <= 0 {
   183  		return m
   184  	}
   185  	// We ignore the tool name, because in practice this is a
   186  	// very short constant which makes no sense to truncate.
   187  	diff := len(m.ActionID) - len(m.InvocationID)
   188  	if diff > 0 {
   189  		if diff > excess {
   190  			m.ActionID = m.ActionID[:len(m.ActionID)-excess]
   191  		} else {
   192  			m.ActionID = m.ActionID[:len(m.ActionID)-diff]
   193  			rem := (excess - diff + 1) / 2
   194  			m.ActionID = m.ActionID[:len(m.ActionID)-rem]
   195  			m.InvocationID = m.InvocationID[:len(m.InvocationID)-rem]
   196  		}
   197  	} else {
   198  		diff = -diff
   199  		if diff > excess {
   200  			m.InvocationID = m.InvocationID[:len(m.InvocationID)-excess]
   201  		} else {
   202  			m.InvocationID = m.InvocationID[:len(m.InvocationID)-diff]
   203  			rem := (excess - diff + 1) / 2
   204  			m.InvocationID = m.InvocationID[:len(m.InvocationID)-rem]
   205  			m.ActionID = m.ActionID[:len(m.ActionID)-rem]
   206  		}
   207  	}
   208  	return m
   209  }