github.com/nxtrace/NTrace-core@v1.3.1-0.20240513132635-39169291e8c9/trace/udp.go (about)

     1  package trace
     2  
     3  import (
     4  	"log"
     5  	"net"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/google/gopacket"
    10  	"github.com/google/gopacket/layers"
    11  	"github.com/nxtrace/NTrace-core/util"
    12  	"golang.org/x/net/context"
    13  	"golang.org/x/net/icmp"
    14  	"golang.org/x/net/ipv4"
    15  	"golang.org/x/sync/semaphore"
    16  )
    17  
    18  type UDPTracer struct {
    19  	Config
    20  	wg                  sync.WaitGroup
    21  	res                 Result
    22  	ctx                 context.Context
    23  	inflightRequest     map[int]chan Hop
    24  	inflightRequestLock sync.Mutex
    25  
    26  	icmp net.PacketConn
    27  
    28  	final     int
    29  	finalLock sync.Mutex
    30  
    31  	sem       *semaphore.Weighted
    32  	fetchLock sync.Mutex
    33  }
    34  
    35  func (t *UDPTracer) Execute() (*Result, error) {
    36  	if len(t.res.Hops) > 0 {
    37  		return &t.res, ErrTracerouteExecuted
    38  	}
    39  
    40  	var err error
    41  	t.icmp, err = icmp.ListenPacket("ip4:icmp", t.SrcAddr)
    42  	if err != nil {
    43  		return &t.res, err
    44  	}
    45  	defer t.icmp.Close()
    46  
    47  	var cancel context.CancelFunc
    48  	t.ctx, cancel = context.WithCancel(context.Background())
    49  	defer cancel()
    50  	t.inflightRequest = make(map[int]chan Hop)
    51  	t.final = -1
    52  
    53  	go t.listenICMP()
    54  
    55  	t.sem = semaphore.NewWeighted(int64(t.ParallelRequests))
    56  	for ttl := 1; ttl <= t.MaxHops; ttl++ {
    57  		// 如果到达最终跳,则退出
    58  		if t.final != -1 && ttl > t.final {
    59  			break
    60  		}
    61  		for i := 0; i < t.NumMeasurements; i++ {
    62  			t.wg.Add(1)
    63  			go t.send(ttl)
    64  			<-time.After(time.Millisecond * time.Duration(t.Config.PacketInterval))
    65  		}
    66  		if t.RealtimePrinter != nil {
    67  			// 对于实时模式,应该按照TTL进行并发请求
    68  			t.wg.Wait()
    69  			t.RealtimePrinter(&t.res, ttl-1)
    70  		}
    71  		<-time.After(time.Millisecond * time.Duration(t.Config.TTLInterval))
    72  	}
    73  	go func() {
    74  		if t.AsyncPrinter != nil {
    75  			for {
    76  				t.AsyncPrinter(&t.res)
    77  				time.Sleep(200 * time.Millisecond)
    78  			}
    79  		}
    80  	}()
    81  	// 如果是表格模式,则一次性并发请求
    82  	if t.AsyncPrinter != nil {
    83  		t.wg.Wait()
    84  	}
    85  	t.res.reduce(t.final)
    86  
    87  	return &t.res, nil
    88  }
    89  
    90  func (t *UDPTracer) listenICMP() {
    91  	lc := NewPacketListener(t.icmp, t.ctx)
    92  	go lc.Start()
    93  	for {
    94  		select {
    95  		case <-t.ctx.Done():
    96  			return
    97  		case msg := <-lc.Messages:
    98  			if msg.N == nil {
    99  				continue
   100  			}
   101  			rm, err := icmp.ParseMessage(1, msg.Msg[:*msg.N])
   102  			if err != nil {
   103  				log.Println(err)
   104  				continue
   105  			}
   106  			switch rm.Type {
   107  			case ipv4.ICMPTypeTimeExceeded:
   108  				t.handleICMPMessage(msg, rm.Body.(*icmp.TimeExceeded).Data)
   109  			case ipv4.ICMPTypeDestinationUnreachable:
   110  				t.handleICMPMessage(msg, rm.Body.(*icmp.DstUnreach).Data)
   111  			default:
   112  				// log.Println("received icmp message of unknown type", rm.Type)
   113  			}
   114  		}
   115  	}
   116  
   117  }
   118  
   119  func (t *UDPTracer) handleICMPMessage(msg ReceivedMessage, data []byte) {
   120  	header, err := util.GetICMPResponsePayload(data)
   121  	if err != nil {
   122  		return
   123  	}
   124  	srcPort := util.GetUDPSrcPort(header)
   125  	t.inflightRequestLock.Lock()
   126  	defer t.inflightRequestLock.Unlock()
   127  	ch, ok := t.inflightRequest[int(srcPort)]
   128  	if !ok {
   129  		return
   130  	}
   131  	ch <- Hop{
   132  		Success: true,
   133  		Address: msg.Peer,
   134  	}
   135  }
   136  
   137  func (t *UDPTracer) getUDPConn(try int) (net.IP, int, net.PacketConn) {
   138  	srcIP, _ := util.LocalIPPort(t.DestIP)
   139  
   140  	var ipString string
   141  	if srcIP == nil {
   142  		ipString = ""
   143  	} else {
   144  		ipString = srcIP.String()
   145  	}
   146  
   147  	udpConn, err := net.ListenPacket("udp", ipString+":0")
   148  	if err != nil {
   149  		if try > 3 {
   150  			log.Fatal(err)
   151  		}
   152  		return t.getUDPConn(try + 1)
   153  	}
   154  	return srcIP, udpConn.LocalAddr().(*net.UDPAddr).Port, udpConn
   155  }
   156  
   157  func (t *UDPTracer) send(ttl int) error {
   158  	err := t.sem.Acquire(context.Background(), 1)
   159  	if err != nil {
   160  		return err
   161  	}
   162  	defer t.sem.Release(1)
   163  
   164  	defer t.wg.Done()
   165  	if t.final != -1 && ttl > t.final {
   166  		return nil
   167  	}
   168  
   169  	srcIP, srcPort, udpConn := t.getUDPConn(0)
   170  
   171  	var payload []byte
   172  	if t.Quic {
   173  		payload = GenerateQuicPayloadWithRandomIds()
   174  	} else {
   175  		ipHeader := &layers.IPv4{
   176  			SrcIP:    srcIP,
   177  			DstIP:    t.DestIP,
   178  			Protocol: layers.IPProtocolTCP,
   179  			TTL:      uint8(ttl),
   180  		}
   181  
   182  		udpHeader := &layers.UDP{
   183  			SrcPort: layers.UDPPort(srcPort),
   184  			DstPort: layers.UDPPort(t.DestPort),
   185  		}
   186  		_ = udpHeader.SetNetworkLayerForChecksum(ipHeader)
   187  		buf := gopacket.NewSerializeBuffer()
   188  		opts := gopacket.SerializeOptions{
   189  			ComputeChecksums: true,
   190  			FixLengths:       true,
   191  		}
   192  
   193  		desiredPayloadSize := t.Config.PktSize
   194  		payload := make([]byte, desiredPayloadSize)
   195  		copy(buf.Bytes(), payload)
   196  
   197  		if err := gopacket.SerializeLayers(buf, opts, udpHeader); err != nil {
   198  			return err
   199  		}
   200  
   201  		payload = buf.Bytes()
   202  	}
   203  
   204  	err = ipv4.NewPacketConn(udpConn).SetTTL(ttl)
   205  	if err != nil {
   206  		return err
   207  	}
   208  
   209  	start := time.Now()
   210  	if _, err := udpConn.WriteTo(payload, &net.UDPAddr{IP: t.DestIP, Port: t.DestPort}); err != nil {
   211  		return err
   212  	}
   213  
   214  	// 在对inflightRequest进行写操作的时候应该加锁保护,以免多个goroutine协程试图同时写入造成panic
   215  	t.inflightRequestLock.Lock()
   216  	hopCh := make(chan Hop)
   217  	t.inflightRequest[srcPort] = hopCh
   218  	t.inflightRequestLock.Unlock()
   219  	defer func() {
   220  		t.inflightRequestLock.Lock()
   221  		close(hopCh)
   222  		delete(t.inflightRequest, srcPort)
   223  		t.inflightRequestLock.Unlock()
   224  	}()
   225  
   226  	go func() {
   227  		reply := make([]byte, 1500)
   228  		_, peer, err := udpConn.ReadFrom(reply)
   229  		if err != nil {
   230  			// probably because we closed the connection
   231  			return
   232  		}
   233  		hopCh <- Hop{
   234  			Success: true,
   235  			Address: peer,
   236  		}
   237  	}()
   238  
   239  	select {
   240  	case <-t.ctx.Done():
   241  		return nil
   242  	case h := <-hopCh:
   243  		rtt := time.Since(start)
   244  		if t.final != -1 && ttl > t.final {
   245  			return nil
   246  		}
   247  
   248  		if addr, ok := h.Address.(*net.IPAddr); ok && addr.IP.Equal(t.DestIP) {
   249  			t.finalLock.Lock()
   250  			if t.final == -1 || ttl < t.final {
   251  				t.final = ttl
   252  			}
   253  			t.finalLock.Unlock()
   254  		} else if addr, ok := h.Address.(*net.UDPAddr); ok && addr.IP.Equal(t.DestIP) {
   255  			t.finalLock.Lock()
   256  			if t.final == -1 || ttl < t.final {
   257  				t.final = ttl
   258  			}
   259  			t.finalLock.Unlock()
   260  		}
   261  
   262  		h.TTL = ttl
   263  		h.RTT = rtt
   264  
   265  		t.fetchLock.Lock()
   266  		defer t.fetchLock.Unlock()
   267  		err := h.fetchIPData(t.Config)
   268  		if err != nil {
   269  			return err
   270  		}
   271  
   272  		t.res.add(h)
   273  
   274  	case <-time.After(t.Timeout):
   275  		if t.final != -1 && ttl > t.final {
   276  			return nil
   277  		}
   278  
   279  		t.res.add(Hop{
   280  			Success: false,
   281  			Address: nil,
   282  			TTL:     ttl,
   283  			RTT:     0,
   284  			Error:   ErrHopLimitTimeout,
   285  		})
   286  	}
   287  
   288  	return nil
   289  }