github.com/nxtrace/NTrace-core@v1.3.1-0.20240513132635-39169291e8c9/trace/udp.go (about) 1 package trace 2 3 import ( 4 "log" 5 "net" 6 "sync" 7 "time" 8 9 "github.com/google/gopacket" 10 "github.com/google/gopacket/layers" 11 "github.com/nxtrace/NTrace-core/util" 12 "golang.org/x/net/context" 13 "golang.org/x/net/icmp" 14 "golang.org/x/net/ipv4" 15 "golang.org/x/sync/semaphore" 16 ) 17 18 type UDPTracer struct { 19 Config 20 wg sync.WaitGroup 21 res Result 22 ctx context.Context 23 inflightRequest map[int]chan Hop 24 inflightRequestLock sync.Mutex 25 26 icmp net.PacketConn 27 28 final int 29 finalLock sync.Mutex 30 31 sem *semaphore.Weighted 32 fetchLock sync.Mutex 33 } 34 35 func (t *UDPTracer) Execute() (*Result, error) { 36 if len(t.res.Hops) > 0 { 37 return &t.res, ErrTracerouteExecuted 38 } 39 40 var err error 41 t.icmp, err = icmp.ListenPacket("ip4:icmp", t.SrcAddr) 42 if err != nil { 43 return &t.res, err 44 } 45 defer t.icmp.Close() 46 47 var cancel context.CancelFunc 48 t.ctx, cancel = context.WithCancel(context.Background()) 49 defer cancel() 50 t.inflightRequest = make(map[int]chan Hop) 51 t.final = -1 52 53 go t.listenICMP() 54 55 t.sem = semaphore.NewWeighted(int64(t.ParallelRequests)) 56 for ttl := 1; ttl <= t.MaxHops; ttl++ { 57 // 如果到达最终跳,则退出 58 if t.final != -1 && ttl > t.final { 59 break 60 } 61 for i := 0; i < t.NumMeasurements; i++ { 62 t.wg.Add(1) 63 go t.send(ttl) 64 <-time.After(time.Millisecond * time.Duration(t.Config.PacketInterval)) 65 } 66 if t.RealtimePrinter != nil { 67 // 对于实时模式,应该按照TTL进行并发请求 68 t.wg.Wait() 69 t.RealtimePrinter(&t.res, ttl-1) 70 } 71 <-time.After(time.Millisecond * time.Duration(t.Config.TTLInterval)) 72 } 73 go func() { 74 if t.AsyncPrinter != nil { 75 for { 76 t.AsyncPrinter(&t.res) 77 time.Sleep(200 * time.Millisecond) 78 } 79 } 80 }() 81 // 如果是表格模式,则一次性并发请求 82 if t.AsyncPrinter != nil { 83 t.wg.Wait() 84 } 85 t.res.reduce(t.final) 86 87 return &t.res, nil 88 } 89 90 func (t *UDPTracer) listenICMP() { 91 lc := NewPacketListener(t.icmp, t.ctx) 92 go lc.Start() 93 for { 94 select { 95 case <-t.ctx.Done(): 96 return 97 case msg := <-lc.Messages: 98 if msg.N == nil { 99 continue 100 } 101 rm, err := icmp.ParseMessage(1, msg.Msg[:*msg.N]) 102 if err != nil { 103 log.Println(err) 104 continue 105 } 106 switch rm.Type { 107 case ipv4.ICMPTypeTimeExceeded: 108 t.handleICMPMessage(msg, rm.Body.(*icmp.TimeExceeded).Data) 109 case ipv4.ICMPTypeDestinationUnreachable: 110 t.handleICMPMessage(msg, rm.Body.(*icmp.DstUnreach).Data) 111 default: 112 // log.Println("received icmp message of unknown type", rm.Type) 113 } 114 } 115 } 116 117 } 118 119 func (t *UDPTracer) handleICMPMessage(msg ReceivedMessage, data []byte) { 120 header, err := util.GetICMPResponsePayload(data) 121 if err != nil { 122 return 123 } 124 srcPort := util.GetUDPSrcPort(header) 125 t.inflightRequestLock.Lock() 126 defer t.inflightRequestLock.Unlock() 127 ch, ok := t.inflightRequest[int(srcPort)] 128 if !ok { 129 return 130 } 131 ch <- Hop{ 132 Success: true, 133 Address: msg.Peer, 134 } 135 } 136 137 func (t *UDPTracer) getUDPConn(try int) (net.IP, int, net.PacketConn) { 138 srcIP, _ := util.LocalIPPort(t.DestIP) 139 140 var ipString string 141 if srcIP == nil { 142 ipString = "" 143 } else { 144 ipString = srcIP.String() 145 } 146 147 udpConn, err := net.ListenPacket("udp", ipString+":0") 148 if err != nil { 149 if try > 3 { 150 log.Fatal(err) 151 } 152 return t.getUDPConn(try + 1) 153 } 154 return srcIP, udpConn.LocalAddr().(*net.UDPAddr).Port, udpConn 155 } 156 157 func (t *UDPTracer) send(ttl int) error { 158 err := t.sem.Acquire(context.Background(), 1) 159 if err != nil { 160 return err 161 } 162 defer t.sem.Release(1) 163 164 defer t.wg.Done() 165 if t.final != -1 && ttl > t.final { 166 return nil 167 } 168 169 srcIP, srcPort, udpConn := t.getUDPConn(0) 170 171 var payload []byte 172 if t.Quic { 173 payload = GenerateQuicPayloadWithRandomIds() 174 } else { 175 ipHeader := &layers.IPv4{ 176 SrcIP: srcIP, 177 DstIP: t.DestIP, 178 Protocol: layers.IPProtocolTCP, 179 TTL: uint8(ttl), 180 } 181 182 udpHeader := &layers.UDP{ 183 SrcPort: layers.UDPPort(srcPort), 184 DstPort: layers.UDPPort(t.DestPort), 185 } 186 _ = udpHeader.SetNetworkLayerForChecksum(ipHeader) 187 buf := gopacket.NewSerializeBuffer() 188 opts := gopacket.SerializeOptions{ 189 ComputeChecksums: true, 190 FixLengths: true, 191 } 192 193 desiredPayloadSize := t.Config.PktSize 194 payload := make([]byte, desiredPayloadSize) 195 copy(buf.Bytes(), payload) 196 197 if err := gopacket.SerializeLayers(buf, opts, udpHeader); err != nil { 198 return err 199 } 200 201 payload = buf.Bytes() 202 } 203 204 err = ipv4.NewPacketConn(udpConn).SetTTL(ttl) 205 if err != nil { 206 return err 207 } 208 209 start := time.Now() 210 if _, err := udpConn.WriteTo(payload, &net.UDPAddr{IP: t.DestIP, Port: t.DestPort}); err != nil { 211 return err 212 } 213 214 // 在对inflightRequest进行写操作的时候应该加锁保护,以免多个goroutine协程试图同时写入造成panic 215 t.inflightRequestLock.Lock() 216 hopCh := make(chan Hop) 217 t.inflightRequest[srcPort] = hopCh 218 t.inflightRequestLock.Unlock() 219 defer func() { 220 t.inflightRequestLock.Lock() 221 close(hopCh) 222 delete(t.inflightRequest, srcPort) 223 t.inflightRequestLock.Unlock() 224 }() 225 226 go func() { 227 reply := make([]byte, 1500) 228 _, peer, err := udpConn.ReadFrom(reply) 229 if err != nil { 230 // probably because we closed the connection 231 return 232 } 233 hopCh <- Hop{ 234 Success: true, 235 Address: peer, 236 } 237 }() 238 239 select { 240 case <-t.ctx.Done(): 241 return nil 242 case h := <-hopCh: 243 rtt := time.Since(start) 244 if t.final != -1 && ttl > t.final { 245 return nil 246 } 247 248 if addr, ok := h.Address.(*net.IPAddr); ok && addr.IP.Equal(t.DestIP) { 249 t.finalLock.Lock() 250 if t.final == -1 || ttl < t.final { 251 t.final = ttl 252 } 253 t.finalLock.Unlock() 254 } else if addr, ok := h.Address.(*net.UDPAddr); ok && addr.IP.Equal(t.DestIP) { 255 t.finalLock.Lock() 256 if t.final == -1 || ttl < t.final { 257 t.final = ttl 258 } 259 t.finalLock.Unlock() 260 } 261 262 h.TTL = ttl 263 h.RTT = rtt 264 265 t.fetchLock.Lock() 266 defer t.fetchLock.Unlock() 267 err := h.fetchIPData(t.Config) 268 if err != nil { 269 return err 270 } 271 272 t.res.add(h) 273 274 case <-time.After(t.Timeout): 275 if t.final != -1 && ttl > t.final { 276 return nil 277 } 278 279 t.res.add(Hop{ 280 Success: false, 281 Address: nil, 282 TTL: ttl, 283 RTT: 0, 284 Error: ErrHopLimitTimeout, 285 }) 286 } 287 288 return nil 289 }