github.com/ipfans/trojan-go@v0.11.0/test/util/util.go (about) 1 package util 2 3 import ( 4 "bytes" 5 "crypto/rand" 6 "fmt" 7 "net" 8 "sync" 9 10 "github.com/ipfans/trojan-go/common" 11 ) 12 13 // CheckConn checks if two netConn were connected and work properly 14 func CheckConn(a net.Conn, b net.Conn) bool { 15 payload1 := make([]byte, 1024) 16 payload2 := make([]byte, 1024) 17 18 result1 := make([]byte, 1024) 19 result2 := make([]byte, 1024) 20 21 rand.Reader.Read(payload1) 22 rand.Reader.Read(payload2) 23 24 wg := sync.WaitGroup{} 25 wg.Add(2) 26 27 go func() { 28 a.Write(payload1) 29 a.Read(result2) 30 wg.Done() 31 }() 32 33 go func() { 34 b.Read(result1) 35 b.Write(payload2) 36 wg.Done() 37 }() 38 39 wg.Wait() 40 41 return bytes.Equal(payload1, result1) && bytes.Equal(payload2, result2) 42 } 43 44 // CheckPacketOverConn checks if two PacketConn streaming over a connection work properly 45 func CheckPacketOverConn(a, b net.PacketConn) bool { 46 port := common.PickPort("tcp", "127.0.0.1") 47 addr := &net.UDPAddr{ 48 IP: net.ParseIP("127.0.0.1"), 49 Port: port, 50 } 51 52 payload1 := make([]byte, 1024) 53 payload2 := make([]byte, 1024) 54 55 result1 := make([]byte, 1024) 56 result2 := make([]byte, 1024) 57 58 rand.Reader.Read(payload1) 59 rand.Reader.Read(payload2) 60 61 common.Must2(a.WriteTo(payload1, addr)) 62 _, addr1, err := b.ReadFrom(result1) 63 common.Must(err) 64 if addr1.String() != addr.String() { 65 return false 66 } 67 68 common.Must2(a.WriteTo(payload2, addr)) 69 _, addr2, err := b.ReadFrom(result2) 70 common.Must(err) 71 if addr2.String() != addr.String() { 72 return false 73 } 74 75 return bytes.Equal(payload1, result1) && bytes.Equal(payload2, result2) 76 } 77 78 func CheckPacket(a, b net.PacketConn) bool { 79 payload1 := make([]byte, 1024) 80 payload2 := make([]byte, 1024) 81 82 result1 := make([]byte, 1024) 83 result2 := make([]byte, 1024) 84 85 rand.Reader.Read(payload1) 86 rand.Reader.Read(payload2) 87 88 _, err := a.WriteTo(payload1, b.LocalAddr()) 89 common.Must(err) 90 _, _, err = b.ReadFrom(result1) 91 common.Must(err) 92 93 _, err = b.WriteTo(payload2, a.LocalAddr()) 94 common.Must(err) 95 _, _, err = a.ReadFrom(result2) 96 common.Must(err) 97 98 return bytes.Equal(payload1, result1) && bytes.Equal(payload2, result2) 99 } 100 101 func GetTestAddr() string { 102 port := common.PickPort("tcp", "127.0.0.1") 103 return fmt.Sprintf("127.0.0.1:%d", port) 104 }