github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/components/inbound/inbound.go (about)

     1  package inbound
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  
     7  	"github.com/Asutorufa/yuhaiin/pkg/log"
     8  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
     9  	pc "github.com/Asutorufa/yuhaiin/pkg/protos/config"
    10  	pl "github.com/Asutorufa/yuhaiin/pkg/protos/config/listener"
    11  	"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
    12  	"google.golang.org/protobuf/proto"
    13  )
    14  
// key identifies a running inbound server in listener.store.
type key struct {
	name string // map key from the Setting's Server.Servers or Server.Inbounds map
	old  bool   // true when the config came from Server.Servers (converted with ToInbound)
}
    19  
// entry records a running server together with the config it was started
// from, so start can skip a restart when the new config is proto.Equal.
type entry struct {
	config *pl.Inbound
	server netapi.Accepter
}
    24  
    25  type listener struct {
    26  	store syncmap.SyncMap[key, entry]
    27  
    28  	handler *handler
    29  
    30  	ctx   context.Context
    31  	close context.CancelFunc
    32  
    33  	tcpChannel chan *netapi.StreamMeta
    34  	udpChannel chan *netapi.Packet
    35  
    36  	hijackDNS bool
    37  	fakeip    bool
    38  }
    39  
    40  func NewListener(dnsHandler netapi.DNSServer, dialer netapi.Proxy) *listener {
    41  	ctx, cancel := context.WithCancel(context.Background())
    42  
    43  	l := &listener{
    44  		handler:    NewHandler(dialer, dnsHandler),
    45  		ctx:        ctx,
    46  		close:      cancel,
    47  		tcpChannel: make(chan *netapi.StreamMeta, 100),
    48  		udpChannel: make(chan *netapi.Packet, 100),
    49  
    50  		hijackDNS: true,
    51  		fakeip:    true,
    52  	}
    53  
    54  	go l.tcp()
    55  	go l.udp()
    56  
    57  	return l
    58  }
    59  
    60  func (l *listener) tcp() {
    61  	for {
    62  		select {
    63  		case <-l.ctx.Done():
    64  			return
    65  		case stream := <-l.tcpChannel:
    66  			if stream.Address.Port().Port() == 53 && l.hijackDNS {
    67  				err := l.handler.dnsHandler.HandleTCP(l.ctx, stream.Src)
    68  				_ = stream.Src.Close()
    69  				if err != nil {
    70  					if errors.Is(err, netapi.ErrBlocked) {
    71  						log.Debug("blocked", "msg", err)
    72  					} else {
    73  						log.Error("tcp server handle DnsHijacking failed", "err", err)
    74  					}
    75  				}
    76  				continue
    77  			}
    78  
    79  			l.handler.Stream(l.ctx, stream)
    80  		}
    81  	}
    82  }
    83  
    84  func (l *listener) udp() {
    85  	for {
    86  		select {
    87  		case <-l.ctx.Done():
    88  			return
    89  		case packet := <-l.udpChannel:
    90  			if packet.Dst.Port().Port() == 53 && l.hijackDNS {
    91  				go func() {
    92  					ctx := l.ctx
    93  					if l.fakeip {
    94  						ctx = context.WithValue(ctx,
    95  							netapi.ForceFakeIP{}, true)
    96  					}
    97  
    98  					err := l.handler.dnsHandler.Do(ctx, packet.Payload, func(b []byte) error {
    99  						_, err := packet.WriteBack(b, packet.Dst)
   100  						return err
   101  					})
   102  					if err != nil {
   103  						if errors.Is(err, netapi.ErrBlocked) {
   104  							log.Debug("blocked", "msg", err)
   105  						} else {
   106  							log.Error("udp server handle DnsHijacking failed", "err", err)
   107  						}
   108  					}
   109  				}()
   110  
   111  				continue
   112  			}
   113  
   114  			l.handler.Packet(l.ctx, packet)
   115  		}
   116  	}
   117  }
   118  
   119  func (l *listener) Update(current *pc.Setting) {
   120  	// l.hijackDNS = current.Server.HijackDns
   121  	l.fakeip = current.Server.HijackDnsFakeip
   122  	// l.handler.sniffyEnabled = current.GetBypass().GetSniffy()
   123  
   124  	l.store.Range(func(key key, v entry) bool {
   125  		var z interface{ GetEnabled() bool }
   126  		var ok bool
   127  		if key.old {
   128  			z, ok = current.Server.Servers[key.name]
   129  		} else {
   130  			z, ok = current.Server.Inbounds[key.name]
   131  		}
   132  
   133  		if !ok || !z.GetEnabled() {
   134  			v.server.Close()
   135  			l.store.Delete(key)
   136  		}
   137  
   138  		return true
   139  	})
   140  
   141  	for k, v := range current.Server.Servers {
   142  		l.start(key{k, true}, v.ToInbound())
   143  	}
   144  
   145  	for k, v := range current.Server.Inbounds {
   146  		l.start(key{k, false}, v)
   147  	}
   148  }
   149  
   150  func (l *listener) start(key key, config *pl.Inbound) {
   151  	if config == nil {
   152  		return
   153  	}
   154  
   155  	v, ok := l.store.Load(key)
   156  	if ok {
   157  		if proto.Equal(v.config, config) {
   158  			return
   159  		}
   160  		v.server.Close()
   161  		l.store.Delete(key)
   162  	}
   163  
   164  	if !config.GetEnabled() {
   165  		log.Debug("server disabled", "name", key)
   166  		return
   167  	}
   168  
   169  	server, err := pl.Listen(config)
   170  	if err != nil {
   171  		log.Error("start server failed", "name", key, "err", err)
   172  		return
   173  	}
   174  
   175  	go func() {
   176  		for {
   177  			stream, err := server.AcceptStream()
   178  			if err != nil {
   179  				log.Error("accept stream failed", "err", err)
   180  				return
   181  			}
   182  
   183  			select {
   184  			case <-l.ctx.Done():
   185  				return
   186  			case l.tcpChannel <- stream:
   187  			}
   188  		}
   189  	}()
   190  
   191  	go func() {
   192  		for {
   193  			packet, err := server.AcceptPacket()
   194  			if err != nil {
   195  				log.Error("accept packet failed", "err", err)
   196  				return
   197  			}
   198  
   199  			select {
   200  			case <-l.ctx.Done():
   201  				return
   202  			case l.udpChannel <- packet:
   203  			}
   204  		}
   205  	}()
   206  
   207  	l.store.Store(key, entry{config, server})
   208  }
   209  
   210  func (l *listener) Close() error {
   211  	l.close()
   212  	l.store.Range(func(key key, value entry) bool {
   213  		log.Info("start close server", "name", key)
   214  		defer log.Info("closed server", "name", key)
   215  		value.server.Close()
   216  		l.store.Delete(key)
   217  		return true
   218  	})
   219  	return l.handler.Close()
   220  }