github.com/erda-project/erda-infra@v1.0.9/providers/remote-forward/protocol_test.go (about) 1 // Copyright (c) 2021 Terminus, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package forward 16 17 import ( 18 "bytes" 19 "io" 20 "math" 21 "testing" 22 ) 23 24 func TestRequestHeader(t *testing.T) { 25 tests := []struct { 26 name string 27 h *RequestHeader 28 tail string 29 want *RequestHeader 30 wantErr bool 31 }{ 32 { 33 h: &RequestHeader{ 34 Version: ProtocolVersion, 35 Name: "test-name", 36 Token: "test-token", 37 ShadowAddr: "test-addr", 38 }, 39 tail: "test", 40 want: &RequestHeader{ 41 Version: ProtocolVersion, 42 Name: "test-name", 43 Token: "test-token", 44 ShadowAddr: "test-addr", 45 }, 46 }, 47 { 48 h: &RequestHeader{ 49 Version: math.MaxUint32, 50 Name: "test2-name", 51 Token: "test2-token", 52 ShadowAddr: "test2-addr", 53 }, 54 tail: "test", 55 want: &RequestHeader{ 56 Version: math.MaxUint32, 57 Name: "test2-name", 58 Token: "test2-token", 59 ShadowAddr: "test2-addr", 60 }, 61 }, 62 } 63 for _, tt := range tests { 64 t.Run(tt.name, func(t *testing.T) { 65 w := &bytes.Buffer{} 66 if err := EncodeRequestHeader(w, tt.h); (err != nil) != tt.wantErr { 67 t.Errorf("EncodeRequestHeader() error = %v, wantErr %v", err, tt.wantErr) 68 return 69 } else if (err != nil) && tt.wantErr { 70 return 71 } 72 w.WriteString(tt.tail) 73 74 r := bytes.NewReader(w.Bytes()) 75 h, err := DecodeRequestHeader(r) 76 if (err != nil) != tt.wantErr { 77 t.Errorf("DecodeRequestHeader() error = %v, wantErr %v", err, tt.wantErr) 78 return 79 } else if (err != nil) && tt.wantErr { 80 return 81 } 82 if *h != *tt.h { 83 t.Errorf("EncodeRequestHeader() != DecodeRequestHeader()") 84 return 85 } 86 byts, _ := io.ReadAll(r) 87 if tt.tail != string(byts) { 88 t.Errorf("EncodeRequestHeader() tail bytes != DecodeRequestHeader() tail bytes") 89 return 90 } 91 92 if *h != *tt.want { 93 t.Errorf("DecodeRequestHeader() got %v, want %v", *h, *tt.want) 94 return 95 } 96 }) 97 } 98 }