github.com/nxtrace/NTrace-core@v1.3.1-0.20240513132635-39169291e8c9/trace/icmp_ipv4.go (about)

     1  package trace
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"log"
     9  	"net"
    10  	"os"
    11  	"strconv"
    12  	"sync"
    13  	"time"
    14  
    15  	"golang.org/x/net/context"
    16  	"golang.org/x/net/icmp"
    17  	"golang.org/x/net/ipv4"
    18  
    19  	"github.com/nxtrace/NTrace-core/trace/internal"
    20  )
    21  
// ICMPTracer runs a traceroute using ICMPv4 echo requests with increasing TTLs.
//
// Field roles, as used in this file:
//   - Config: embedded settings; this file reads BeginHop, MaxHops,
//     NumMeasurements, PacketInterval, TTLInterval, PktSize, Timeout, SrcAddr,
//     DestIP and the AsyncPrinter / RealtimePrinter callbacks from it.
//   - wg: counts the PrintFunc goroutine and every send goroutine; Execute
//     waits on it before finalising the result.
//   - res: the accumulated hops; Execute returns a pointer to it.
//   - ctx: created (and cancelled on return) by Execute; listenICMP and send
//     stop when it is done.
//   - inflightRequest: maps a TTL to the channel on which the listener
//     delivers replies for that TTL; send receives from it.
//   - inflightRequestRWLock: guards inflightRequest.
//   - icmpListen: the ICMP socket opened by Execute and shared by the listener
//     and all senders.
//   - final: the smallest TTL at which the destination itself replied, or -1
//     while that is still unknown.
//   - finalLock: held by send while it updates final.
//   - fetchLock: held by send while it enriches a hop via fetchIPData and
//     appends it to res.
type ICMPTracer struct {
	Config
	wg                    sync.WaitGroup
	res                   Result
	ctx                   context.Context
	inflightRequest       map[int]chan Hop
	inflightRequestRWLock sync.RWMutex
	icmpListen            net.PacketConn
	final                 int
	finalLock             sync.Mutex
	fetchLock             sync.Mutex
}
    34  
// psize is a package-level packet size. It starts at 52 and is overwritten
// with Config.PktSize every time listenICMP starts. Nothing in this file reads
// it; presumably other files of the package do (TODO confirm against callers).
var psize = 52
    36  
    37  func (t *ICMPTracer) PrintFunc() {
    38  	defer t.wg.Done()
    39  	var ttl = t.Config.BeginHop - 1
    40  	for {
    41  		if t.AsyncPrinter != nil {
    42  			t.AsyncPrinter(&t.res)
    43  		}
    44  		// 接收的时候检查一下是不是 3 跳都齐了
    45  		if len(t.res.Hops)-1 > ttl {
    46  			if len(t.res.Hops[ttl]) == t.NumMeasurements {
    47  				if t.RealtimePrinter != nil {
    48  					t.RealtimePrinter(&t.res, ttl)
    49  				}
    50  				ttl++
    51  
    52  				if ttl == t.final-1 || ttl >= t.MaxHops-1 {
    53  					return
    54  				}
    55  			}
    56  		}
    57  		<-time.After(200 * time.Millisecond)
    58  	}
    59  }
    60  
    61  func (t *ICMPTracer) Execute() (*Result, error) {
    62  	t.inflightRequestRWLock.Lock()
    63  	t.inflightRequest = make(map[int]chan Hop)
    64  	t.inflightRequestRWLock.Unlock()
    65  
    66  	if len(t.res.Hops) > 0 {
    67  		return &t.res, ErrTracerouteExecuted
    68  	}
    69  
    70  	var err error
    71  
    72  	t.icmpListen, err = internal.ListenICMP("ip4:1", t.SrcAddr)
    73  	if err != nil {
    74  		return &t.res, err
    75  	}
    76  	defer t.icmpListen.Close()
    77  
    78  	var cancel context.CancelFunc
    79  	t.ctx, cancel = context.WithCancel(context.Background())
    80  	defer cancel()
    81  	t.final = -1
    82  
    83  	go t.listenICMP()
    84  	t.wg.Add(1)
    85  	go t.PrintFunc()
    86  	for ttl := t.BeginHop; ttl <= t.MaxHops; ttl++ {
    87  		t.inflightRequestRWLock.Lock()
    88  		t.inflightRequest[ttl] = make(chan Hop, t.NumMeasurements)
    89  		t.inflightRequestRWLock.Unlock()
    90  		if t.final != -1 && ttl > t.final {
    91  			break
    92  		}
    93  		for i := 0; i < t.NumMeasurements; i++ {
    94  			t.wg.Add(1)
    95  			go t.send(ttl)
    96  			<-time.After(time.Millisecond * time.Duration(t.Config.PacketInterval))
    97  		}
    98  		<-time.After(time.Millisecond * time.Duration(t.Config.TTLInterval))
    99  	}
   100  
   101  	t.wg.Wait()
   102  	t.res.reduce(t.final)
   103  	if t.final != -1 {
   104  		if t.RealtimePrinter != nil {
   105  			t.RealtimePrinter(&t.res, t.final-1)
   106  		}
   107  	} else {
   108  		for i := 0; i < t.NumMeasurements; i++ {
   109  			t.res.add(Hop{
   110  				Success: false,
   111  				Address: nil,
   112  				TTL:     30,
   113  				RTT:     0,
   114  				Error:   ErrHopLimitTimeout,
   115  			})
   116  		}
   117  		if t.RealtimePrinter != nil {
   118  			t.RealtimePrinter(&t.res, t.MaxHops-1)
   119  		}
   120  	}
   121  	return &t.res, nil
   122  }
   123  
   124  func (t *ICMPTracer) listenICMP() {
   125  	lc := NewPacketListener(t.icmpListen, t.ctx)
   126  	psize = t.Config.PktSize
   127  	go lc.Start()
   128  	for {
   129  		select {
   130  		case <-t.ctx.Done():
   131  			return
   132  		case msg := <-lc.Messages:
   133  			if msg.N == nil {
   134  				continue
   135  			}
   136  			// log.Println(msg.Msg)
   137  			if msg.Msg[0] == 0 {
   138  				rm, err := icmp.ParseMessage(1, msg.Msg[:*msg.N])
   139  				if err != nil {
   140  					log.Println(err)
   141  					continue
   142  				}
   143  				echoReply := rm.Body.(*icmp.Echo)
   144  				ttl := echoReply.Seq // This is the TTL value
   145  				if ttl > 100 {
   146  					continue
   147  				}
   148  				if msg.Peer.String() == t.DestIP.String() {
   149  					t.handleICMPMessage(msg, 1, rm.Body.(*icmp.Echo).Data, ttl)
   150  				}
   151  				continue
   152  			}
   153  			ttl := int64(binary.BigEndian.Uint16(msg.Msg[34:36]))
   154  			packetId := strconv.FormatInt(int64(binary.BigEndian.Uint16(msg.Msg[32:34])), 2)
   155  			if processId, _, err := reverseID(packetId); err == nil {
   156  				if processId == int64(os.Getpid()&0x7f) {
   157  					dstip := net.IP(msg.Msg[24:28])
   158  					if dstip.Equal(t.DestIP) || dstip.Equal(net.IPv4zero) {
   159  						// 匹配再继续解析包,否则直接丢弃
   160  						rm, err := icmp.ParseMessage(1, msg.Msg[:*msg.N])
   161  						if err != nil {
   162  							log.Println(err)
   163  							continue
   164  						}
   165  
   166  						switch rm.Type {
   167  						case ipv4.ICMPTypeTimeExceeded:
   168  							t.handleICMPMessage(msg, 0, rm.Body.(*icmp.TimeExceeded).Data, int(ttl))
   169  						case ipv4.ICMPTypeEchoReply:
   170  							t.handleICMPMessage(msg, 1, rm.Body.(*icmp.Echo).Data, int(ttl))
   171  						//unreachable
   172  						case ipv4.ICMPTypeDestinationUnreachable:
   173  							t.handleICMPMessage(msg, 2, rm.Body.(*icmp.DstUnreach).Data, int(ttl))
   174  						default:
   175  							// log.Println("received icmp message of unknown type", rm.Type)
   176  						}
   177  					}
   178  				}
   179  			}
   180  		}
   181  	}
   182  
   183  }
   184  
   185  func (t *ICMPTracer) handleICMPMessage(msg ReceivedMessage, icmpType int8, data []byte, ttl int) {
   186  	if icmpType == 2 {
   187  		if t.DestIP.String() != msg.Peer.String() {
   188  			return
   189  		}
   190  	}
   191  
   192  	t.inflightRequestRWLock.RLock()
   193  	defer t.inflightRequestRWLock.RUnlock()
   194  
   195  	mpls := extractMPLS(msg, data)
   196  	if _, ok := t.inflightRequest[ttl]; ok {
   197  		t.inflightRequest[ttl] <- Hop{
   198  			Success: true,
   199  			Address: msg.Peer,
   200  			MPLS:    mpls,
   201  		}
   202  	}
   203  }
   204  
   205  func gernerateID(ttlInt int) int {
   206  	const IdFixedHeader = "10"
   207  	var processID = fmt.Sprintf("%07b", os.Getpid()&0x7f) //取进程ID的前7位
   208  	var ttl = fmt.Sprintf("%06b", ttlInt)                 //取TTL的后6位
   209  
   210  	var parity int
   211  	id := IdFixedHeader + processID + ttl
   212  	for _, c := range id {
   213  		if c == '1' {
   214  			parity++
   215  		}
   216  	}
   217  	if parity%2 == 0 {
   218  		id += "1"
   219  	} else {
   220  		id += "0"
   221  	}
   222  
   223  	res, _ := strconv.ParseInt(id, 2, 32)
   224  	return int(res)
   225  }
   226  
   227  func reverseID(id string) (int64, int64, error) {
   228  	if len(id) < 16 {
   229  		return 0, 0, errors.New("err")
   230  	}
   231  	ttl, err := strconv.ParseInt(id[9:15], 2, 32)
   232  	if err != nil {
   233  		return 0, 0, err
   234  	}
   235  	//process ID
   236  	processID, _ := strconv.ParseInt(id[2:9], 2, 32)
   237  
   238  	parity := 0
   239  	for i := 0; i < len(id)-1; i++ {
   240  		if id[i] == '1' {
   241  			parity++
   242  		}
   243  	}
   244  
   245  	if parity%2 == 1 {
   246  		if id[len(id)-1] == '0' {
   247  			// fmt.Println("Parity check passed.")
   248  		} else {
   249  			// fmt.Println("Parity check failed.")
   250  			return 0, 0, errors.New("err")
   251  		}
   252  	} else {
   253  		if id[len(id)-1] == '1' {
   254  			// fmt.Println("Parity check passed.")
   255  		} else {
   256  			// fmt.Println("Parity check failed.")
   257  			return 0, 0, errors.New("err")
   258  		}
   259  	}
   260  	return processID, ttl, nil
   261  }
   262  
   263  func (t *ICMPTracer) send(ttl int) error {
   264  
   265  	defer t.wg.Done()
   266  	if t.final != -1 && ttl > t.final {
   267  		return nil
   268  	}
   269  
   270  	//id := gernerateID(ttl)
   271  	id := gernerateID(0)
   272  	// log.Println("发送的", id)
   273  
   274  	//data := []byte{byte(ttl)}
   275  	data := []byte{byte(0)}
   276  	data = append(data, bytes.Repeat([]byte{1}, t.Config.PktSize-5)...)
   277  	data = append(data, 0x00, 0x00, 0x4f, 0xff)
   278  
   279  	icmpHeader := icmp.Message{
   280  		Type: ipv4.ICMPTypeEcho, Code: 0,
   281  		Body: &icmp.Echo{
   282  			ID: id,
   283  			//Data: []byte("HELLO-R-U-THERE"),
   284  			Data: data,
   285  			Seq:  ttl,
   286  		},
   287  	}
   288  
   289  	err := ipv4.NewPacketConn(t.icmpListen).SetTTL(ttl)
   290  	if err != nil {
   291  		return err
   292  	}
   293  
   294  	wb, err := icmpHeader.Marshal(nil)
   295  	if err != nil {
   296  		log.Fatal(err)
   297  	}
   298  
   299  	start := time.Now()
   300  	if _, err := t.icmpListen.WriteTo(wb, &net.IPAddr{IP: t.DestIP}); err != nil {
   301  		log.Fatal(err)
   302  	}
   303  	if err := t.icmpListen.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
   304  		log.Fatal(err)
   305  	}
   306  	select {
   307  	case <-t.ctx.Done():
   308  		return nil
   309  	case h := <-t.inflightRequest[ttl]:
   310  		rtt := time.Since(start)
   311  		if t.final != -1 && ttl > t.final {
   312  			return nil
   313  		}
   314  		if addr, ok := h.Address.(*net.IPAddr); ok && addr.IP.Equal(t.DestIP) {
   315  			t.finalLock.Lock()
   316  			if t.final == -1 || ttl < t.final {
   317  
   318  				t.final = ttl
   319  			}
   320  			t.finalLock.Unlock()
   321  		} else if addr, ok := h.Address.(*net.TCPAddr); ok && addr.IP.Equal(t.DestIP) {
   322  			t.finalLock.Lock()
   323  			if t.final == -1 || ttl < t.final {
   324  				t.final = ttl
   325  			}
   326  			t.finalLock.Unlock()
   327  		}
   328  
   329  		h.TTL = ttl
   330  		h.RTT = rtt
   331  
   332  		t.fetchLock.Lock()
   333  		defer t.fetchLock.Unlock()
   334  		err := h.fetchIPData(t.Config)
   335  		if err != nil {
   336  			return err
   337  		}
   338  
   339  		t.res.add(h)
   340  	case <-time.After(t.Timeout):
   341  		if t.final != -1 && ttl > t.final {
   342  			return nil
   343  		}
   344  
   345  		t.res.add(Hop{
   346  			Success: false,
   347  			Address: nil,
   348  			TTL:     ttl,
   349  			RTT:     0,
   350  			Error:   ErrHopLimitTimeout,
   351  		})
   352  
   353  	}
   354  
   355  	return nil
   356  }