github.com/GuanceCloud/cliutils@v1.1.21/dialtesting/traceroute_other.go (about) 1 // Unless explicitly stated otherwise all files in this repository are licensed 2 // under the MIT License. 3 // This product includes software developed at Guance Cloud (https://www.guance.com/). 4 // Copyright 2021-present Guance, Inc. 5 6 //go:build !windows 7 // +build !windows 8 9 package dialtesting 10 11 import ( 12 "fmt" 13 "math/rand" 14 "net" 15 "strconv" 16 "sync" 17 "sync/atomic" 18 "syscall" 19 "time" 20 21 "golang.org/x/net/icmp" 22 "golang.org/x/net/ipv4" 23 "golang.org/x/net/ipv6" 24 ) 25 26 type receivePacket struct { 27 from *net.IPAddr 28 packetRecvTime time.Time 29 buf []byte 30 } 31 32 // Traceroute specified host with max hops and timeout. 33 type Traceroute struct { 34 Host string 35 Hops int 36 Retry int 37 Timeout time.Duration 38 39 routes []*Route 40 response chan *Response 41 stopCh chan interface{} 42 packetCh chan *Packet 43 receivePacketsCh chan *receivePacket 44 id uint32 45 } 46 47 // init config: hops, retry, timeout should not be greater than the max value. 48 func (t *Traceroute) init() { 49 if t.Hops <= 0 { 50 t.Hops = 30 51 } else if t.Hops > MaxHops { 52 t.Hops = MaxHops 53 } 54 55 if t.Retry <= 0 { 56 t.Retry = 3 57 } else if t.Retry > MaxRetry { 58 t.Retry = MaxRetry 59 } 60 61 if t.Timeout <= 0 { 62 t.Timeout = 1 * time.Second 63 } else if t.Timeout > MaxTimeout { 64 t.Timeout = MaxTimeout 65 } 66 67 t.routes = make([]*Route, 0) 68 69 t.response = make(chan *Response) 70 t.stopCh = make(chan interface{}) 71 t.packetCh = make(chan *Packet) 72 t.receivePacketsCh = make(chan *receivePacket, 5000) 73 74 t.id = t.getRandomID() 75 } 76 77 // getRandomID generate random id, max 60000. 78 func (t *Traceroute) getRandomID() uint32 { 79 rand.Seed(time.Now().UnixNano()) 80 return uint32(rand.Intn(60000)) //nolint:gosec 81 } 82 83 func (t *Traceroute) Run() error { 84 var runError error 85 ips, err := net.LookupIP(t.Host) 86 if err != nil { 87 return err 88 } 89 90 t.init() 91 92 if len(ips) == 0 { 93 return fmt.Errorf("invalid host: %s", t.Host) 94 } 95 ip := ips[0] 96 97 var wg sync.WaitGroup 98 wg.Add(2) 99 go func() { 100 defer wg.Done() 101 if err := t.startTrace(ip); err != nil { 102 runError = fmt.Errorf("start trace error: %w", err) 103 } 104 }() 105 106 go func() { 107 defer wg.Done() 108 if err := t.listenICMP(); err != nil { 109 runError = fmt.Errorf("listen icmp error: %w", err) 110 } 111 }() 112 wg.Wait() 113 return runError 114 } 115 116 func (t *Traceroute) startTrace(ip net.IP) error { 117 var icmpResponse *Response 118 119 defer close(t.stopCh) 120 121 for i := 1; i <= t.Hops; i++ { 122 isReply := false 123 routeItems := []*RouteItem{} 124 responseTimes := []float64{} 125 var minCost, maxCost time.Duration 126 var failed int 127 for j := 0; j < t.Retry; j++ { 128 if err := t.sendICMP(ip, i); err != nil { 129 return err 130 } 131 icmpResponse = <-t.response 132 routeItem := &RouteItem{ 133 IP: icmpResponse.From.String(), 134 ResponseTime: float64(icmpResponse.ResponseTime.Microseconds()), 135 } 136 137 if icmpResponse.fail { 138 routeItem.IP = "*" 139 failed++ 140 } else { 141 if icmpResponse.From.String() == ip.String() { 142 isReply = true 143 } 144 145 if icmpResponse.ResponseTime > 0 { 146 if minCost == 0 || minCost > icmpResponse.ResponseTime { 147 minCost = icmpResponse.ResponseTime 148 } 149 150 if maxCost == 0 || maxCost < icmpResponse.ResponseTime { 151 maxCost = icmpResponse.ResponseTime 152 } 153 154 responseTimes = append(responseTimes, float64(icmpResponse.ResponseTime.Microseconds())) 155 } 156 } 157 158 routeItems = append(routeItems, routeItem) 159 } 160 161 loss, _ := strconv.ParseFloat(fmt.Sprintf("%.2f", float64(failed)*100/float64(t.Retry)), 64) 162 163 route := &Route{ 164 Total: t.Retry, 165 Failed: failed, 166 Loss: loss, 167 MinCost: float64(minCost.Microseconds()), 168 AvgCost: mean(responseTimes), 169 MaxCost: float64(maxCost.Microseconds()), 170 StdCost: std(responseTimes), 171 Items: routeItems, 172 } 173 t.routes = append(t.routes, route) 174 175 if isReply { 176 return nil 177 } 178 } 179 180 return nil 181 } 182 183 func (t *Traceroute) dealPacket() { 184 for { 185 select { 186 case <-t.stopCh: 187 return 188 case packet, ok := <-t.packetCh: 189 if ok { 190 for { 191 p := <-t.receivePacketsCh 192 if p.packetRecvTime.Sub(packet.startTime) > t.Timeout { 193 t.response <- &Response{fail: true} 194 break 195 } 196 if p.from == nil || p.from.IP == nil || len(p.buf) == 0 { 197 continue 198 } 199 msg, err := icmp.ParseMessage(1, p.buf) 200 if err != nil { 201 continue 202 } 203 204 if msg.Type == ipv4.ICMPTypeEchoReply { 205 echo := msg.Body.(*icmp.Echo) 206 207 if echo.ID != packet.ID { 208 continue 209 } 210 } else { 211 icmpData := t.getReplyData(msg) 212 if len(icmpData) < ipv4.HeaderLen { 213 continue 214 } 215 216 var packetID int 217 218 func() { 219 switch icmpData[0] >> 4 { 220 case ipv4.Version: 221 header, err := ipv4.ParseHeader(icmpData) 222 if err != nil { 223 return 224 } 225 packetID = header.ID 226 case ipv6.Version: 227 header, err := ipv6.ParseHeader(icmpData) 228 if err != nil { 229 return 230 } 231 232 packetID = header.FlowLabel 233 } 234 }() 235 if packetID != packet.ID { 236 continue 237 } 238 } 239 240 t.response <- &Response{From: p.from.IP, ResponseTime: p.packetRecvTime.Sub(packet.startTime)} 241 break 242 } 243 } 244 } 245 } 246 } 247 248 func (t *Traceroute) listenICMP() error { 249 var addr *net.IPAddr 250 conn, err := net.ListenIP("ip4:icmp", addr) 251 if err != nil { 252 return err 253 } 254 255 defer func() { 256 if err := conn.Close(); err != nil { 257 _ = err // pass 258 } 259 }() 260 261 go t.dealPacket() 262 263 for { 264 select { 265 case <-t.stopCh: 266 return nil 267 default: 268 } 269 270 buf := make([]byte, 1500) 271 deadLine := time.Now().Add(time.Second) 272 273 if t.Timeout > 0 && t.Timeout < 10*time.Second { // max 10s 274 deadLine = time.Now().Add(t.Timeout) 275 } 276 277 _ = conn.SetDeadline(deadLine) 278 279 n, from, _ := conn.ReadFromIP(buf) 280 t.receivePacketsCh <- &receivePacket{ 281 from: from, 282 packetRecvTime: time.Now(), 283 buf: buf[:n], 284 } 285 } 286 } 287 288 func (t *Traceroute) getReplyData(msg *icmp.Message) []byte { 289 switch b := msg.Body.(type) { 290 case *icmp.TimeExceeded: 291 return b.Data 292 case *icmp.DstUnreach: 293 return b.Data 294 case *icmp.ParamProb: 295 return b.Data 296 } 297 298 return nil 299 } 300 301 func (t *Traceroute) sendICMP(ip net.IP, ttl int) error { 302 if ip.To4() == nil { 303 return fmt.Errorf("support ip version 4 only") 304 } 305 id := uint16(atomic.AddUint32(&t.id, 1)) 306 307 dst := net.ParseIP(ip.String()) 308 echoBody := &icmp.Echo{ 309 ID: int(id), 310 Seq: int(id), 311 } 312 msg := icmp.Message{ 313 Type: ipv4.ICMPTypeEcho, 314 Body: echoBody, 315 } 316 317 p, err := msg.Marshal(nil) 318 if err != nil { 319 return err 320 } 321 322 ipHeader := &ipv4.Header{ 323 Version: ipv4.Version, 324 Len: ipv4.HeaderLen, 325 TotalLen: ipv4.HeaderLen + len(p), 326 TOS: 16, 327 ID: int(id), 328 Dst: dst, 329 Protocol: 1, 330 TTL: ttl, 331 } 332 333 buf, err := ipHeader.Marshal() 334 if err != nil { 335 return err 336 } 337 338 buf = append(buf, p...) 339 340 conn, err := net.ListenIP("ip4:icmp", nil) 341 if err != nil { 342 return err 343 } 344 defer func() { 345 if err := conn.Close(); err != nil { 346 _ = err // pass 347 } 348 }() 349 350 raw, err := conn.SyscallConn() 351 if err != nil { 352 return err 353 } 354 355 _ = raw.Control(func(fd uintptr) { 356 err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) 357 }) 358 359 if err != nil { 360 return err 361 } 362 363 t.packetCh <- &Packet{ID: echoBody.ID, Dst: ipHeader.Dst, startTime: time.Now()} 364 365 _, err = conn.WriteToIP(buf, &net.IPAddr{IP: dst}) 366 367 if err != nil { 368 return err 369 } 370 371 return nil 372 } 373 374 func TracerouteIP(ip string, opt *TracerouteOption) (routes []*Route, err error) { 375 defaultTimeout := 30 * time.Millisecond 376 if opt == nil { 377 opt = &TracerouteOption{ 378 Hops: 30, 379 Retry: 2, 380 timeout: defaultTimeout, 381 } 382 } else { 383 if timeout, err := time.ParseDuration(opt.Timeout); err != nil { 384 opt.timeout = defaultTimeout 385 } else { 386 opt.timeout = timeout 387 } 388 } 389 390 traceroute := Traceroute{ 391 Host: ip, 392 Hops: opt.Hops, 393 Retry: opt.Retry, 394 Timeout: opt.timeout, 395 } 396 397 err = traceroute.Run() 398 399 if err != nil { 400 return 401 } 402 403 routes = traceroute.routes 404 405 return routes, err 406 }