github.com/ergo-services/ergo@v1.999.224/apps/cloud/handshake.go (about)

     1  package cloud
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"encoding/binary"
     7  	"fmt"
     8  	"hash"
     9  	"io"
    10  	"net"
    11  	"time"
    12  
    13  	"github.com/ergo-services/ergo/etf"
    14  	"github.com/ergo-services/ergo/lib"
    15  	"github.com/ergo-services/ergo/node"
    16  )
    17  
    18  const (
    19  	defaultHandshakeTimeout = 5 * time.Second
    20  	clusterNameLengthMax    = 128
    21  )
    22  
// Handshake performs the cloud (V1) handshake from the initiating side:
// Start sends the Auth message, verifies the digest carried by the peer's
// AuthReply and answers with a Challenge.
type Handshake struct {
	// NOTE(review): presumably supplies default implementations of the
	// node.HandshakeInterface methods not defined in this file — confirm.
	node.Handshake
	// nodename is the local node name, stored by Init.
	nodename string
	// creation is the local node creation value, stored by Init.
	creation uint32
	// options are the cloud options passed to createHandshake.
	options node.Cloud
	// flags are the node flags stored by Init; EnableRemoteSpawn may be
	// overridden from options.Flags there.
	flags node.Flags
}
    30  
// handshakeDetails holds the state of a single Start call.
type handshakeDetails struct {
	// cookieHash is hash.Sum([]byte(cookie)) as computed in Start; it is
	// an input to every GenDigest call.
	cookieHash []byte
	// digestRemote is the digest verified against the peer's AuthReply; it
	// is fed into the digest of the Challenge we send back.
	digestRemote []byte
	// details is the result returned to the caller of Start.
	details node.HandshakeDetails
	// mapName is the node name received in the peer's ChallengeAccept
	// (not read within this file).
	mapName string
	// hash is the sha256 hash created in Start and passed to GenDigest.
	hash hash.Hash
}
    38  
    39  func createHandshake(options node.Cloud) (node.HandshakeInterface, error) {
    40  	if options.Timeout == 0 {
    41  		options.Timeout = defaultHandshakeTimeout
    42  	}
    43  
    44  	if err := RegisterTypes(); err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	return &Handshake{
    49  		options: options,
    50  	}, nil
    51  }
    52  
    53  func (ch *Handshake) Init(nodename string, creation uint32, flags node.Flags) error {
    54  	if flags.EnableProxy == false {
    55  		s := "proxy feature must be enabled for the cloud connection"
    56  		lib.Warning(s)
    57  		return fmt.Errorf(s)
    58  	}
    59  	if ch.options.Cluster == "" {
    60  		s := "option Cloud.Cluster can not be empty"
    61  		lib.Warning(s)
    62  		return fmt.Errorf(s)
    63  	}
    64  	if len(ch.options.Cluster) > clusterNameLengthMax {
    65  		s := "option Cloud.Cluster has too long name"
    66  		lib.Warning(s)
    67  		return fmt.Errorf(s)
    68  	}
    69  	ch.nodename = nodename
    70  	ch.creation = creation
    71  	ch.flags = flags
    72  	if ch.options.Flags.Enable == false {
    73  		return nil
    74  	}
    75  
    76  	ch.flags.EnableRemoteSpawn = ch.options.Flags.EnableRemoteSpawn
    77  	return nil
    78  }
    79  
    80  func (ch *Handshake) Start(remote net.Addr, conn lib.NetReadWriter, tls bool, cookie string) (node.HandshakeDetails, error) {
    81  	hash := sha256.New()
    82  	handshake := &handshakeDetails{
    83  		cookieHash: hash.Sum([]byte(cookie)),
    84  		hash:       hash,
    85  	}
    86  	handshake.details.Flags = ch.flags
    87  
    88  	ch.sendV1Auth(conn)
    89  
    90  	// define timeout for the handshaking
    91  	timer := time.NewTimer(ch.options.Timeout)
    92  	defer timer.Stop()
    93  
    94  	b := lib.TakeBuffer()
    95  	defer lib.ReleaseBuffer(b)
    96  
    97  	asyncReadChannel := make(chan error, 2)
    98  	asyncRead := func() {
    99  		_, err := b.ReadDataFrom(conn, 1024)
   100  		asyncReadChannel <- err
   101  	}
   102  
   103  	expectingBytes := 4
   104  	await := []byte{ProtoHandshakeV1AuthReply, ProtoHandshakeV1Error}
   105  	rest := []byte{}
   106  
   107  	for {
   108  		go asyncRead()
   109  		select {
   110  		case <-timer.C:
   111  			return handshake.details, fmt.Errorf("timeout")
   112  		case err := <-asyncReadChannel:
   113  			if err != nil {
   114  				return handshake.details, err
   115  			}
   116  
   117  			if b.Len() < expectingBytes {
   118  				continue
   119  			}
   120  
   121  			if b.B[0] != ProtoHandshakeV1 {
   122  				return handshake.details, fmt.Errorf("malformed handshake proto")
   123  			}
   124  
   125  			l := int(binary.BigEndian.Uint16(b.B[2:4]))
   126  			buffer := b.B[4 : l+4]
   127  
   128  			if len(buffer) != l {
   129  				return handshake.details, fmt.Errorf("malformed handshake (wrong packet length)")
   130  			}
   131  
   132  			// check if we got correct message type regarding to 'await' value
   133  			if bytes.Count(await, b.B[1:2]) == 0 {
   134  				return handshake.details, fmt.Errorf("malformed handshake sequence")
   135  			}
   136  
   137  			await, rest, err = ch.handle(conn, b.B[1], buffer, handshake)
   138  			if err != nil {
   139  				return handshake.details, err
   140  			}
   141  
   142  			if await == nil && rest != nil {
   143  				// handshaked with some extra data. keep them for the Proto handler
   144  				handshake.details.Buffer = lib.TakeBuffer()
   145  				handshake.details.Buffer.Set(rest)
   146  			}
   147  
   148  			b.Reset()
   149  		}
   150  
   151  		if await == nil {
   152  			// handshaked
   153  			break
   154  		}
   155  	}
   156  
   157  	return handshake.details, nil
   158  }
   159  
   160  func (ch *Handshake) handle(socket io.Writer, messageType byte, buffer []byte, details *handshakeDetails) ([]byte, []byte, error) {
   161  	switch messageType {
   162  	case ProtoHandshakeV1AuthReply:
   163  		if err := ch.handleV1AuthReply(buffer, details); err != nil {
   164  			return nil, nil, err
   165  		}
   166  		if err := ch.sendV1Challenge(socket, details); err != nil {
   167  			return nil, nil, err
   168  		}
   169  		return []byte{ProtoHandshakeV1ChallengeAccept, ProtoHandshakeV1Error}, nil, nil
   170  
   171  	case ProtoHandshakeV1ChallengeAccept:
   172  		rest, err := ch.handleV1ChallegeAccept(buffer, details)
   173  		if err != nil {
   174  			return nil, nil, err
   175  		}
   176  		return nil, rest, err
   177  
   178  	case ProtoHandshakeV1Error:
   179  		return nil, nil, ch.handleV1Error(buffer)
   180  
   181  	default:
   182  		return nil, nil, fmt.Errorf("unknown message type")
   183  	}
   184  }
   185  
   186  func (ch *Handshake) sendV1Auth(socket io.Writer) error {
   187  	b := lib.TakeBuffer()
   188  	defer lib.ReleaseBuffer(b)
   189  
   190  	message := MessageHandshakeV1Auth{
   191  		Node:     ch.nodename,
   192  		Cluster:  ch.options.Cluster,
   193  		Creation: ch.creation,
   194  		Flags:    ch.options.Flags,
   195  	}
   196  	b.Allocate(1 + 1 + 2)
   197  	b.B[0] = ProtoHandshakeV1
   198  	b.B[1] = ProtoHandshakeV1Auth
   199  	if err := etf.Encode(message, b, etf.EncodeOptions{}); err != nil {
   200  		return err
   201  	}
   202  	binary.BigEndian.PutUint16(b.B[2:4], uint16(b.Len()-4))
   203  	if err := b.WriteDataTo(socket); err != nil {
   204  		return err
   205  	}
   206  
   207  	return nil
   208  }
   209  
   210  func (ch *Handshake) sendV1Challenge(socket io.Writer, handshake *handshakeDetails) error {
   211  	b := lib.TakeBuffer()
   212  	defer lib.ReleaseBuffer(b)
   213  
   214  	digest := GenDigest(handshake.hash, []byte(ch.nodename), handshake.digestRemote, handshake.cookieHash)
   215  	message := MessageHandshakeV1Challenge{
   216  		Digest: digest,
   217  	}
   218  	b.Allocate(1 + 1 + 2)
   219  	b.B[0] = ProtoHandshakeV1
   220  	b.B[1] = ProtoHandshakeV1Challenge
   221  	if err := etf.Encode(message, b, etf.EncodeOptions{}); err != nil {
   222  		return err
   223  	}
   224  	binary.BigEndian.PutUint16(b.B[2:4], uint16(b.Len()-4))
   225  	if err := b.WriteDataTo(socket); err != nil {
   226  		return err
   227  	}
   228  
   229  	return nil
   230  
   231  }
   232  
   233  func (ch *Handshake) handleV1AuthReply(buffer []byte, handshake *handshakeDetails) error {
   234  	m, _, err := etf.Decode(buffer, nil, etf.DecodeOptions{})
   235  	if err != nil {
   236  		return fmt.Errorf("malformed MessageHandshakeV1AuthReply message: %s", err)
   237  	}
   238  	message, ok := m.(MessageHandshakeV1AuthReply)
   239  	if ok == false {
   240  		return fmt.Errorf("malformed MessageHandshakeV1AuthReply message: %#v", m)
   241  	}
   242  
   243  	digest := GenDigest(handshake.hash, []byte(message.Node), []byte(ch.options.Cluster), handshake.cookieHash)
   244  	if bytes.Compare(message.Digest, digest) != 0 {
   245  		return fmt.Errorf("authorization failed")
   246  	}
   247  	handshake.digestRemote = digest
   248  	handshake.details.Name = message.Node
   249  	handshake.details.Creation = message.Creation
   250  
   251  	return nil
   252  }
   253  
   254  func (ch *Handshake) handleV1ChallegeAccept(buffer []byte, handshake *handshakeDetails) ([]byte, error) {
   255  	m, rest, err := etf.Decode(buffer, nil, etf.DecodeOptions{})
   256  	if err != nil {
   257  		return nil, fmt.Errorf("malformed MessageHandshakeV1ChallengeAccept message: %s", err)
   258  	}
   259  	message, ok := m.(MessageHandshakeV1ChallengeAccept)
   260  	if ok == false {
   261  		return nil, fmt.Errorf("malformed MessageHandshakeV1ChallengeAccept message: %#v", m)
   262  	}
   263  
   264  	mapping := etf.NewAtomMapping()
   265  	mapping.In[etf.Atom(message.Node)] = etf.Atom(ch.nodename)
   266  	mapping.Out[etf.Atom(ch.nodename)] = etf.Atom(message.Node)
   267  	handshake.details.AtomMapping = mapping
   268  	handshake.mapName = message.Node
   269  	return rest, nil
   270  }
   271  
   272  func (ch *Handshake) handleV1Error(buffer []byte) error {
   273  	m, _, err := etf.Decode(buffer, nil, etf.DecodeOptions{})
   274  	if err != nil {
   275  		return fmt.Errorf("malformed MessageHandshakeV1Error message: %s", err)
   276  	}
   277  	message, ok := m.(MessageHandshakeV1Error)
   278  	if ok == false {
   279  		return fmt.Errorf("malformed MessageHandshakeV1Error message: %#v", m)
   280  	}
   281  	return fmt.Errorf(message.Reason)
   282  }