github.com/yaling888/clash@v1.53.0/transport/trojan/trojan_test.go (about) 1 package trojan 2 3 import ( 4 "bytes" 5 "crypto/rand" 6 "io" 7 "net" 8 "net/netip" 9 "reflect" 10 "testing" 11 ) 12 13 type fakeConn struct { 14 net.Conn 15 rw *bytes.Buffer 16 } 17 18 func (f *fakeConn) Read(b []byte) (n int, err error) { 19 return f.rw.Read(b) 20 } 21 22 func (f *fakeConn) Write(b []byte) (n int, err error) { 23 return f.rw.Write(b) 24 } 25 26 func TestPacketConn_ReadFrom(t *testing.T) { 27 srcS := make([]byte, 64*1025) 28 srcL := make([]byte, 7*1024) 29 _, _ = rand.Read(srcS) 30 _, _ = rand.Read(srcL) 31 32 addr := &net.UDPAddr{ 33 IP: net.ParseIP("127.0.0.1"), 34 Port: 443, 35 } 36 37 type fields struct { 38 Conn net.Conn 39 } 40 type args struct { 41 src []byte 42 buf []byte 43 addr net.Addr 44 } 45 tests := []struct { 46 name string 47 fields fields 48 args args 49 wantN int 50 wantAddr netip.AddrPort 51 wantErr bool 52 }{ 53 { 54 name: "smallBuffer", 55 fields: fields{ 56 Conn: &fakeConn{ 57 rw: &bytes.Buffer{}, 58 }, 59 }, 60 args: args{ 61 src: srcS, 62 buf: make([]byte, 1024), 63 addr: addr, 64 }, 65 wantN: len(srcS), 66 wantAddr: addr.AddrPort(), 67 wantErr: false, 68 }, 69 { 70 name: "largeBuffer", 71 fields: fields{ 72 Conn: &fakeConn{ 73 rw: &bytes.Buffer{}, 74 }, 75 }, 76 args: args{ 77 src: srcL, 78 buf: make([]byte, 32*1024), 79 addr: addr, 80 }, 81 wantN: len(srcL), 82 wantAddr: addr.AddrPort(), 83 wantErr: false, 84 }, 85 } 86 for _, tt := range tests { 87 t.Run(tt.name, func(t *testing.T) { 88 pc := &PacketConn{ 89 Conn: tt.fields.Conn, 90 } 91 gotN, err := pc.WriteTo(tt.args.src, tt.args.addr) 92 if (err != nil) != tt.wantErr { 93 t.Errorf("WriteTo() error = %v, wantErr %v", err, tt.wantErr) 94 return 95 } 96 if gotN != tt.wantN { 97 t.Errorf("WriteTo() gotN = %v, want %v", gotN, tt.wantN) 98 } 99 100 buf := tt.args.buf 101 dst := make([]byte, 0, 64*1024) 102 for { 103 n, gotAddr, err1 := pc.ReadFrom(buf) 104 if err1 != nil { 105 if err1 == io.EOF { 106 break 107 } else if !tt.wantErr { 108 t.Errorf("ReadFrom() error = %v, wantErr %v", err1, tt.wantErr) 109 return 110 } 111 } 112 if !reflect.DeepEqual(gotAddr.(*net.UDPAddr).AddrPort(), tt.wantAddr) { 113 t.Errorf("ReadFrom() gotAddr = %v, want %v", gotAddr, tt.wantAddr) 114 } 115 dst = append(dst, buf[:n]...) 116 } 117 118 if len(dst) != tt.wantN { 119 t.Errorf("ReadFrom() read data doesn't match write data, gotN = %v, want %v", len(dst), tt.wantN) 120 return 121 } 122 123 if !reflect.DeepEqual(dst, tt.args.src) { 124 t.Errorf("ReadFrom() read data doesn't match write data") 125 } 126 }) 127 } 128 }