github.com/cilium/cilium@v1.16.2/pkg/dial/dialer_test.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package dial
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"net"
    10  	"testing"
    11  
    12  	"github.com/sirupsen/logrus"
    13  	"github.com/stretchr/testify/assert"
    14  )
    15  
    16  type mockResolver struct{ before, after string }
    17  
    18  func (mr mockResolver) Resolve(_ context.Context, host string) (string, error) {
    19  	if host != mr.before {
    20  		return "", errors.New("unknown translation")
    21  	}
    22  
    23  	return mr.after, nil
    24  }
    25  
    26  func TestNewContextDialer(t *testing.T) {
    27  	tests := []struct {
    28  		hostport  string
    29  		expected  string
    30  		assertErr assert.ErrorAssertionFunc
    31  	}{
    32  		{
    33  			hostport:  "foo.bar",
    34  			assertErr: assert.Error,
    35  		},
    36  		{
    37  			hostport:  "[fd00::9999]:8080",
    38  			expected:  "[fd00::9999]:8080",
    39  			assertErr: assert.NoError,
    40  		},
    41  		{
    42  			hostport:  "foo.bar:9090",
    43  			expected:  "foo.bar:9090",
    44  			assertErr: assert.NoError,
    45  		},
    46  		{
    47  			hostport:  "resolve.foo:8888",
    48  			expected:  "1.2.3.4:8888",
    49  			assertErr: assert.NoError,
    50  		},
    51  		{
    52  			hostport:  "resolve.bar:9999",
    53  			expected:  "[fd00::8888]:9999",
    54  			assertErr: assert.NoError,
    55  		},
    56  		{
    57  			hostport:  "resolve.baz:9898",
    58  			expected:  "qux.fred:9898",
    59  			assertErr: assert.NoError,
    60  		},
    61  	}
    62  
    63  	ctx := context.Background()
    64  	var expected string
    65  
    66  	upstream := func(uctx context.Context, address string) (net.Conn, error) {
    67  		assert.Equal(t, ctx, uctx, "context not propagated correctly")
    68  		assert.Equal(t, expected, address, "address not translated correctly")
    69  		return nil, nil
    70  	}
    71  
    72  	dialer := newContextDialer(
    73  		logrus.New(),
    74  		upstream,
    75  		mockResolver{"resolve.foo", "1.2.3.4"},
    76  		mockResolver{"resolve.bar", "fd00::8888"},
    77  		mockResolver{"resolve.baz", "qux.fred"},
    78  	)
    79  
    80  	for _, tt := range tests {
    81  		expected = tt.expected
    82  		_, err := dialer(ctx, tt.hostport)
    83  		tt.assertErr(t, err, "Got incorrect error for address %q", tt.hostport)
    84  	}
    85  }