github.com/newrelic/go-agent@v3.26.0+incompatible/internal/distributed_tracing.go (about)

     1  // Copyright 2020 New Relic Corporation. All rights reserved.
     2  // SPDX-License-Identifier: Apache-2.0
     3  
     4  package internal
     5  
     6  import (
     7  	"encoding/base64"
     8  	"encoding/json"
     9  	"fmt"
    10  	"time"
    11  )
    12  
    13  type distTraceVersion [2]int
    14  
    15  func (v distTraceVersion) major() int { return v[0] }
    16  func (v distTraceVersion) minor() int { return v[1] }
    17  
    18  const (
    19  	// CallerType is the Type field's value for outbound payloads.
    20  	CallerType = "App"
    21  )
    22  
    23  var (
    24  	currentDistTraceVersion = distTraceVersion([2]int{0 /* Major */, 1 /* Minor */})
    25  	callerUnknown           = payloadCaller{Type: "Unknown", App: "Unknown", Account: "Unknown", TransportType: "Unknown"}
    26  )
    27  
    28  // timestampMillis allows raw payloads to use exact times, and marshalled
    29  // payloads to use times in millis.
    30  type timestampMillis time.Time
    31  
    32  func (tm *timestampMillis) UnmarshalJSON(data []byte) error {
    33  	var millis uint64
    34  	if err := json.Unmarshal(data, &millis); nil != err {
    35  		return err
    36  	}
    37  	*tm = timestampMillis(timeFromUnixMilliseconds(millis))
    38  	return nil
    39  }
    40  
    41  func (tm timestampMillis) MarshalJSON() ([]byte, error) {
    42  	return json.Marshal(TimeToUnixMilliseconds(tm.Time()))
    43  }
    44  
    45  func (tm timestampMillis) Time() time.Time  { return time.Time(tm) }
    46  func (tm *timestampMillis) Set(t time.Time) { *tm = timestampMillis(t) }
    47  
// Payload is the distributed tracing payload.
//
// encoding/json promotes the exported fields of the embedded payloadCaller,
// so they appear at the top level of the JSON object ahead of the fields
// below. Field order here determines the serialized key order. Fields tagged
// `json:"-"` are never written to or read from JSON.
type Payload struct {
	payloadCaller
	// TransactionID ("tx") and ID ("id") may each be omitted from JSON, but
	// IsValid requires at least one of them to be non-empty.
	TransactionID     string          `json:"tx,omitempty"`
	ID                string          `json:"id,omitempty"`
	// TracedID is serialized as "tr" (presumably the trace ID — TODO
	// confirm); IsValid requires it.
	TracedID          string          `json:"tr"`
	Priority          Priority        `json:"pr"`
	// Sampled is a *bool, so it serializes as null when unset; assign it
	// with SetSampled.
	Sampled           *bool           `json:"sa"`
	// Timestamp holds an exact time in memory but is marshalled as an
	// integer millisecond count; IsValid rejects a zero/epoch value.
	Timestamp         timestampMillis `json:"ti"`
	// TransportDuration is not part of the JSON form.
	TransportDuration time.Duration   `json:"-"`
}
    59  
// payloadCaller groups the caller fields of a Payload (presumably describing
// the sender — TODO confirm). It is embedded in Payload, so its exported,
// JSON-tagged fields appear at the top level of the payload object.
type payloadCaller struct {
	// TransportType is not serialized.
	TransportType     string `json:"-"`
	// Type is serialized as "ty"; outbound payloads use CallerType ("App").
	// IsValid requires it.
	Type              string `json:"ty"`
	// App ("ap") and Account ("ac") are both required by IsValid.
	App               string `json:"ap"`
	Account           string `json:"ac"`
	// TrustedAccountKey is omitted from JSON when empty.
	TrustedAccountKey string `json:"tk,omitempty"`
}
    67  
    68  // IsValid validates the payload data by looking for missing fields.
    69  // Returns an error if there's a problem, nil if everything's fine
    70  func (p Payload) IsValid() error {
    71  
    72  	// If a payload is missing both `guid` and `transactionId` is received,
    73  	// a ParseException supportability metric should be generated.
    74  	if "" == p.TransactionID && "" == p.ID {
    75  		return ErrPayloadMissingField{message: "missing both guid/id and TransactionId/tx"}
    76  	}
    77  
    78  	if "" == p.Type {
    79  		return ErrPayloadMissingField{message: "missing Type/ty"}
    80  	}
    81  
    82  	if "" == p.Account {
    83  		return ErrPayloadMissingField{message: "missing Account/ac"}
    84  	}
    85  
    86  	if "" == p.App {
    87  		return ErrPayloadMissingField{message: "missing App/ap"}
    88  	}
    89  
    90  	if "" == p.TracedID {
    91  		return ErrPayloadMissingField{message: "missing TracedID/tr"}
    92  	}
    93  
    94  	if p.Timestamp.Time().IsZero() || 0 == p.Timestamp.Time().Unix() {
    95  		return ErrPayloadMissingField{message: "missing Timestamp/ti"}
    96  	}
    97  
    98  	return nil
    99  }
   100  
   101  func (p Payload) text(v distTraceVersion) []byte {
   102  	js, _ := json.Marshal(struct {
   103  		Version distTraceVersion `json:"v"`
   104  		Data    Payload          `json:"d"`
   105  	}{
   106  		Version: v,
   107  		Data:    p,
   108  	})
   109  	return js
   110  }
   111  
   112  // Text implements newrelic.DistributedTracePayload.
   113  func (p Payload) Text() string {
   114  	t := p.text(currentDistTraceVersion)
   115  	return string(t)
   116  }
   117  
   118  // HTTPSafe implements newrelic.DistributedTracePayload.
   119  func (p Payload) HTTPSafe() string {
   120  	t := p.text(currentDistTraceVersion)
   121  	return base64.StdEncoding.EncodeToString(t)
   122  }
   123  
   124  // SetSampled lets us set a value for our *bool,
   125  // which we can't do directly since a pointer
   126  // needs something to point at.
   127  func (p *Payload) SetSampled(sampled bool) {
   128  	p.Sampled = &sampled
   129  }
   130  
   131  // ErrPayloadParse indicates that the payload was malformed.
   132  type ErrPayloadParse struct{ err error }
   133  
   134  func (e ErrPayloadParse) Error() string {
   135  	return fmt.Sprintf("unable to parse inbound payload: %s", e.err.Error())
   136  }
   137  
   138  // ErrPayloadMissingField indicates there's a required field that's missing
   139  type ErrPayloadMissingField struct{ message string }
   140  
   141  func (e ErrPayloadMissingField) Error() string {
   142  	return fmt.Sprintf("payload is missing required fields: %s", e.message)
   143  }
   144  
   145  // ErrUnsupportedPayloadVersion indicates that the major version number is
   146  // unknown.
   147  type ErrUnsupportedPayloadVersion struct{ version int }
   148  
   149  func (e ErrUnsupportedPayloadVersion) Error() string {
   150  	return fmt.Sprintf("unsupported major version number %d", e.version)
   151  }
   152  
   153  // AcceptPayload parses the inbound distributed tracing payload.
   154  func AcceptPayload(p interface{}) (*Payload, error) {
   155  	var payload Payload
   156  	if byteSlice, ok := p.([]byte); ok {
   157  		p = string(byteSlice)
   158  	}
   159  	switch v := p.(type) {
   160  	case string:
   161  		if "" == v {
   162  			return nil, nil
   163  		}
   164  		var decoded []byte
   165  		if '{' == v[0] {
   166  			decoded = []byte(v)
   167  		} else {
   168  			var err error
   169  			decoded, err = base64.StdEncoding.DecodeString(v)
   170  			if nil != err {
   171  				return nil, ErrPayloadParse{err: err}
   172  			}
   173  		}
   174  		envelope := struct {
   175  			Version distTraceVersion `json:"v"`
   176  			Data    json.RawMessage  `json:"d"`
   177  		}{}
   178  		if err := json.Unmarshal(decoded, &envelope); nil != err {
   179  			return nil, ErrPayloadParse{err: err}
   180  		}
   181  
   182  		if 0 == envelope.Version.major() && 0 == envelope.Version.minor() {
   183  			return nil, ErrPayloadMissingField{message: "missing v"}
   184  		}
   185  
   186  		if envelope.Version.major() > currentDistTraceVersion.major() {
   187  			return nil, ErrUnsupportedPayloadVersion{
   188  				version: envelope.Version.major(),
   189  			}
   190  		}
   191  		if err := json.Unmarshal(envelope.Data, &payload); nil != err {
   192  			return nil, ErrPayloadParse{err: err}
   193  		}
   194  	case Payload:
   195  		payload = v
   196  	default:
   197  		// Could be a shim payload (if the app is not yet connected).
   198  		return nil, nil
   199  	}
   200  	// Ensure that we don't have a reference to the input payload: we don't
   201  	// want to change it, it could be used multiple times.
   202  	alloc := new(Payload)
   203  	*alloc = payload
   204  
   205  	return alloc, nil
   206  }