git.lukeshu.com/go/lowmemjson@v0.3.9-0.20230723050957-72f6d13f6fb2/internal/base64dec/base64.go (about)

     1  // Copyright (C) 2022-2023  Luke Shumaker <lukeshu@lukeshu.com>
     2  //
     3  // SPDX-License-Identifier: GPL-2.0-or-later
     4  
     5  package base64dec
     6  
     7  import (
     8  	"encoding/base64"
     9  	"io"
    10  	"strings"
    11  
    12  	"git.lukeshu.com/go/lowmemjson/internal/fastio"
    13  	"git.lukeshu.com/go/lowmemjson/internal/fastio/noescape"
    14  )
    15  
    16  type base64Decoder struct {
    17  	dst io.Writer
    18  
    19  	err    error
    20  	pos    int64
    21  	buf    [4]byte
    22  	bufLen int
    23  }
    24  
    25  func NewBase64Decoder(w io.Writer) interface {
    26  	io.WriteCloser
    27  	fastio.RuneWriter
    28  } {
    29  	return &base64Decoder{
    30  		dst: w,
    31  	}
    32  }
    33  
    34  func (dec *base64Decoder) decodeByte(b byte) (byte, bool) {
    35  	const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
    36  	n := strings.IndexByte(alphabet, b)
    37  	if n < 0 {
    38  		return 0, false
    39  	}
    40  	dec.pos++
    41  	return byte(n), true
    42  }
    43  
    44  func (dec *base64Decoder) decodeTuple(a, b, c, d byte) error {
    45  	var decodedLen int
    46  	var encoded [4]byte
    47  	var ok bool
    48  
    49  	if a != '=' {
    50  		encoded[0], ok = dec.decodeByte(a)
    51  		if !ok {
    52  			return base64.CorruptInputError(dec.pos)
    53  		}
    54  		decodedLen++
    55  	}
    56  	if b != '=' {
    57  		encoded[1], ok = dec.decodeByte(b)
    58  		if !ok {
    59  			return base64.CorruptInputError(dec.pos)
    60  		}
    61  		// do NOT increment decodedLen here
    62  	}
    63  	if c != '=' {
    64  		encoded[2], ok = dec.decodeByte(c)
    65  		if !ok {
    66  			return base64.CorruptInputError(dec.pos)
    67  		}
    68  		decodedLen++
    69  	}
    70  	if d != '=' {
    71  		encoded[3], ok = dec.decodeByte(d)
    72  		if !ok {
    73  			return base64.CorruptInputError(dec.pos)
    74  		}
    75  		decodedLen++
    76  	}
    77  
    78  	val := 0 |
    79  		uint32(encoded[0])<<18 |
    80  		uint32(encoded[1])<<12 |
    81  		uint32(encoded[2])<<6 |
    82  		uint32(encoded[3])<<0
    83  	var decoded [3]byte
    84  	decoded[0] = byte(val >> 16)
    85  	decoded[1] = byte(val >> 8)
    86  	decoded[2] = byte(val >> 0)
    87  
    88  	_, err := noescape.Write(dec.dst, decoded[:decodedLen])
    89  	return err
    90  }
    91  
    92  func (dec *base64Decoder) Write(dat []byte) (int, error) {
    93  	if len(dat) == 0 {
    94  		return 0, nil
    95  	}
    96  	if dec.err != nil {
    97  		return 0, dec.err
    98  	}
    99  	var n int
   100  	if dec.bufLen > 0 {
   101  		n = copy(dec.buf[dec.bufLen:], dat)
   102  		dec.bufLen += n
   103  		if dec.bufLen < 4 {
   104  			return len(dat), nil
   105  		}
   106  		if err := dec.decodeTuple(dec.buf[0], dec.buf[1], dec.buf[2], dec.buf[3]); err != nil {
   107  			dec.err = err
   108  			return 0, dec.err
   109  		}
   110  	}
   111  	for ; n+3 < len(dat); n += 4 {
   112  		if err := dec.decodeTuple(dat[n], dat[n+1], dat[n+2], dat[n+3]); err != nil {
   113  			dec.err = err
   114  			return n, dec.err
   115  		}
   116  	}
   117  	dec.bufLen = copy(dec.buf[:], dat[n:])
   118  	return len(dat), nil
   119  }
   120  
   121  func (dec *base64Decoder) WriteRune(r rune) (int, error) {
   122  	return fastio.WriteRune(dec, r)
   123  }
   124  
   125  func (dec *base64Decoder) Close() error {
   126  	if dec.bufLen == 0 {
   127  		return nil
   128  	}
   129  	copy(dec.buf[:], "====")
   130  	return dec.decodeTuple(dec.buf[0], dec.buf[1], dec.buf[2], dec.buf[3])
   131  }