github.com/nxtrace/NTrace-core@v1.3.1-0.20240513132635-39169291e8c9/trace/tcp_ipv4.go (about)

     1  package trace
     2  
     3  import (
     4  	"log"
     5  	"math"
     6  	"math/rand"
     7  	"net"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/google/gopacket"
    12  	"github.com/google/gopacket/layers"
    13  	"github.com/nxtrace/NTrace-core/util"
    14  	"golang.org/x/net/context"
    15  	"golang.org/x/net/icmp"
    16  	"golang.org/x/net/ipv4"
    17  	"golang.org/x/sync/semaphore"
    18  )
    19  
    20  type TCPTracer struct {
    21  	Config
    22  	wg                  sync.WaitGroup
    23  	res                 Result
    24  	ctx                 context.Context
    25  	inflightRequest     map[int]chan Hop
    26  	inflightRequestLock sync.Mutex
    27  	SrcIP               net.IP
    28  	icmp                net.PacketConn
    29  	tcp                 net.PacketConn
    30  
    31  	final     int
    32  	finalLock sync.Mutex
    33  
    34  	sem       *semaphore.Weighted
    35  	fetchLock sync.Mutex
    36  }
    37  
    38  func (t *TCPTracer) Execute() (*Result, error) {
    39  	if len(t.res.Hops) > 0 {
    40  		return &t.res, ErrTracerouteExecuted
    41  	}
    42  
    43  	t.SrcIP, _ = util.LocalIPPort(t.DestIP)
    44  
    45  	var err error
    46  	if t.SrcAddr != "" {
    47  		t.tcp, err = net.ListenPacket("ip4:tcp", t.SrcAddr)
    48  	} else {
    49  		t.tcp, err = net.ListenPacket("ip4:tcp", t.SrcIP.String())
    50  	}
    51  
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	t.icmp, err = icmp.ListenPacket("ip4:icmp", t.SrcAddr)
    56  	if err != nil {
    57  		return &t.res, err
    58  	}
    59  	defer t.icmp.Close()
    60  
    61  	var cancel context.CancelFunc
    62  	t.ctx, cancel = context.WithCancel(context.Background())
    63  	defer cancel()
    64  	t.inflightRequestLock.Lock()
    65  	t.inflightRequest = make(map[int]chan Hop)
    66  	t.inflightRequestLock.Unlock()
    67  
    68  	t.final = -1
    69  
    70  	go t.listenICMP()
    71  	go t.listenTCP()
    72  
    73  	t.sem = semaphore.NewWeighted(int64(t.ParallelRequests))
    74  
    75  	for ttl := t.BeginHop; ttl <= t.MaxHops; ttl++ {
    76  		// 如果到达最终跳,则退出
    77  		if t.final != -1 && ttl > t.final {
    78  			break
    79  		}
    80  		for i := 0; i < t.NumMeasurements; i++ {
    81  			t.wg.Add(1)
    82  			go t.send(ttl)
    83  			<-time.After(time.Millisecond * time.Duration(t.Config.PacketInterval))
    84  		}
    85  		if t.RealtimePrinter != nil {
    86  			// 对于实时模式,应该按照TTL进行并发请求
    87  			t.wg.Wait()
    88  			t.RealtimePrinter(&t.res, ttl-1)
    89  		}
    90  
    91  		<-time.After(time.Millisecond * time.Duration(t.Config.TTLInterval))
    92  	}
    93  	go func() {
    94  		if t.AsyncPrinter != nil {
    95  			for {
    96  				t.AsyncPrinter(&t.res)
    97  				time.Sleep(200 * time.Millisecond)
    98  			}
    99  		}
   100  
   101  	}()
   102  
   103  	// 如果是表格模式,则一次性并发请求
   104  	if t.RealtimePrinter == nil {
   105  		t.wg.Wait()
   106  	}
   107  	t.res.reduce(t.final)
   108  
   109  	return &t.res, nil
   110  }
   111  
   112  func (t *TCPTracer) listenICMP() {
   113  	lc := NewPacketListener(t.icmp, t.ctx)
   114  	go lc.Start()
   115  	for {
   116  		select {
   117  		case <-t.ctx.Done():
   118  			return
   119  		case msg := <-lc.Messages:
   120  			if msg.N == nil {
   121  				continue
   122  			}
   123  			dstip := net.IP(msg.Msg[24:28])
   124  			if dstip.Equal(t.DestIP) {
   125  				rm, err := icmp.ParseMessage(1, msg.Msg[:*msg.N])
   126  				if err != nil {
   127  					log.Println(err)
   128  					continue
   129  				}
   130  				switch rm.Type {
   131  				case ipv4.ICMPTypeTimeExceeded:
   132  					t.handleICMPMessage(msg, rm.Body.(*icmp.TimeExceeded).Data)
   133  				case ipv4.ICMPTypeDestinationUnreachable:
   134  					t.handleICMPMessage(msg, rm.Body.(*icmp.DstUnreach).Data)
   135  				default:
   136  					//log.Println("received icmp message of unknown type", rm.Type)
   137  				}
   138  			}
   139  		}
   140  	}
   141  
   142  }
   143  
   144  // @title    listenTCP
   145  // @description   监听TCP的响应数据包
   146  func (t *TCPTracer) listenTCP() {
   147  	lc := NewPacketListener(t.tcp, t.ctx)
   148  	go lc.Start()
   149  
   150  	for {
   151  		select {
   152  		case <-t.ctx.Done():
   153  			return
   154  		case msg := <-lc.Messages:
   155  			if msg.N == nil {
   156  				continue
   157  			}
   158  			if msg.Peer.String() != t.DestIP.String() {
   159  				continue
   160  			}
   161  
   162  			// 解包
   163  			packet := gopacket.NewPacket(msg.Msg[:*msg.N], layers.LayerTypeTCP, gopacket.Default)
   164  			// 从包中获取TCP layer信息
   165  			if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil {
   166  				tcp, _ := tcpLayer.(*layers.TCP)
   167  				// 取得目标主机的Sequence Number
   168  				t.inflightRequestLock.Lock()
   169  				if ch, ok := t.inflightRequest[int(tcp.Ack-1)]; ok {
   170  					// 最后一跳
   171  					ch <- Hop{
   172  						Success: true,
   173  						Address: msg.Peer,
   174  					}
   175  				}
   176  				t.inflightRequestLock.Unlock()
   177  			}
   178  		}
   179  	}
   180  }
   181  
   182  func (t *TCPTracer) handleICMPMessage(msg ReceivedMessage, data []byte) {
   183  	header, err := util.GetICMPResponsePayload(data)
   184  	if err != nil {
   185  		return
   186  	}
   187  	sequenceNumber := util.GetTCPSeq(header)
   188  	t.inflightRequestLock.Lock()
   189  	defer t.inflightRequestLock.Unlock()
   190  	ch, ok := t.inflightRequest[int(sequenceNumber)]
   191  	if !ok {
   192  		return
   193  	}
   194  	ch <- Hop{
   195  		Success: true,
   196  		Address: msg.Peer,
   197  	}
   198  
   199  }
   200  
   201  func (t *TCPTracer) send(ttl int) error {
   202  	err := t.sem.Acquire(context.Background(), 1)
   203  	if err != nil {
   204  		return err
   205  	}
   206  	defer t.sem.Release(1)
   207  
   208  	defer t.wg.Done()
   209  	if t.final != -1 && ttl > t.final {
   210  		return nil
   211  	}
   212  	// 随机种子
   213  	r := rand.New(rand.NewSource(time.Now().UnixNano()))
   214  	_, srcPort := util.LocalIPPort(t.DestIP)
   215  	ipHeader := &layers.IPv4{
   216  		SrcIP:    t.SrcIP,
   217  		DstIP:    t.DestIP,
   218  		Protocol: layers.IPProtocolTCP,
   219  		TTL:      uint8(ttl),
   220  	}
   221  	// 使用Uint16兼容32位系统,防止在rand的时候因使用int32而溢出
   222  	sequenceNumber := uint32(r.Intn(math.MaxUint16))
   223  	tcpHeader := &layers.TCP{
   224  		SrcPort: layers.TCPPort(srcPort),
   225  		DstPort: layers.TCPPort(t.DestPort),
   226  		Seq:     sequenceNumber,
   227  		SYN:     true,
   228  		Window:  14600,
   229  	}
   230  	_ = tcpHeader.SetNetworkLayerForChecksum(ipHeader)
   231  
   232  	buf := gopacket.NewSerializeBuffer()
   233  	opts := gopacket.SerializeOptions{
   234  		ComputeChecksums: true,
   235  		FixLengths:       true,
   236  	}
   237  
   238  	desiredPayloadSize := t.Config.PktSize
   239  	payload := make([]byte, desiredPayloadSize)
   240  	copy(buf.Bytes(), payload)
   241  
   242  	if err := gopacket.SerializeLayers(buf, opts, tcpHeader); err != nil {
   243  		return err
   244  	}
   245  
   246  	err = ipv4.NewPacketConn(t.tcp).SetTTL(ttl)
   247  	if err != nil {
   248  		return err
   249  	}
   250  
   251  	start := time.Now()
   252  	if _, err := t.tcp.WriteTo(buf.Bytes(), &net.IPAddr{IP: t.DestIP}); err != nil {
   253  		return err
   254  	}
   255  	t.inflightRequestLock.Lock()
   256  	hopCh := make(chan Hop)
   257  	t.inflightRequest[int(sequenceNumber)] = hopCh
   258  	t.inflightRequestLock.Unlock()
   259  	/*
   260  		// 这里属于 2个Sender,N个Receiver的情况,在哪里关闭Channel都容易导致Panic
   261  		defer func() {
   262  			t.inflightRequestLock.Lock()
   263  			close(hopCh)
   264  			delete(t.inflightRequest, srcPort)
   265  			t.inflightRequestLock.Unlock()
   266  		}()
   267  	*/
   268  	select {
   269  	case <-t.ctx.Done():
   270  		return nil
   271  	case h := <-hopCh:
   272  		rtt := time.Since(start)
   273  		if t.final != -1 && ttl > t.final {
   274  			return nil
   275  		}
   276  
   277  		if addr, ok := h.Address.(*net.IPAddr); ok && addr.IP.Equal(t.DestIP) {
   278  			t.finalLock.Lock()
   279  			if t.final == -1 || ttl < t.final {
   280  				t.final = ttl
   281  			}
   282  			t.finalLock.Unlock()
   283  		} else if addr, ok := h.Address.(*net.TCPAddr); ok && addr.IP.Equal(t.DestIP) {
   284  			t.finalLock.Lock()
   285  			if t.final == -1 || ttl < t.final {
   286  				t.final = ttl
   287  			}
   288  			t.finalLock.Unlock()
   289  		}
   290  
   291  		h.TTL = ttl
   292  		h.RTT = rtt
   293  
   294  		t.fetchLock.Lock()
   295  		defer t.fetchLock.Unlock()
   296  		err := h.fetchIPData(t.Config)
   297  		if err != nil {
   298  			return err
   299  		}
   300  
   301  		t.res.add(h)
   302  
   303  	case <-time.After(t.Timeout):
   304  		if t.final != -1 && ttl > t.final {
   305  			return nil
   306  		}
   307  
   308  		t.res.add(Hop{
   309  			Success: false,
   310  			Address: nil,
   311  			TTL:     ttl,
   312  			RTT:     0,
   313  			Error:   ErrHopLimitTimeout,
   314  		})
   315  	}
   316  
   317  	return nil
   318  }