github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/proxy/definition_test.go (about)

     1  package proxy
     2  
     3  import (
     4  	"encoding/json"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/hellofresh/janus/pkg/middleware"
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/require"
    11  )
    12  
    13  func TestDefinition(t *testing.T) {
    14  	t.Parallel()
    15  
    16  	tests := []struct {
    17  		scenario string
    18  		function func(*testing.T)
    19  	}{
    20  		{
    21  			scenario: "new definitions",
    22  			function: testNewDefinitions,
    23  		},
    24  		{
    25  			scenario: "successful validation",
    26  			function: testSuccessfulValidation,
    27  		},
    28  		{
    29  			scenario: "empty listen path validation",
    30  			function: testEmptyListenPathValidation,
    31  		},
    32  		{
    33  			scenario: "invalid target url validation",
    34  			function: testInvalidTargetURLValidation,
    35  		},
    36  		{
    37  			scenario: "is balancer defined",
    38  			function: testIsBalancerDefined,
    39  		},
    40  		{
    41  			scenario: "add middleware",
    42  			function: testAddMiddlewares,
    43  		},
    44  		{
    45  			scenario: "marshal forwarding_timeouts to json",
    46  			function: testMarshalForwardingTimeoutsToJSON,
    47  		},
    48  		{
    49  			scenario: "unmarshal forwarding_timeouts from json",
    50  			function: testUnmarshalForwardingTimeoutsFromJSON,
    51  		},
    52  	}
    53  
    54  	for _, test := range tests {
    55  		t.Run(test.scenario, func(t *testing.T) {
    56  			test.function(t)
    57  		})
    58  	}
    59  }
    60  
    61  func testNewDefinitions(t *testing.T) {
    62  	definition := NewDefinition()
    63  
    64  	assert.Equal(t, []string{"GET"}, definition.Methods)
    65  	assert.NotNil(t, definition)
    66  }
    67  
    68  func testSuccessfulValidation(t *testing.T) {
    69  	definition := Definition{
    70  		ListenPath: "/*",
    71  		Upstreams: &Upstreams{
    72  			Balancing: "roundrobin",
    73  			Targets: Targets{
    74  				{Target: "http://test.com"},
    75  			},
    76  		},
    77  	}
    78  	isValid, err := definition.Validate()
    79  
    80  	assert.NoError(t, err)
    81  	assert.True(t, isValid)
    82  }
    83  
    84  func testEmptyListenPathValidation(t *testing.T) {
    85  	definition := Definition{}
    86  	isValid, err := definition.Validate()
    87  
    88  	assert.Error(t, err)
    89  	assert.False(t, isValid)
    90  }
    91  
    92  func testInvalidTargetURLValidation(t *testing.T) {
    93  	definition := Definition{
    94  		ListenPath: " ",
    95  		Upstreams: &Upstreams{
    96  			Balancing: "roundrobin",
    97  			Targets: Targets{
    98  				{Target: "wrong"},
    99  			},
   100  		},
   101  	}
   102  	isValid, err := definition.Validate()
   103  
   104  	assert.Error(t, err)
   105  	assert.False(t, isValid)
   106  }
   107  
   108  func testIsBalancerDefined(t *testing.T) {
   109  	definition := NewDefinition()
   110  	assert.False(t, definition.IsBalancerDefined())
   111  
   112  	target := &Target{Target: "http://localhost:8080/api-name"}
   113  	definition.Upstreams.Targets = append(definition.Upstreams.Targets, target)
   114  	assert.True(t, definition.IsBalancerDefined())
   115  	assert.Len(t, definition.Upstreams.Targets.ToBalancerTargets(), 1)
   116  }
   117  
   118  func testAddMiddlewares(t *testing.T) {
   119  	routerDefinition := NewRouterDefinition(NewDefinition())
   120  	routerDefinition.AddMiddleware(middleware.NewLogger().Handler)
   121  
   122  	assert.Len(t, routerDefinition.Middleware(), 1)
   123  }
   124  
   125  func testMarshalForwardingTimeoutsToJSON(t *testing.T) {
   126  	definition := Definition{
   127  		ListenPath: "/*",
   128  		Upstreams: &Upstreams{
   129  			Balancing: "roundrobin",
   130  			Targets: Targets{
   131  				{Target: "http://test.com"},
   132  			},
   133  		},
   134  		ForwardingTimeouts: ForwardingTimeouts{
   135  			DialTimeout:           Duration(30 * time.Second),
   136  			ResponseHeaderTimeout: Duration(31 * time.Second),
   137  		},
   138  	}
   139  	jsonDefinition, err := json.Marshal(&definition)
   140  	require.NoError(t, err)
   141  	assert.Contains(t, string(jsonDefinition), `"dial_timeout":"30s"`)
   142  	assert.Contains(t, string(jsonDefinition), `"response_header_timeout":"31s"`)
   143  }
   144  
   145  func testUnmarshalForwardingTimeoutsFromJSON(t *testing.T) {
   146  	rawDefinition := []byte(`
   147    {
   148      "preserve_host":false,
   149      "listen_path":"/example/*",
   150      "upstreams":{
   151        "balancing":"roundrobin",
   152        "targets":[
   153          {
   154            "target":"http://localhost:9089/hello-world"
   155          }
   156        ]
   157      },
   158      "strip_path":false,
   159      "append_path":false,
   160      "methods":[
   161        "GET"
   162      ],
   163      "forwarding_timeouts": {
   164        "dial_timeout": "30s",
   165        "response_header_timeout": "31s"
   166      }
   167    }
   168  `)
   169  	definition := NewDefinition()
   170  	err := json.Unmarshal(rawDefinition, &definition)
   171  	require.NoError(t, err)
   172  
   173  	assert.Equal(t, 30*time.Second, time.Duration(definition.ForwardingTimeouts.DialTimeout))
   174  	assert.Equal(t, 31*time.Second, time.Duration(definition.ForwardingTimeouts.ResponseHeaderTimeout))
   175  }