github.com/nxtrace/NTrace-core@v1.3.1-0.20240513132635-39169291e8c9/trace/tcp_ipv6.go (about)

     1  package trace
     2  
     3  import (
     4  	"encoding/binary"
     5  	"math"
     6  	"math/rand"
     7  	"net"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/google/gopacket"
    12  	"github.com/google/gopacket/layers"
    13  	"github.com/nxtrace/NTrace-core/util"
    14  	"golang.org/x/net/context"
    15  	"golang.org/x/net/icmp"
    16  	"golang.org/x/net/ipv6"
    17  	"golang.org/x/sync/semaphore"
    18  )
    19  
    20  type TCPTracerv6 struct {
    21  	Config
    22  	wg                  sync.WaitGroup
    23  	res                 Result
    24  	ctx                 context.Context
    25  	inflightRequest     map[int]chan Hop
    26  	inflightRequestLock sync.Mutex
    27  	SrcIP               net.IP
    28  	icmp                net.PacketConn
    29  	tcp                 net.PacketConn
    30  
    31  	final     int
    32  	finalLock sync.Mutex
    33  
    34  	sem       *semaphore.Weighted
    35  	fetchLock sync.Mutex
    36  }
    37  
    38  func (t *TCPTracerv6) Execute() (*Result, error) {
    39  	if len(t.res.Hops) > 0 {
    40  		return &t.res, ErrTracerouteExecuted
    41  	}
    42  
    43  	t.SrcIP, _ = util.LocalIPPortv6(t.DestIP)
    44  	// log.Println(util.LocalIPPortv6(t.DestIP))
    45  	var err error
    46  	t.tcp, err = net.ListenPacket("ip6:tcp", t.SrcIP.String())
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  	t.icmp, err = icmp.ListenPacket("ip6:58", "::")
    51  	if err != nil {
    52  		return &t.res, err
    53  	}
    54  	defer t.icmp.Close()
    55  
    56  	var cancel context.CancelFunc
    57  	t.ctx, cancel = context.WithCancel(context.Background())
    58  	defer cancel()
    59  	t.inflightRequest = make(map[int]chan Hop)
    60  	t.final = -1
    61  
    62  	go t.listenICMP()
    63  	go t.listenTCP()
    64  
    65  	t.sem = semaphore.NewWeighted(int64(t.ParallelRequests))
    66  
    67  	for ttl := t.BeginHop; ttl <= t.MaxHops; ttl++ {
    68  		// 如果到达最终跳,则退出
    69  		if t.final != -1 && ttl > t.final {
    70  			break
    71  		}
    72  		for i := 0; i < t.NumMeasurements; i++ {
    73  			t.wg.Add(1)
    74  			go t.send(ttl)
    75  			<-time.After(time.Millisecond * time.Duration(t.Config.PacketInterval))
    76  		}
    77  		if t.RealtimePrinter != nil {
    78  			// 对于实时模式,应该按照TTL进行并发请求
    79  			t.wg.Wait()
    80  			t.RealtimePrinter(&t.res, ttl-1)
    81  		}
    82  		<-time.After(time.Millisecond * time.Duration(t.Config.TTLInterval))
    83  	}
    84  
    85  	go func() {
    86  		if t.AsyncPrinter != nil {
    87  			for {
    88  				t.AsyncPrinter(&t.res)
    89  				time.Sleep(200 * time.Millisecond)
    90  			}
    91  		}
    92  
    93  	}()
    94  
    95  	if t.RealtimePrinter == nil {
    96  		t.wg.Wait()
    97  	}
    98  	t.res.reduce(t.final)
    99  
   100  	return &t.res, nil
   101  }
   102  
   103  func (t *TCPTracerv6) listenICMP() {
   104  	lc := NewPacketListener(t.icmp, t.ctx)
   105  	go lc.Start()
   106  	for {
   107  		select {
   108  		case <-t.ctx.Done():
   109  			return
   110  		case msg := <-lc.Messages:
   111  			// log.Println(msg)
   112  
   113  			if msg.N == nil {
   114  				continue
   115  			}
   116  			rm, err := icmp.ParseMessage(58, msg.Msg[:*msg.N])
   117  			if err != nil {
   118  				// log.Println(err)
   119  				continue
   120  			}
   121  			switch rm.Type {
   122  			case ipv6.ICMPTypeTimeExceeded:
   123  				t.handleICMPMessage(msg)
   124  			case ipv6.ICMPTypeDestinationUnreachable:
   125  				t.handleICMPMessage(msg)
   126  			default:
   127  				//log.Println("received icmp message of unknown type", rm.Type)
   128  			}
   129  		}
   130  	}
   131  
   132  }
   133  
   134  // @title    listenTCP
   135  // @description   监听TCP的响应数据包
   136  func (t *TCPTracerv6) listenTCP() {
   137  	lc := NewPacketListener(t.tcp, t.ctx)
   138  	go lc.Start()
   139  
   140  	for {
   141  		select {
   142  		case <-t.ctx.Done():
   143  			return
   144  		case msg := <-lc.Messages:
   145  			// log.Println(msg)
   146  			// return
   147  			if msg.N == nil {
   148  				continue
   149  			}
   150  			if msg.Peer.String() != t.DestIP.String() {
   151  				continue
   152  			}
   153  
   154  			// 解包
   155  			packet := gopacket.NewPacket(msg.Msg[:*msg.N], layers.LayerTypeTCP, gopacket.Default)
   156  			// 从包中获取TCP layer信息
   157  			if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil {
   158  				tcp, _ := tcpLayer.(*layers.TCP)
   159  				// 取得目标主机的Sequence Number
   160  
   161  				if ch, ok := t.inflightRequest[int(tcp.Ack-1)]; ok {
   162  					// 最后一跳
   163  					ch <- Hop{
   164  						Success: true,
   165  						Address: msg.Peer,
   166  					}
   167  				}
   168  			}
   169  		}
   170  	}
   171  }
   172  
   173  func (t *TCPTracerv6) handleICMPMessage(msg ReceivedMessage) {
   174  	var sequenceNumber = binary.BigEndian.Uint32(msg.Msg[52:56])
   175  
   176  	t.inflightRequestLock.Lock()
   177  	defer t.inflightRequestLock.Unlock()
   178  	ch, ok := t.inflightRequest[int(sequenceNumber)]
   179  	if !ok {
   180  		return
   181  	}
   182  	// log.Println("发送数据", sequenceNumber)
   183  	ch <- Hop{
   184  		Success: true,
   185  		Address: msg.Peer,
   186  	}
   187  	// log.Println("发送成功")
   188  }
   189  
   190  func (t *TCPTracerv6) send(ttl int) error {
   191  	err := t.sem.Acquire(context.Background(), 1)
   192  	if err != nil {
   193  		return err
   194  	}
   195  	defer t.sem.Release(1)
   196  
   197  	defer t.wg.Done()
   198  	if t.final != -1 && ttl > t.final {
   199  		return nil
   200  	}
   201  	// 随机种子
   202  	r := rand.New(rand.NewSource(time.Now().UnixNano()))
   203  	_, srcPort := util.LocalIPPortv6(t.DestIP)
   204  	ipHeader := &layers.IPv6{
   205  		SrcIP:      t.SrcIP,
   206  		DstIP:      t.DestIP,
   207  		NextHeader: layers.IPProtocolTCP,
   208  		HopLimit:   uint8(ttl),
   209  	}
   210  	// 使用Uint16兼容32位系统,防止在rand的时候因使用int32而溢出
   211  	sequenceNumber := uint32(r.Intn(math.MaxUint16))
   212  
   213  	tcpHeader := &layers.TCP{
   214  		SrcPort: layers.TCPPort(srcPort),
   215  		DstPort: layers.TCPPort(t.DestPort),
   216  		Seq:     sequenceNumber,
   217  		SYN:     true,
   218  		Window:  14600,
   219  	}
   220  	_ = tcpHeader.SetNetworkLayerForChecksum(ipHeader)
   221  
   222  	buf := gopacket.NewSerializeBuffer()
   223  	opts := gopacket.SerializeOptions{
   224  		ComputeChecksums: true,
   225  		FixLengths:       true,
   226  	}
   227  
   228  	desiredPayloadSize := t.Config.PktSize
   229  	payload := make([]byte, desiredPayloadSize)
   230  	copy(buf.Bytes(), payload)
   231  
   232  	if err := gopacket.SerializeLayers(buf, opts, tcpHeader); err != nil {
   233  		return err
   234  	}
   235  
   236  	err = ipv6.NewPacketConn(t.tcp).SetHopLimit(ttl)
   237  	if err != nil {
   238  		return err
   239  	}
   240  	if err != nil {
   241  		return err
   242  	}
   243  
   244  	start := time.Now()
   245  	if _, err := t.tcp.WriteTo(buf.Bytes(), &net.IPAddr{IP: t.DestIP}); err != nil {
   246  		return err
   247  	}
   248  	// log.Println(ttl, sequenceNumber)
   249  	t.inflightRequestLock.Lock()
   250  	hopCh := make(chan Hop)
   251  	t.inflightRequest[int(sequenceNumber)] = hopCh
   252  	t.inflightRequestLock.Unlock()
   253  
   254  	select {
   255  	case <-t.ctx.Done():
   256  		return nil
   257  	case h := <-hopCh:
   258  		rtt := time.Since(start)
   259  		if t.final != -1 && ttl > t.final {
   260  			return nil
   261  		}
   262  
   263  		if addr, ok := h.Address.(*net.IPAddr); ok && addr.IP.Equal(t.DestIP) {
   264  			t.finalLock.Lock()
   265  			if t.final == -1 || ttl < t.final {
   266  				t.final = ttl
   267  			}
   268  			t.finalLock.Unlock()
   269  		} else if addr, ok := h.Address.(*net.TCPAddr); ok && addr.IP.Equal(t.DestIP) {
   270  			t.finalLock.Lock()
   271  			if t.final == -1 || ttl < t.final {
   272  				t.final = ttl
   273  			}
   274  			t.finalLock.Unlock()
   275  		}
   276  
   277  		h.TTL = ttl
   278  		h.RTT = rtt
   279  
   280  		t.fetchLock.Lock()
   281  		defer t.fetchLock.Unlock()
   282  		h.fetchIPData(t.Config)
   283  
   284  		t.res.add(h)
   285  
   286  	case <-time.After(t.Timeout):
   287  		if t.final != -1 && ttl > t.final {
   288  			return nil
   289  		}
   290  
   291  		t.res.add(Hop{
   292  			Success: false,
   293  			Address: nil,
   294  			TTL:     ttl,
   295  			RTT:     0,
   296  			Error:   ErrHopLimitTimeout,
   297  		})
   298  	}
   299  
   300  	return nil
   301  }