github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/node/outbound.go (about)

     1  package node
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math/rand/v2"
     7  	"net"
     8  	"net/http"
     9  	"time"
    10  
    11  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    12  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/direct"
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/drop"
    14  	"github.com/Asutorufa/yuhaiin/pkg/protos/config/bypass"
    15  	"github.com/Asutorufa/yuhaiin/pkg/protos/node"
    16  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/point"
    17  	pt "github.com/Asutorufa/yuhaiin/pkg/protos/node/tag"
    18  	"github.com/Asutorufa/yuhaiin/pkg/utils/jsondb"
    19  	"github.com/Asutorufa/yuhaiin/pkg/utils/lru"
    20  )
    21  
// outbound picks the netapi.Proxy used for outgoing connections.
//
// Fields:
//   - manager: node and tag lookups (GetNode, GetNodeByName, ExistTag).
//   - db: persisted node config; db.Data.Tcp and db.Data.Udp are the points
//     Get uses for TCP and UDP traffic when no tag/mode decides otherwise.
//   - lruCache: dialers built by point.Dialer, keyed by node hash, bounded
//     at the capacity chosen in NewOutbound.
type outbound struct {
	manager *manager
	db      *jsondb.DB[*node.Node]

	lruCache *lru.LRU[string, netapi.Proxy]
}
    28  
    29  func NewOutbound(db *jsondb.DB[*node.Node], mamanager *manager) *outbound {
    30  	return &outbound{
    31  		manager:  mamanager,
    32  		db:       db,
    33  		lruCache: lru.New(lru.WithCapacity[string, netapi.Proxy](200)),
    34  	}
    35  }
    36  
    37  type TagKey struct{}
    38  
    39  func (TagKey) String() string { return "Tag" }
    40  
    41  func (o *outbound) getNowPoint(p *point.Point) *point.Point {
    42  	pp, ok := o.manager.GetNodeByName(p.Group, p.Name)
    43  	if ok {
    44  		return pp
    45  	}
    46  
    47  	return p
    48  }
    49  
    50  func (o *outbound) GetDialer(p *point.Point) (netapi.Proxy, error) {
    51  	if p.Hash == "" {
    52  		return point.Dialer(p)
    53  	}
    54  
    55  	var err error
    56  	r, ok := o.lruCache.Load(p.Hash)
    57  	if !ok {
    58  		r, err = point.Dialer(p)
    59  		if err != nil {
    60  			return nil, err
    61  		}
    62  
    63  		o.lruCache.Add(p.Hash, r)
    64  	}
    65  
    66  	return r, nil
    67  }
    68  
    69  type HashKey struct{}
    70  
    71  func (HashKey) String() string { return "Hash" }
    72  
    73  func (o *outbound) Get(ctx context.Context, network string, str string, tag string) (netapi.Proxy, error) {
    74  	if tag != "" {
    75  		netapi.StoreFromContext(ctx).Add(TagKey{}, tag)
    76  		if hash := o.tagConn(tag); hash != "" {
    77  			p := o.GetDialerByHash(ctx, hash)
    78  			if p != nil {
    79  				return p, nil
    80  			}
    81  		}
    82  	}
    83  
    84  	switch str {
    85  	case bypass.Mode_direct.String():
    86  		return direct.Default, nil
    87  	case bypass.Mode_block.String():
    88  		return drop.Drop, nil
    89  	}
    90  
    91  	if len(network) < 3 {
    92  		return nil, fmt.Errorf("invalid network: %s", network)
    93  	}
    94  
    95  	var point *point.Point
    96  	switch network[:3] {
    97  	case "tcp":
    98  		point = o.getNowPoint(o.db.Data.Tcp)
    99  	case "udp":
   100  		point = o.getNowPoint(o.db.Data.Udp)
   101  	default:
   102  		return nil, fmt.Errorf("invalid network: %s", network)
   103  	}
   104  
   105  	p, err := o.GetDialer(point)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	netapi.StoreFromContext(ctx).Add(HashKey{}, point.Hash)
   111  	return p, nil
   112  }
   113  
   114  func (o *outbound) GetDialerByHash(ctx context.Context, hash string) netapi.Proxy {
   115  	v, ok := o.lruCache.Load(hash)
   116  	if !ok {
   117  		p, ok := o.manager.GetNode(hash)
   118  		if !ok {
   119  			return nil
   120  		}
   121  
   122  		var err error
   123  		v, err = point.Dialer(p)
   124  		if err != nil {
   125  			return nil
   126  		}
   127  
   128  		o.lruCache.Add(hash, v)
   129  	}
   130  
   131  	netapi.StoreFromContext(ctx).Add(HashKey{}, hash)
   132  	return v
   133  }
   134  
   135  func (o *outbound) tagConn(tag string) string {
   136  _retry:
   137  	t, ok := o.manager.ExistTag(tag)
   138  	if !ok || len(t.Hash) <= 0 {
   139  		return ""
   140  	}
   141  
   142  	if t.Type == pt.TagType_mirror {
   143  		if tag == t.Hash[0] {
   144  			return ""
   145  		}
   146  		tag = t.Hash[0]
   147  		goto _retry
   148  	}
   149  
   150  	hash := t.Hash[rand.IntN(len(t.Hash))]
   151  
   152  	return hash
   153  }
   154  
   155  func (o *outbound) Do(req *http.Request) (*http.Response, error) {
   156  	f, err := o.Get(req.Context(), "tcp", bypass.Mode_proxy.String(), "")
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  
   161  	c := &http.Client{
   162  		Timeout: time.Minute * 2,
   163  		Transport: &http.Transport{
   164  			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
   165  				ad, err := netapi.ParseAddress(netapi.PaseNetwork(network), addr)
   166  				if err != nil {
   167  					return nil, fmt.Errorf("parse address failed: %w", err)
   168  				}
   169  
   170  				return f.Conn(ctx, ad)
   171  			},
   172  		},
   173  	}
   174  
   175  	r, err := c.Do(req)
   176  	if err == nil {
   177  		return r, nil
   178  	}
   179  
   180  	f = direct.Default
   181  
   182  	return c.Do(req)
   183  }