github.com/nxtrace/NTrace-core@v1.3.1-0.20240513132635-39169291e8c9/trace/tcp_ipv4.go (about) 1 package trace 2 3 import ( 4 "log" 5 "math" 6 "math/rand" 7 "net" 8 "sync" 9 "time" 10 11 "github.com/google/gopacket" 12 "github.com/google/gopacket/layers" 13 "github.com/nxtrace/NTrace-core/util" 14 "golang.org/x/net/context" 15 "golang.org/x/net/icmp" 16 "golang.org/x/net/ipv4" 17 "golang.org/x/sync/semaphore" 18 ) 19 20 type TCPTracer struct { 21 Config 22 wg sync.WaitGroup 23 res Result 24 ctx context.Context 25 inflightRequest map[int]chan Hop 26 inflightRequestLock sync.Mutex 27 SrcIP net.IP 28 icmp net.PacketConn 29 tcp net.PacketConn 30 31 final int 32 finalLock sync.Mutex 33 34 sem *semaphore.Weighted 35 fetchLock sync.Mutex 36 } 37 38 func (t *TCPTracer) Execute() (*Result, error) { 39 if len(t.res.Hops) > 0 { 40 return &t.res, ErrTracerouteExecuted 41 } 42 43 t.SrcIP, _ = util.LocalIPPort(t.DestIP) 44 45 var err error 46 if t.SrcAddr != "" { 47 t.tcp, err = net.ListenPacket("ip4:tcp", t.SrcAddr) 48 } else { 49 t.tcp, err = net.ListenPacket("ip4:tcp", t.SrcIP.String()) 50 } 51 52 if err != nil { 53 return nil, err 54 } 55 t.icmp, err = icmp.ListenPacket("ip4:icmp", t.SrcAddr) 56 if err != nil { 57 return &t.res, err 58 } 59 defer t.icmp.Close() 60 61 var cancel context.CancelFunc 62 t.ctx, cancel = context.WithCancel(context.Background()) 63 defer cancel() 64 t.inflightRequestLock.Lock() 65 t.inflightRequest = make(map[int]chan Hop) 66 t.inflightRequestLock.Unlock() 67 68 t.final = -1 69 70 go t.listenICMP() 71 go t.listenTCP() 72 73 t.sem = semaphore.NewWeighted(int64(t.ParallelRequests)) 74 75 for ttl := t.BeginHop; ttl <= t.MaxHops; ttl++ { 76 // 如果到达最终跳,则退出 77 if t.final != -1 && ttl > t.final { 78 break 79 } 80 for i := 0; i < t.NumMeasurements; i++ { 81 t.wg.Add(1) 82 go t.send(ttl) 83 <-time.After(time.Millisecond * time.Duration(t.Config.PacketInterval)) 84 } 85 if t.RealtimePrinter != nil { 86 // 对于实时模式,应该按照TTL进行并发请求 87 t.wg.Wait() 88 t.RealtimePrinter(&t.res, ttl-1) 89 } 90 91 <-time.After(time.Millisecond * time.Duration(t.Config.TTLInterval)) 92 } 93 go func() { 94 if t.AsyncPrinter != nil { 95 for { 96 t.AsyncPrinter(&t.res) 97 time.Sleep(200 * time.Millisecond) 98 } 99 } 100 101 }() 102 103 // 如果是表格模式,则一次性并发请求 104 if t.RealtimePrinter == nil { 105 t.wg.Wait() 106 } 107 t.res.reduce(t.final) 108 109 return &t.res, nil 110 } 111 112 func (t *TCPTracer) listenICMP() { 113 lc := NewPacketListener(t.icmp, t.ctx) 114 go lc.Start() 115 for { 116 select { 117 case <-t.ctx.Done(): 118 return 119 case msg := <-lc.Messages: 120 if msg.N == nil { 121 continue 122 } 123 dstip := net.IP(msg.Msg[24:28]) 124 if dstip.Equal(t.DestIP) { 125 rm, err := icmp.ParseMessage(1, msg.Msg[:*msg.N]) 126 if err != nil { 127 log.Println(err) 128 continue 129 } 130 switch rm.Type { 131 case ipv4.ICMPTypeTimeExceeded: 132 t.handleICMPMessage(msg, rm.Body.(*icmp.TimeExceeded).Data) 133 case ipv4.ICMPTypeDestinationUnreachable: 134 t.handleICMPMessage(msg, rm.Body.(*icmp.DstUnreach).Data) 135 default: 136 //log.Println("received icmp message of unknown type", rm.Type) 137 } 138 } 139 } 140 } 141 142 } 143 144 // @title listenTCP 145 // @description 监听TCP的响应数据包 146 func (t *TCPTracer) listenTCP() { 147 lc := NewPacketListener(t.tcp, t.ctx) 148 go lc.Start() 149 150 for { 151 select { 152 case <-t.ctx.Done(): 153 return 154 case msg := <-lc.Messages: 155 if msg.N == nil { 156 continue 157 } 158 if msg.Peer.String() != t.DestIP.String() { 159 continue 160 } 161 162 // 解包 163 packet := gopacket.NewPacket(msg.Msg[:*msg.N], layers.LayerTypeTCP, gopacket.Default) 164 // 从包中获取TCP layer信息 165 if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil { 166 tcp, _ := tcpLayer.(*layers.TCP) 167 // 取得目标主机的Sequence Number 168 t.inflightRequestLock.Lock() 169 if ch, ok := t.inflightRequest[int(tcp.Ack-1)]; ok { 170 // 最后一跳 171 ch <- Hop{ 172 Success: true, 173 Address: msg.Peer, 174 } 175 } 176 t.inflightRequestLock.Unlock() 177 } 178 } 179 } 180 } 181 182 func (t *TCPTracer) handleICMPMessage(msg ReceivedMessage, data []byte) { 183 header, err := util.GetICMPResponsePayload(data) 184 if err != nil { 185 return 186 } 187 sequenceNumber := util.GetTCPSeq(header) 188 t.inflightRequestLock.Lock() 189 defer t.inflightRequestLock.Unlock() 190 ch, ok := t.inflightRequest[int(sequenceNumber)] 191 if !ok { 192 return 193 } 194 ch <- Hop{ 195 Success: true, 196 Address: msg.Peer, 197 } 198 199 } 200 201 func (t *TCPTracer) send(ttl int) error { 202 err := t.sem.Acquire(context.Background(), 1) 203 if err != nil { 204 return err 205 } 206 defer t.sem.Release(1) 207 208 defer t.wg.Done() 209 if t.final != -1 && ttl > t.final { 210 return nil 211 } 212 // 随机种子 213 r := rand.New(rand.NewSource(time.Now().UnixNano())) 214 _, srcPort := util.LocalIPPort(t.DestIP) 215 ipHeader := &layers.IPv4{ 216 SrcIP: t.SrcIP, 217 DstIP: t.DestIP, 218 Protocol: layers.IPProtocolTCP, 219 TTL: uint8(ttl), 220 } 221 // 使用Uint16兼容32位系统,防止在rand的时候因使用int32而溢出 222 sequenceNumber := uint32(r.Intn(math.MaxUint16)) 223 tcpHeader := &layers.TCP{ 224 SrcPort: layers.TCPPort(srcPort), 225 DstPort: layers.TCPPort(t.DestPort), 226 Seq: sequenceNumber, 227 SYN: true, 228 Window: 14600, 229 } 230 _ = tcpHeader.SetNetworkLayerForChecksum(ipHeader) 231 232 buf := gopacket.NewSerializeBuffer() 233 opts := gopacket.SerializeOptions{ 234 ComputeChecksums: true, 235 FixLengths: true, 236 } 237 238 desiredPayloadSize := t.Config.PktSize 239 payload := make([]byte, desiredPayloadSize) 240 copy(buf.Bytes(), payload) 241 242 if err := gopacket.SerializeLayers(buf, opts, tcpHeader); err != nil { 243 return err 244 } 245 246 err = ipv4.NewPacketConn(t.tcp).SetTTL(ttl) 247 if err != nil { 248 return err 249 } 250 251 start := time.Now() 252 if _, err := t.tcp.WriteTo(buf.Bytes(), &net.IPAddr{IP: t.DestIP}); err != nil { 253 return err 254 } 255 t.inflightRequestLock.Lock() 256 hopCh := make(chan Hop) 257 t.inflightRequest[int(sequenceNumber)] = hopCh 258 t.inflightRequestLock.Unlock() 259 /* 260 // 这里属于 2个Sender,N个Receiver的情况,在哪里关闭Channel都容易导致Panic 261 defer func() { 262 t.inflightRequestLock.Lock() 263 close(hopCh) 264 delete(t.inflightRequest, srcPort) 265 t.inflightRequestLock.Unlock() 266 }() 267 */ 268 select { 269 case <-t.ctx.Done(): 270 return nil 271 case h := <-hopCh: 272 rtt := time.Since(start) 273 if t.final != -1 && ttl > t.final { 274 return nil 275 } 276 277 if addr, ok := h.Address.(*net.IPAddr); ok && addr.IP.Equal(t.DestIP) { 278 t.finalLock.Lock() 279 if t.final == -1 || ttl < t.final { 280 t.final = ttl 281 } 282 t.finalLock.Unlock() 283 } else if addr, ok := h.Address.(*net.TCPAddr); ok && addr.IP.Equal(t.DestIP) { 284 t.finalLock.Lock() 285 if t.final == -1 || ttl < t.final { 286 t.final = ttl 287 } 288 t.finalLock.Unlock() 289 } 290 291 h.TTL = ttl 292 h.RTT = rtt 293 294 t.fetchLock.Lock() 295 defer t.fetchLock.Unlock() 296 err := h.fetchIPData(t.Config) 297 if err != nil { 298 return err 299 } 300 301 t.res.add(h) 302 303 case <-time.After(t.Timeout): 304 if t.final != -1 && ttl > t.final { 305 return nil 306 } 307 308 t.res.add(Hop{ 309 Success: false, 310 Address: nil, 311 TTL: ttl, 312 RTT: 0, 313 Error: ErrHopLimitTimeout, 314 }) 315 } 316 317 return nil 318 }