github.com/GuanceCloud/cliutils@v1.1.21/dialtesting/traceroute_other.go (about)

     1  // Unless explicitly stated otherwise all files in this repository are licensed
     2  // under the MIT License.
     3  // This product includes software developed at Guance Cloud (https://www.guance.com/).
     4  // Copyright 2021-present Guance, Inc.
     5  
     6  //go:build !windows
     7  // +build !windows
     8  
     9  package dialtesting
    10  
    11  import (
    12  	"fmt"
    13  	"math/rand"
    14  	"net"
    15  	"strconv"
    16  	"sync"
    17  	"sync/atomic"
    18  	"syscall"
    19  	"time"
    20  
    21  	"golang.org/x/net/icmp"
    22  	"golang.org/x/net/ipv4"
    23  	"golang.org/x/net/ipv6"
    24  )
    25  
// receivePacket is one read result from the ICMP listening socket, handed
// from listenICMP to dealPacket through Traceroute.receivePacketsCh.
type receivePacket struct {
	// from is the sender address returned by the socket read. dealPacket
	// skips entries where it (or its IP) is nil.
	from           *net.IPAddr
	// packetRecvTime is taken with time.Now() right after the read returns.
	packetRecvTime time.Time
	// buf holds the bytes that were read; dealPacket skips empty buffers.
	buf            []byte
}
    31  
// Traceroute traces the route to the specified host with a bounded number of
// hops, a number of probes per hop and a per-probe timeout.
type Traceroute struct {
	// Host is resolved with net.LookupIP in Run; the first address is traced.
	Host    string
	// Hops is the largest TTL probed; init defaults it to 30 and caps it at MaxHops.
	Hops    int
	// Retry is the number of probes sent per hop; init defaults it to 3 and
	// caps it at MaxRetry.
	Retry   int
	// Timeout bounds the wait for each probe's reply; init defaults it to one
	// second and caps it at MaxTimeout.
	Timeout time.Duration

	// routes collects one entry per probed hop (appended by startTrace).
	routes           []*Route
	// response carries the outcome of each probe from dealPacket to startTrace.
	response         chan *Response
	// stopCh is closed by startTrace when tracing ends, stopping the listeners.
	stopCh           chan interface{}
	// packetCh announces every probe that sendICMP is about to transmit.
	packetCh         chan *Packet
	// receivePacketsCh buffers the socket reads performed by listenICMP.
	receivePacketsCh chan *receivePacket
	// id is incremented atomically by sendICMP to derive each probe's identifier.
	id               uint32
}
    46  
    47  // init config: hops, retry, timeout should not be greater than the max value.
    48  func (t *Traceroute) init() {
    49  	if t.Hops <= 0 {
    50  		t.Hops = 30
    51  	} else if t.Hops > MaxHops {
    52  		t.Hops = MaxHops
    53  	}
    54  
    55  	if t.Retry <= 0 {
    56  		t.Retry = 3
    57  	} else if t.Retry > MaxRetry {
    58  		t.Retry = MaxRetry
    59  	}
    60  
    61  	if t.Timeout <= 0 {
    62  		t.Timeout = 1 * time.Second
    63  	} else if t.Timeout > MaxTimeout {
    64  		t.Timeout = MaxTimeout
    65  	}
    66  
    67  	t.routes = make([]*Route, 0)
    68  
    69  	t.response = make(chan *Response)
    70  	t.stopCh = make(chan interface{})
    71  	t.packetCh = make(chan *Packet)
    72  	t.receivePacketsCh = make(chan *receivePacket, 5000)
    73  
    74  	t.id = t.getRandomID()
    75  }
    76  
    77  // getRandomID generate random id, max 60000.
    78  func (t *Traceroute) getRandomID() uint32 {
    79  	rand.Seed(time.Now().UnixNano())
    80  	return uint32(rand.Intn(60000)) //nolint:gosec
    81  }
    82  
    83  func (t *Traceroute) Run() error {
    84  	var runError error
    85  	ips, err := net.LookupIP(t.Host)
    86  	if err != nil {
    87  		return err
    88  	}
    89  
    90  	t.init()
    91  
    92  	if len(ips) == 0 {
    93  		return fmt.Errorf("invalid host: %s", t.Host)
    94  	}
    95  	ip := ips[0]
    96  
    97  	var wg sync.WaitGroup
    98  	wg.Add(2)
    99  	go func() {
   100  		defer wg.Done()
   101  		if err := t.startTrace(ip); err != nil {
   102  			runError = fmt.Errorf("start trace error: %w", err)
   103  		}
   104  	}()
   105  
   106  	go func() {
   107  		defer wg.Done()
   108  		if err := t.listenICMP(); err != nil {
   109  			runError = fmt.Errorf("listen icmp error: %w", err)
   110  		}
   111  	}()
   112  	wg.Wait()
   113  	return runError
   114  }
   115  
   116  func (t *Traceroute) startTrace(ip net.IP) error {
   117  	var icmpResponse *Response
   118  
   119  	defer close(t.stopCh)
   120  
   121  	for i := 1; i <= t.Hops; i++ {
   122  		isReply := false
   123  		routeItems := []*RouteItem{}
   124  		responseTimes := []float64{}
   125  		var minCost, maxCost time.Duration
   126  		var failed int
   127  		for j := 0; j < t.Retry; j++ {
   128  			if err := t.sendICMP(ip, i); err != nil {
   129  				return err
   130  			}
   131  			icmpResponse = <-t.response
   132  			routeItem := &RouteItem{
   133  				IP:           icmpResponse.From.String(),
   134  				ResponseTime: float64(icmpResponse.ResponseTime.Microseconds()),
   135  			}
   136  
   137  			if icmpResponse.fail {
   138  				routeItem.IP = "*"
   139  				failed++
   140  			} else {
   141  				if icmpResponse.From.String() == ip.String() {
   142  					isReply = true
   143  				}
   144  
   145  				if icmpResponse.ResponseTime > 0 {
   146  					if minCost == 0 || minCost > icmpResponse.ResponseTime {
   147  						minCost = icmpResponse.ResponseTime
   148  					}
   149  
   150  					if maxCost == 0 || maxCost < icmpResponse.ResponseTime {
   151  						maxCost = icmpResponse.ResponseTime
   152  					}
   153  
   154  					responseTimes = append(responseTimes, float64(icmpResponse.ResponseTime.Microseconds()))
   155  				}
   156  			}
   157  
   158  			routeItems = append(routeItems, routeItem)
   159  		}
   160  
   161  		loss, _ := strconv.ParseFloat(fmt.Sprintf("%.2f", float64(failed)*100/float64(t.Retry)), 64)
   162  
   163  		route := &Route{
   164  			Total:   t.Retry,
   165  			Failed:  failed,
   166  			Loss:    loss,
   167  			MinCost: float64(minCost.Microseconds()),
   168  			AvgCost: mean(responseTimes),
   169  			MaxCost: float64(maxCost.Microseconds()),
   170  			StdCost: std(responseTimes),
   171  			Items:   routeItems,
   172  		}
   173  		t.routes = append(t.routes, route)
   174  
   175  		if isReply {
   176  			return nil
   177  		}
   178  	}
   179  
   180  	return nil
   181  }
   182  
   183  func (t *Traceroute) dealPacket() {
   184  	for {
   185  		select {
   186  		case <-t.stopCh:
   187  			return
   188  		case packet, ok := <-t.packetCh:
   189  			if ok {
   190  				for {
   191  					p := <-t.receivePacketsCh
   192  					if p.packetRecvTime.Sub(packet.startTime) > t.Timeout {
   193  						t.response <- &Response{fail: true}
   194  						break
   195  					}
   196  					if p.from == nil || p.from.IP == nil || len(p.buf) == 0 {
   197  						continue
   198  					}
   199  					msg, err := icmp.ParseMessage(1, p.buf)
   200  					if err != nil {
   201  						continue
   202  					}
   203  
   204  					if msg.Type == ipv4.ICMPTypeEchoReply {
   205  						echo := msg.Body.(*icmp.Echo)
   206  
   207  						if echo.ID != packet.ID {
   208  							continue
   209  						}
   210  					} else {
   211  						icmpData := t.getReplyData(msg)
   212  						if len(icmpData) < ipv4.HeaderLen {
   213  							continue
   214  						}
   215  
   216  						var packetID int
   217  
   218  						func() {
   219  							switch icmpData[0] >> 4 {
   220  							case ipv4.Version:
   221  								header, err := ipv4.ParseHeader(icmpData)
   222  								if err != nil {
   223  									return
   224  								}
   225  								packetID = header.ID
   226  							case ipv6.Version:
   227  								header, err := ipv6.ParseHeader(icmpData)
   228  								if err != nil {
   229  									return
   230  								}
   231  
   232  								packetID = header.FlowLabel
   233  							}
   234  						}()
   235  						if packetID != packet.ID {
   236  							continue
   237  						}
   238  					}
   239  
   240  					t.response <- &Response{From: p.from.IP, ResponseTime: p.packetRecvTime.Sub(packet.startTime)}
   241  					break
   242  				}
   243  			}
   244  		}
   245  	}
   246  }
   247  
   248  func (t *Traceroute) listenICMP() error {
   249  	var addr *net.IPAddr
   250  	conn, err := net.ListenIP("ip4:icmp", addr)
   251  	if err != nil {
   252  		return err
   253  	}
   254  
   255  	defer func() {
   256  		if err := conn.Close(); err != nil {
   257  			_ = err // pass
   258  		}
   259  	}()
   260  
   261  	go t.dealPacket()
   262  
   263  	for {
   264  		select {
   265  		case <-t.stopCh:
   266  			return nil
   267  		default:
   268  		}
   269  
   270  		buf := make([]byte, 1500)
   271  		deadLine := time.Now().Add(time.Second)
   272  
   273  		if t.Timeout > 0 && t.Timeout < 10*time.Second { // max 10s
   274  			deadLine = time.Now().Add(t.Timeout)
   275  		}
   276  
   277  		_ = conn.SetDeadline(deadLine)
   278  
   279  		n, from, _ := conn.ReadFromIP(buf)
   280  		t.receivePacketsCh <- &receivePacket{
   281  			from:           from,
   282  			packetRecvTime: time.Now(),
   283  			buf:            buf[:n],
   284  		}
   285  	}
   286  }
   287  
   288  func (t *Traceroute) getReplyData(msg *icmp.Message) []byte {
   289  	switch b := msg.Body.(type) {
   290  	case *icmp.TimeExceeded:
   291  		return b.Data
   292  	case *icmp.DstUnreach:
   293  		return b.Data
   294  	case *icmp.ParamProb:
   295  		return b.Data
   296  	}
   297  
   298  	return nil
   299  }
   300  
   301  func (t *Traceroute) sendICMP(ip net.IP, ttl int) error {
   302  	if ip.To4() == nil {
   303  		return fmt.Errorf("support ip version 4 only")
   304  	}
   305  	id := uint16(atomic.AddUint32(&t.id, 1))
   306  
   307  	dst := net.ParseIP(ip.String())
   308  	echoBody := &icmp.Echo{
   309  		ID:  int(id),
   310  		Seq: int(id),
   311  	}
   312  	msg := icmp.Message{
   313  		Type: ipv4.ICMPTypeEcho,
   314  		Body: echoBody,
   315  	}
   316  
   317  	p, err := msg.Marshal(nil)
   318  	if err != nil {
   319  		return err
   320  	}
   321  
   322  	ipHeader := &ipv4.Header{
   323  		Version:  ipv4.Version,
   324  		Len:      ipv4.HeaderLen,
   325  		TotalLen: ipv4.HeaderLen + len(p),
   326  		TOS:      16,
   327  		ID:       int(id),
   328  		Dst:      dst,
   329  		Protocol: 1,
   330  		TTL:      ttl,
   331  	}
   332  
   333  	buf, err := ipHeader.Marshal()
   334  	if err != nil {
   335  		return err
   336  	}
   337  
   338  	buf = append(buf, p...)
   339  
   340  	conn, err := net.ListenIP("ip4:icmp", nil)
   341  	if err != nil {
   342  		return err
   343  	}
   344  	defer func() {
   345  		if err := conn.Close(); err != nil {
   346  			_ = err // pass
   347  		}
   348  	}()
   349  
   350  	raw, err := conn.SyscallConn()
   351  	if err != nil {
   352  		return err
   353  	}
   354  
   355  	_ = raw.Control(func(fd uintptr) {
   356  		err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
   357  	})
   358  
   359  	if err != nil {
   360  		return err
   361  	}
   362  
   363  	t.packetCh <- &Packet{ID: echoBody.ID, Dst: ipHeader.Dst, startTime: time.Now()}
   364  
   365  	_, err = conn.WriteToIP(buf, &net.IPAddr{IP: dst})
   366  
   367  	if err != nil {
   368  		return err
   369  	}
   370  
   371  	return nil
   372  }
   373  
   374  func TracerouteIP(ip string, opt *TracerouteOption) (routes []*Route, err error) {
   375  	defaultTimeout := 30 * time.Millisecond
   376  	if opt == nil {
   377  		opt = &TracerouteOption{
   378  			Hops:    30,
   379  			Retry:   2,
   380  			timeout: defaultTimeout,
   381  		}
   382  	} else {
   383  		if timeout, err := time.ParseDuration(opt.Timeout); err != nil {
   384  			opt.timeout = defaultTimeout
   385  		} else {
   386  			opt.timeout = timeout
   387  		}
   388  	}
   389  
   390  	traceroute := Traceroute{
   391  		Host:    ip,
   392  		Hops:    opt.Hops,
   393  		Retry:   opt.Retry,
   394  		Timeout: opt.timeout,
   395  	}
   396  
   397  	err = traceroute.Run()
   398  
   399  	if err != nil {
   400  		return
   401  	}
   402  
   403  	routes = traceroute.routes
   404  
   405  	return routes, err
   406  }