github.com/TeaOSLab/EdgeNode@v1.3.8/internal/nodes/listener_base.go (about) 1 package nodes 2 3 import ( 4 "crypto/tls" 5 "errors" 6 "github.com/TeaOSLab/EdgeCommon/pkg/configutils" 7 "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" 8 "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" 9 "github.com/TeaOSLab/EdgeNode/internal/remotelogs" 10 "github.com/TeaOSLab/EdgeNode/internal/utils" 11 "github.com/iwind/TeaGo/types" 12 "net" 13 ) 14 15 type BaseListener struct { 16 Group *serverconfigs.ServerAddressGroup 17 18 countActiveConnections int64 // 当前活跃的连接数 19 } 20 21 // Init 初始化 22 func (this *BaseListener) Init() { 23 } 24 25 // Reset 清除既有配置 26 func (this *BaseListener) Reset() { 27 28 } 29 30 // CountActiveConnections 获取当前活跃连接数 31 func (this *BaseListener) CountActiveConnections() int { 32 return types.Int(this.countActiveConnections) 33 } 34 35 // 构造TLS配置 36 func (this *BaseListener) buildTLSConfig() *tls.Config { 37 return &tls.Config{ 38 Certificates: nil, 39 GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) { 40 // 指纹信息 41 var fingerprint = this.calculateFingerprint(clientInfo) 42 if len(fingerprint) > 0 && clientInfo.Conn != nil { 43 clientConn, ok := clientInfo.Conn.(ClientConnInterface) 44 if ok { 45 clientConn.SetFingerprint(fingerprint) 46 } 47 } 48 49 tlsPolicy, _, err := this.matchSSL(this.helloServerNames(clientInfo)) 50 if err != nil { 51 return nil, err 52 } 53 54 if tlsPolicy == nil { 55 return nil, nil 56 } 57 58 tlsPolicy.CheckOCSP() 59 60 return tlsPolicy.TLSConfig(), nil 61 }, 62 GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) { 63 // 指纹信息 64 var fingerprint = this.calculateFingerprint(clientInfo) 65 if len(fingerprint) > 0 && clientInfo.Conn != nil { 66 clientConn, ok := clientInfo.Conn.(ClientConnInterface) 67 if ok { 68 clientConn.SetFingerprint(fingerprint) 69 } 70 } 71 72 tlsPolicy, cert, err := this.matchSSL(this.helloServerNames(clientInfo)) 73 if err != nil { 74 return nil, err 75 } 76 if cert == nil { 77 return nil, errors.New("no ssl certs found for '" + clientInfo.ServerName + "'") 78 } 79 80 tlsPolicy.CheckOCSP() 81 82 return cert, nil 83 }, 84 } 85 } 86 87 // 根据域名匹配证书 88 func (this *BaseListener) matchSSL(domains []string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) { 89 var group = this.Group 90 91 if group == nil { 92 return nil, nil, errors.New("no configure found") 93 } 94 95 var globalServerConfig *serverconfigs.GlobalServerConfig 96 if sharedNodeConfig != nil { 97 globalServerConfig = sharedNodeConfig.GlobalServerConfig 98 } 99 100 // 如果域名为空,则取第一个 101 // 通常域名为空是因为是直接通过IP访问的 102 if len(domains) == 0 { 103 if group.IsHTTPS() && globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly { 104 return nil, nil, errors.New("no tls server name matched") 105 } 106 107 firstServer := group.FirstTLSServer() 108 if firstServer == nil { 109 return nil, nil, errors.New("no tls server available") 110 } 111 sslConfig := firstServer.SSLPolicy() 112 113 if sslConfig != nil { 114 return sslConfig, sslConfig.FirstCert(), nil 115 116 } 117 return nil, nil, errors.New("no tls server name found") 118 } 119 var firstDomain = domains[0] 120 121 // 通过网站域名配置匹配 122 var server *serverconfigs.ServerConfig 123 var matchedDomain string 124 for _, domain := range domains { 125 server, _ = this.findNamedServer(domain, true) 126 if server != nil { 127 matchedDomain = domain 128 break 129 } 130 } 131 if server == nil { 132 server, _ = this.findNamedServer(firstDomain, false) 133 if server != nil { 134 matchedDomain = firstDomain 135 } 136 } 137 138 if server == nil { 139 // 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配 140 // 此功能仅为了兼容以往版本(v1.0.4),不应该作为常态启用 141 if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchCertFromAllServers { 142 for _, searchingServer := range group.Servers() { 143 if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn { 144 continue 145 } 146 cert, ok := searchingServer.SSLPolicy().MatchDomain(firstDomain) 147 if ok { 148 return searchingServer.SSLPolicy(), cert, nil 149 } 150 } 151 } 152 153 return nil, nil, errors.New("no server found for '" + firstDomain + "'") 154 } 155 if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn { 156 // 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配 157 // 此功能仅为了兼容以往版本(v1.0.4),不应该作为常态启用 158 if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchCertFromAllServers { 159 for _, searchingServer := range group.Servers() { 160 if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn { 161 continue 162 } 163 cert, ok := searchingServer.SSLPolicy().MatchDomain(matchedDomain) 164 if ok { 165 return searchingServer.SSLPolicy(), cert, nil 166 } 167 } 168 } 169 170 return nil, nil, errors.New("no cert found for '" + matchedDomain + "'") 171 } 172 173 // 证书是否匹配 174 var sslConfig = server.SSLPolicy() 175 cert, ok := sslConfig.MatchDomain(matchedDomain) 176 if ok { 177 return sslConfig, cert, nil 178 } 179 180 if len(sslConfig.Certs) == 0 { 181 remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+matchedDomain+"', server id: "+types.String(server.Id), "", nil) 182 } 183 184 return sslConfig, sslConfig.FirstCert(), nil 185 } 186 187 // 根据域名来查找匹配的域名 188 func (this *BaseListener) findNamedServer(name string, exactly bool) (serverConfig *serverconfigs.ServerConfig, serverName string) { 189 serverConfig, serverName = this.findNamedServerMatched(name) 190 if serverConfig != nil { 191 return 192 } 193 194 var globalServerConfig = sharedNodeConfig.GlobalServerConfig 195 var matchDomainStrictly = globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly 196 197 if globalServerConfig != nil && 198 len(globalServerConfig.HTTPAll.DefaultDomain) > 0 && 199 (!matchDomainStrictly || configutils.MatchDomains(globalServerConfig.HTTPAll.AllowMismatchDomains, name) || (globalServerConfig.HTTPAll.AllowNodeIP && utils.IsWildIP(name))) { 200 if globalServerConfig.HTTPAll.AllowNodeIP && 201 globalServerConfig.HTTPAll.NodeIPShowPage && 202 utils.IsWildIP(name) { 203 return 204 } else { 205 var defaultDomain = globalServerConfig.HTTPAll.DefaultDomain 206 serverConfig, serverName = this.findNamedServerMatched(defaultDomain) 207 if serverConfig != nil { 208 return 209 } 210 } 211 } 212 213 if matchDomainStrictly && !configutils.MatchDomains(globalServerConfig.HTTPAll.AllowMismatchDomains, name) && (!globalServerConfig.HTTPAll.AllowNodeIP || (!utils.IsWildIP(name) || globalServerConfig.HTTPAll.NodeIPShowPage)) { 214 return 215 } 216 217 if !exactly { 218 // 如果没有找到,则匹配到第一个 219 var group = this.Group 220 var currentServers = group.Servers() 221 var countServers = len(currentServers) 222 if countServers == 0 { 223 return nil, "" 224 } 225 return currentServers[0], name 226 } 227 228 return 229 } 230 231 // 严格查找域名 232 func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) { 233 var group = this.Group 234 if group == nil { 235 return nil, "" 236 } 237 238 server := group.MatchServerName(name) 239 if server != nil { 240 return server, name 241 } 242 243 // 是否严格匹配域名 244 var matchDomainStrictly = sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly 245 246 // 如果只有一个server,则默认为这个 247 var currentServers = group.Servers() 248 var countServers = len(currentServers) 249 if countServers == 1 && !matchDomainStrictly { 250 return currentServers[0], name 251 } 252 253 return nil, name 254 } 255 256 // 从Hello信息中获取服务名称 257 func (this *BaseListener) helloServerNames(clientInfo *tls.ClientHelloInfo) (serverNames []string) { 258 if len(clientInfo.ServerName) != 0 { 259 serverNames = append(serverNames, clientInfo.ServerName) 260 return 261 } 262 263 if clientInfo.Conn != nil { 264 var localAddr = clientInfo.Conn.LocalAddr() 265 if localAddr != nil { 266 tcpAddr, ok := localAddr.(*net.TCPAddr) 267 if ok { 268 serverNames = append(serverNames, tcpAddr.IP.String()) 269 } 270 } 271 } 272 273 serverNames = append(serverNames, sharedNodeConfig.IPAddresses...) 274 275 return 276 }