github.com/fumiama/go-registry@v0.2.7/cmd.go (about)

     1  package registry
     2  
     3  import (
     4  	"crypto/md5"
     5  	"encoding/binary"
     6  	"errors"
     7  	"io"
     8  	"unsafe"
     9  
    10  	base14 "github.com/fumiama/go-base16384"
    11  	tea "github.com/fumiama/gofastTEA"
    12  )
    13  
    14  const (
    15  	CMDGET uint8 = iota
    16  	CMDCAT
    17  	CMDMD5
    18  	CMDACK
    19  	CMDEND
    20  	CMDSET
    21  	CMDDEL
    22  	CMDDAT
    23  )
    24  
    25  const (
    26  	ACKNONE uint8 = iota<<4 + 3
    27  	ACKSUCC
    28  	ACKDATA
    29  	ACKNULL
    30  	ACKNEQU
    31  	ACKERRO
    32  )
    33  
    34  var (
    35  	ErrMd5Mismatch = errors.New("cmd: md5 mismatch")
    36  )
    37  
// CmdPacket is a command packet obtained from the package pool (via
// NewCmdPacket, ParseCmdPacket or ReadCmdPacket) and released with Put.
// The embedded rawCmdPacket holds the wire header and encrypted payload.
type CmdPacket struct {
	// NOTE(review): nothing in this file assigns this embedded interface, and
	// *CmdPacket declares its own ReadFrom, which takes precedence over the
	// promoted one; the field looks vestigial — confirm before relying on it.
	io.ReaderFrom
	// t is the TEA cipher used by Encrypt and Decrypt (its sequence byte is
	// overwritten by setseq on every call).
	t    *tea.TEA
	// Data is the plaintext payload: set by NewCmdPacket/Refresh, replaced by
	// a successful Decrypt, and used by Write as its accumulation buffer.
	Data []byte
	rawCmdPacket
}
    44  
// rawCmdPacket is the wire layout of a packet. All fields are bytes or byte
// arrays, so the struct is exactly 1+1+16+255 bytes with no padding;
// ReadCmdPacket, ReadFrom and Encrypt rely on this by viewing it as a
// [1 + 1 + 16 + 255]byte. Do not reorder, resize or insert fields.
type rawCmdPacket struct {
	cmd uint8     // command code, e.g. one of the CMD* constants
	len uint8     // number of meaningful bytes in raw
	md5 [16]byte  // MD5 of the plaintext payload, verified by Decrypt
	raw [255]byte // encrypted payload; only raw[:len] is meaningful
}
    51  
    52  //go:nosplit
    53  func NewCmdPacket(cmd uint8, data []byte, t *tea.TEA) (c *CmdPacket) {
    54  	c = pool.Get().(*CmdPacket)
    55  	c.t = t
    56  	c.Data = data
    57  	c.cmd = cmd
    58  	c.md5 = md5.Sum(data)
    59  	return
    60  }
    61  
    62  //go:nosplit
    63  func ParseCmdPacket(data []byte, t *tea.TEA) (c *CmdPacket) {
    64  	if len(data) < 1+1+16 {
    65  		return nil
    66  	}
    67  	if len(data)-1-1-16 < int(data[1]) {
    68  		return nil
    69  	}
    70  	r := (*rawCmdPacket)(*(*unsafe.Pointer)(unsafe.Pointer(&data)))
    71  	c = pool.Get().(*CmdPacket)
    72  	c.t = t
    73  	c.cmd = r.cmd
    74  	c.len = r.len
    75  	c.md5 = r.md5
    76  	copy(c.raw[:], data[1+1+16:])
    77  	return c
    78  }
    79  
    80  //go:nosplit
    81  func ReadCmdPacket(f io.Reader, t *tea.TEA) (c *CmdPacket, err error) {
    82  	c = pool.Get().(*CmdPacket)
    83  	buf := (*[1 + 1 + 16 + 255]byte)(unsafe.Pointer(&c.rawCmdPacket))
    84  	_, err = io.ReadFull(f, buf[:1+1+16])
    85  	if err != nil {
    86  		c.Put()
    87  		return nil, err
    88  	}
    89  	_, err = io.ReadFull(f, c.raw[:c.len])
    90  	if err != nil {
    91  		c.Put()
    92  		return nil, err
    93  	}
    94  	return
    95  }
    96  
    97  //go:nosplit
    98  func (c *CmdPacket) Refresh(cmd uint8, data []byte, t *tea.TEA) {
    99  	c.t = t
   100  	c.Data = data
   101  	c.cmd = cmd
   102  	c.md5 = md5.Sum(data)
   103  }
   104  
   105  //go:nosplit
   106  func (c *CmdPacket) ReadFrom(f io.Reader) (n int64, err error) {
   107  	if c.cmd > 0 {
   108  		err = io.EOF
   109  		return
   110  	}
   111  	buf := (*[1 + 1 + 16 + 255]byte)(unsafe.Pointer(&c.rawCmdPacket))
   112  	cnt, err := io.ReadFull(f, buf[:1+1+16])
   113  	if err != nil {
   114  		return int64(cnt), err
   115  	}
   116  	cnt, err = io.ReadFull(f, c.raw[:c.len])
   117  	cnt += 1 + 1 + 16
   118  	if err != nil {
   119  		return int64(cnt), err
   120  	}
   121  	return
   122  }
   123  
   124  // Write should not be used due to the full-copy of buf
   125  func (c *CmdPacket) Write(buf []byte) (n int, err error) {
   126  	oldlen := len(c.Data)
   127  	c.Data = append(c.Data, buf...)
   128  	if len(c.Data) < 1+1+16 {
   129  		return len(buf), nil
   130  	}
   131  	if len(c.Data) < 1+1+16+int(c.len) {
   132  		return len(buf), nil
   133  	}
   134  	r := (*rawCmdPacket)(*(*unsafe.Pointer)(unsafe.Pointer(&c.Data)))
   135  	c.cmd = r.cmd
   136  	c.len = r.len
   137  	c.md5 = r.md5
   138  	copy(c.raw[:], r.raw[:c.len])
   139  	c.Data = c.Data[1+1+16+int(c.len):]
   140  	return 1 + 1 + 16 + int(c.len) - oldlen, nil
   141  }
   142  
   143  //go:nosplit
   144  func (c *CmdPacket) Encrypt(seq uint8) (raw []byte) {
   145  	setseq(c.t, seq)
   146  	c.len = uint8(c.t.EncryptLittleEndianTo(c.Data, sumtable, c.raw[:]))
   147  	(*slice)(unsafe.Pointer(&raw)).data = unsafe.Pointer(&c.rawCmdPacket)
   148  	(*slice)(unsafe.Pointer(&raw)).len = 1 + 1 + 16 + int(c.len)
   149  	(*slice)(unsafe.Pointer(&raw)).cap = 1 + 1 + 16 + 255
   150  	return
   151  }
   152  
   153  //go:nosplit
   154  func (c *CmdPacket) Decrypt(seq uint8) error {
   155  	setseq(c.t, seq)
   156  	d := c.t.DecryptLittleEndian(c.raw[:c.len], sumtable)
   157  	if d != nil && c.md5 == md5.Sum(d) {
   158  		c.Data = d
   159  		return nil
   160  	}
   161  	return ErrMd5Mismatch
   162  }
   163  
   164  //go:nosplit
   165  func (c *CmdPacket) Put() {
   166  	c.cmd = 0
   167  	c.Data = nil
   168  	pool.Put(c)
   169  }
   170  
   171  //go:nosplit
   172  func setseq(t *tea.TEA, seq uint8) {
   173  	*(*uint8)(unsafe.Add(unsafe.Pointer(t), 15)) = seq
   174  }
   175  
// randuint32 returns a pseudo-random uint32 without taking any lock, by
// calling the Go runtime's internal generator directly. The declaration has no
// body because it is bound to runtime.fastrand through go:linkname.
// NOTE(review): newer toolchains restrict go:linkname into the runtime —
// confirm runtime.fastrand remains linkable with the Go version in use.
//
//go:linkname randuint32 runtime.fastrand
func randuint32() uint32
   180  
   181  //go:nosplit
   182  func fill() []byte {
   183  	var b [8]byte
   184  	binary.LittleEndian.PutUint32(b[:4], randuint32())
   185  	binary.LittleEndian.PutUint32(b[4:8], randuint32())
   186  	return base14.Encode(b[:7])
   187  }
   188  
   189  // TEA encoding sumtable
   190  var sumtable = [0x10]uint32{
   191  	0x9e3579b9,
   192  	0x3c6ef172,
   193  	0xd2a66d2b,
   194  	0x78dd36e4,
   195  	0x17e5609d,
   196  	0xb54fda56,
   197  	0x5384560f,
   198  	0xf1bb77c8,
   199  	0x8ff24781,
   200  	0x2e4ac13a,
   201  	0xcc653af3,
   202  	0x6a9964ac,
   203  	0x08d12965,
   204  	0xa708081e,
   205  	0x451221d7,
   206  	0xe37793d0,
   207  }