gitlab.com/gitlab-org/labkit@v1.21.0/tracing/env_extractor_test.go (about) 1 package tracing 2 3 import ( 4 "context" 5 "os" 6 "reflect" 7 "testing" 8 9 "github.com/stretchr/testify/require" 10 "gitlab.com/gitlab-org/labkit/correlation" 11 ) 12 13 func TestExtractFromEnv(t *testing.T) { 14 tests := []struct { 15 name string 16 ctx context.Context 17 opts []ExtractFromEnvOption 18 additionalEnv map[string]string 19 wantCorrelationID string 20 }{ 21 { 22 name: "no_options", 23 ctx: context.Background(), 24 opts: []ExtractFromEnvOption{}, 25 }, 26 { 27 name: "pass_correlation_id", 28 ctx: context.Background(), 29 opts: []ExtractFromEnvOption{}, 30 additionalEnv: map[string]string{ 31 envCorrelationIDKey: "abc123", 32 }, 33 wantCorrelationID: "abc123", 34 }, 35 } 36 for _, tt := range tests { 37 t.Run(tt.name, func(t *testing.T) { 38 resetEnvironment := addAdditionalEnv(tt.additionalEnv) 39 defer resetEnvironment() 40 41 ctx, finished := ExtractFromEnv(tt.ctx, tt.opts...) 42 require.NotNil(t, ctx, "ctx is nil") 43 require.NotNil(t, finished, "finished is nil") 44 gotCorrelationID := correlation.ExtractFromContext(ctx) 45 require.Equal(t, tt.wantCorrelationID, gotCorrelationID) 46 defer finished() 47 }) 48 } 49 } 50 51 func Test_environAsMap(t *testing.T) { 52 tests := []struct { 53 name string 54 env []string 55 want map[string]string 56 }{ 57 { 58 name: "trivial", 59 env: nil, 60 want: map[string]string{}, 61 }, 62 } 63 for _, tt := range tests { 64 t.Run(tt.name, func(t *testing.T) { 65 if got := environAsMap(tt.env); !reflect.DeepEqual(got, tt.want) { 66 t.Errorf("environAsMap() = %v, want %v", got, tt.want) 67 } 68 }) 69 } 70 } 71 72 // addAdditionalEnv will configure additional environment values 73 // and return a deferrable function to reset the environment to 74 // it's original state after the test. 
func addAdditionalEnv(envMap map[string]string) func() {
	// Record, per key, what must be done to undo our change:
	// restore a previous value, or unset a key that did not exist before.
	prevValues := make(map[string]string, len(envMap))
	var unsetValues []string

	// reset undoes every change recorded so far. Errors are ignored on
	// purpose: this is best-effort cleanup, usually run from a defer.
	reset := func() {
		for k, v := range prevValues {
			_ = os.Setenv(k, v)
		}
		for _, k := range unsetValues {
			_ = os.Unsetenv(k)
		}
	}

	for k, v := range envMap {
		if value, exists := os.LookupEnv(k); exists {
			prevValues[k] = value
		} else {
			unsetValues = append(unsetValues, k)
		}

		// os.Setenv rejects invalid keys (for example an empty key). Rather
		// than silently running the test against an environment it did not
		// ask for, roll back what was already applied and fail loudly; this
		// helper has no *testing.T to report through.
		if err := os.Setenv(k, v); err != nil {
			reset()
			panic("addAdditionalEnv: setting " + k + ": " + err.Error())
		}
	}

	return reset
}