github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/components/inbound/inbound.go (about) 1 package inbound 2 3 import ( 4 "context" 5 "errors" 6 7 "github.com/Asutorufa/yuhaiin/pkg/log" 8 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 9 pc "github.com/Asutorufa/yuhaiin/pkg/protos/config" 10 pl "github.com/Asutorufa/yuhaiin/pkg/protos/config/listener" 11 "github.com/Asutorufa/yuhaiin/pkg/utils/syncmap" 12 "google.golang.org/protobuf/proto" 13 ) 14 15 type key struct { 16 name string 17 old bool 18 } 19 20 type entry struct { 21 config *pl.Inbound 22 server netapi.Accepter 23 } 24 25 type listener struct { 26 store syncmap.SyncMap[key, entry] 27 28 handler *handler 29 30 ctx context.Context 31 close context.CancelFunc 32 33 tcpChannel chan *netapi.StreamMeta 34 udpChannel chan *netapi.Packet 35 36 hijackDNS bool 37 fakeip bool 38 } 39 40 func NewListener(dnsHandler netapi.DNSServer, dialer netapi.Proxy) *listener { 41 ctx, cancel := context.WithCancel(context.Background()) 42 43 l := &listener{ 44 handler: NewHandler(dialer, dnsHandler), 45 ctx: ctx, 46 close: cancel, 47 tcpChannel: make(chan *netapi.StreamMeta, 100), 48 udpChannel: make(chan *netapi.Packet, 100), 49 50 hijackDNS: true, 51 fakeip: true, 52 } 53 54 go l.tcp() 55 go l.udp() 56 57 return l 58 } 59 60 func (l *listener) tcp() { 61 for { 62 select { 63 case <-l.ctx.Done(): 64 return 65 case stream := <-l.tcpChannel: 66 if stream.Address.Port().Port() == 53 && l.hijackDNS { 67 err := l.handler.dnsHandler.HandleTCP(l.ctx, stream.Src) 68 _ = stream.Src.Close() 69 if err != nil { 70 if errors.Is(err, netapi.ErrBlocked) { 71 log.Debug("blocked", "msg", err) 72 } else { 73 log.Error("tcp server handle DnsHijacking failed", "err", err) 74 } 75 } 76 continue 77 } 78 79 l.handler.Stream(l.ctx, stream) 80 } 81 } 82 } 83 84 func (l *listener) udp() { 85 for { 86 select { 87 case <-l.ctx.Done(): 88 return 89 case packet := <-l.udpChannel: 90 if packet.Dst.Port().Port() == 53 && l.hijackDNS { 91 go func() { 92 ctx := l.ctx 93 if l.fakeip { 94 ctx = context.WithValue(ctx, 95 netapi.ForceFakeIP{}, true) 96 } 97 98 err := l.handler.dnsHandler.Do(ctx, packet.Payload, func(b []byte) error { 99 _, err := packet.WriteBack(b, packet.Dst) 100 return err 101 }) 102 if err != nil { 103 if errors.Is(err, netapi.ErrBlocked) { 104 log.Debug("blocked", "msg", err) 105 } else { 106 log.Error("udp server handle DnsHijacking failed", "err", err) 107 } 108 } 109 }() 110 111 continue 112 } 113 114 l.handler.Packet(l.ctx, packet) 115 } 116 } 117 } 118 119 func (l *listener) Update(current *pc.Setting) { 120 // l.hijackDNS = current.Server.HijackDns 121 l.fakeip = current.Server.HijackDnsFakeip 122 // l.handler.sniffyEnabled = current.GetBypass().GetSniffy() 123 124 l.store.Range(func(key key, v entry) bool { 125 var z interface{ GetEnabled() bool } 126 var ok bool 127 if key.old { 128 z, ok = current.Server.Servers[key.name] 129 } else { 130 z, ok = current.Server.Inbounds[key.name] 131 } 132 133 if !ok || !z.GetEnabled() { 134 v.server.Close() 135 l.store.Delete(key) 136 } 137 138 return true 139 }) 140 141 for k, v := range current.Server.Servers { 142 l.start(key{k, true}, v.ToInbound()) 143 } 144 145 for k, v := range current.Server.Inbounds { 146 l.start(key{k, false}, v) 147 } 148 } 149 150 func (l *listener) start(key key, config *pl.Inbound) { 151 if config == nil { 152 return 153 } 154 155 v, ok := l.store.Load(key) 156 if ok { 157 if proto.Equal(v.config, config) { 158 return 159 } 160 v.server.Close() 161 l.store.Delete(key) 162 } 163 164 if !config.GetEnabled() { 165 log.Debug("server disabled", "name", key) 166 return 167 } 168 169 server, err := pl.Listen(config) 170 if err != nil { 171 log.Error("start server failed", "name", key, "err", err) 172 return 173 } 174 175 go func() { 176 for { 177 stream, err := server.AcceptStream() 178 if err != nil { 179 log.Error("accept stream failed", "err", err) 180 return 181 } 182 183 select { 184 case <-l.ctx.Done(): 185 return 186 case l.tcpChannel <- stream: 187 } 188 } 189 }() 190 191 go func() { 192 for { 193 packet, err := server.AcceptPacket() 194 if err != nil { 195 log.Error("accept packet failed", "err", err) 196 return 197 } 198 199 select { 200 case <-l.ctx.Done(): 201 return 202 case l.udpChannel <- packet: 203 } 204 } 205 }() 206 207 l.store.Store(key, entry{config, server}) 208 } 209 210 func (l *listener) Close() error { 211 l.close() 212 l.store.Range(func(key key, value entry) bool { 213 log.Info("start close server", "name", key) 214 defer log.Info("closed server", "name", key) 215 value.server.Close() 216 l.store.Delete(key) 217 return true 218 }) 219 return l.handler.Close() 220 }