github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/sdk/android-ios/dns.go (about)

     1  package proxy
     2  
     3  import (
     4  	"crypto/md5"
     5  	"encoding/hex"
     6  	"fmt"
     7  	logger "log"
     8  	"net"
     9  	"runtime/debug"
    10  	"time"
    11  
    12  	"golang.org/x/net/proxy"
    13  
    14  	"github.com/miekg/dns"
    15  	gocache "github.com/pmylund/go-cache"
    16  	"github.com/AntonOrnatskyi/goproxy/core/lib/kcpcfg"
    17  	services "github.com/AntonOrnatskyi/goproxy/services"
    18  )
    19  
    20  type DNSArgs struct {
    21  	ParentServiceType *string
    22  	ParentType        *string
    23  	Parent            *string
    24  	ParentAuth        *string
    25  	ParentKey         *string
    26  	ParentCompress    *bool
    27  	KCP               kcpcfg.KCPConfigArgs
    28  	CertFile          *string
    29  	KeyFile           *string
    30  	CaCertFile        *string
    31  	Local             *string
    32  	Timeout           *int
    33  	RemoteDNSAddress  *string
    34  	DNSTTL            *int
    35  	CacheFile         *string
    36  	LocalSocks5Port   *string
    37  }
    38  type DNS struct {
    39  	cfg        DNSArgs
    40  	log        *logger.Logger
    41  	cache      *gocache.Cache
    42  	exitSig    chan bool
    43  	serviceKey string
    44  	dialer     proxy.Dialer
    45  }
    46  
    47  func NewDNS() services.Service {
    48  	return &DNS{
    49  		cfg:        DNSArgs{},
    50  		exitSig:    make(chan bool, 1),
    51  		serviceKey: "dns-service-" + fmt.Sprintf("%d", time.Now().UnixNano()),
    52  	}
    53  }
    54  func (s *DNS) CheckArgs() (err error) {
    55  	return
    56  }
    57  func (s *DNS) InitService() (err error) {
    58  	s.cache = gocache.New(time.Second*time.Duration(*s.cfg.DNSTTL), time.Second*60)
    59  	s.cache.LoadFile(*s.cfg.CacheFile)
    60  	go func() {
    61  		for {
    62  			select {
    63  			case <-s.exitSig:
    64  				return
    65  			case <-time.After(time.Second * 300):
    66  				s.cache.DeleteExpired()
    67  				s.cache.SaveFile(*s.cfg.CacheFile)
    68  			}
    69  		}
    70  	}()
    71  	go func() {
    72  		defer func() {
    73  			if e := recover(); e != nil {
    74  				fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
    75  			}
    76  		}()
    77  		for {
    78  			select {
    79  			case <-s.exitSig:
    80  				return
    81  			case <-time.After(time.Second * 60):
    82  				err := s.cache.SaveFile(*s.cfg.CacheFile)
    83  				if err == nil {
    84  					//s.log.Printf("cache saved: %s", *s.cfg.CacheFile)
    85  				} else {
    86  					s.log.Printf("cache save failed: %s, %s", *s.cfg.CacheFile, err)
    87  				}
    88  			}
    89  		}
    90  	}()
    91  	s.dialer, err = proxy.SOCKS5("tcp", *s.cfg.Parent,
    92  		nil,
    93  		&net.Dialer{
    94  			Timeout:   5 * time.Second,
    95  			KeepAlive: 2 * time.Second,
    96  		},
    97  	)
    98  	if err != nil {
    99  		return
   100  	}
   101  
   102  	sdkArgs := fmt.Sprintf("sps -S %s -T %s -P %s -C %s -K %s -i %d -p 127.0.0.1:%s --disable-http",
   103  		*s.cfg.ParentServiceType,
   104  		*s.cfg.ParentType,
   105  		*s.cfg.Parent,
   106  		*s.cfg.CertFile,
   107  		*s.cfg.KeyFile,
   108  		*s.cfg.Timeout,
   109  		*s.cfg.LocalSocks5Port,
   110  	)
   111  	if *s.cfg.ParentKey != "" {
   112  		sdkArgs += " -Z " + *s.cfg.ParentKey
   113  	}
   114  	if *s.cfg.ParentAuth != "" {
   115  		sdkArgs += " -A " + *s.cfg.ParentAuth
   116  	}
   117  	if *s.cfg.CaCertFile != "" {
   118  		sdkArgs += " --ca " + *s.cfg.CaCertFile
   119  	}
   120  	if *s.cfg.ParentCompress {
   121  		sdkArgs += " -M"
   122  	}
   123  	s.log.Printf("start sps with : %s", sdkArgs)
   124  	errStr := Start(s.serviceKey, sdkArgs)
   125  	if errStr != "" {
   126  		err = fmt.Errorf("start sps service fail,%s", errStr)
   127  	}
   128  	return
   129  }
   130  func (s *DNS) StopService() {
   131  	defer func() {
   132  		e := recover()
   133  		if e != nil {
   134  			s.log.Printf("stop dns service crashed,%s", e)
   135  		} else {
   136  			s.log.Printf("service dns stopped")
   137  		}
   138  	}()
   139  	Stop(s.serviceKey)
   140  	s.cache.Flush()
   141  	s.exitSig <- true
   142  }
   143  func (s *DNS) Start(args interface{}, log *logger.Logger) (err error) {
   144  	s.log = log
   145  	s.cfg = args.(DNSArgs)
   146  	if err = s.CheckArgs(); err != nil {
   147  		return
   148  	}
   149  	if err = s.InitService(); err != nil {
   150  		return
   151  	}
   152  	dns.HandleFunc(".", s.callback)
   153  	go func() {
   154  		defer func() {
   155  			if e := recover(); e != nil {
   156  				fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
   157  			}
   158  		}()
   159  		log.Printf("dns server on udp %s", *s.cfg.Local)
   160  		err := dns.ListenAndServe(*s.cfg.Local, "udp", nil)
   161  		if err != nil {
   162  			log.Printf("dns listen error: %s", err)
   163  		}
   164  	}()
   165  	return
   166  }
   167  
   168  func (s *DNS) Clean() {
   169  	s.StopService()
   170  }
   171  func (s *DNS) callback(w dns.ResponseWriter, req *dns.Msg) {
   172  	defer func() {
   173  		if err := recover(); err != nil {
   174  			s.log.Printf("dns handler crashed with err : %s \nstack: %s", err, string(debug.Stack()))
   175  		}
   176  	}()
   177  	var (
   178  		key       string
   179  		m         *dns.Msg
   180  		err       error
   181  		data      []byte
   182  		id        uint16
   183  		query     []string
   184  		questions []dns.Question
   185  	)
   186  	if req.MsgHdr.Response == true {
   187  		return
   188  	}
   189  	query = make([]string, len(req.Question))
   190  	for i, q := range req.Question {
   191  		if q.Qtype != dns.TypeAAAA {
   192  			questions = append(questions, q)
   193  		}
   194  		query[i] = fmt.Sprintf("(%s %s %s)", q.Name, dns.ClassToString[q.Qclass], dns.TypeToString[q.Qtype])
   195  	}
   196  
   197  	if len(questions) == 0 {
   198  		return
   199  	}
   200  
   201  	req.Question = questions
   202  	id = req.Id
   203  	req.Id = 0
   204  	key = s.toMd5(req.String())
   205  	req.Id = id
   206  	if reply, ok := s.cache.Get(key); ok {
   207  		data, _ = reply.([]byte)
   208  	}
   209  	if data != nil && len(data) > 0 {
   210  		m = &dns.Msg{}
   211  		m.Unpack(data)
   212  		m.Id = id
   213  		err = w.WriteMsg(m)
   214  		s.log.Printf("id: %5d cache: HIT %v", id, query)
   215  		return
   216  
   217  	} else {
   218  		s.log.Printf("id: %5d cache: MISS %v", id, query)
   219  	}
   220  
   221  	s.log.Printf("id: %5d resolve: %v %s", id, query, *s.cfg.RemoteDNSAddress)
   222  
   223  	rawConn, err := s.dialer.Dial("tcp", *s.cfg.RemoteDNSAddress)
   224  	if err != nil {
   225  		s.log.Printf("dail to %s fail,%s", *s.cfg.RemoteDNSAddress, err)
   226  		return
   227  	}
   228  	defer rawConn.Close()
   229  	co := new(dns.Conn)
   230  	co.Conn = rawConn
   231  	defer co.Close()
   232  	if err = co.WriteMsg(req); err != nil {
   233  		s.log.Printf("write dns query fail,%s", err)
   234  		return
   235  	}
   236  	m, err = co.ReadMsg()
   237  	if err == nil && m.Id != req.Id {
   238  		s.log.Printf("id: %5d mismath", id)
   239  		return
   240  	}
   241  	if err != nil || len(m.Answer) == 0 {
   242  		s.log.Printf("dns query fail,%s", err)
   243  		return
   244  	}
   245  	data, err = m.Pack()
   246  	if err != nil {
   247  		s.log.Printf("dns query fail,%s", err)
   248  		return
   249  	}
   250  
   251  	_, err = w.Write(data)
   252  	if err != nil {
   253  		s.log.Printf("dns query fail,%s", err)
   254  		return
   255  	}
   256  	m.Id = 0
   257  	data, _ = m.Pack()
   258  	ttl := 0
   259  	if len(m.Answer) > 0 {
   260  		if *s.cfg.DNSTTL > 0 {
   261  			ttl = *s.cfg.DNSTTL
   262  		} else {
   263  			ttl = int(m.Answer[0].Header().Ttl)
   264  			if ttl < 0 {
   265  				ttl = *s.cfg.DNSTTL
   266  			}
   267  		}
   268  	}
   269  	s.cache.Set(key, data, time.Second*time.Duration(ttl))
   270  	m.Id = id
   271  	s.log.Printf("id: %5d cache: CACHED %v TTL %v", id, query, ttl)
   272  }
   273  func (s *DNS) toMd5(data string) string {
   274  	m := md5.New()
   275  	m.Write([]byte(data))
   276  	return hex.EncodeToString(m.Sum(nil))
   277  }