github.com/erda-project/erda-infra@v1.0.10-0.20240327085753-f3a249292aeb/providers/remote-forward/server/provider.go (about)

     1  // Copyright (c) 2021 Terminus, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package server
    16  
import (
	"crypto/subtle"
	"errors"
	"fmt"
	"io"
	"net"
	"reflect"
	"time"

	"github.com/erda-project/erda-infra/base/logs"
	"github.com/erda-project/erda-infra/base/servicehub"
	forward "github.com/erda-project/erda-infra/providers/remote-forward"
	yamux "github.com/hashicorp/yamux"
)
    30  
    31  type (
    32  	// Handshaker .
    33  	Handshaker func(req *forward.RequestHeader, resp *forward.ResponseHeader) error
    34  	// Interface .
    35  	Interface interface {
    36  		AddHandshaker(h Handshaker)
    37  	}
    38  )
    39  
    40  var _ (Interface) = (*provider)(nil)
    41  
    42  type config struct {
    43  	Addr  string `file:"addr"`
    44  	Token string `file:"token"`
    45  }
    46  
    47  type provider struct {
    48  	Cfg        *config
    49  	Log        logs.Logger
    50  	ln         net.Listener
    51  	handshaker []Handshaker
    52  }
    53  
    54  func (p *provider) Init(ctx servicehub.Context) error {
    55  	ln, err := net.Listen("tcp", p.Cfg.Addr)
    56  	if err != nil {
    57  		return err
    58  	}
    59  	p.ln = ln
    60  	p.Log.Infof("forward server listen at %s", p.Cfg.Addr)
    61  	return nil
    62  }
    63  
    64  func (p *provider) AddHandshaker(h Handshaker) {
    65  	p.handshaker = append(p.handshaker, h)
    66  }
    67  
    68  func (p *provider) Start() error {
    69  	ln := p.ln
    70  	defer ln.Close()
    71  	for {
    72  		conn, err := ln.Accept()
    73  		if err != nil {
    74  			if errors.Is(err, net.ErrClosed) {
    75  				return nil
    76  			}
    77  			return err
    78  		}
    79  		go p.handleConn(conn)
    80  	}
    81  }
    82  
    83  func (p *provider) Close() error {
    84  	if p.ln != nil {
    85  		ln := p.ln
    86  		p.ln = nil
    87  		err := ln.Close()
    88  		if !errors.Is(err, net.ErrClosed) {
    89  			return err
    90  		}
    91  	}
    92  	return nil
    93  }
    94  
    95  func (p *provider) handleConn(conn net.Conn) {
    96  	defer conn.Close()
    97  	req, err := p.handshake(conn)
    98  	if err != nil {
    99  		p.responseError(conn, err)
   100  		return
   101  	}
   102  	if req == nil {
   103  		return
   104  	}
   105  	resp := &forward.ResponseHeader{Values: make(map[string]interface{})}
   106  	for _, h := range p.handshaker {
   107  		err := h(req, resp)
   108  		if err != nil {
   109  			p.responseError(conn, err)
   110  			return
   111  		}
   112  	}
   113  
   114  	ln, err := net.Listen("tcp", req.ShadowAddr)
   115  	if err != nil {
   116  		p.responseError(conn, err)
   117  		return
   118  	}
   119  	defer ln.Close()
   120  
   121  	session, err := yamux.Client(conn, nil)
   122  	if err != nil {
   123  		p.responseError(conn, err)
   124  		return
   125  	}
   126  	defer session.Close()
   127  
   128  	resp.ShadowAddr = ln.Addr().String()
   129  	p.Log.Infof("%q shadow address listen at %s", req.Name, resp.ShadowAddr)
   130  	if err := p.responseOK(conn, resp); err != nil {
   131  		return
   132  	}
   133  
   134  	go func() {
   135  		<-session.CloseChan()
   136  		ln.Close()
   137  	}()
   138  	for {
   139  		source, err := ln.Accept()
   140  		if err != nil {
   141  			if !errors.Is(err, net.ErrClosed) {
   142  				p.Log.Errorf("accept error: %s", err)
   143  			}
   144  			return
   145  		}
   146  		go func() {
   147  			defer source.Close()
   148  			target, err := session.Open()
   149  			if err != nil {
   150  				if !errors.Is(err, net.ErrClosed) {
   151  					p.Log.Errorf("failed to open connect in session: %s", err)
   152  				}
   153  				return
   154  			}
   155  			defer target.Close()
   156  			forward.Pipe(p.Log, target, source)
   157  		}()
   158  	}
   159  }
   160  
   161  func (p *provider) handshake(conn net.Conn) (header *forward.RequestHeader, err error) {
   162  	defer func() {
   163  		if err != nil && errors.Is(err, net.ErrClosed) {
   164  			header, err = nil, nil
   165  		} else if err != nil {
   166  			err = fmt.Errorf("handshake error: %w", err)
   167  		}
   168  	}()
   169  	err = conn.SetDeadline(time.Now().Add(forward.HandshakeTimeout))
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  	header, err = forward.DecodeRequestHeader(conn)
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  	if header.Version != forward.ProtocolVersion {
   178  		return nil, fmt.Errorf("not support version %q", header.Version)
   179  	}
   180  	if header.Token != p.Cfg.Token {
   181  		return nil, fmt.Errorf("invalid token")
   182  	}
   183  	err = conn.SetDeadline(time.Time{})
   184  	if err != nil {
   185  		return nil, err
   186  	}
   187  	return header, nil
   188  }
   189  
   190  func (p *provider) responseError(conn net.Conn, err error) error {
   191  	if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
   192  		return err
   193  	}
   194  	err = forward.EncodeResponseHeader(conn, &forward.ResponseHeader{Error: err.Error()})
   195  	if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
   196  		p.Log.Errorf("failed to encode response: %s", err)
   197  	}
   198  	return err
   199  }
   200  
   201  func (p *provider) responseOK(conn net.Conn, resp *forward.ResponseHeader) error {
   202  	resp.Error = ""
   203  	err := forward.EncodeResponseHeader(conn, resp)
   204  	if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
   205  		p.Log.Errorf("failed to encode response: %s", err)
   206  	}
   207  	return err
   208  }
   209  
   210  func init() {
   211  	servicehub.Register("remote-forward-server", &servicehub.Spec{
   212  		Services:   []string{"remote-forward-server"},
   213  		Types:      []reflect.Type{reflect.TypeOf((*Interface)(nil)).Elem()},
   214  		ConfigFunc: func() interface{} { return &config{} },
   215  		Creator:    func() servicehub.Provider { return &provider{} },
   216  	})
   217  }