github.com/ipfans/trojan-go@v0.11.0/proxy/custom/custom.go (about)

     1  package custom
     2  
import (
	"context"
	"fmt"
	"strings"

	"gopkg.in/yaml.v3"

	"github.com/ipfans/trojan-go/common"
	"github.com/ipfans/trojan-go/config"
	"github.com/ipfans/trojan-go/proxy"
	"github.com/ipfans/trojan-go/tunnel"
)
    14  
    15  func convert(i interface{}) interface{} {
    16  	switch x := i.(type) {
    17  	case map[interface{}]interface{}:
    18  		m2 := map[string]interface{}{}
    19  		for k, v := range x {
    20  			m2[k.(string)] = convert(v)
    21  		}
    22  		return m2
    23  	case []interface{}:
    24  		for i, v := range x {
    25  			x[i] = convert(v)
    26  		}
    27  	}
    28  	return i
    29  }
    30  
    31  func buildNodes(ctx context.Context, nodeConfigList []NodeConfig) (map[string]*proxy.Node, error) {
    32  	nodes := make(map[string]*proxy.Node)
    33  	for _, nodeCfg := range nodeConfigList {
    34  		nodeCfg.Protocol = strings.ToUpper(nodeCfg.Protocol)
    35  		if _, err := tunnel.GetTunnel(nodeCfg.Protocol); err != nil {
    36  			return nil, common.NewError("invalid protocol name:" + nodeCfg.Protocol)
    37  		}
    38  		data, err := yaml.Marshal(nodeCfg.Config)
    39  		common.Must(err)
    40  		nodeContext, err := config.WithYAMLConfig(ctx, data)
    41  		if err != nil {
    42  			return nil, common.NewError("failed to parse config data for " + nodeCfg.Tag + " with protocol" + nodeCfg.Protocol).Base(err)
    43  		}
    44  		node := &proxy.Node{
    45  			Name:    nodeCfg.Protocol,
    46  			Next:    make(map[string]*proxy.Node),
    47  			Context: nodeContext,
    48  		}
    49  		nodes[nodeCfg.Tag] = node
    50  	}
    51  	return nodes, nil
    52  }
    53  
    54  func init() {
    55  	proxy.RegisterProxyCreator(Name, func(ctx context.Context) (*proxy.Proxy, error) {
    56  		cfg := config.FromContext(ctx, Name).(*Config)
    57  
    58  		ctx, cancel := context.WithCancel(ctx)
    59  		success := false
    60  		defer func() {
    61  			if !success {
    62  				cancel()
    63  			}
    64  		}()
    65  		// inbound
    66  		nodes, err := buildNodes(ctx, cfg.Inbound.Node)
    67  		if err != nil {
    68  			return nil, err
    69  		}
    70  
    71  		var root *proxy.Node
    72  		// build server tree
    73  		for _, path := range cfg.Inbound.Path {
    74  			var lastNode *proxy.Node
    75  			for _, tag := range path {
    76  				if _, found := nodes[tag]; !found {
    77  					return nil, common.NewError("invalid node tag: " + tag)
    78  				}
    79  				if lastNode == nil {
    80  					if root == nil {
    81  						lastNode = nodes[tag]
    82  						root = lastNode
    83  						t, err := tunnel.GetTunnel(root.Name)
    84  						if err != nil {
    85  							return nil, common.NewError("failed to find root tunnel").Base(err)
    86  						}
    87  						s, err := t.NewServer(root.Context, nil)
    88  						if err != nil {
    89  							return nil, common.NewError("failed to init root server").Base(err)
    90  						}
    91  						root.Server = s
    92  					} else {
    93  						lastNode = root
    94  					}
    95  				} else {
    96  					lastNode = lastNode.LinkNextNode(nodes[tag])
    97  				}
    98  			}
    99  			lastNode.IsEndpoint = true
   100  		}
   101  
   102  		servers := proxy.FindAllEndpoints(root)
   103  
   104  		if len(cfg.Outbound.Path) != 1 {
   105  			return nil, common.NewError("there must be only 1 path for outbound protocol stack")
   106  		}
   107  
   108  		// outbound
   109  		nodes, err = buildNodes(ctx, cfg.Outbound.Node)
   110  		if err != nil {
   111  			return nil, err
   112  		}
   113  
   114  		// build client stack
   115  		var client tunnel.Client
   116  		for _, tag := range cfg.Outbound.Path[0] {
   117  			if _, found := nodes[tag]; !found {
   118  				return nil, common.NewError("invalid node tag: " + tag)
   119  			}
   120  			t, err := tunnel.GetTunnel(nodes[tag].Name)
   121  			if err != nil {
   122  				return nil, common.NewError("invalid tunnel name").Base(err)
   123  			}
   124  			client, err = t.NewClient(nodes[tag].Context, client)
   125  			if err != nil {
   126  				return nil, common.NewError("failed to create client").Base(err)
   127  			}
   128  		}
   129  
   130  		success = true
   131  		return proxy.NewProxy(ctx, cancel, servers, client), nil
   132  	})
   133  }