github.com/storacha/go-ucanto@v0.7.2/transport/http/channel_test.go (about) 1 package http 2 3 import ( 4 "context" 5 "net/http" 6 "net/http/httptest" 7 "net/url" 8 "testing" 9 10 "go.opentelemetry.io/otel" 11 "go.opentelemetry.io/otel/propagation" 12 "go.opentelemetry.io/otel/trace" 13 ) 14 15 func TestChannelPropagatesTraceContext(t *testing.T) { 16 const ( 17 requestTraceIDHex = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" 18 requestSpanIDHex = "bbbbbbbbbbbbbbbb" 19 responseTraceIDHex = "cccccccccccccccccccccccccccccccc" 20 responseSpanIDHex = "dddddddddddddddd" 21 responseTrace = "00-" + responseTraceIDHex + "-" + responseSpanIDHex + "-01" 22 expectedRequest = "00-" + requestTraceIDHex + "-" + requestSpanIDHex + "-01" 23 ) 24 25 var seenRequestTrace string 26 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 27 seenRequestTrace = r.Header.Get("traceparent") 28 w.Header().Set("traceparent", responseTrace) 29 w.WriteHeader(http.StatusOK) 30 })) 31 t.Cleanup(server.Close) 32 33 endpoint, err := url.Parse(server.URL) 34 if err != nil { 35 t.Fatalf("parsing server URL: %v", err) 36 } 37 38 channel := NewChannel(endpoint, WithClient(server.Client())) 39 40 restoreProp := setTraceContextPropagator() 41 t.Cleanup(restoreProp) 42 43 ctx := context.Background() 44 ctx = trace.ContextWithSpanContext(ctx, newSpanContext(t, requestTraceIDHex, requestSpanIDHex)) 45 46 res, err := channel.Request(ctx, NewRequest(http.NoBody, nil)) 47 if err != nil { 48 t.Fatalf("request failed: %v", err) 49 } 50 t.Cleanup(func() { res.Body().Close() }) 51 52 if seenRequestTrace != expectedRequest { 53 t.Fatalf("expected traceparent %q, got %q", expectedRequest, seenRequestTrace) 54 } 55 56 responseCtx, ok := res.(*Response) 57 if !ok { 58 t.Fatalf("expected *Response, got %T", res) 59 } 60 sc := trace.SpanContextFromContext(responseCtx.Context()) 61 expectedTraceID := mustTraceIDFromHex(t, responseTraceIDHex) 62 if sc.TraceID() != expectedTraceID { 63 t.Fatalf("expected response trace ID %s, got %s", expectedTraceID, sc.TraceID()) 64 } 65 expectedSpanID := mustSpanIDFromHex(t, responseSpanIDHex) 66 if sc.SpanID() != expectedSpanID { 67 t.Fatalf("expected response span ID %s, got %s", expectedSpanID, sc.SpanID()) 68 } 69 } 70 71 func newSpanContext(t *testing.T, traceIDHex, spanIDHex string) trace.SpanContext { 72 t.Helper() 73 traceID := mustTraceIDFromHex(t, traceIDHex) 74 spanID := mustSpanIDFromHex(t, spanIDHex) 75 return trace.NewSpanContext(trace.SpanContextConfig{ 76 TraceID: traceID, 77 SpanID: spanID, 78 TraceFlags: trace.FlagsSampled, 79 }) 80 } 81 82 func mustTraceIDFromHex(t *testing.T, hex string) trace.TraceID { 83 t.Helper() 84 traceID, err := trace.TraceIDFromHex(hex) 85 if err != nil { 86 t.Fatalf("parsing trace ID: %v", err) 87 } 88 return traceID 89 } 90 91 func mustSpanIDFromHex(t *testing.T, hex string) trace.SpanID { 92 t.Helper() 93 spanID, err := trace.SpanIDFromHex(hex) 94 if err != nil { 95 t.Fatalf("parsing span ID: %v", err) 96 } 97 return spanID 98 } 99 100 func setTraceContextPropagator() func() { 101 prev := otel.GetTextMapPropagator() 102 otel.SetTextMapPropagator(propagation.TraceContext{}) 103 return func() { 104 otel.SetTextMapPropagator(prev) 105 } 106 }