github.com/TeaOSLab/EdgeNode@v1.3.8/internal/nodes/listener_base.go (about)

     1  package nodes
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
     7  	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
     8  	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
     9  	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
    10  	"github.com/TeaOSLab/EdgeNode/internal/utils"
    11  	"github.com/iwind/TeaGo/types"
    12  	"net"
    13  )
    14  
// BaseListener holds the server address group a listener serves plus a
// counter of active connections. Its methods build the TLS configuration and
// resolve servers / SSL policies by host name through Group.
type BaseListener struct {
	// Group is the server address group bound to this listener. The lookup
	// helpers treat a nil Group as "not configured".
	Group *serverconfigs.ServerAddressGroup

	// countActiveConnections is the number of currently active connections,
	// reported by CountActiveConnections.
	// NOTE(review): presumably updated by the concrete listeners; if that
	// happens from multiple goroutines, reads should also go through
	// sync/atomic (and the field must stay 64-bit aligned on 32-bit
	// platforms) — confirm against the writers.
	countActiveConnections int64
}
    20  
    21  // Init 初始化
    22  func (this *BaseListener) Init() {
    23  }
    24  
    25  // Reset 清除既有配置
    26  func (this *BaseListener) Reset() {
    27  
    28  }
    29  
    30  // CountActiveConnections 获取当前活跃连接数
    31  func (this *BaseListener) CountActiveConnections() int {
    32  	return types.Int(this.countActiveConnections)
    33  }
    34  
    35  // 构造TLS配置
    36  func (this *BaseListener) buildTLSConfig() *tls.Config {
    37  	return &tls.Config{
    38  		Certificates: nil,
    39  		GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
    40  			// 指纹信息
    41  			var fingerprint = this.calculateFingerprint(clientInfo)
    42  			if len(fingerprint) > 0 && clientInfo.Conn != nil {
    43  				clientConn, ok := clientInfo.Conn.(ClientConnInterface)
    44  				if ok {
    45  					clientConn.SetFingerprint(fingerprint)
    46  				}
    47  			}
    48  
    49  			tlsPolicy, _, err := this.matchSSL(this.helloServerNames(clientInfo))
    50  			if err != nil {
    51  				return nil, err
    52  			}
    53  
    54  			if tlsPolicy == nil {
    55  				return nil, nil
    56  			}
    57  
    58  			tlsPolicy.CheckOCSP()
    59  
    60  			return tlsPolicy.TLSConfig(), nil
    61  		},
    62  		GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
    63  			// 指纹信息
    64  			var fingerprint = this.calculateFingerprint(clientInfo)
    65  			if len(fingerprint) > 0 && clientInfo.Conn != nil {
    66  				clientConn, ok := clientInfo.Conn.(ClientConnInterface)
    67  				if ok {
    68  					clientConn.SetFingerprint(fingerprint)
    69  				}
    70  			}
    71  
    72  			tlsPolicy, cert, err := this.matchSSL(this.helloServerNames(clientInfo))
    73  			if err != nil {
    74  				return nil, err
    75  			}
    76  			if cert == nil {
    77  				return nil, errors.New("no ssl certs found for '" + clientInfo.ServerName + "'")
    78  			}
    79  
    80  			tlsPolicy.CheckOCSP()
    81  
    82  			return cert, nil
    83  		},
    84  	}
    85  }
    86  
    87  // 根据域名匹配证书
    88  func (this *BaseListener) matchSSL(domains []string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
    89  	var group = this.Group
    90  
    91  	if group == nil {
    92  		return nil, nil, errors.New("no configure found")
    93  	}
    94  
    95  	var globalServerConfig *serverconfigs.GlobalServerConfig
    96  	if sharedNodeConfig != nil {
    97  		globalServerConfig = sharedNodeConfig.GlobalServerConfig
    98  	}
    99  
   100  	// 如果域名为空,则取第一个
   101  	// 通常域名为空是因为是直接通过IP访问的
   102  	if len(domains) == 0 {
   103  		if group.IsHTTPS() && globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
   104  			return nil, nil, errors.New("no tls server name matched")
   105  		}
   106  
   107  		firstServer := group.FirstTLSServer()
   108  		if firstServer == nil {
   109  			return nil, nil, errors.New("no tls server available")
   110  		}
   111  		sslConfig := firstServer.SSLPolicy()
   112  
   113  		if sslConfig != nil {
   114  			return sslConfig, sslConfig.FirstCert(), nil
   115  
   116  		}
   117  		return nil, nil, errors.New("no tls server name found")
   118  	}
   119  	var firstDomain = domains[0]
   120  
   121  	// 通过网站域名配置匹配
   122  	var server *serverconfigs.ServerConfig
   123  	var matchedDomain string
   124  	for _, domain := range domains {
   125  		server, _ = this.findNamedServer(domain, true)
   126  		if server != nil {
   127  			matchedDomain = domain
   128  			break
   129  		}
   130  	}
   131  	if server == nil {
   132  		server, _ = this.findNamedServer(firstDomain, false)
   133  		if server != nil {
   134  			matchedDomain = firstDomain
   135  		}
   136  	}
   137  
   138  	if server == nil {
   139  		// 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配
   140  		// 此功能仅为了兼容以往版本(v1.0.4),不应该作为常态启用
   141  		if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchCertFromAllServers {
   142  			for _, searchingServer := range group.Servers() {
   143  				if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn {
   144  					continue
   145  				}
   146  				cert, ok := searchingServer.SSLPolicy().MatchDomain(firstDomain)
   147  				if ok {
   148  					return searchingServer.SSLPolicy(), cert, nil
   149  				}
   150  			}
   151  		}
   152  
   153  		return nil, nil, errors.New("no server found for '" + firstDomain + "'")
   154  	}
   155  	if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
   156  		// 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配
   157  		// 此功能仅为了兼容以往版本(v1.0.4),不应该作为常态启用
   158  		if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchCertFromAllServers {
   159  			for _, searchingServer := range group.Servers() {
   160  				if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn {
   161  					continue
   162  				}
   163  				cert, ok := searchingServer.SSLPolicy().MatchDomain(matchedDomain)
   164  				if ok {
   165  					return searchingServer.SSLPolicy(), cert, nil
   166  				}
   167  			}
   168  		}
   169  
   170  		return nil, nil, errors.New("no cert found for '" + matchedDomain + "'")
   171  	}
   172  
   173  	// 证书是否匹配
   174  	var sslConfig = server.SSLPolicy()
   175  	cert, ok := sslConfig.MatchDomain(matchedDomain)
   176  	if ok {
   177  		return sslConfig, cert, nil
   178  	}
   179  
   180  	if len(sslConfig.Certs) == 0 {
   181  		remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+matchedDomain+"', server id: "+types.String(server.Id), "", nil)
   182  	}
   183  
   184  	return sslConfig, sslConfig.FirstCert(), nil
   185  }
   186  
   187  // 根据域名来查找匹配的域名
   188  func (this *BaseListener) findNamedServer(name string, exactly bool) (serverConfig *serverconfigs.ServerConfig, serverName string) {
   189  	serverConfig, serverName = this.findNamedServerMatched(name)
   190  	if serverConfig != nil {
   191  		return
   192  	}
   193  
   194  	var globalServerConfig = sharedNodeConfig.GlobalServerConfig
   195  	var matchDomainStrictly = globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly
   196  
   197  	if globalServerConfig != nil &&
   198  		len(globalServerConfig.HTTPAll.DefaultDomain) > 0 &&
   199  		(!matchDomainStrictly || configutils.MatchDomains(globalServerConfig.HTTPAll.AllowMismatchDomains, name) || (globalServerConfig.HTTPAll.AllowNodeIP && utils.IsWildIP(name))) {
   200  		if globalServerConfig.HTTPAll.AllowNodeIP &&
   201  			globalServerConfig.HTTPAll.NodeIPShowPage &&
   202  			utils.IsWildIP(name) {
   203  			return
   204  		} else {
   205  			var defaultDomain = globalServerConfig.HTTPAll.DefaultDomain
   206  			serverConfig, serverName = this.findNamedServerMatched(defaultDomain)
   207  			if serverConfig != nil {
   208  				return
   209  			}
   210  		}
   211  	}
   212  
   213  	if matchDomainStrictly && !configutils.MatchDomains(globalServerConfig.HTTPAll.AllowMismatchDomains, name) && (!globalServerConfig.HTTPAll.AllowNodeIP || (!utils.IsWildIP(name) || globalServerConfig.HTTPAll.NodeIPShowPage)) {
   214  		return
   215  	}
   216  
   217  	if !exactly {
   218  		// 如果没有找到,则匹配到第一个
   219  		var group = this.Group
   220  		var currentServers = group.Servers()
   221  		var countServers = len(currentServers)
   222  		if countServers == 0 {
   223  			return nil, ""
   224  		}
   225  		return currentServers[0], name
   226  	}
   227  
   228  	return
   229  }
   230  
   231  // 严格查找域名
   232  func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) {
   233  	var group = this.Group
   234  	if group == nil {
   235  		return nil, ""
   236  	}
   237  
   238  	server := group.MatchServerName(name)
   239  	if server != nil {
   240  		return server, name
   241  	}
   242  
   243  	// 是否严格匹配域名
   244  	var matchDomainStrictly = sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly
   245  
   246  	// 如果只有一个server,则默认为这个
   247  	var currentServers = group.Servers()
   248  	var countServers = len(currentServers)
   249  	if countServers == 1 && !matchDomainStrictly {
   250  		return currentServers[0], name
   251  	}
   252  
   253  	return nil, name
   254  }
   255  
   256  // 从Hello信息中获取服务名称
   257  func (this *BaseListener) helloServerNames(clientInfo *tls.ClientHelloInfo) (serverNames []string) {
   258  	if len(clientInfo.ServerName) != 0 {
   259  		serverNames = append(serverNames, clientInfo.ServerName)
   260  		return
   261  	}
   262  
   263  	if clientInfo.Conn != nil {
   264  		var localAddr = clientInfo.Conn.LocalAddr()
   265  		if localAddr != nil {
   266  			tcpAddr, ok := localAddr.(*net.TCPAddr)
   267  			if ok {
   268  				serverNames = append(serverNames, tcpAddr.IP.String())
   269  			}
   270  		}
   271  	}
   272  
   273  	serverNames = append(serverNames, sharedNodeConfig.IPAddresses...)
   274  
   275  	return
   276  }