gitlab.com/gitlab-org/labkit@v1.21.0/tracing/env_extractor_test.go (about)

     1  package tracing
     2  
     3  import (
     4  	"context"
     5  	"os"
     6  	"reflect"
     7  	"testing"
     8  
     9  	"github.com/stretchr/testify/require"
    10  	"gitlab.com/gitlab-org/labkit/correlation"
    11  )
    12  
    13  func TestExtractFromEnv(t *testing.T) {
    14  	tests := []struct {
    15  		name              string
    16  		ctx               context.Context
    17  		opts              []ExtractFromEnvOption
    18  		additionalEnv     map[string]string
    19  		wantCorrelationID string
    20  	}{
    21  		{
    22  			name: "no_options",
    23  			ctx:  context.Background(),
    24  			opts: []ExtractFromEnvOption{},
    25  		},
    26  		{
    27  			name: "pass_correlation_id",
    28  			ctx:  context.Background(),
    29  			opts: []ExtractFromEnvOption{},
    30  			additionalEnv: map[string]string{
    31  				envCorrelationIDKey: "abc123",
    32  			},
    33  			wantCorrelationID: "abc123",
    34  		},
    35  	}
    36  	for _, tt := range tests {
    37  		t.Run(tt.name, func(t *testing.T) {
    38  			resetEnvironment := addAdditionalEnv(tt.additionalEnv)
    39  			defer resetEnvironment()
    40  
    41  			ctx, finished := ExtractFromEnv(tt.ctx, tt.opts...)
    42  			require.NotNil(t, ctx, "ctx is nil")
    43  			require.NotNil(t, finished, "finished is nil")
    44  			gotCorrelationID := correlation.ExtractFromContext(ctx)
    45  			require.Equal(t, tt.wantCorrelationID, gotCorrelationID)
    46  			defer finished()
    47  		})
    48  	}
    49  }
    50  
    51  func Test_environAsMap(t *testing.T) {
    52  	tests := []struct {
    53  		name string
    54  		env  []string
    55  		want map[string]string
    56  	}{
    57  		{
    58  			name: "trivial",
    59  			env:  nil,
    60  			want: map[string]string{},
    61  		},
    62  	}
    63  	for _, tt := range tests {
    64  		t.Run(tt.name, func(t *testing.T) {
    65  			if got := environAsMap(tt.env); !reflect.DeepEqual(got, tt.want) {
    66  				t.Errorf("environAsMap() = %v, want %v", got, tt.want)
    67  			}
    68  		})
    69  	}
    70  }
    71  
    72  // addAdditionalEnv will configure additional environment values
    73  // and return a deferrable function to reset the environment to
    74  // it's original state after the test.
    75  func addAdditionalEnv(envMap map[string]string) func() {
    76  	prevValues := map[string]string{}
    77  	unsetValues := []string{}
    78  	for k, v := range envMap {
    79  		value, exists := os.LookupEnv(k)
    80  		if exists {
    81  			prevValues[k] = value
    82  		} else {
    83  			unsetValues = append(unsetValues, k)
    84  		}
    85  		os.Setenv(k, v)
    86  	}
    87  
    88  	return func() {
    89  		for k, v := range prevValues {
    90  			os.Setenv(k, v)
    91  		}
    92  
    93  		for _, k := range unsetValues {
    94  			os.Unsetenv(k)
    95  		}
    96  	}
    97  }