github.com/iDigitalFlame/xmt@v0.5.4/data/crypto/cbk.go (about)

     1  // Copyright (C) 2020 - 2023 iDigitalFlame
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU General Public License as published by
     5  // the Free Software Foundation, either version 3 of the License, or
     6  // any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU General Public License
    14  // along with this program.  If not, see <https://www.gnu.org/licenses/>.
    15  //
    16  
    17  package crypto
    18  
    19  import (
    20  	"crypto/rand"
    21  	"io"
    22  	"sync"
    23  
    24  	"github.com/iDigitalFlame/xmt/util/xerr"
    25  )
    26  
    27  const size = 128
    28  
    29  var (
    30  	chains = sync.Pool{
    31  		New: func() interface{} {
    32  			var b [size + 1]byte
    33  			return &b
    34  		},
    35  	}
    36  	tables = sync.Pool{
    37  		New: func() interface{} {
    38  			var b [size + 1][256]byte
    39  			return &b
    40  		},
    41  	}
    42  )
    43  
    44  // CBK is the representation of the CBK Cipher.
    45  // CBK is a block based cipher that allows for a variable size index in encoding.
    46  type CBK struct {
    47  	// Random Source to use for data generation from keys.
    48  	// This source MUST be repeatable.
    49  	Source     source
    50  	buf        []byte
    51  	pos, total int
    52  
    53  	A, B  byte
    54  	C, D  byte
    55  	index uint8
    56  }
    57  type source interface {
    58  	Seed(int64)
    59  	Int63() int64
    60  }
    61  
    62  // NewCBK returns a new CBK Cipher with the D value specified. The other A, B and
    63  // C values are randomly generated at runtime.
    64  func NewCBK(d int) CBK {
    65  	c, _ := NewCBKEx(d, size, nil)
    66  	return c
    67  }
    68  
    69  // Reset resets the encryption keys and sets them to new random bytes.
    70  func (e *CBK) Reset() error {
    71  	if _, err := rand.Read(e.buf[0:3]); err != nil {
    72  		return err
    73  	}
    74  	_ = e.buf[2]
    75  	e.A, e.B, e.C = e.buf[0], e.buf[1], e.buf[2]
    76  	if e.pos, e.index = 0, 0; e.A == 0 {
    77  		e.A = 1
    78  	}
    79  	return nil
    80  }
    81  
    82  // BlockSize returns the cipher's block BlockSize.
    83  func (e *CBK) BlockSize() int {
    84  	return len(e.buf) - 1
    85  }
    86  
    87  // Shuffle will switch around the bytes in the array based on the Cipher bytes.
    88  func (e *CBK) Shuffle(b []byte) {
    89  	if len(b) > 1 {
    90  		b[0] += e.A
    91  	}
    92  	for i := byte(0); i < byte(len(b)); i++ {
    93  		switch {
    94  		case i%e.A == 0:
    95  			b[i] += e.A - i
    96  		case e.C%i == 0:
    97  			b[i] += e.B - e.D
    98  		case i == e.D:
    99  			b[i] -= e.A + i
   100  		default:
   101  			if i%2 == 0 {
   102  				b[i] += e.B / 3
   103  			} else {
   104  				b[i] += e.C / 5
   105  			}
   106  		}
   107  	}
   108  }
   109  
   110  // Deshuffle will reverse the switch around the bytes in the array based on the
   111  // Cipher bytes.
   112  func (e *CBK) Deshuffle(b []byte) {
   113  	if len(b) > 1 {
   114  		b[0] -= e.A
   115  	}
   116  	for i := byte(0); i < byte(len(b)); i++ {
   117  		switch {
   118  		case i%e.A == 0:
   119  			b[i] -= e.A - i
   120  		case e.C%i == 0:
   121  			b[i] -= e.B - e.D
   122  		case i == e.D:
   123  			b[i] += e.A + i
   124  		default:
   125  			if i%2 == 0 {
   126  				b[i] -= e.B / 3
   127  			} else {
   128  				b[i] -= e.C / 5
   129  			}
   130  		}
   131  	}
   132  }
   133  func (e *CBK) adjust(i uint16) uint16 {
   134  	if e.Source != nil {
   135  		return uint16(e.Source.Int63() * int64(i+1))
   136  	}
   137  	if n := ((uint16(e.A) ^ uint16(e.B)) - uint16(e.C)) * (i + 1); n > 1 {
   138  		return n
   139  	}
   140  	return 1
   141  }
   142  
   143  // Encrypt encrypts the first block in src into dst. Dst and src must overlap entirely
   144  // or not at all.
   145  func (e *CBK) Encrypt(dst, src []byte) {
   146  	copy(dst, src)
   147  	e.Shuffle(dst)
   148  	e.scramble(dst, true)
   149  }
   150  
   151  // Decrypt decrypts the first block in src into dst. Dst and src must overlap entirely
   152  // or not at all.
   153  func (e *CBK) Decrypt(dst, src []byte) {
   154  	copy(dst, src)
   155  	e.scramble(dst, false)
   156  	e.Deshuffle(dst)
   157  }
   158  
   159  // Flush pushes the remaining bytes stored into the buffer into the supplies Writer.
   160  func (e *CBK) Flush(w io.Writer) error {
   161  	_, err := e.flushOutput(w)
   162  	return err
   163  }
// scramble runs six mixing rounds over the leading bytes of b, in place.
//
// Each round derives two slot numbers g and h (both in 0..7) from the key
// bytes, the block index and the round number. When they differ, the round
// mixes the nibbles of b[g]/b[h] and b[g+1]/b[h+1] and swaps the two-byte
// pairs b[g*2:g*2+2] and b[h*2:h*2+2]. A round can therefore touch indexes up
// to 15, so b has to provide at least 16 bytes.
//
// The d flag selects the direction: when false the rounds run 0..5 and mix
// nibbles before swapping pairs; when true they run 5..0 and swap pairs before
// mixing nibbles.
// NOTE(review): a d=true pass looks like the inverse of a d=false pass, but
// that relies on adjust returning the same values in both passes - confirm
// this when a Source is set, since every adjust call consumes a Source value.
func (e *CBK) scramble(b []byte, d bool) {
	var (
		// o is a pooled scratch array; only its first two bytes are used, as
		// the temporary for the pair swap.
		o = chains.Get().(*[size + 1]byte)
		// x, y and z are the round constants derived from the key bytes and
		// the current block index.
		x    = e.adjust(uint16(e.A*e.B) + uint16(e.D))
		y    = e.adjust(uint16((e.C-e.D)*e.A) + x + e.adjust(uint16(e.index)))
		z    = e.adjust(uint16(byte(x*y) + e.B - e.D*e.index))
		i    int8
		g, h byte
	)
	if d {
		// The reverse direction starts at the last round.
		i = 5
	}
	for (i < 6 && !d) || (i >= 0 && d) {
		g = (byte(z*y) + e.blockIndex(true, uint16(e.D*e.A)+uint16(i)+x, uint16(e.D)+uint16(e.index))) % 8
		h = (byte(y) - e.blockIndex(false, y+uint16(e.D)+uint16(e.index*uint8(i+1)), uint16(e.D)+x+uint16(byte(uint16(i)*z)*e.A))) % 8
		if g != h {
			if !d {
				// Nibble mix: b[h] becomes (low nibble of b[g], low nibble of
				// b[h]) and b[g] becomes (high nibble of b[g], high nibble of
				// b[h]). The same is done for the following byte of each slot.
				b[h], b[g] = (b[g]&0xF)<<4|(b[h]&0xF), (b[g]>>4)<<4|((b[h]>>4)&0xF)
				b[h+1], b[g+1] = (b[g+1]&0xF)<<4|(b[h+1]&0xF), (b[g+1]>>4)<<4|((b[h+1]>>4)&0xF)
			}
			// Swap the two-byte pairs at g*2 and h*2 through the scratch array.
			copy((*o)[0:2], b[g*2:(g*2)+2])
			copy(b[g*2:], b[h*2:(h*2)+2])
			copy(b[h*2:], (*o)[0:2])
			if d {
				// In the reverse direction the nibble mix runs after the swap.
				b[h], b[g] = (b[g]&0xF)<<4|(b[h]&0xF), (b[g]>>4)<<4|((b[h]>>4)&0xF)
				b[h+1], b[g+1] = (b[g+1]&0xF)<<4|(b[h+1]&0xF), (b[g+1]>>4)<<4|((b[h+1]>>4)&0xF)
			}
		}
		if d {
			i--
		} else {
			i++
		}
	}
	// Zero the scratch array and hand it back to the pool.
	clear(o, nil)
}
// cipherTable fills b with one seed byte per position, derived from the key
// bytes, the current block index and adjust. readInput and flushOutput use
// entry x as the starting value of substitution row x.
//
// All intermediate arithmetic wraps (byte or uint16) and the result is
// truncated to a byte. Every adjust call consumes a value from Source when one
// is set, so the order of the assignments matters.
func (e *CBK) cipherTable(b *[size + 1]byte) {
	// Position zero has its own formula.
	(*b)[0] = byte(uint16(e.index+1)*uint16(e.D+1) + e.adjust(uint16(e.D)))
	// Positions 1 .. len-2; the last position is assigned after the loop.
	for i := byte(1); i < byte(len(*b))-1; i++ {
		switch {
		case i <= 6:
			// Even and odd positions differ only in the term subtracted from B.
			if i%2 == 0 {
				(*b)[i] = byte(uint16(e.index) - uint16(e.A) + uint16(e.B-(i-e.C)) + uint16(i) - e.adjust(uint16(e.A)))
			} else {
				(*b)[i] = byte(uint16(e.index) - uint16(e.A) + uint16(e.B-(i-3)) + uint16(i) - e.adjust(uint16(e.A)))
			}
		case i > 6 && i <= 11:
			(*b)[i] = byte(uint16(e.C) - uint16(e.B) + uint16((e.index+1)*i) + e.adjust(uint16(e.C)))
		case i > 11:
			// This formula does not depend on i.
			(*b)[i] = byte(e.adjust(uint16(e.B+e.C)) + uint16(e.D) - uint16(len(*b)-1) - uint16(e.D) + uint16(e.A-e.C))
		}
	}
	// The last position mixes in the block index.
	(*b)[len(*b)-1] = byte(e.adjust(uint16(e.B+e.C)) + uint16(e.index) - uint16(len(*b)-1) - uint16(e.D) + uint16(e.A-e.C))
}
// readInput reads one full encrypted block from r into the buffer, decrypts
// it in place and sets pos and total for the decrypted payload.
//
// It returns:
//   - (0, io.EOF) when nothing could be read without an error, or when the
//     decrypted block carries a payload length of zero;
//   - (0, err) when nothing could be read and the reader reported err;
//   - (0, io.ErrUnexpectedEOF) when only part of a block could be read;
//   - (n, io.ErrShortBuffer) when the decrypted payload length is larger than
//     the block size;
//   - (n, nil) otherwise, where n is the number of raw bytes read from r.
func (e *CBK) readInput(r io.Reader) (int, error) {
	n, err := io.ReadFull(r, e.buf)
	if n <= 0 {
		// Nothing was read: mark the buffer as empty.
		if e.total = 0; err == nil {
			return 0, io.EOF
		}
		return 0, err
	}
	if n != len(e.buf) {
		return 0, io.ErrUnexpectedEOF
	}
	// Advance the block counter before the key schedule is derived; it wraps
	// to zero after 30.
	if e.index++; e.index > 30 {
		e.index = 0
	}
	var (
		t = chains.Get().(*[size + 1]byte)
		c = tables.Get().(*[size + 1][256]byte)
	)
	e.cipherTable(t)
	e.Deshuffle(e.buf)
	e.scramble(e.buf, true)
	// Build the inverse substitution rows: row x maps the value (seed + z)
	// back to z. After 256 increments the seed byte is back at its start value.
	for x := range *c {
		for z := range (*c)[x] {
			(*c)[x][(*t)[x]&0xFF] = byte(z)
			(*t)[x]++
		}
	}
	// Only the first 16 rows are used; the row is picked by the low four bits
	// of the position.
	for i := range e.buf {
		e.buf[i] = (*c)[i&0xF][e.buf[i]&0xFF]
	}
	// The last byte of the decrypted block holds the payload length.
	e.total, e.pos = int(e.buf[len(e.buf)-1]), 0
	// Return both scratch objects to their pools before reporting the result.
	if clear(t, c); e.total == 0 {
		return 0, io.EOF
	}
	if e.total > len(e.buf)-1 {
		return n, io.ErrShortBuffer
	}
	return n, err
}
   257  func (e *CBK) blockIndex(a bool, t, i uint16) byte {
   258  	switch v := t % 8; {
   259  	case v == 0 && a:
   260  		return byte((((t+1)*(1+i+uint16(e.A)*t) + t + 5) / 3) + 4 + (5 * t) + (i / 5))
   261  	case v == 1 && a:
   262  		return byte((t / 5) + i + ((i + 1) * 7) + ((1 + t) * 3) + (i / 2) + t)
   263  	case v == 2 && a:
   264  		return byte((((3+t+uint16(e.B+e.C))/4+1)+i)/2 + (3 * t) + (t / 5) + i + 3)
   265  	case v == 3 && a:
   266  		return byte(((t / 2) * 3) + 7 + ((t + i) * 3) - 2 + ((t * (i + 5 + uint16(e.D))) * 3))
   267  	case v == 4 && a:
   268  		return byte((((i*6)+2)/5)*3 + ((4 * i) / 5) + 3 + (t / 4))
   269  	case v == 5 && a:
   270  		return byte((((t*3)/5)+(5+i))*3 + (t * (2 - uint16(e.A*e.D))) + (i / (t + 1)) + (6 + t))
   271  	case v == 6 && a:
   272  		return byte((((((i + 5) / 3) * 7) + 3 + uint16(e.B)) / (t + 1)) + 3 + (t/(i+1))*3)
   273  	case v == 7 && a:
   274  		return byte(((((t / (i + 1) * 2) + 5) / 4) + 10) + (3 * t) + ((i / 2) + (t * 3)) + 4)
   275  	case v == 0 && !a:
   276  		return byte((((3/(2+i) + 3) / (t + 1)) * 9) + 6 - uint16(e.A*e.C) + i)
   277  	case v == 1 && !a:
   278  		return byte(((((4*i)/3 + (t * 2)) / 3) + 8) / 3)
   279  	case v == 2 && !a:
   280  		return byte((((9 + i + uint16(e.A*e.D)) / 4) + (t / 2) + (2*i + 1 + uint16(e.D))) / (((i + 3) / (5 + t)) + 6))
   281  	case v == 3 && !a:
   282  		return byte(((((4+(t-5)/2)/6)+3)*2)*((5+i)/3) + 4)
   283  	case v == 4 && !a:
   284  		return byte((((((t/3)/(3+i) + uint16(e.C)) / 9) * 2) + 8) + (5+i)/(3+t))
   285  	case v == 5 && !a:
   286  		return byte(((i * 4) + (t / 3) - uint16(e.A*byte(1+t)) + (6 / (1 + i))) + (6 / (3 + t)) + (i * 3))
   287  	case v == 6 && !a:
   288  		return byte((((((t*9)/6)+(i*3)/9)*5 + i) - uint16(e.D*byte(i))) + (t+2)/4)
   289  	case v == 7 && !a:
   290  		return byte((((((i/3)*7)+3-uint16(e.B))*5 + t) * (t + 3) / 7) + uint16(e.D*e.B))
   291  	}
   292  	return 0
   293  }
// flushOutput encrypts the buffered block in place and writes the whole
// buffer (payload plus the trailing length byte) to w. It returns the byte
// count and error reported by w.Write.
//
// Nothing is written, and (0, nil) is returned, when the buffer holds no
// pending bytes.
func (e *CBK) flushOutput(w io.Writer) (int, error) {
	if e.pos == 0 {
		return 0, nil
	}
	// Advance the block counter before the key schedule is derived; it wraps
	// to zero after 30.
	if e.index++; e.index > 30 {
		e.index = 0
	}
	// Record the number of pending payload bytes in the slot after the payload.
	e.buf[e.total] = byte(e.pos)
	var (
		t = chains.Get().(*[size + 1]byte)
		c = tables.Get().(*[size + 1][256]byte)
	)
	e.cipherTable(t)
	// Build the substitution rows: row x maps z to (seed + z). After 256
	// increments the seed byte is back at its start value.
	for x := range *c {
		for z := range (*c)[x] {
			(*c)[x][z] = (*t)[x]
			(*t)[x]++
		}
	}
	// Only the first 16 rows are used; the row is picked by the low four bits
	// of the position.
	for i := range e.buf {
		e.buf[i] = (*c)[i&0xF][e.buf[i]&0xFF]
	}
	e.scramble(e.buf, false)
	e.Shuffle(e.buf)
	// The buffer counts as empty from here on, whether or not the write succeeds.
	e.pos = 0
	n, err := w.Write(e.buf)
	// Return both scratch objects to their pools.
	clear(t, c)
	return n, err
}
   323  
   324  // NewCBKSource returns a new CBK Cipher with the A, B, C, D, BlockSize values
   325  // specified.
   326  func NewCBKSource(a, b, c, d, sz byte) (CBK, error) {
   327  	switch sz {
   328  	case 0:
   329  		sz = size
   330  	case 16, 32, 64, 128:
   331  	default:
   332  		return CBK{}, xerr.Sub("block size must be a power of two between 16 and 128", 0x28)
   333  	}
   334  	if a == 0 {
   335  		a = 1
   336  	}
   337  	return CBK{A: a, B: b, C: c, D: d, buf: make([]byte, sz+1), total: -1}, nil
   338  }
   339  func clear(b *[size + 1]byte, z *[size + 1][256]byte) {
   340  	for i := range *b {
   341  		(*b)[i] = 0
   342  	}
   343  	if chains.Put(b); z != nil {
   344  		tables.Put(z)
   345  	}
   346  }
   347  
   348  // NewCBKEx returns a new CBK Cipher with the D value, BlockSize and Entropy source
   349  // specified. The other A, B and C values are randomly generated at runtime.
   350  func NewCBKEx(d int, sz int, src source) (CBK, error) {
   351  	switch sz {
   352  	case 0:
   353  		sz = size
   354  	case 16, 32, 64, 128:
   355  	default:
   356  		return CBK{}, xerr.Sub("block size must be a power of two between 16 and 128", 0x28)
   357  	}
   358  	c := CBK{D: byte(d), buf: make([]byte, sz+1), total: -1, Source: src}
   359  	c.Reset()
   360  	return c, nil
   361  }
   362  
   363  // Read reads the contents of the Reader to the byte array after decrypting with
   364  // this Cipher.
   365  func (e *CBK) Read(r io.Reader, b []byte) (int, error) {
   366  	if e.buf == nil {
   367  		e.buf = make([]byte, size+1)
   368  	}
   369  	if e.total-e.pos > len(b) {
   370  		if e.pos+len(b) > len(e.buf) {
   371  			return 0, io.ErrShortBuffer
   372  		}
   373  		u := copy(b, e.buf[e.pos:e.pos+len(b)])
   374  		e.pos += len(b)
   375  		return u, nil
   376  	}
   377  	if e.pos >= e.total {
   378  		if o, err := e.readInput(r); err != nil && (err != io.EOF || o == 0) {
   379  			return o, err
   380  		}
   381  	}
   382  	var n int
   383  	for i := 0; n < len(b) && e.pos < e.total && e.total < len(e.buf); n += i {
   384  		if e.total <= 0 {
   385  			return n, io.EOF
   386  		}
   387  		i = copy(b[n:], e.buf[e.pos:e.total])
   388  		if e.pos += i; e.pos >= e.total && e.total >= len(e.buf)-1 {
   389  			if _, err := e.readInput(r); err != nil && err != io.EOF {
   390  				return n, err
   391  			}
   392  		}
   393  	}
   394  	if e.total > len(e.buf) {
   395  		return n, io.EOF
   396  	}
   397  	return n, nil
   398  }
   399  
   400  // Write writes the contents of the byte array to the Writer after encrypting with
   401  // this Cipher.
   402  func (e *CBK) Write(w io.Writer, b []byte) (int, error) {
   403  	if e.buf == nil {
   404  		e.buf = make([]byte, size+1)
   405  	} else if e.total == -1 {
   406  		e.total = len(e.buf) - 1
   407  	}
   408  	var n, i int
   409  	for n < len(b) {
   410  		if e.pos >= e.total {
   411  			if _, err := e.flushOutput(w); err != nil {
   412  				return n, err
   413  			}
   414  		}
   415  		i = copy(e.buf[e.pos:e.total], b[n:])
   416  		e.pos += i
   417  		n += i
   418  	}
   419  	if e.pos < e.total {
   420  		return n, nil
   421  	}
   422  	o, err := e.flushOutput(w)
   423  	if o < e.total {
   424  		return n - (e.total - o), err
   425  	}
   426  	return n, err
   427  }