github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/sdk/android-ios/dns.go (about) 1 package proxy 2 3 import ( 4 "crypto/md5" 5 "encoding/hex" 6 "fmt" 7 logger "log" 8 "net" 9 "runtime/debug" 10 "time" 11 12 "golang.org/x/net/proxy" 13 14 "github.com/miekg/dns" 15 gocache "github.com/pmylund/go-cache" 16 "github.com/AntonOrnatskyi/goproxy/core/lib/kcpcfg" 17 services "github.com/AntonOrnatskyi/goproxy/services" 18 ) 19 20 type DNSArgs struct { 21 ParentServiceType *string 22 ParentType *string 23 Parent *string 24 ParentAuth *string 25 ParentKey *string 26 ParentCompress *bool 27 KCP kcpcfg.KCPConfigArgs 28 CertFile *string 29 KeyFile *string 30 CaCertFile *string 31 Local *string 32 Timeout *int 33 RemoteDNSAddress *string 34 DNSTTL *int 35 CacheFile *string 36 LocalSocks5Port *string 37 } 38 type DNS struct { 39 cfg DNSArgs 40 log *logger.Logger 41 cache *gocache.Cache 42 exitSig chan bool 43 serviceKey string 44 dialer proxy.Dialer 45 } 46 47 func NewDNS() services.Service { 48 return &DNS{ 49 cfg: DNSArgs{}, 50 exitSig: make(chan bool, 1), 51 serviceKey: "dns-service-" + fmt.Sprintf("%d", time.Now().UnixNano()), 52 } 53 } 54 func (s *DNS) CheckArgs() (err error) { 55 return 56 } 57 func (s *DNS) InitService() (err error) { 58 s.cache = gocache.New(time.Second*time.Duration(*s.cfg.DNSTTL), time.Second*60) 59 s.cache.LoadFile(*s.cfg.CacheFile) 60 go func() { 61 for { 62 select { 63 case <-s.exitSig: 64 return 65 case <-time.After(time.Second * 300): 66 s.cache.DeleteExpired() 67 s.cache.SaveFile(*s.cfg.CacheFile) 68 } 69 } 70 }() 71 go func() { 72 defer func() { 73 if e := recover(); e != nil { 74 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 75 } 76 }() 77 for { 78 select { 79 case <-s.exitSig: 80 return 81 case <-time.After(time.Second * 60): 82 err := s.cache.SaveFile(*s.cfg.CacheFile) 83 if err == nil { 84 //s.log.Printf("cache saved: %s", *s.cfg.CacheFile) 85 } else { 86 s.log.Printf("cache save failed: %s, %s", *s.cfg.CacheFile, err) 87 } 88 } 89 } 90 }() 91 s.dialer, err = proxy.SOCKS5("tcp", *s.cfg.Parent, 92 nil, 93 &net.Dialer{ 94 Timeout: 5 * time.Second, 95 KeepAlive: 2 * time.Second, 96 }, 97 ) 98 if err != nil { 99 return 100 } 101 102 sdkArgs := fmt.Sprintf("sps -S %s -T %s -P %s -C %s -K %s -i %d -p 127.0.0.1:%s --disable-http", 103 *s.cfg.ParentServiceType, 104 *s.cfg.ParentType, 105 *s.cfg.Parent, 106 *s.cfg.CertFile, 107 *s.cfg.KeyFile, 108 *s.cfg.Timeout, 109 *s.cfg.LocalSocks5Port, 110 ) 111 if *s.cfg.ParentKey != "" { 112 sdkArgs += " -Z " + *s.cfg.ParentKey 113 } 114 if *s.cfg.ParentAuth != "" { 115 sdkArgs += " -A " + *s.cfg.ParentAuth 116 } 117 if *s.cfg.CaCertFile != "" { 118 sdkArgs += " --ca " + *s.cfg.CaCertFile 119 } 120 if *s.cfg.ParentCompress { 121 sdkArgs += " -M" 122 } 123 s.log.Printf("start sps with : %s", sdkArgs) 124 errStr := Start(s.serviceKey, sdkArgs) 125 if errStr != "" { 126 err = fmt.Errorf("start sps service fail,%s", errStr) 127 } 128 return 129 } 130 func (s *DNS) StopService() { 131 defer func() { 132 e := recover() 133 if e != nil { 134 s.log.Printf("stop dns service crashed,%s", e) 135 } else { 136 s.log.Printf("service dns stopped") 137 } 138 }() 139 Stop(s.serviceKey) 140 s.cache.Flush() 141 s.exitSig <- true 142 } 143 func (s *DNS) Start(args interface{}, log *logger.Logger) (err error) { 144 s.log = log 145 s.cfg = args.(DNSArgs) 146 if err = s.CheckArgs(); err != nil { 147 return 148 } 149 if err = s.InitService(); err != nil { 150 return 151 } 152 dns.HandleFunc(".", s.callback) 153 go func() { 154 defer func() { 155 if e := recover(); e != nil { 156 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 157 } 158 }() 159 log.Printf("dns server on udp %s", *s.cfg.Local) 160 err := dns.ListenAndServe(*s.cfg.Local, "udp", nil) 161 if err != nil { 162 log.Printf("dns listen error: %s", err) 163 } 164 }() 165 return 166 } 167 168 func (s *DNS) Clean() { 169 s.StopService() 170 } 171 func (s *DNS) callback(w dns.ResponseWriter, req *dns.Msg) { 172 defer func() { 173 if err := recover(); err != nil { 174 s.log.Printf("dns handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) 175 } 176 }() 177 var ( 178 key string 179 m *dns.Msg 180 err error 181 data []byte 182 id uint16 183 query []string 184 questions []dns.Question 185 ) 186 if req.MsgHdr.Response == true { 187 return 188 } 189 query = make([]string, len(req.Question)) 190 for i, q := range req.Question { 191 if q.Qtype != dns.TypeAAAA { 192 questions = append(questions, q) 193 } 194 query[i] = fmt.Sprintf("(%s %s %s)", q.Name, dns.ClassToString[q.Qclass], dns.TypeToString[q.Qtype]) 195 } 196 197 if len(questions) == 0 { 198 return 199 } 200 201 req.Question = questions 202 id = req.Id 203 req.Id = 0 204 key = s.toMd5(req.String()) 205 req.Id = id 206 if reply, ok := s.cache.Get(key); ok { 207 data, _ = reply.([]byte) 208 } 209 if data != nil && len(data) > 0 { 210 m = &dns.Msg{} 211 m.Unpack(data) 212 m.Id = id 213 err = w.WriteMsg(m) 214 s.log.Printf("id: %5d cache: HIT %v", id, query) 215 return 216 217 } else { 218 s.log.Printf("id: %5d cache: MISS %v", id, query) 219 } 220 221 s.log.Printf("id: %5d resolve: %v %s", id, query, *s.cfg.RemoteDNSAddress) 222 223 rawConn, err := s.dialer.Dial("tcp", *s.cfg.RemoteDNSAddress) 224 if err != nil { 225 s.log.Printf("dail to %s fail,%s", *s.cfg.RemoteDNSAddress, err) 226 return 227 } 228 defer rawConn.Close() 229 co := new(dns.Conn) 230 co.Conn = rawConn 231 defer co.Close() 232 if err = co.WriteMsg(req); err != nil { 233 s.log.Printf("write dns query fail,%s", err) 234 return 235 } 236 m, err = co.ReadMsg() 237 if err == nil && m.Id != req.Id { 238 s.log.Printf("id: %5d mismath", id) 239 return 240 } 241 if err != nil || len(m.Answer) == 0 { 242 s.log.Printf("dns query fail,%s", err) 243 return 244 } 245 data, err = m.Pack() 246 if err != nil { 247 s.log.Printf("dns query fail,%s", err) 248 return 249 } 250 251 _, err = w.Write(data) 252 if err != nil { 253 s.log.Printf("dns query fail,%s", err) 254 return 255 } 256 m.Id = 0 257 data, _ = m.Pack() 258 ttl := 0 259 if len(m.Answer) > 0 { 260 if *s.cfg.DNSTTL > 0 { 261 ttl = *s.cfg.DNSTTL 262 } else { 263 ttl = int(m.Answer[0].Header().Ttl) 264 if ttl < 0 { 265 ttl = *s.cfg.DNSTTL 266 } 267 } 268 } 269 s.cache.Set(key, data, time.Second*time.Duration(ttl)) 270 m.Id = id 271 s.log.Printf("id: %5d cache: CACHED %v TTL %v", id, query, ttl) 272 } 273 func (s *DNS) toMd5(data string) string { 274 m := md5.New() 275 m.Write([]byte(data)) 276 return hex.EncodeToString(m.Sum(nil)) 277 }