github.com/erda-project/erda-infra@v1.0.10-0.20240327085753-f3a249292aeb/providers/remote-forward/server/provider.go (about) 1 // Copyright (c) 2021 Terminus, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package server 16 17 import ( 18 "errors" 19 "fmt" 20 "io" 21 "net" 22 "reflect" 23 "time" 24 25 "github.com/erda-project/erda-infra/base/logs" 26 "github.com/erda-project/erda-infra/base/servicehub" 27 forward "github.com/erda-project/erda-infra/providers/remote-forward" 28 yamux "github.com/hashicorp/yamux" 29 ) 30 31 type ( 32 // Handshaker . 33 Handshaker func(req *forward.RequestHeader, resp *forward.ResponseHeader) error 34 // Interface . 35 Interface interface { 36 AddHandshaker(h Handshaker) 37 } 38 ) 39 40 var _ (Interface) = (*provider)(nil) 41 42 type config struct { 43 Addr string `file:"addr"` 44 Token string `file:"token"` 45 } 46 47 type provider struct { 48 Cfg *config 49 Log logs.Logger 50 ln net.Listener 51 handshaker []Handshaker 52 } 53 54 func (p *provider) Init(ctx servicehub.Context) error { 55 ln, err := net.Listen("tcp", p.Cfg.Addr) 56 if err != nil { 57 return err 58 } 59 p.ln = ln 60 p.Log.Infof("forward server listen at %s", p.Cfg.Addr) 61 return nil 62 } 63 64 func (p *provider) AddHandshaker(h Handshaker) { 65 p.handshaker = append(p.handshaker, h) 66 } 67 68 func (p *provider) Start() error { 69 ln := p.ln 70 defer ln.Close() 71 for { 72 conn, err := ln.Accept() 73 if err != nil { 74 if errors.Is(err, net.ErrClosed) { 75 return nil 76 } 77 return err 78 } 79 go p.handleConn(conn) 80 } 81 } 82 83 func (p *provider) Close() error { 84 if p.ln != nil { 85 ln := p.ln 86 p.ln = nil 87 err := ln.Close() 88 if !errors.Is(err, net.ErrClosed) { 89 return err 90 } 91 } 92 return nil 93 } 94 95 func (p *provider) handleConn(conn net.Conn) { 96 defer conn.Close() 97 req, err := p.handshake(conn) 98 if err != nil { 99 p.responseError(conn, err) 100 return 101 } 102 if req == nil { 103 return 104 } 105 resp := &forward.ResponseHeader{Values: make(map[string]interface{})} 106 for _, h := range p.handshaker { 107 err := h(req, resp) 108 if err != nil { 109 p.responseError(conn, err) 110 return 111 } 112 } 113 114 ln, err := net.Listen("tcp", req.ShadowAddr) 115 if err != nil { 116 p.responseError(conn, err) 117 return 118 } 119 defer ln.Close() 120 121 session, err := yamux.Client(conn, nil) 122 if err != nil { 123 p.responseError(conn, err) 124 return 125 } 126 defer session.Close() 127 128 resp.ShadowAddr = ln.Addr().String() 129 p.Log.Infof("%q shadow address listen at %s", req.Name, resp.ShadowAddr) 130 if err := p.responseOK(conn, resp); err != nil { 131 return 132 } 133 134 go func() { 135 <-session.CloseChan() 136 ln.Close() 137 }() 138 for { 139 source, err := ln.Accept() 140 if err != nil { 141 if !errors.Is(err, net.ErrClosed) { 142 p.Log.Errorf("accept error: %s", err) 143 } 144 return 145 } 146 go func() { 147 defer source.Close() 148 target, err := session.Open() 149 if err != nil { 150 if !errors.Is(err, net.ErrClosed) { 151 p.Log.Errorf("failed to open connect in session: %s", err) 152 } 153 return 154 } 155 defer target.Close() 156 forward.Pipe(p.Log, target, source) 157 }() 158 } 159 } 160 161 func (p *provider) handshake(conn net.Conn) (header *forward.RequestHeader, err error) { 162 defer func() { 163 if err != nil && errors.Is(err, net.ErrClosed) { 164 header, err = nil, nil 165 } else if err != nil { 166 err = fmt.Errorf("handshake error: %w", err) 167 } 168 }() 169 err = conn.SetDeadline(time.Now().Add(forward.HandshakeTimeout)) 170 if err != nil { 171 return nil, err 172 } 173 header, err = forward.DecodeRequestHeader(conn) 174 if err != nil { 175 return nil, err 176 } 177 if header.Version != forward.ProtocolVersion { 178 return nil, fmt.Errorf("not support version %q", header.Version) 179 } 180 if header.Token != p.Cfg.Token { 181 return nil, fmt.Errorf("invalid token") 182 } 183 err = conn.SetDeadline(time.Time{}) 184 if err != nil { 185 return nil, err 186 } 187 return header, nil 188 } 189 190 func (p *provider) responseError(conn net.Conn, err error) error { 191 if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { 192 return err 193 } 194 err = forward.EncodeResponseHeader(conn, &forward.ResponseHeader{Error: err.Error()}) 195 if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { 196 p.Log.Errorf("failed to encode response: %s", err) 197 } 198 return err 199 } 200 201 func (p *provider) responseOK(conn net.Conn, resp *forward.ResponseHeader) error { 202 resp.Error = "" 203 err := forward.EncodeResponseHeader(conn, resp) 204 if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { 205 p.Log.Errorf("failed to encode response: %s", err) 206 } 207 return err 208 } 209 210 func init() { 211 servicehub.Register("remote-forward-server", &servicehub.Spec{ 212 Services: []string{"remote-forward-server"}, 213 Types: []reflect.Type{reflect.TypeOf((*Interface)(nil)).Elem()}, 214 ConfigFunc: func() interface{} { return &config{} }, 215 Creator: func() servicehub.Provider { return &provider{} }, 216 }) 217 }