github.com/yaling888/clash@v1.53.0/transport/trojan/trojan_test.go (about)

     1  package trojan
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"io"
     7  	"net"
     8  	"net/netip"
     9  	"reflect"
    10  	"testing"
    11  )
    12  
    13  type fakeConn struct {
    14  	net.Conn
    15  	rw *bytes.Buffer
    16  }
    17  
    18  func (f *fakeConn) Read(b []byte) (n int, err error) {
    19  	return f.rw.Read(b)
    20  }
    21  
    22  func (f *fakeConn) Write(b []byte) (n int, err error) {
    23  	return f.rw.Write(b)
    24  }
    25  
    26  func TestPacketConn_ReadFrom(t *testing.T) {
    27  	srcS := make([]byte, 64*1025)
    28  	srcL := make([]byte, 7*1024)
    29  	_, _ = rand.Read(srcS)
    30  	_, _ = rand.Read(srcL)
    31  
    32  	addr := &net.UDPAddr{
    33  		IP:   net.ParseIP("127.0.0.1"),
    34  		Port: 443,
    35  	}
    36  
    37  	type fields struct {
    38  		Conn net.Conn
    39  	}
    40  	type args struct {
    41  		src  []byte
    42  		buf  []byte
    43  		addr net.Addr
    44  	}
    45  	tests := []struct {
    46  		name     string
    47  		fields   fields
    48  		args     args
    49  		wantN    int
    50  		wantAddr netip.AddrPort
    51  		wantErr  bool
    52  	}{
    53  		{
    54  			name: "smallBuffer",
    55  			fields: fields{
    56  				Conn: &fakeConn{
    57  					rw: &bytes.Buffer{},
    58  				},
    59  			},
    60  			args: args{
    61  				src:  srcS,
    62  				buf:  make([]byte, 1024),
    63  				addr: addr,
    64  			},
    65  			wantN:    len(srcS),
    66  			wantAddr: addr.AddrPort(),
    67  			wantErr:  false,
    68  		},
    69  		{
    70  			name: "largeBuffer",
    71  			fields: fields{
    72  				Conn: &fakeConn{
    73  					rw: &bytes.Buffer{},
    74  				},
    75  			},
    76  			args: args{
    77  				src:  srcL,
    78  				buf:  make([]byte, 32*1024),
    79  				addr: addr,
    80  			},
    81  			wantN:    len(srcL),
    82  			wantAddr: addr.AddrPort(),
    83  			wantErr:  false,
    84  		},
    85  	}
    86  	for _, tt := range tests {
    87  		t.Run(tt.name, func(t *testing.T) {
    88  			pc := &PacketConn{
    89  				Conn: tt.fields.Conn,
    90  			}
    91  			gotN, err := pc.WriteTo(tt.args.src, tt.args.addr)
    92  			if (err != nil) != tt.wantErr {
    93  				t.Errorf("WriteTo() error = %v, wantErr %v", err, tt.wantErr)
    94  				return
    95  			}
    96  			if gotN != tt.wantN {
    97  				t.Errorf("WriteTo() gotN = %v, want %v", gotN, tt.wantN)
    98  			}
    99  
   100  			buf := tt.args.buf
   101  			dst := make([]byte, 0, 64*1024)
   102  			for {
   103  				n, gotAddr, err1 := pc.ReadFrom(buf)
   104  				if err1 != nil {
   105  					if err1 == io.EOF {
   106  						break
   107  					} else if !tt.wantErr {
   108  						t.Errorf("ReadFrom() error = %v, wantErr %v", err1, tt.wantErr)
   109  						return
   110  					}
   111  				}
   112  				if !reflect.DeepEqual(gotAddr.(*net.UDPAddr).AddrPort(), tt.wantAddr) {
   113  					t.Errorf("ReadFrom() gotAddr = %v, want %v", gotAddr, tt.wantAddr)
   114  				}
   115  				dst = append(dst, buf[:n]...)
   116  			}
   117  
   118  			if len(dst) != tt.wantN {
   119  				t.Errorf("ReadFrom() read data doesn't match write data, gotN = %v, want %v", len(dst), tt.wantN)
   120  				return
   121  			}
   122  
   123  			if !reflect.DeepEqual(dst, tt.args.src) {
   124  				t.Errorf("ReadFrom() read data doesn't match write data")
   125  			}
   126  		})
   127  	}
   128  }