github.com/status-im/status-go@v1.1.0/protocol/common/message_segmentation.go (about) 1 package common 2 3 import ( 4 "bytes" 5 "math" 6 "time" 7 8 "github.com/golang/protobuf/proto" 9 "github.com/jinzhu/copier" 10 "github.com/klauspost/reedsolomon" 11 "github.com/pkg/errors" 12 "go.uber.org/zap" 13 14 "github.com/status-im/status-go/eth-node/crypto" 15 "github.com/status-im/status-go/eth-node/types" 16 "github.com/status-im/status-go/protocol/protobuf" 17 v1protocol "github.com/status-im/status-go/protocol/v1" 18 ) 19 20 var ErrMessageSegmentsIncomplete = errors.New("message segments incomplete") 21 var ErrMessageSegmentsAlreadyCompleted = errors.New("message segments already completed") 22 var ErrMessageSegmentsInvalidCount = errors.New("invalid segments count") 23 var ErrMessageSegmentsHashMismatch = errors.New("hash of entire payload does not match") 24 var ErrMessageSegmentsInvalidParity = errors.New("invalid parity segments") 25 26 const ( 27 segmentsParityRate = 0.125 28 segmentsReedsolomonMaxCount = 256 29 ) 30 31 type SegmentMessage struct { 32 *protobuf.SegmentMessage 33 } 34 35 func (s *SegmentMessage) IsValid() bool { 36 return s.SegmentsCount >= 2 || s.ParitySegmentsCount > 0 37 } 38 39 func (s *SegmentMessage) IsParityMessage() bool { 40 return s.SegmentsCount == 0 && s.ParitySegmentsCount > 0 41 } 42 43 func (s *MessageSender) segmentMessage(newMessage *types.NewMessage) ([]*types.NewMessage, error) { 44 // We set the max message size to 3/4 of the allowed message size, to leave 45 // room for segment message metadata. 46 newMessages, err := segmentMessage(newMessage, int(s.transport.MaxMessageSize()/4*3)) 47 s.logger.Debug("message segmented", zap.Int("segments", len(newMessages))) 48 return newMessages, err 49 } 50 51 func replicateMessageWithNewPayload(message *types.NewMessage, payload []byte) (*types.NewMessage, error) { 52 copy := &types.NewMessage{} 53 err := copier.Copy(copy, message) 54 if err != nil { 55 return nil, err 56 } 57 58 copy.Payload = payload 59 copy.PowTarget = calculatePoW(payload) 60 return copy, nil 61 } 62 63 // Segments message into smaller chunks if the size exceeds segmentSize. 64 func segmentMessage(newMessage *types.NewMessage, segmentSize int) ([]*types.NewMessage, error) { 65 if len(newMessage.Payload) <= segmentSize { 66 return []*types.NewMessage{newMessage}, nil 67 } 68 69 entireMessageHash := crypto.Keccak256(newMessage.Payload) 70 entirePayloadSize := len(newMessage.Payload) 71 72 segmentsCount := int(math.Ceil(float64(entirePayloadSize) / float64(segmentSize))) 73 paritySegmentsCount := int(math.Floor(float64(segmentsCount) * segmentsParityRate)) 74 75 segmentPayloads := make([][]byte, segmentsCount+paritySegmentsCount) 76 segmentMessages := make([]*types.NewMessage, segmentsCount) 77 78 for start, index := 0, 0; start < entirePayloadSize; start += segmentSize { 79 end := start + segmentSize 80 if end > entirePayloadSize { 81 end = entirePayloadSize 82 } 83 84 segmentPayload := newMessage.Payload[start:end] 85 segmentWithMetadata := &protobuf.SegmentMessage{ 86 EntireMessageHash: entireMessageHash, 87 Index: uint32(index), 88 SegmentsCount: uint32(segmentsCount), 89 Payload: segmentPayload, 90 } 91 marshaledSegmentWithMetadata, err := proto.Marshal(segmentWithMetadata) 92 if err != nil { 93 return nil, err 94 } 95 segmentMessage, err := replicateMessageWithNewPayload(newMessage, marshaledSegmentWithMetadata) 96 if err != nil { 97 return nil, err 98 } 99 100 segmentPayloads[index] = segmentPayload 101 segmentMessages[index] = segmentMessage 102 index++ 103 } 104 105 // Skip reedsolomon if the combined total of data and parity segments exceeds the predefined limit of segmentsReedsolomonMaxCount. 106 // Exceeding this limit necessitates shard sizes to be multiples of 64, which are incompatible with clients that do not support forward error correction. 107 if paritySegmentsCount == 0 || segmentsCount+paritySegmentsCount > segmentsReedsolomonMaxCount { 108 return segmentMessages, nil 109 } 110 111 enc, err := reedsolomon.New(segmentsCount, paritySegmentsCount) 112 if err != nil { 113 return nil, err 114 } 115 116 // Align the size of the last segment payload. 117 lastSegmentPayload := segmentPayloads[segmentsCount-1] 118 segmentPayloads[segmentsCount-1] = make([]byte, segmentSize) 119 copy(segmentPayloads[segmentsCount-1], lastSegmentPayload) 120 121 // Make space for parity data. 122 for i := segmentsCount; i < segmentsCount+paritySegmentsCount; i++ { 123 segmentPayloads[i] = make([]byte, segmentSize) 124 } 125 126 err = enc.Encode(segmentPayloads) 127 if err != nil { 128 return nil, err 129 } 130 131 // Create parity messages. 132 for i, index := segmentsCount, 0; i < segmentsCount+paritySegmentsCount; i++ { 133 segmentWithMetadata := &protobuf.SegmentMessage{ 134 EntireMessageHash: entireMessageHash, 135 SegmentsCount: 0, // indicates parity message 136 ParitySegmentIndex: uint32(index), 137 ParitySegmentsCount: uint32(paritySegmentsCount), 138 Payload: segmentPayloads[i], 139 } 140 marshaledSegmentWithMetadata, err := proto.Marshal(segmentWithMetadata) 141 if err != nil { 142 return nil, err 143 } 144 segmentMessage, err := replicateMessageWithNewPayload(newMessage, marshaledSegmentWithMetadata) 145 if err != nil { 146 return nil, err 147 } 148 149 segmentMessages = append(segmentMessages, segmentMessage) 150 index++ 151 } 152 153 return segmentMessages, nil 154 } 155 156 // SegmentationLayerV1 reconstructs the message only when all segments have been successfully retrieved. 157 // It lacks the capability to perform forward error correction. 158 // Kept to test forward compatibility. 159 func (s *MessageSender) handleSegmentationLayerV1(message *v1protocol.StatusMessage) error { 160 logger := s.logger.With(zap.String("site", "handleSegmentationLayerV1")).With(zap.String("hash", types.HexBytes(message.TransportLayer.Hash).String())) 161 162 segmentMessage := &SegmentMessage{ 163 SegmentMessage: &protobuf.SegmentMessage{}, 164 } 165 err := proto.Unmarshal(message.TransportLayer.Payload, segmentMessage.SegmentMessage) 166 if err != nil { 167 return errors.Wrap(err, "failed to unmarshal SegmentMessage") 168 } 169 170 logger.Debug("handling message segment", zap.String("EntireMessageHash", types.HexBytes(segmentMessage.EntireMessageHash).String()), 171 zap.Uint32("Index", segmentMessage.Index), zap.Uint32("SegmentsCount", segmentMessage.SegmentsCount)) 172 173 alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash) 174 if err != nil { 175 return err 176 } 177 if alreadyCompleted { 178 return ErrMessageSegmentsAlreadyCompleted 179 } 180 181 if segmentMessage.SegmentsCount < 2 { 182 return ErrMessageSegmentsInvalidCount 183 } 184 185 err = s.persistence.SaveMessageSegment(segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix()) 186 if err != nil { 187 return err 188 } 189 190 segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey) 191 if err != nil { 192 return err 193 } 194 195 if len(segments) != int(segmentMessage.SegmentsCount) { 196 return ErrMessageSegmentsIncomplete 197 } 198 199 // Combine payload 200 var entirePayload bytes.Buffer 201 for _, segment := range segments { 202 _, err := entirePayload.Write(segment.Payload) 203 if err != nil { 204 return errors.Wrap(err, "failed to write segment payload") 205 } 206 } 207 208 // Sanity check 209 entirePayloadHash := crypto.Keccak256(entirePayload.Bytes()) 210 if !bytes.Equal(entirePayloadHash, segmentMessage.EntireMessageHash) { 211 return ErrMessageSegmentsHashMismatch 212 } 213 214 err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix()) 215 if err != nil { 216 return err 217 } 218 219 message.TransportLayer.Payload = entirePayload.Bytes() 220 221 return nil 222 } 223 224 // SegmentationLayerV2 is capable of reconstructing the message from both complete and partial sets of data segments. 225 // It has capability to perform forward error correction. 226 func (s *MessageSender) handleSegmentationLayerV2(message *v1protocol.StatusMessage) error { 227 logger := s.logger.With(zap.String("site", "handleSegmentationLayerV2")).With(zap.String("hash", types.HexBytes(message.TransportLayer.Hash).String())) 228 229 segmentMessage := &SegmentMessage{ 230 SegmentMessage: &protobuf.SegmentMessage{}, 231 } 232 err := proto.Unmarshal(message.TransportLayer.Payload, segmentMessage.SegmentMessage) 233 if err != nil { 234 return errors.Wrap(err, "failed to unmarshal SegmentMessage") 235 } 236 237 logger.Debug("handling message segment", 238 zap.String("EntireMessageHash", types.HexBytes(segmentMessage.EntireMessageHash).String()), 239 zap.Uint32("Index", segmentMessage.Index), 240 zap.Uint32("SegmentsCount", segmentMessage.SegmentsCount), 241 zap.Uint32("ParitySegmentIndex", segmentMessage.ParitySegmentIndex), 242 zap.Uint32("ParitySegmentsCount", segmentMessage.ParitySegmentsCount)) 243 244 alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash) 245 if err != nil { 246 return err 247 } 248 if alreadyCompleted { 249 return ErrMessageSegmentsAlreadyCompleted 250 } 251 252 if !segmentMessage.IsValid() { 253 return ErrMessageSegmentsInvalidCount 254 } 255 256 err = s.persistence.SaveMessageSegment(segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix()) 257 if err != nil { 258 return err 259 } 260 261 segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey) 262 if err != nil { 263 return err 264 } 265 266 if len(segments) == 0 { 267 return errors.New("unexpected state: no segments found after save operation") // This should theoretically never occur. 268 } 269 270 firstSegmentMessage := segments[0] 271 lastSegmentMessage := segments[len(segments)-1] 272 273 // First segment message must not be a parity message. 274 if firstSegmentMessage.IsParityMessage() || len(segments) != int(firstSegmentMessage.SegmentsCount) { 275 return ErrMessageSegmentsIncomplete 276 } 277 278 payloads := make([][]byte, firstSegmentMessage.SegmentsCount+lastSegmentMessage.ParitySegmentsCount) 279 payloadSize := len(firstSegmentMessage.Payload) 280 281 restoreUsingParityData := lastSegmentMessage.IsParityMessage() 282 if !restoreUsingParityData { 283 for i, segment := range segments { 284 payloads[i] = segment.Payload 285 } 286 } else { 287 enc, err := reedsolomon.New(int(firstSegmentMessage.SegmentsCount), int(lastSegmentMessage.ParitySegmentsCount)) 288 if err != nil { 289 return err 290 } 291 292 var lastNonParitySegmentPayload []byte 293 for _, segment := range segments { 294 if !segment.IsParityMessage() { 295 if segment.Index == firstSegmentMessage.SegmentsCount-1 { 296 // Ensure last segment is aligned to payload size, as it is required by reedsolomon. 297 payloads[segment.Index] = make([]byte, payloadSize) 298 copy(payloads[segment.Index], segment.Payload) 299 lastNonParitySegmentPayload = segment.Payload 300 } else { 301 payloads[segment.Index] = segment.Payload 302 } 303 } else { 304 payloads[firstSegmentMessage.SegmentsCount+segment.ParitySegmentIndex] = segment.Payload 305 } 306 } 307 308 err = enc.Reconstruct(payloads) 309 if err != nil { 310 return err 311 } 312 313 ok, err := enc.Verify(payloads) 314 if err != nil { 315 return err 316 } 317 if !ok { 318 return ErrMessageSegmentsInvalidParity 319 } 320 321 if lastNonParitySegmentPayload != nil { 322 payloads[firstSegmentMessage.SegmentsCount-1] = lastNonParitySegmentPayload // Bring back last segment with original length. 323 } 324 } 325 326 // Combine payload. 327 var entirePayload bytes.Buffer 328 for i := 0; i < int(firstSegmentMessage.SegmentsCount); i++ { 329 _, err := entirePayload.Write(payloads[i]) 330 if err != nil { 331 return errors.Wrap(err, "failed to write segment payload") 332 } 333 } 334 335 // Sanity check. 336 entirePayloadHash := crypto.Keccak256(entirePayload.Bytes()) 337 if !bytes.Equal(entirePayloadHash, segmentMessage.EntireMessageHash) { 338 return ErrMessageSegmentsHashMismatch 339 } 340 341 err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix()) 342 if err != nil { 343 return err 344 } 345 346 message.TransportLayer.Payload = entirePayload.Bytes() 347 348 return nil 349 } 350 351 func (s *MessageSender) CleanupSegments() error { 352 monthAgo := time.Now().AddDate(0, -1, 0).Unix() 353 354 err := s.persistence.RemoveMessageSegmentsOlderThan(monthAgo) 355 if err != nil { 356 return err 357 } 358 359 err = s.persistence.RemoveMessageSegmentsCompletedOlderThan(monthAgo) 360 if err != nil { 361 return err 362 } 363 364 return nil 365 }