github.com/zmap/zcrypto@v0.0.0-20240512203510-0fef58d9a9db/tls/tls_heartbeat.go (about)

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"errors"
     9  )
    10  
    11  const (
    12  	// Record Type
    13  	recordTypeHeartbeat recordType = 24
    14  
    15  	// Extension Number
    16  	extensionHeartbeat uint16 = 15
    17  
    18  	// Heartbeat Mode
    19  	heartbeatModePeerAllowed    uint8 = 1
    20  	heartbeatModePeerNotAllowed uint8 = 2
    21  
    22  	// Heartbeat Message Types
    23  	heartbeatTypeRequest  uint8 = 1
    24  	heartbeatTypeResponse uint8 = 2
    25  )
    26  
    27  var (
    28  	HeartbleedError = errors.New("Error after Heartbleed")
    29  )
    30  
    31  type Heartbleed struct {
    32  	HeartbeatEnabled bool `json:"heartbeat_enabled"`
    33  	Vulnerable       bool `json:"heartbleed_vulnerable"`
    34  }
    35  
    36  type heartbleedMessage struct {
    37  	raw []byte
    38  }
    39  
    40  func (m *heartbleedMessage) marshal() []byte {
    41  	x := make([]byte, 3)
    42  	x[0] = 1
    43  	x[1] = byte(0x00)
    44  	x[2] = byte(0x00)
    45  	m.raw = x
    46  	return x
    47  }
    48  
    49  func (c *Conn) CheckHeartbleed(b []byte) (n int, err error) {
    50  	if err = c.Handshake(); err != nil {
    51  		return
    52  	}
    53  	if !c.heartbeat {
    54  		return
    55  	}
    56  	c.in.Lock()
    57  	defer c.in.Unlock()
    58  
    59  	hb := heartbleedMessage{}
    60  	hb.marshal()
    61  
    62  	if _, err = c.writeRecord(recordTypeHeartbeat, hb.raw); err != nil {
    63  		return 0, err
    64  	}
    65  
    66  	if err = c.readRecord(recordTypeHeartbeat); err != nil {
    67  		return 0, HeartbleedError
    68  	}
    69  	if c.in.err != nil {
    70  		return 0, HeartbleedError
    71  	}
    72  	n, err = c.input.Read(b)
    73  	if c.input.off >= len(c.input.data) {
    74  		c.in.freeBlock(c.input)
    75  		c.input = nil
    76  	}
    77  
    78  	if n != 0 {
    79  		return n, HeartbleedError
    80  	}
    81  	if err != nil {
    82  		return 0, HeartbleedError
    83  	}
    84  	return 0, HeartbleedError
    85  }
    86  
    87  func (c *Conn) GetHeartbleedLog() *Heartbleed {
    88  	return c.heartbleedLog
    89  }