github.com/ergo-services/ergo@v1.999.224/apps/cloud/handshake.go (about) 1 package cloud 2 3 import ( 4 "bytes" 5 "crypto/sha256" 6 "encoding/binary" 7 "fmt" 8 "hash" 9 "io" 10 "net" 11 "time" 12 13 "github.com/ergo-services/ergo/etf" 14 "github.com/ergo-services/ergo/lib" 15 "github.com/ergo-services/ergo/node" 16 ) 17 18 const ( 19 defaultHandshakeTimeout = 5 * time.Second 20 clusterNameLengthMax = 128 21 ) 22 23 type Handshake struct { 24 node.Handshake 25 nodename string 26 creation uint32 27 options node.Cloud 28 flags node.Flags 29 } 30 31 type handshakeDetails struct { 32 cookieHash []byte 33 digestRemote []byte 34 details node.HandshakeDetails 35 mapName string 36 hash hash.Hash 37 } 38 39 func createHandshake(options node.Cloud) (node.HandshakeInterface, error) { 40 if options.Timeout == 0 { 41 options.Timeout = defaultHandshakeTimeout 42 } 43 44 if err := RegisterTypes(); err != nil { 45 return nil, err 46 } 47 48 return &Handshake{ 49 options: options, 50 }, nil 51 } 52 53 func (ch *Handshake) Init(nodename string, creation uint32, flags node.Flags) error { 54 if flags.EnableProxy == false { 55 s := "proxy feature must be enabled for the cloud connection" 56 lib.Warning(s) 57 return fmt.Errorf(s) 58 } 59 if ch.options.Cluster == "" { 60 s := "option Cloud.Cluster can not be empty" 61 lib.Warning(s) 62 return fmt.Errorf(s) 63 } 64 if len(ch.options.Cluster) > clusterNameLengthMax { 65 s := "option Cloud.Cluster has too long name" 66 lib.Warning(s) 67 return fmt.Errorf(s) 68 } 69 ch.nodename = nodename 70 ch.creation = creation 71 ch.flags = flags 72 if ch.options.Flags.Enable == false { 73 return nil 74 } 75 76 ch.flags.EnableRemoteSpawn = ch.options.Flags.EnableRemoteSpawn 77 return nil 78 } 79 80 func (ch *Handshake) Start(remote net.Addr, conn lib.NetReadWriter, tls bool, cookie string) (node.HandshakeDetails, error) { 81 hash := sha256.New() 82 handshake := &handshakeDetails{ 83 cookieHash: hash.Sum([]byte(cookie)), 84 hash: hash, 85 } 86 handshake.details.Flags = ch.flags 87 88 ch.sendV1Auth(conn) 89 90 // define timeout for the handshaking 91 timer := time.NewTimer(ch.options.Timeout) 92 defer timer.Stop() 93 94 b := lib.TakeBuffer() 95 defer lib.ReleaseBuffer(b) 96 97 asyncReadChannel := make(chan error, 2) 98 asyncRead := func() { 99 _, err := b.ReadDataFrom(conn, 1024) 100 asyncReadChannel <- err 101 } 102 103 expectingBytes := 4 104 await := []byte{ProtoHandshakeV1AuthReply, ProtoHandshakeV1Error} 105 rest := []byte{} 106 107 for { 108 go asyncRead() 109 select { 110 case <-timer.C: 111 return handshake.details, fmt.Errorf("timeout") 112 case err := <-asyncReadChannel: 113 if err != nil { 114 return handshake.details, err 115 } 116 117 if b.Len() < expectingBytes { 118 continue 119 } 120 121 if b.B[0] != ProtoHandshakeV1 { 122 return handshake.details, fmt.Errorf("malformed handshake proto") 123 } 124 125 l := int(binary.BigEndian.Uint16(b.B[2:4])) 126 buffer := b.B[4 : l+4] 127 128 if len(buffer) != l { 129 return handshake.details, fmt.Errorf("malformed handshake (wrong packet length)") 130 } 131 132 // check if we got correct message type regarding to 'await' value 133 if bytes.Count(await, b.B[1:2]) == 0 { 134 return handshake.details, fmt.Errorf("malformed handshake sequence") 135 } 136 137 await, rest, err = ch.handle(conn, b.B[1], buffer, handshake) 138 if err != nil { 139 return handshake.details, err 140 } 141 142 if await == nil && rest != nil { 143 // handshaked with some extra data. keep them for the Proto handler 144 handshake.details.Buffer = lib.TakeBuffer() 145 handshake.details.Buffer.Set(rest) 146 } 147 148 b.Reset() 149 } 150 151 if await == nil { 152 // handshaked 153 break 154 } 155 } 156 157 return handshake.details, nil 158 } 159 160 func (ch *Handshake) handle(socket io.Writer, messageType byte, buffer []byte, details *handshakeDetails) ([]byte, []byte, error) { 161 switch messageType { 162 case ProtoHandshakeV1AuthReply: 163 if err := ch.handleV1AuthReply(buffer, details); err != nil { 164 return nil, nil, err 165 } 166 if err := ch.sendV1Challenge(socket, details); err != nil { 167 return nil, nil, err 168 } 169 return []byte{ProtoHandshakeV1ChallengeAccept, ProtoHandshakeV1Error}, nil, nil 170 171 case ProtoHandshakeV1ChallengeAccept: 172 rest, err := ch.handleV1ChallegeAccept(buffer, details) 173 if err != nil { 174 return nil, nil, err 175 } 176 return nil, rest, err 177 178 case ProtoHandshakeV1Error: 179 return nil, nil, ch.handleV1Error(buffer) 180 181 default: 182 return nil, nil, fmt.Errorf("unknown message type") 183 } 184 } 185 186 func (ch *Handshake) sendV1Auth(socket io.Writer) error { 187 b := lib.TakeBuffer() 188 defer lib.ReleaseBuffer(b) 189 190 message := MessageHandshakeV1Auth{ 191 Node: ch.nodename, 192 Cluster: ch.options.Cluster, 193 Creation: ch.creation, 194 Flags: ch.options.Flags, 195 } 196 b.Allocate(1 + 1 + 2) 197 b.B[0] = ProtoHandshakeV1 198 b.B[1] = ProtoHandshakeV1Auth 199 if err := etf.Encode(message, b, etf.EncodeOptions{}); err != nil { 200 return err 201 } 202 binary.BigEndian.PutUint16(b.B[2:4], uint16(b.Len()-4)) 203 if err := b.WriteDataTo(socket); err != nil { 204 return err 205 } 206 207 return nil 208 } 209 210 func (ch *Handshake) sendV1Challenge(socket io.Writer, handshake *handshakeDetails) error { 211 b := lib.TakeBuffer() 212 defer lib.ReleaseBuffer(b) 213 214 digest := GenDigest(handshake.hash, []byte(ch.nodename), handshake.digestRemote, handshake.cookieHash) 215 message := MessageHandshakeV1Challenge{ 216 Digest: digest, 217 } 218 b.Allocate(1 + 1 + 2) 219 b.B[0] = ProtoHandshakeV1 220 b.B[1] = ProtoHandshakeV1Challenge 221 if err := etf.Encode(message, b, etf.EncodeOptions{}); err != nil { 222 return err 223 } 224 binary.BigEndian.PutUint16(b.B[2:4], uint16(b.Len()-4)) 225 if err := b.WriteDataTo(socket); err != nil { 226 return err 227 } 228 229 return nil 230 231 } 232 233 func (ch *Handshake) handleV1AuthReply(buffer []byte, handshake *handshakeDetails) error { 234 m, _, err := etf.Decode(buffer, nil, etf.DecodeOptions{}) 235 if err != nil { 236 return fmt.Errorf("malformed MessageHandshakeV1AuthReply message: %s", err) 237 } 238 message, ok := m.(MessageHandshakeV1AuthReply) 239 if ok == false { 240 return fmt.Errorf("malformed MessageHandshakeV1AuthReply message: %#v", m) 241 } 242 243 digest := GenDigest(handshake.hash, []byte(message.Node), []byte(ch.options.Cluster), handshake.cookieHash) 244 if bytes.Compare(message.Digest, digest) != 0 { 245 return fmt.Errorf("authorization failed") 246 } 247 handshake.digestRemote = digest 248 handshake.details.Name = message.Node 249 handshake.details.Creation = message.Creation 250 251 return nil 252 } 253 254 func (ch *Handshake) handleV1ChallegeAccept(buffer []byte, handshake *handshakeDetails) ([]byte, error) { 255 m, rest, err := etf.Decode(buffer, nil, etf.DecodeOptions{}) 256 if err != nil { 257 return nil, fmt.Errorf("malformed MessageHandshakeV1ChallengeAccept message: %s", err) 258 } 259 message, ok := m.(MessageHandshakeV1ChallengeAccept) 260 if ok == false { 261 return nil, fmt.Errorf("malformed MessageHandshakeV1ChallengeAccept message: %#v", m) 262 } 263 264 mapping := etf.NewAtomMapping() 265 mapping.In[etf.Atom(message.Node)] = etf.Atom(ch.nodename) 266 mapping.Out[etf.Atom(ch.nodename)] = etf.Atom(message.Node) 267 handshake.details.AtomMapping = mapping 268 handshake.mapName = message.Node 269 return rest, nil 270 } 271 272 func (ch *Handshake) handleV1Error(buffer []byte) error { 273 m, _, err := etf.Decode(buffer, nil, etf.DecodeOptions{}) 274 if err != nil { 275 return fmt.Errorf("malformed MessageHandshakeV1Error message: %s", err) 276 } 277 message, ok := m.(MessageHandshakeV1Error) 278 if ok == false { 279 return fmt.Errorf("malformed MessageHandshakeV1Error message: %#v", m) 280 } 281 return fmt.Errorf(message.Reason) 282 }