github.com/fumiama/go-registry@v0.2.7/reg.go (about)

     1  package registry
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"os"
    10  	"strconv"
    11  	"sync"
    12  	"time"
    13  
    14  	spb "github.com/fumiama/go-simple-protobuf"
    15  	tea "github.com/fumiama/gofastTEA"
    16  )
    17  
    18  var (
    19  	ErrGetKeyTooLong    = errors.New("reg: get key too long")
    20  	ErrDecAck           = errors.New("reg: decrypt ack error")
    21  	ErrInternalServer   = errors.New("reg: internal server error")
    22  	ErrPermissionDenied = errors.New("reg: permission denied")
    23  	ErrSetKeyTooLong    = errors.New("reg: set key too long")
    24  	ErrSetValTooLong    = errors.New("reg: set val too long")
    25  	ErrUnknownAck       = errors.New("reg: unknown ack error")
    26  	ErrNoSuchKey        = errors.New("reg: no such key")
    27  	ErrRawDataTooLong   = errors.New("reg: raw data too long")
    28  	ErrMd5NotEqual      = errors.New("reg: md5 not equal")
    29  	ErrInvalidCatData   = errors.New("reg: invalid cat data")
    30  	ErrNilStorData      = errors.New("reg: nil stor data")
    31  )
    32  
// Regedit is a client for a registry server reached over TCP.
type Regedit struct {
	// mu is locked around every use of seq and whenever conn is (re)assigned.
	mu   sync.Mutex
	// conn is the TCP connection opened by Connect/ConnectIn; nil until then
	// and again after Close.
	conn net.Conn
	// addr is the address dialed by Connect/ConnectIn.
	addr string
	// stor is the path of the local snapshot file written by Cat and read by
	// Load; the constructors skip creating it when it is empty.
	stor string
	// tp is the cipher keyed by pwd, used for ordinary requests and for
	// decrypting the server's replies.
	tp   tea.TEA
	// ts is the cipher keyed by sps, used by Set and Del; nil for clients
	// built with NewRegReader, which makes Set/Del return ErrPermissionDenied.
	ts   *tea.TEA
	// seq is the packet sequence number handed to Encrypt/Decrypt; it is
	// incremented per packet and reset to 0 by Close.
	seq  byte
}
    42  
    43  func NewRegedit(addr, stor, pwd, sps string) *Regedit {
    44  	var tp, ts [16]byte
    45  	if len(pwd) > 15 {
    46  		pwd = pwd[:15]
    47  	}
    48  	if len(sps) > 15 {
    49  		sps = sps[:15]
    50  	}
    51  	copy(tp[:], pwd)
    52  	copy(ts[:], sps)
    53  	s := tea.NewTeaCipherLittleEndian(ts[:])
    54  	if stor != "" {
    55  		f, err := os.Open(stor)
    56  		if err != nil {
    57  			f, err = os.Create(stor)
    58  			if err != nil {
    59  				panic(err)
    60  			}
    61  		}
    62  		_ = f.Close()
    63  	}
    64  	return &Regedit{addr: addr, stor: stor, tp: tea.NewTeaCipherLittleEndian(tp[:]), ts: &s}
    65  }
    66  
    67  func NewRegReader(addr, stor, pwd string) *Regedit {
    68  	var tp [16]byte
    69  	if len(pwd) > 15 {
    70  		pwd = pwd[:15]
    71  	}
    72  	copy(tp[:], pwd)
    73  	if stor != "" {
    74  		f, err := os.Open(stor)
    75  		if err != nil {
    76  			f, err = os.Create(stor)
    77  			if err != nil {
    78  				panic(err)
    79  			}
    80  		}
    81  		_ = f.Close()
    82  	}
    83  	return &Regedit{addr: addr, stor: stor, tp: tea.NewTeaCipherLittleEndian(tp[:])}
    84  }
    85  
    86  func (r *Regedit) Connect() (err error) {
    87  	r.mu.Lock()
    88  	if r.conn == nil {
    89  		r.conn, err = net.Dial("tcp", r.addr)
    90  	}
    91  	r.mu.Unlock()
    92  	return
    93  }
    94  
    95  func (r *Regedit) ConnectIn(timeout time.Duration) (err error) {
    96  	r.mu.Lock()
    97  	if r.conn == nil {
    98  		r.conn, err = net.DialTimeout("tcp", r.addr, timeout)
    99  	}
   100  	r.mu.Unlock()
   101  	return
   102  }
   103  
   104  func (r *Regedit) Close() (err error) {
   105  	r.mu.Lock()
   106  	defer r.mu.Unlock()
   107  	if r.conn != nil {
   108  		p := NewCmdPacket(CMDEND, fill(), &r.tp)
   109  		r.conn.Write(p.Encrypt(r.seq))
   110  		p.Put()
   111  		r.seq = 0
   112  		err = r.conn.Close()
   113  		r.conn = nil
   114  		return
   115  	}
   116  	return
   117  }
   118  
   119  func (r *Regedit) Get(key string) (string, error) {
   120  	if len(key) > 127 {
   121  		return "", ErrGetKeyTooLong
   122  	}
   123  	p := NewCmdPacket(CMDGET, StringToBytes(key), &r.tp)
   124  	defer p.Put()
   125  	r.mu.Lock()
   126  	r.conn.Write(p.Encrypt(r.seq))
   127  	r.seq++
   128  	err := r.ack(p)
   129  	if err != nil {
   130  		r.mu.Unlock()
   131  		return "", err
   132  	}
   133  	err = p.Decrypt(r.seq)
   134  	r.seq++
   135  	r.mu.Unlock()
   136  	if err != nil {
   137  		return "", ErrDecAck
   138  	}
   139  	a := string(p.Data)
   140  	if a == "erro" && p.cmd == ACKERRO {
   141  		return "", ErrInternalServer
   142  	}
   143  	if a == "null" && p.cmd == ACKNULL {
   144  		return "", ErrNoSuchKey
   145  	}
   146  	return a, nil
   147  }
   148  
   149  func (r *Regedit) Cat() (*Storage, error) {
   150  	p := NewCmdPacket(CMDCAT, fill(), &r.tp)
   151  	defer p.Put()
   152  	r.mu.Lock()
   153  	r.conn.Write(p.Encrypt(r.seq))
   154  	r.seq++
   155  	seq := r.seq
   156  	r.seq++
   157  	r.mu.Unlock()
   158  	var buf [64]byte
   159  	i := 0
   160  	for {
   161  		_, err := r.conn.Read(buf[i : i+1])
   162  		if err != nil {
   163  			return nil, err
   164  		}
   165  		if buf[i] == '$' {
   166  			break
   167  		}
   168  		i++
   169  		if i >= 64 {
   170  			return nil, ErrRawDataTooLong
   171  		}
   172  	}
   173  	n, err := strconv.ParseUint(BytesToString(buf[:i]), 10, 64)
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  	data := make([]byte, n)
   178  	_, err = io.ReadFull(r.conn, data)
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  	setseq(&r.tp, seq)
   183  	data = r.tp.DecryptLittleEndian(data, sumtable)
   184  	s := new(Storage)
   185  	s.m = make(map[string]string, 256)
   186  	s.Md5 = md5.Sum(data)
   187  	rd := bytes.NewReader(data)
   188  	for i = 0; i < len(data); {
   189  		sp, err := spb.NewSimplePB(rd)
   190  		if err == io.EOF {
   191  			break
   192  		}
   193  		if err != nil {
   194  			return nil, err
   195  		}
   196  		if len(sp.Target) <= 1 {
   197  			return nil, ErrInvalidCatData
   198  		}
   199  		s.m[BytesToString(sp.Target[0])] = BytesToString(sp.Target[1])
   200  		i += int(sp.RealLen)
   201  	}
   202  	f, err := os.Create(r.stor)
   203  	if err != nil {
   204  		return s, err
   205  	}
   206  	defer f.Close()
   207  	_, err = f.Write(data)
   208  	return s, err
   209  }
   210  
   211  func (r *Regedit) Load() (*Storage, error) {
   212  	data, err := os.ReadFile(r.stor)
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	if len(data) == 0 {
   217  		return nil, ErrNilStorData
   218  	}
   219  	s := new(Storage)
   220  	s.m = make(map[string]string, 256)
   221  	s.Md5 = md5.Sum(data)
   222  	rd := bytes.NewReader(data)
   223  	for i := 0; i < len(data); {
   224  		sp, err := spb.NewSimplePB(rd)
   225  		if err == io.EOF {
   226  			break
   227  		}
   228  		if err != nil {
   229  			return nil, err
   230  		}
   231  		if len(sp.Target) <= 1 {
   232  			return nil, ErrInvalidCatData
   233  		}
   234  		s.m[BytesToString(sp.Target[0])] = BytesToString(sp.Target[1])
   235  		i += int(sp.RealLen)
   236  	}
   237  	f, err := os.Create(r.stor)
   238  	if err != nil {
   239  		return s, err
   240  	}
   241  	defer f.Close()
   242  	_, err = f.Write(data)
   243  	return s, err
   244  }
   245  
   246  func (r *Regedit) IsMd5Equal(m [md5.Size]byte) (bool, error) {
   247  	p := NewCmdPacket(CMDMD5, m[:], &r.tp)
   248  	defer p.Put()
   249  	r.mu.Lock()
   250  	r.conn.Write(p.Encrypt(r.seq))
   251  	r.seq++
   252  	err := r.ack(p)
   253  	if err != nil {
   254  		r.mu.Unlock()
   255  		return false, err
   256  	}
   257  	err = p.Decrypt(r.seq)
   258  	r.seq++
   259  	r.mu.Unlock()
   260  	if err != nil {
   261  		return false, ErrDecAck
   262  	}
   263  	a := string(p.Data)
   264  	if a == "erro" && p.cmd == ACKERRO {
   265  		return false, ErrInternalServer
   266  	}
   267  	if a == "nequ" && p.cmd == ACKNEQU {
   268  		return false, ErrNoSuchKey
   269  	}
   270  	if a == "null" && p.cmd == ACKNULL {
   271  		return true, nil
   272  	}
   273  	return false, ErrUnknownAck
   274  }
   275  
   276  func (r *Regedit) Set(key, value string) error {
   277  	if r.ts == nil {
   278  		return ErrPermissionDenied
   279  	}
   280  	if len(key) > 127 {
   281  		return ErrSetKeyTooLong
   282  	}
   283  	if len(value) > 127 {
   284  		return ErrSetValTooLong
   285  	}
   286  	p := NewCmdPacket(CMDSET, StringToBytes(key), r.ts)
   287  	defer p.Put()
   288  	r.mu.Lock()
   289  	defer r.mu.Unlock()
   290  	r.conn.Write(p.Encrypt(r.seq))
   291  	r.seq++
   292  	ack := NewCmdPacket(CMDACK, nil, &r.tp)
   293  	defer ack.Put()
   294  	err := r.ack(ack)
   295  	if err != nil {
   296  		return err
   297  	}
   298  	err = ack.Decrypt(r.seq)
   299  	r.seq++
   300  	if err != nil {
   301  		return ErrDecAck
   302  	}
   303  	a := BytesToString(ack.Data)
   304  	if a == "erro" || ack.cmd == ACKERRO {
   305  		return ErrInternalServer
   306  	}
   307  	if a != "data" && ack.cmd != ACKDATA {
   308  		return ErrUnknownAck
   309  	}
   310  	p.Refresh(CMDDAT, StringToBytes(value), r.ts)
   311  	r.conn.Write(p.Encrypt(r.seq))
   312  	r.seq++
   313  	err = r.ack(ack)
   314  	if err != nil {
   315  		return err
   316  	}
   317  	err = ack.Decrypt(r.seq)
   318  	r.seq++
   319  	if err != nil {
   320  		return ErrDecAck
   321  	}
   322  	a = BytesToString(ack.Data)
   323  	if a == "erro" || ack.cmd == ACKERRO {
   324  		return ErrInternalServer
   325  	}
   326  	if a != "succ" && ack.cmd != ACKSUCC {
   327  		return ErrUnknownAck
   328  	}
   329  	return nil
   330  }
   331  
   332  func (r *Regedit) Del(key string) error {
   333  	if r.ts == nil {
   334  		return ErrPermissionDenied
   335  	}
   336  	if len(key) > 127 {
   337  		return ErrGetKeyTooLong
   338  	}
   339  	p := NewCmdPacket(CMDDEL, StringToBytes(key), r.ts)
   340  	defer p.Put()
   341  	r.mu.Lock()
   342  	r.conn.Write(p.Encrypt(r.seq))
   343  	r.seq++
   344  	ack := NewCmdPacket(CMDACK, nil, &r.tp)
   345  	defer ack.Put()
   346  	err := r.ack(ack)
   347  	if err != nil {
   348  		r.mu.Unlock()
   349  		return err
   350  	}
   351  	err = ack.Decrypt(r.seq)
   352  	r.seq++
   353  	r.mu.Unlock()
   354  	if err != nil {
   355  		return ErrDecAck
   356  	}
   357  	a := BytesToString(ack.Data)
   358  	if a == "erro" || ack.cmd == ACKERRO {
   359  		return ErrInternalServer
   360  	}
   361  	if a == "null" || ack.cmd == ACKNULL {
   362  		return ErrNoSuchKey
   363  	}
   364  	if a != "succ" && ack.cmd != ACKSUCC {
   365  		return ErrUnknownAck
   366  	}
   367  	return nil
   368  }
   369  
   370  func (r *Regedit) ack(c *CmdPacket) error {
   371  	c.cmd = 0
   372  	_, err := io.Copy(c, r.conn)
   373  	return err
   374  }