github.com/ergo-services/ergo@v1.999.224/proto/dist/handshake.go (about)

     1  package dist
     2  
import (
	"bytes"
	"crypto/md5"
	crand "crypto/rand"
	"encoding/binary"
	"fmt"
	"math/rand"
	"net"
	"time"

	"github.com/ergo-services/ergo/lib"
	"github.com/ergo-services/ergo/node"
)
    15  
    16  const (
    17  	HandshakeVersion5 node.HandshakeVersion = 5
    18  	HandshakeVersion6 node.HandshakeVersion = 6
    19  
    20  	DefaultHandshakeVersion = HandshakeVersion5
    21  	DefaultHandshakeTimeout = 5 * time.Second
    22  
    23  	// distribution flags are defined here https://erlang.org/doc/apps/erts/erl_dist_protocol.html#distribution-flags
    24  	flagPublished          nodeFlagId = 0x1
    25  	flagAtomCache          nodeFlagId = 0x2
    26  	flagExtendedReferences nodeFlagId = 0x4
    27  	flagDistMonitor        nodeFlagId = 0x8
    28  	flagFunTags            nodeFlagId = 0x10
    29  	flagDistMonitorName    nodeFlagId = 0x20
    30  	flagHiddenAtomCache    nodeFlagId = 0x40
    31  	flagNewFunTags         nodeFlagId = 0x80
    32  	flagExtendedPidsPorts  nodeFlagId = 0x100
    33  	flagExportPtrTag       nodeFlagId = 0x200
    34  	flagBitBinaries        nodeFlagId = 0x400
    35  	flagNewFloats          nodeFlagId = 0x800
    36  	flagUnicodeIO          nodeFlagId = 0x1000
    37  	flagDistHdrAtomCache   nodeFlagId = 0x2000
    38  	flagSmallAtomTags      nodeFlagId = 0x4000
    39  	//	flagCompressed                   = 0x8000 // erlang uses this flag for the internal purposes
    40  	flagUTF8Atoms         nodeFlagId = 0x10000
    41  	flagMapTag            nodeFlagId = 0x20000
    42  	flagBigCreation       nodeFlagId = 0x40000
    43  	flagSendSender        nodeFlagId = 0x80000 // since OTP.21 enable replacement for SEND (distProtoSEND by distProtoSEND_SENDER)
    44  	flagBigSeqTraceLabels            = 0x100000
    45  	flagExitPayload       nodeFlagId = 0x400000 // since OTP.22 enable replacement for EXIT, EXIT2, MONITOR_P_EXIT
    46  	flagFragments         nodeFlagId = 0x800000
    47  	flagHandshake23       nodeFlagId = 0x1000000 // new connection setup handshake (version 6) introduced in OTP 23
    48  	flagUnlinkID          nodeFlagId = 0x2000000
    49  	// for 64bit flags
    50  	flagSpawn  nodeFlagId = 1 << 32
    51  	flagNameMe nodeFlagId = 1 << 33
    52  	flagV4NC   nodeFlagId = 1 << 34
    53  	flagAlias  nodeFlagId = 1 << 35
    54  
    55  	// ergo flags
    56  	flagCompression = 1 << 63
    57  	flagProxy       = 1 << 62
    58  )
    59  
    60  type nodeFlagId uint64
    61  type nodeFlags nodeFlagId
    62  
    63  func (nf nodeFlags) toUint32() uint32 {
    64  	return uint32(nf)
    65  }
    66  
    67  func (nf nodeFlags) toUint64() uint64 {
    68  	return uint64(nf)
    69  }
    70  
    71  func (nf nodeFlags) isSet(f nodeFlagId) bool {
    72  	return (uint64(nf) & uint64(f)) != 0
    73  }
    74  
    75  func toNodeFlags(f ...nodeFlagId) nodeFlags {
    76  	var flags uint64
    77  	for _, v := range f {
    78  		flags |= uint64(v)
    79  	}
    80  	return nodeFlags(flags)
    81  }
    82  
// DistHandshake implements the Erlang distribution handshake (versions 5 and 6).
// Instances are created with CreateHandshake and configured with Init.
type DistHandshake struct {
	node.Handshake
	nodename  string     // local node name, set by Init
	flags     node.Flags // local capabilities, set by Init
	creation  uint32     // local creation value, set by Init
	challenge uint32     // challenge sent to peers; generated once in CreateHandshake
	options   HandshakeOptions
}

// HandshakeOptions configures a DistHandshake created by CreateHandshake.
type HandshakeOptions struct {
	// Timeout bounds a whole Start/Accept exchange; a zero value selects
	// DefaultHandshakeTimeout.
	Timeout time.Duration
	Version node.HandshakeVersion // 5 or 6; any other value selects DefaultHandshakeVersion
}
    97  
    98  func CreateHandshake(options HandshakeOptions) node.HandshakeInterface {
    99  	// must be 5 or 6
   100  	if options.Version != HandshakeVersion5 && options.Version != HandshakeVersion6 {
   101  		options.Version = DefaultHandshakeVersion
   102  	}
   103  
   104  	if options.Timeout == 0 {
   105  		options.Timeout = DefaultHandshakeTimeout
   106  	}
   107  	return &DistHandshake{
   108  		options:   options,
   109  		challenge: rand.Uint32(),
   110  	}
   111  }
   112  
   113  // Init implements Handshake interface mothod
   114  func (dh *DistHandshake) Init(nodename string, creation uint32, flags node.Flags) error {
   115  	dh.nodename = nodename
   116  	dh.creation = creation
   117  	dh.flags = flags
   118  	return nil
   119  }
   120  
   121  func (dh *DistHandshake) Version() node.HandshakeVersion {
   122  	return dh.options.Version
   123  }
   124  
   125  func (dh *DistHandshake) Start(remote net.Addr, conn lib.NetReadWriter, tls bool, cookie string) (node.HandshakeDetails, error) {
   126  
   127  	var details node.HandshakeDetails
   128  
   129  	b := lib.TakeBuffer()
   130  	defer lib.ReleaseBuffer(b)
   131  
   132  	var await []byte
   133  
   134  	if dh.options.Version == HandshakeVersion5 {
   135  		dh.composeName(b, tls)
   136  		// the next message must be send_status 's' or send_challenge 'n' (for
   137  		// handshake version 5) or 'N' (for handshake version 6)
   138  		await = []byte{'s', 'n', 'N'}
   139  	} else {
   140  		dh.composeNameVersion6(b, tls)
   141  		await = []byte{'s', 'N'}
   142  	}
   143  	if e := b.WriteDataTo(conn); e != nil {
   144  		return details, e
   145  	}
   146  
   147  	// define timeout for the handshaking
   148  	timer := time.NewTimer(dh.options.Timeout)
   149  	defer timer.Stop()
   150  
   151  	asyncReadChannel := make(chan error, 2)
   152  	asyncRead := func() {
   153  		_, e := b.ReadDataFrom(conn, 512)
   154  		asyncReadChannel <- e
   155  	}
   156  
   157  	// http://erlang.org/doc/apps/erts/erl_dist_protocol.html#distribution-handshake
   158  	// Every message in the handshake starts with a 16-bit big-endian integer,
   159  	// which contains the message length (not counting the two initial bytes).
   160  	// In Erlang this corresponds to option {packet, 2} in gen_tcp(3). Notice
   161  	// that after the handshake, the distribution switches to 4 byte packet headers.
   162  	expectingBytes := 2
   163  	if tls {
   164  		// TLS connection has 4 bytes packet length header
   165  		expectingBytes = 4
   166  	}
   167  
   168  	for {
   169  		go asyncRead()
   170  
   171  		select {
   172  		case <-timer.C:
   173  			return details, fmt.Errorf("handshake timeout")
   174  
   175  		case e := <-asyncReadChannel:
   176  			if e != nil {
   177  				return details, e
   178  			}
   179  
   180  		next:
   181  			l := binary.BigEndian.Uint16(b.B[expectingBytes-2 : expectingBytes])
   182  			buffer := b.B[expectingBytes:]
   183  
   184  			if len(buffer) < int(l) {
   185  				return details, fmt.Errorf("malformed handshake (wrong packet length)")
   186  			}
   187  
   188  			// check if we got correct message type regarding to 'await' value
   189  			if bytes.Count(await, buffer[0:1]) == 0 {
   190  				return details, fmt.Errorf("malformed handshake (wrong response)")
   191  			}
   192  
   193  			switch buffer[0] {
   194  			case 'n':
   195  				// 'n' + 2 (version) + 4 (flags) + 4 (challenge) + name...
   196  				if len(b.B) < 12 {
   197  					return details, fmt.Errorf("malformed handshake ('n')")
   198  				}
   199  
   200  				challenge, err := dh.readChallenge(buffer[1:], &details)
   201  				if err != nil {
   202  					return details, err
   203  				}
   204  				b.Reset()
   205  
   206  				dh.composeChallengeReply(b, challenge, tls, cookie)
   207  
   208  				if e := b.WriteDataTo(conn); e != nil {
   209  					return details, e
   210  				}
   211  				// add 's' status for the case if we got it after 'n' or 'N' message
   212  				// yes, sometime it happens
   213  				await = []byte{'s', 'a'}
   214  
   215  			case 'N':
   216  				// Peer support version 6.
   217  
   218  				// The new challenge message format (version 6)
   219  				// 8 (flags) + 4 (Creation) + 2 (NameLen) + Name
   220  				if len(buffer) < 16 {
   221  					return details, fmt.Errorf("malformed handshake ('N' length)")
   222  				}
   223  
   224  				challenge, err := dh.readChallengeVersion6(buffer[1:], &details)
   225  				if err != nil {
   226  					return details, err
   227  				}
   228  				b.Reset()
   229  
   230  				if dh.options.Version == HandshakeVersion5 {
   231  					// upgrade handshake to version 6 by sending complement message
   232  					dh.composeComplement(b, tls)
   233  					if e := b.WriteDataTo(conn); e != nil {
   234  						return details, e
   235  					}
   236  				}
   237  
   238  				dh.composeChallengeReply(b, challenge, tls, cookie)
   239  
   240  				if e := b.WriteDataTo(conn); e != nil {
   241  					return details, e
   242  				}
   243  
   244  				// add 's' (send_status message) for the case if we got it after 'n' or 'N' message
   245  				await = []byte{'s', 'a'}
   246  
   247  			case 'a':
   248  				// 'a' + 16 (digest)
   249  				if len(buffer) < 17 {
   250  					return details, fmt.Errorf("malformed handshake ('a' length of digest)")
   251  				}
   252  
   253  				// 'a' + 16 (digest)
   254  				digest := genDigest(dh.challenge, cookie)
   255  				if bytes.Compare(buffer[1:17], digest) != 0 {
   256  					return details, fmt.Errorf("malformed handshake ('a' digest)")
   257  				}
   258  
   259  				// check if we got DIST packet with the final handshake data.
   260  				if len(buffer) > 17 {
   261  					details.Buffer = lib.TakeBuffer()
   262  					details.Buffer.Set(buffer[17:])
   263  				}
   264  
   265  				// handshaked
   266  				return details, nil
   267  
   268  			case 's':
   269  				if dh.readStatus(buffer[1:]) == false {
   270  					return details, fmt.Errorf("handshake negotiation failed")
   271  				}
   272  
   273  				await = []byte{'n', 'N'}
   274  				// "sok"
   275  				if len(buffer) > 4 {
   276  					b.B = b.B[expectingBytes+3:]
   277  					goto next
   278  				}
   279  				b.Reset()
   280  
   281  			default:
   282  				return details, fmt.Errorf("malformed handshake ('%c' digest)", buffer[0])
   283  			}
   284  
   285  		}
   286  
   287  	}
   288  
   289  }
   290  
   291  func (dh *DistHandshake) Accept(remote net.Addr, conn lib.NetReadWriter, tls bool, cookie string) (node.HandshakeDetails, error) {
   292  	var details node.HandshakeDetails
   293  
   294  	b := lib.TakeBuffer()
   295  	defer lib.ReleaseBuffer(b)
   296  
   297  	var await []byte
   298  
   299  	// define timeout for the handshaking
   300  	timer := time.NewTimer(dh.options.Timeout)
   301  	defer timer.Stop()
   302  
   303  	asyncReadChannel := make(chan error, 2)
   304  	asyncRead := func() {
   305  		_, e := b.ReadDataFrom(conn, 512)
   306  		asyncReadChannel <- e
   307  	}
   308  
   309  	// http://erlang.org/doc/apps/erts/erl_dist_protocol.html#distribution-handshake
   310  	// Every message in the handshake starts with a 16-bit big-endian integer,
   311  	// which contains the message length (not counting the two initial bytes).
   312  	// In Erlang this corresponds to option {packet, 2} in gen_tcp(3). Notice
   313  	// that after the handshake, the distribution switches to 4 byte packet headers.
   314  	expectingBytes := 2
   315  	if tls {
   316  		// TLS connection has 4 bytes packet length header
   317  		expectingBytes = 4
   318  	}
   319  
   320  	// the comming message must be 'receive_name' as an answer for the
   321  	// 'send_name' message request we just sent
   322  	await = []byte{'n', 'N'}
   323  
   324  	for {
   325  		go asyncRead()
   326  
   327  		select {
   328  		case <-timer.C:
   329  			return details, fmt.Errorf("handshake accept timeout")
   330  		case e := <-asyncReadChannel:
   331  			if e != nil {
   332  				return details, e
   333  			}
   334  
   335  			if b.Len() < expectingBytes+1 {
   336  				return details, fmt.Errorf("malformed handshake (too short packet)")
   337  			}
   338  
   339  		next:
   340  			l := binary.BigEndian.Uint16(b.B[expectingBytes-2 : expectingBytes])
   341  			buffer := b.B[expectingBytes:]
   342  
   343  			if len(buffer) < int(l) {
   344  				return details, fmt.Errorf("malformed handshake (wrong packet length)")
   345  			}
   346  
   347  			if bytes.Count(await, buffer[0:1]) == 0 {
   348  				return details, fmt.Errorf("malformed handshake (wrong response %d)", buffer[0])
   349  			}
   350  
   351  			switch buffer[0] {
   352  			case 'n':
   353  				if len(buffer) < 8 {
   354  					return details, fmt.Errorf("malformed handshake ('n' length)")
   355  				}
   356  
   357  				if err := dh.readName(buffer[1:], &details); err != nil {
   358  					return details, err
   359  				}
   360  				b.Reset()
   361  				dh.composeStatus(b, tls)
   362  				if e := b.WriteDataTo(conn); e != nil {
   363  					return details, fmt.Errorf("malformed handshake ('n' accept name)")
   364  				}
   365  
   366  				b.Reset()
   367  				if details.Version == 6 {
   368  					dh.composeChallengeVersion6(b, tls)
   369  					await = []byte{'s', 'r', 'c'}
   370  				} else {
   371  					dh.composeChallenge(b, tls)
   372  					await = []byte{'s', 'r'}
   373  				}
   374  				if e := b.WriteDataTo(conn); e != nil {
   375  					return details, e
   376  				}
   377  
   378  			case 'N':
   379  				// The new challenge message format (version 6)
   380  				// 8 (flags) + 4 (Creation) + 2 (NameLen) + Name
   381  				if len(buffer) < 16 {
   382  					return details, fmt.Errorf("malformed handshake ('N' length)")
   383  				}
   384  				if err := dh.readNameVersion6(buffer[1:], &details); err != nil {
   385  					return details, err
   386  				}
   387  				b.Reset()
   388  				dh.composeStatus(b, tls)
   389  				if e := b.WriteDataTo(conn); e != nil {
   390  					return details, fmt.Errorf("malformed handshake ('N' accept name)")
   391  				}
   392  
   393  				b.Reset()
   394  				dh.composeChallengeVersion6(b, tls)
   395  				if e := b.WriteDataTo(conn); e != nil {
   396  					return details, e
   397  				}
   398  
   399  				await = []byte{'s', 'r'}
   400  
   401  			case 'c':
   402  				if len(buffer) < 9 {
   403  					return details, fmt.Errorf("malformed handshake ('c' length)")
   404  				}
   405  				dh.readComplement(buffer[1:], &details)
   406  
   407  				await = []byte{'r'}
   408  
   409  				if len(buffer) > 9 {
   410  					b.B = b.B[expectingBytes+9:]
   411  					goto next
   412  				}
   413  				b.Reset()
   414  
   415  			case 'r':
   416  				if len(buffer) < 19 {
   417  					return details, fmt.Errorf("malformed handshake ('r' length)")
   418  				}
   419  
   420  				challenge, valid := dh.validateChallengeReply(buffer[1:], cookie)
   421  				if valid == false {
   422  					return details, fmt.Errorf("malformed handshake ('r' invalid reply)")
   423  				}
   424  				b.Reset()
   425  
   426  				dh.composeChallengeAck(b, challenge, tls, cookie)
   427  				if e := b.WriteDataTo(conn); e != nil {
   428  					return details, e
   429  				}
   430  
   431  				// handshaked
   432  
   433  				return details, nil
   434  
   435  			case 's':
   436  				if dh.readStatus(buffer[1:]) == false {
   437  					return details, fmt.Errorf("link status != ok")
   438  				}
   439  
   440  				await = []byte{'c', 'r'}
   441  				if len(buffer) > 4 {
   442  					b.B = b.B[expectingBytes+3:]
   443  					goto next
   444  				}
   445  				b.Reset()
   446  
   447  			default:
   448  				return details, fmt.Errorf("malformed handshake (unknown code %d)", b.B[0])
   449  			}
   450  
   451  		}
   452  
   453  	}
   454  }
   455  
   456  // private functions
   457  
   458  func (dh *DistHandshake) composeName(b *lib.Buffer, tls bool) {
   459  	flags := composeFlags(dh.flags)
   460  	version := uint16(dh.options.Version)
   461  	if tls {
   462  		b.Allocate(11)
   463  		dataLength := 7 + len(dh.nodename) // byte + uint16 + uint32 + len(dh.nodename)
   464  		binary.BigEndian.PutUint32(b.B[0:4], uint32(dataLength))
   465  		b.B[4] = 'n'
   466  		binary.BigEndian.PutUint16(b.B[5:7], version)           // uint16
   467  		binary.BigEndian.PutUint32(b.B[7:11], flags.toUint32()) // uint32
   468  		b.Append([]byte(dh.nodename))
   469  		return
   470  	}
   471  
   472  	b.Allocate(9)
   473  	dataLength := 7 + len(dh.nodename) // byte + uint16 + uint32 + len(dh.nodename)
   474  	binary.BigEndian.PutUint16(b.B[0:2], uint16(dataLength))
   475  	b.B[2] = 'n'
   476  	binary.BigEndian.PutUint16(b.B[3:5], version)          // uint16
   477  	binary.BigEndian.PutUint32(b.B[5:9], flags.toUint32()) // uint32
   478  	b.Append([]byte(dh.nodename))
   479  }
   480  
   481  func (dh *DistHandshake) composeNameVersion6(b *lib.Buffer, tls bool) {
   482  	flags := composeFlags(dh.flags)
   483  	creation := uint32(dh.creation)
   484  	if tls {
   485  		b.Allocate(19)
   486  		dataLength := 15 + len(dh.nodename) // 1 + 8 (flags) + 4 (creation) + 2 (len dh.nodename)
   487  		binary.BigEndian.PutUint32(b.B[0:4], uint32(dataLength))
   488  		b.B[4] = 'N'
   489  		binary.BigEndian.PutUint64(b.B[5:13], flags.toUint64())          // uint64
   490  		binary.BigEndian.PutUint32(b.B[13:17], creation)                 //uint32
   491  		binary.BigEndian.PutUint16(b.B[17:19], uint16(len(dh.nodename))) // uint16
   492  		b.Append([]byte(dh.nodename))
   493  		return
   494  	}
   495  
   496  	b.Allocate(17)
   497  	dataLength := 15 + len(dh.nodename) // 1 + 8 (flags) + 4 (creation) + 2 (len dh.nodename)
   498  	binary.BigEndian.PutUint16(b.B[0:2], uint16(dataLength))
   499  	b.B[2] = 'N'
   500  	binary.BigEndian.PutUint64(b.B[3:11], flags.toUint64())          // uint64
   501  	binary.BigEndian.PutUint32(b.B[11:15], creation)                 // uint32
   502  	binary.BigEndian.PutUint16(b.B[15:17], uint16(len(dh.nodename))) // uint16
   503  	b.Append([]byte(dh.nodename))
   504  }
   505  
   506  func (dh *DistHandshake) readName(b []byte, details *node.HandshakeDetails) error {
   507  	flags := nodeFlags(binary.BigEndian.Uint32(b[2:6]))
   508  	details.Flags = node.DefaultFlags()
   509  	details.Flags.EnableFragmentation = flags.isSet(flagFragments)
   510  	details.Flags.EnableBigCreation = flags.isSet(flagBigCreation)
   511  	details.Flags.EnableHeaderAtomCache = flags.isSet(flagDistHdrAtomCache)
   512  	details.Flags.EnableAlias = flags.isSet(flagAlias)
   513  	details.Flags.EnableRemoteSpawn = flags.isSet(flagSpawn)
   514  	details.Flags.EnableBigPidRef = flags.isSet(flagV4NC)
   515  	version := int(binary.BigEndian.Uint16(b[0:2]))
   516  	if version != 5 {
   517  		return fmt.Errorf("Malformed version for handshake 5")
   518  	}
   519  
   520  	details.Version = 5
   521  	if flags.isSet(flagHandshake23) {
   522  		details.Version = 6
   523  	}
   524  
   525  	// Erlang node limits the node name length to 256 characters (not bytes).
   526  	// I don't think anyone wants to use such a ridiculous name with a length > 250 bytes.
   527  	// Report an issue you really want to have a name longer that 255 bytes.
   528  	if len(b[6:]) > 255 {
   529  		return fmt.Errorf("Malformed node name")
   530  	}
   531  	details.Name = string(b[6:])
   532  
   533  	return nil
   534  }
   535  
   536  func (dh *DistHandshake) readNameVersion6(b []byte, details *node.HandshakeDetails) error {
   537  	details.Creation = binary.BigEndian.Uint32(b[8:12])
   538  
   539  	flags := nodeFlags(binary.BigEndian.Uint64(b[0:8]))
   540  	details.Flags = node.DefaultFlags()
   541  	details.Flags.EnableFragmentation = flags.isSet(flagFragments)
   542  	details.Flags.EnableBigCreation = flags.isSet(flagBigCreation)
   543  	details.Flags.EnableHeaderAtomCache = flags.isSet(flagDistHdrAtomCache)
   544  	details.Flags.EnableAlias = flags.isSet(flagAlias)
   545  	details.Flags.EnableRemoteSpawn = flags.isSet(flagSpawn)
   546  	details.Flags.EnableBigPidRef = flags.isSet(flagV4NC)
   547  	details.Flags.EnableCompression = flags.isSet(flagCompression)
   548  	details.Flags.EnableProxy = flags.isSet(flagProxy)
   549  
   550  	// see my prev comment about name len
   551  	nameLen := int(binary.BigEndian.Uint16(b[12:14]))
   552  	if nameLen > 255 {
   553  		return fmt.Errorf("Malformed node name")
   554  	}
   555  	nodename := string(b[14 : 14+nameLen])
   556  	details.Name = nodename
   557  
   558  	return nil
   559  }
   560  
   561  func (dh *DistHandshake) composeStatus(b *lib.Buffer, tls bool) {
   562  	// there are few options for the status: ok, ok_simultaneous, nok, not_allowed, alive
   563  	// More details here: https://erlang.org/doc/apps/erts/erl_dist_protocol.html#the-handshake-in-detail
   564  	// support "ok" only, in any other cases link will be just closed
   565  
   566  	if tls {
   567  		b.Allocate(4)
   568  		dataLength := 3 // 's' + "ok"
   569  		binary.BigEndian.PutUint32(b.B[0:4], uint32(dataLength))
   570  		b.Append([]byte("sok"))
   571  		return
   572  	}
   573  
   574  	b.Allocate(2)
   575  	dataLength := 3 // 's' + "ok"
   576  	binary.BigEndian.PutUint16(b.B[0:2], uint16(dataLength))
   577  	b.Append([]byte("sok"))
   578  
   579  }
   580  
   581  func (dh *DistHandshake) readStatus(msg []byte) bool {
   582  	if string(msg[:2]) == "ok" {
   583  		return true
   584  	}
   585  
   586  	return false
   587  }
   588  
   589  func (dh *DistHandshake) composeChallenge(b *lib.Buffer, tls bool) {
   590  	flags := composeFlags(dh.flags)
   591  	if tls {
   592  		b.Allocate(15)
   593  		dataLength := uint32(11 + len(dh.nodename))
   594  		binary.BigEndian.PutUint32(b.B[0:4], dataLength)
   595  		b.B[4] = 'n'
   596  
   597  		//https://www.erlang.org/doc/apps/erts/erl_dist_protocol.html#distribution-handshake
   598  		// The Version is a 16-bit big endian integer and must always have the value 5
   599  		binary.BigEndian.PutUint16(b.B[5:7], 5) // uint16
   600  
   601  		binary.BigEndian.PutUint32(b.B[7:11], flags.toUint32()) // uint32
   602  		binary.BigEndian.PutUint32(b.B[11:15], dh.challenge)    // uint32
   603  		b.Append([]byte(dh.nodename))
   604  		return
   605  	}
   606  
   607  	b.Allocate(13)
   608  	dataLength := 11 + len(dh.nodename)
   609  	binary.BigEndian.PutUint16(b.B[0:2], uint16(dataLength))
   610  	b.B[2] = 'n'
   611  	//https://www.erlang.org/doc/apps/erts/erl_dist_protocol.html#distribution-handshake
   612  	// The Version is a 16-bit big endian integer and must always have the value 5
   613  	binary.BigEndian.PutUint16(b.B[3:5], 5)                // uint16
   614  	binary.BigEndian.PutUint32(b.B[5:9], flags.toUint32()) // uint32
   615  	binary.BigEndian.PutUint32(b.B[9:13], dh.challenge)    // uint32
   616  	b.Append([]byte(dh.nodename))
   617  }
   618  
   619  func (dh *DistHandshake) composeChallengeVersion6(b *lib.Buffer, tls bool) {
   620  
   621  	flags := composeFlags(dh.flags)
   622  	if tls {
   623  		// 1 ('N') + 8 (flags) + 4 (chalange) + 4 (creation) + 2 (len(dh.nodename))
   624  		b.Allocate(23)
   625  		dataLength := 19 + len(dh.nodename)
   626  		binary.BigEndian.PutUint32(b.B[0:4], uint32(dataLength))
   627  		b.B[4] = 'N'
   628  		binary.BigEndian.PutUint64(b.B[5:13], uint64(flags))             // uint64
   629  		binary.BigEndian.PutUint32(b.B[13:17], dh.challenge)             // uint32
   630  		binary.BigEndian.PutUint32(b.B[17:21], dh.creation)              // uint32
   631  		binary.BigEndian.PutUint16(b.B[21:23], uint16(len(dh.nodename))) // uint16
   632  		b.Append([]byte(dh.nodename))
   633  		return
   634  	}
   635  
   636  	// 1 ('N') + 8 (flags) + 4 (chalange) + 4 (creation) + 2 (len(dh.nodename))
   637  	b.Allocate(21)
   638  	dataLength := 19 + len(dh.nodename)
   639  	binary.BigEndian.PutUint16(b.B[0:2], uint16(dataLength))
   640  	b.B[2] = 'N'
   641  	binary.BigEndian.PutUint64(b.B[3:11], uint64(flags))             // uint64
   642  	binary.BigEndian.PutUint32(b.B[11:15], dh.challenge)             // uint32
   643  	binary.BigEndian.PutUint32(b.B[15:19], dh.creation)              // uint32
   644  	binary.BigEndian.PutUint16(b.B[19:21], uint16(len(dh.nodename))) // uint16
   645  	b.Append([]byte(dh.nodename))
   646  }
   647  
   648  func (dh *DistHandshake) readChallenge(msg []byte, details *node.HandshakeDetails) (uint32, error) {
   649  	var challenge uint32
   650  	if len(msg) < 15 {
   651  		return challenge, fmt.Errorf("malformed handshake challenge")
   652  	}
   653  	flags := nodeFlags(binary.BigEndian.Uint32(msg[2:6]))
   654  	details.Flags = node.DefaultFlags()
   655  	details.Flags.EnableFragmentation = flags.isSet(flagFragments)
   656  	details.Flags.EnableBigCreation = flags.isSet(flagBigCreation)
   657  	details.Flags.EnableHeaderAtomCache = flags.isSet(flagDistHdrAtomCache)
   658  	details.Flags.EnableAlias = flags.isSet(flagAlias)
   659  	details.Flags.EnableRemoteSpawn = flags.isSet(flagSpawn)
   660  	details.Flags.EnableBigPidRef = flags.isSet(flagV4NC)
   661  
   662  	version := binary.BigEndian.Uint16(msg[0:2])
   663  	if version != uint16(HandshakeVersion5) {
   664  		return challenge, fmt.Errorf("malformed handshake version %d", version)
   665  	}
   666  	details.Version = int(version)
   667  
   668  	if flags.isSet(flagHandshake23) {
   669  		// remote peer does support version 6
   670  		details.Version = 6
   671  	}
   672  
   673  	details.Name = string(msg[10:])
   674  	challenge = binary.BigEndian.Uint32(msg[6:10])
   675  	return challenge, nil
   676  }
   677  
   678  func (dh *DistHandshake) readChallengeVersion6(msg []byte, details *node.HandshakeDetails) (uint32, error) {
   679  	var challenge uint32
   680  	flags := nodeFlags(binary.BigEndian.Uint64(msg[0:8]))
   681  	details.Flags = node.DefaultFlags()
   682  	details.Flags.EnableFragmentation = flags.isSet(flagFragments)
   683  	details.Flags.EnableBigCreation = flags.isSet(flagBigCreation)
   684  	details.Flags.EnableHeaderAtomCache = flags.isSet(flagDistHdrAtomCache)
   685  	details.Flags.EnableAlias = flags.isSet(flagAlias)
   686  	details.Flags.EnableRemoteSpawn = flags.isSet(flagSpawn)
   687  	details.Flags.EnableBigPidRef = flags.isSet(flagV4NC)
   688  	details.Flags.EnableCompression = flags.isSet(flagCompression)
   689  	details.Flags.EnableProxy = flags.isSet(flagProxy)
   690  
   691  	details.Creation = binary.BigEndian.Uint32(msg[12:16])
   692  	details.Version = 6
   693  
   694  	challenge = binary.BigEndian.Uint32(msg[8:12])
   695  
   696  	lenName := int(binary.BigEndian.Uint16(msg[16:18]))
   697  	details.Name = string(msg[18 : 18+lenName])
   698  
   699  	return challenge, nil
   700  }
   701  
   702  func (dh *DistHandshake) readComplement(msg []byte, details *node.HandshakeDetails) {
   703  	flags := nodeFlags(uint64(binary.BigEndian.Uint32(msg[0:4])) << 32)
   704  
   705  	details.Flags.EnableCompression = flags.isSet(flagCompression)
   706  	details.Flags.EnableProxy = flags.isSet(flagProxy)
   707  	details.Creation = binary.BigEndian.Uint32(msg[4:8])
   708  }
   709  
   710  func (dh *DistHandshake) validateChallengeReply(b []byte, cookie string) (uint32, bool) {
   711  	challenge := binary.BigEndian.Uint32(b[:4])
   712  	digestB := b[4:]
   713  
   714  	digestA := genDigest(dh.challenge, cookie)
   715  	return challenge, bytes.Equal(digestA[:], digestB)
   716  }
   717  
   718  func (dh *DistHandshake) composeChallengeAck(b *lib.Buffer, challenge uint32, tls bool, cookie string) {
   719  	if tls {
   720  		b.Allocate(5)
   721  		dataLength := uint32(17) // 'a' + 16 (digest)
   722  		binary.BigEndian.PutUint32(b.B[0:4], dataLength)
   723  		b.B[4] = 'a'
   724  		digest := genDigest(challenge, cookie)
   725  		b.Append(digest)
   726  		return
   727  	}
   728  
   729  	b.Allocate(3)
   730  	dataLength := uint16(17) // 'a' + 16 (digest)
   731  	binary.BigEndian.PutUint16(b.B[0:2], dataLength)
   732  	b.B[2] = 'a'
   733  	digest := genDigest(challenge, cookie)
   734  	b.Append(digest)
   735  }
   736  
   737  func (dh *DistHandshake) composeChallengeReply(b *lib.Buffer, challenge uint32, tls bool, cookie string) {
   738  	if tls {
   739  		digest := genDigest(challenge, cookie)
   740  		b.Allocate(9)
   741  		dataLength := 5 + len(digest) // 1 (byte) + 4 (challenge) + 16 (digest)
   742  		binary.BigEndian.PutUint32(b.B[0:4], uint32(dataLength))
   743  		b.B[4] = 'r'
   744  		binary.BigEndian.PutUint32(b.B[5:9], dh.challenge) // uint32
   745  		b.Append(digest)
   746  		return
   747  	}
   748  
   749  	b.Allocate(7)
   750  	digest := genDigest(challenge, cookie)
   751  	dataLength := 5 + len(digest) // 1 (byte) + 4 (challenge) + 16 (digest)
   752  	binary.BigEndian.PutUint16(b.B[0:2], uint16(dataLength))
   753  	b.B[2] = 'r'
   754  	binary.BigEndian.PutUint32(b.B[3:7], dh.challenge) // uint32
   755  	b.Append(digest)
   756  }
   757  
   758  func (dh *DistHandshake) composeComplement(b *lib.Buffer, tls bool) {
   759  	flags := composeFlags(dh.flags)
   760  	// cast must cast creation to int32 in order to follow the
   761  	// erlang's handshake. Ergo don't care of it.
   762  	node_flags := uint32(flags.toUint64() >> 32)
   763  	if tls {
   764  		b.Allocate(13)
   765  		dataLength := 9 // 1 + 4 (flag high) + 4 (creation)
   766  		binary.BigEndian.PutUint32(b.B[0:4], uint32(dataLength))
   767  		b.B[4] = 'c'
   768  		binary.BigEndian.PutUint32(b.B[5:9], node_flags)
   769  		binary.BigEndian.PutUint32(b.B[9:13], dh.creation)
   770  		return
   771  	}
   772  
   773  	dataLength := 9 // 1 + 4 (flag high) + 4 (creation)
   774  	b.Allocate(11)
   775  	binary.BigEndian.PutUint16(b.B[0:2], uint16(dataLength))
   776  	b.B[2] = 'c'
   777  	binary.BigEndian.PutUint32(b.B[3:7], node_flags)
   778  	binary.BigEndian.PutUint32(b.B[7:11], dh.creation)
   779  }
   780  
   781  func genDigest(challenge uint32, cookie string) []byte {
   782  	s := fmt.Sprintf("%s%d", cookie, challenge)
   783  	digest := md5.Sum([]byte(s))
   784  	return digest[:]
   785  }
   786  
   787  func composeFlags(flags node.Flags) nodeFlags {
   788  
   789  	// default flags
   790  	enabledFlags := []nodeFlagId{
   791  		flagPublished,
   792  		flagUnicodeIO,
   793  		flagDistMonitor,
   794  		flagNewFloats,
   795  		flagBitBinaries,
   796  		flagDistMonitorName,
   797  		flagExtendedPidsPorts,
   798  		flagExtendedReferences,
   799  		flagAtomCache,
   800  		flagHiddenAtomCache,
   801  		flagFunTags,
   802  		flagNewFunTags,
   803  		flagExportPtrTag,
   804  		flagSmallAtomTags,
   805  		flagUTF8Atoms,
   806  		flagMapTag,
   807  		flagHandshake23,
   808  	}
   809  
   810  	// optional flags
   811  	if flags.EnableHeaderAtomCache {
   812  		enabledFlags = append(enabledFlags, flagDistHdrAtomCache)
   813  	}
   814  	if flags.EnableFragmentation {
   815  		enabledFlags = append(enabledFlags, flagFragments)
   816  	}
   817  	if flags.EnableBigCreation {
   818  		enabledFlags = append(enabledFlags, flagBigCreation)
   819  	}
   820  	if flags.EnableAlias {
   821  		enabledFlags = append(enabledFlags, flagAlias)
   822  	}
   823  	if flags.EnableBigPidRef {
   824  		enabledFlags = append(enabledFlags, flagV4NC)
   825  	}
   826  	if flags.EnableRemoteSpawn {
   827  		enabledFlags = append(enabledFlags, flagSpawn)
   828  	}
   829  	if flags.EnableCompression {
   830  		enabledFlags = append(enabledFlags, flagCompression)
   831  	}
   832  	if flags.EnableProxy {
   833  		enabledFlags = append(enabledFlags, flagProxy)
   834  	}
   835  	return toNodeFlags(enabledFlags...)
   836  }