github.com/storacha/go-ucanto@v0.7.2/transport/http/channel_test.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"net/url"
     8  	"testing"
     9  
    10  	"go.opentelemetry.io/otel"
    11  	"go.opentelemetry.io/otel/propagation"
    12  	"go.opentelemetry.io/otel/trace"
    13  )
    14  
    15  func TestChannelPropagatesTraceContext(t *testing.T) {
    16  	const (
    17  		requestTraceIDHex  = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
    18  		requestSpanIDHex   = "bbbbbbbbbbbbbbbb"
    19  		responseTraceIDHex = "cccccccccccccccccccccccccccccccc"
    20  		responseSpanIDHex  = "dddddddddddddddd"
    21  		responseTrace      = "00-" + responseTraceIDHex + "-" + responseSpanIDHex + "-01"
    22  		expectedRequest    = "00-" + requestTraceIDHex + "-" + requestSpanIDHex + "-01"
    23  	)
    24  
    25  	var seenRequestTrace string
    26  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    27  		seenRequestTrace = r.Header.Get("traceparent")
    28  		w.Header().Set("traceparent", responseTrace)
    29  		w.WriteHeader(http.StatusOK)
    30  	}))
    31  	t.Cleanup(server.Close)
    32  
    33  	endpoint, err := url.Parse(server.URL)
    34  	if err != nil {
    35  		t.Fatalf("parsing server URL: %v", err)
    36  	}
    37  
    38  	channel := NewChannel(endpoint, WithClient(server.Client()))
    39  
    40  	restoreProp := setTraceContextPropagator()
    41  	t.Cleanup(restoreProp)
    42  
    43  	ctx := context.Background()
    44  	ctx = trace.ContextWithSpanContext(ctx, newSpanContext(t, requestTraceIDHex, requestSpanIDHex))
    45  
    46  	res, err := channel.Request(ctx, NewRequest(http.NoBody, nil))
    47  	if err != nil {
    48  		t.Fatalf("request failed: %v", err)
    49  	}
    50  	t.Cleanup(func() { res.Body().Close() })
    51  
    52  	if seenRequestTrace != expectedRequest {
    53  		t.Fatalf("expected traceparent %q, got %q", expectedRequest, seenRequestTrace)
    54  	}
    55  
    56  	responseCtx, ok := res.(*Response)
    57  	if !ok {
    58  		t.Fatalf("expected *Response, got %T", res)
    59  	}
    60  	sc := trace.SpanContextFromContext(responseCtx.Context())
    61  	expectedTraceID := mustTraceIDFromHex(t, responseTraceIDHex)
    62  	if sc.TraceID() != expectedTraceID {
    63  		t.Fatalf("expected response trace ID %s, got %s", expectedTraceID, sc.TraceID())
    64  	}
    65  	expectedSpanID := mustSpanIDFromHex(t, responseSpanIDHex)
    66  	if sc.SpanID() != expectedSpanID {
    67  		t.Fatalf("expected response span ID %s, got %s", expectedSpanID, sc.SpanID())
    68  	}
    69  }
    70  
    71  func newSpanContext(t *testing.T, traceIDHex, spanIDHex string) trace.SpanContext {
    72  	t.Helper()
    73  	traceID := mustTraceIDFromHex(t, traceIDHex)
    74  	spanID := mustSpanIDFromHex(t, spanIDHex)
    75  	return trace.NewSpanContext(trace.SpanContextConfig{
    76  		TraceID:    traceID,
    77  		SpanID:     spanID,
    78  		TraceFlags: trace.FlagsSampled,
    79  	})
    80  }
    81  
    82  func mustTraceIDFromHex(t *testing.T, hex string) trace.TraceID {
    83  	t.Helper()
    84  	traceID, err := trace.TraceIDFromHex(hex)
    85  	if err != nil {
    86  		t.Fatalf("parsing trace ID: %v", err)
    87  	}
    88  	return traceID
    89  }
    90  
    91  func mustSpanIDFromHex(t *testing.T, hex string) trace.SpanID {
    92  	t.Helper()
    93  	spanID, err := trace.SpanIDFromHex(hex)
    94  	if err != nil {
    95  		t.Fatalf("parsing span ID: %v", err)
    96  	}
    97  	return spanID
    98  }
    99  
   100  func setTraceContextPropagator() func() {
   101  	prev := otel.GetTextMapPropagator()
   102  	otel.SetTextMapPropagator(propagation.TraceContext{})
   103  	return func() {
   104  		otel.SetTextMapPropagator(prev)
   105  	}
   106  }