github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/vfs/adiantum/hbsh.go (about)

     1  package adiantum
     2  
     3  import (
     4  	"encoding/binary"
     5  	"encoding/hex"
     6  	"io"
     7  
     8  	"github.com/ncruces/go-sqlite3"
     9  	"github.com/ncruces/go-sqlite3/internal/util"
    10  	"github.com/ncruces/go-sqlite3/vfs"
    11  	"lukechampine.com/adiantum/hbsh"
    12  )
    13  
// hbshVFS wraps another VFS so that the files it opens are encrypted
// with HBSH ciphers obtained from an HBSHCreator.
type hbshVFS struct {
	// VFS is the wrapped VFS that actually opens the files.
	vfs.VFS
	// hbsh derives keys (KDF) and creates ciphers (HBSH) for opened files.
	hbsh HBSHCreator
}
    18  
    19  func (h *hbshVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
    20  	return nil, 0, sqlite3.CANTOPEN
    21  }
    22  
    23  func (h *hbshVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) {
    24  	if hf, ok := h.VFS.(vfs.VFSFilename); ok {
    25  		file, flags, err = hf.OpenFilename(name, flags)
    26  	} else {
    27  		file, flags, err = h.VFS.Open(name.String(), flags)
    28  	}
    29  
    30  	// Encrypt everything except super journals and memory files.
    31  	if err != nil || flags&(vfs.OPEN_SUPER_JOURNAL|vfs.OPEN_MEMORY) != 0 {
    32  		return file, flags, err
    33  	}
    34  
    35  	var hbsh *hbsh.HBSH
    36  	if f, ok := name.DatabaseFile().(*hbshFile); ok {
    37  		hbsh = f.hbsh
    38  	} else {
    39  		var key []byte
    40  		if params := name.URIParameters(); name == nil {
    41  			key = h.hbsh.KDF("") // Temporary files get a random key.
    42  		} else if t, ok := params["key"]; ok {
    43  			key = []byte(t[0])
    44  		} else if t, ok := params["hexkey"]; ok {
    45  			key, _ = hex.DecodeString(t[0])
    46  		} else if t, ok := params["textkey"]; ok {
    47  			key = h.hbsh.KDF(t[0])
    48  		} else if flags&vfs.OPEN_MAIN_DB != 0 {
    49  			// Main datatabases may have their key specified as a PRAGMA.
    50  			return &hbshFile{File: file, reset: h.hbsh}, flags, nil
    51  		}
    52  		hbsh = h.hbsh.HBSH(key)
    53  	}
    54  
    55  	if hbsh == nil {
    56  		return nil, flags, sqlite3.CANTOPEN
    57  	}
    58  	return &hbshFile{File: file, hbsh: hbsh, reset: h.hbsh}, flags, nil
    59  }
    60  
    61  const (
    62  	tweakSize = 8
    63  	blockSize = 4096
    64  )
    65  
// hbshFile is a vfs.File whose contents are stored as blockSize blocks,
// each encrypted with an HBSH cipher.
//
// NOTE(review): tweak and block are shared scratch buffers, so ReadAt and
// WriteAt must not run concurrently on the same hbshFile; presumably SQLite
// serializes access to a file — confirm.
type hbshFile struct {
	vfs.File
	// hbsh is the active cipher. It is nil for a main database opened without
	// a key, until a key is supplied through Pragma.
	hbsh  *hbsh.HBSH
	// reset derives keys and creates ciphers when a key is set through Pragma.
	reset HBSHCreator
	// tweak is scratch space for the current block's tweak (its offset).
	tweak [tweakSize]byte
	// block is scratch space for the block currently being read or written.
	block [blockSize]byte
}
    73  
    74  func (h *hbshFile) Pragma(name string, value string) (string, error) {
    75  	var key []byte
    76  	switch name {
    77  	case "key":
    78  		key = []byte(value)
    79  	case "hexkey":
    80  		key, _ = hex.DecodeString(value)
    81  	case "textkey":
    82  		key = h.reset.KDF(value)
    83  	default:
    84  		if f, ok := h.File.(vfs.FilePragma); ok {
    85  			return f.Pragma(name, value)
    86  		}
    87  		return "", sqlite3.NOTFOUND
    88  	}
    89  
    90  	if h.hbsh = h.reset.HBSH(key); h.hbsh != nil {
    91  		return "ok", nil
    92  	}
    93  	return "", sqlite3.CANTOPEN
    94  }
    95  
    96  func (h *hbshFile) ReadAt(p []byte, off int64) (n int, err error) {
    97  	if h.hbsh == nil {
    98  		// Only OPEN_MAIN_DB can have a missing key.
    99  		if off == 0 && len(p) == 100 {
   100  			// SQLite is trying to read the header of a database file.
   101  			// Pretend the file is empty so the key may specified as a PRAGMA.
   102  			return 0, io.EOF
   103  		}
   104  		return 0, sqlite3.CANTOPEN
   105  	}
   106  
   107  	min := (off) &^ (blockSize - 1)                                   // round down
   108  	max := (off + int64(len(p)) + (blockSize - 1)) &^ (blockSize - 1) // round up
   109  
   110  	// Read one block at a time.
   111  	for ; min < max; min += blockSize {
   112  		m, err := h.File.ReadAt(h.block[:], min)
   113  		if m != blockSize {
   114  			return n, err
   115  		}
   116  
   117  		binary.LittleEndian.PutUint64(h.tweak[:], uint64(min))
   118  		data := h.hbsh.Decrypt(h.block[:], h.tweak[:])
   119  
   120  		if off > min {
   121  			data = data[off-min:]
   122  		}
   123  		n += copy(p[n:], data)
   124  	}
   125  
   126  	if n != len(p) {
   127  		panic(util.AssertErr())
   128  	}
   129  	return n, nil
   130  }
   131  
   132  func (h *hbshFile) WriteAt(p []byte, off int64) (n int, err error) {
   133  	if h.hbsh == nil {
   134  		return 0, sqlite3.READONLY
   135  	}
   136  
   137  	min := (off) &^ (blockSize - 1)                                   // round down
   138  	max := (off + int64(len(p)) + (blockSize - 1)) &^ (blockSize - 1) // round up
   139  
   140  	// Write one block at a time.
   141  	for ; min < max; min += blockSize {
   142  		binary.LittleEndian.PutUint64(h.tweak[:], uint64(min))
   143  		data := h.block[:]
   144  
   145  		if off > min || len(p[n:]) < blockSize {
   146  			// Partial block write: read-update-write.
   147  			m, err := h.File.ReadAt(h.block[:], min)
   148  			if m != blockSize {
   149  				if err != io.EOF {
   150  					return n, err
   151  				}
   152  				// Writing past the EOF.
   153  				// We're either appending an entirely new block,
   154  				// or the final block was only partially written.
   155  				// A partially written block can't be decrypted,
   156  				// and is as good as corrupt.
   157  				// Either way, zero pad the file to the next block size.
   158  				clear(data)
   159  			} else {
   160  				data = h.hbsh.Decrypt(h.block[:], h.tweak[:])
   161  			}
   162  			if off > min {
   163  				data = data[off-min:]
   164  			}
   165  		}
   166  
   167  		t := copy(data, p[n:])
   168  		h.hbsh.Encrypt(h.block[:], h.tweak[:])
   169  
   170  		m, err := h.File.WriteAt(h.block[:], min)
   171  		if m != blockSize {
   172  			return n, err
   173  		}
   174  		n += t
   175  	}
   176  
   177  	if n != len(p) {
   178  		panic(util.AssertErr())
   179  	}
   180  	return n, nil
   181  }
   182  
   183  func (h *hbshFile) Truncate(size int64) error {
   184  	size = (size + (blockSize - 1)) &^ (blockSize - 1) // round up
   185  	return h.File.Truncate(size)
   186  }
   187  
   188  func (h *hbshFile) SectorSize() int {
   189  	return lcm(h.File.SectorSize(), blockSize)
   190  }
   191  
   192  func (h *hbshFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
   193  	return h.File.DeviceCharacteristics() & (0 |
   194  		// The only safe flags are these:
   195  		vfs.IOCAP_UNDELETABLE_WHEN_OPEN |
   196  		vfs.IOCAP_IMMUTABLE |
   197  		vfs.IOCAP_BATCH_ATOMIC)
   198  }
   199  
   200  // Wrap optional methods.
   201  
   202  func (h *hbshFile) SharedMemory() vfs.SharedMemory {
   203  	if f, ok := h.File.(vfs.FileSharedMemory); ok {
   204  		return f.SharedMemory()
   205  	}
   206  	return nil
   207  }
   208  
   209  func (h *hbshFile) ChunkSize(size int) {
   210  	if f, ok := h.File.(vfs.FileChunkSize); ok {
   211  		size = (size + (blockSize - 1)) &^ (blockSize - 1) // round up
   212  		f.ChunkSize(size)
   213  	}
   214  }
   215  
   216  func (h *hbshFile) SizeHint(size int64) error {
   217  	if f, ok := h.File.(vfs.FileSizeHint); ok {
   218  		size = (size + (blockSize - 1)) &^ (blockSize - 1) // round up
   219  		return f.SizeHint(size)
   220  	}
   221  	return sqlite3.NOTFOUND
   222  }
   223  
   224  func (h *hbshFile) HasMoved() (bool, error) {
   225  	if f, ok := h.File.(vfs.FileHasMoved); ok {
   226  		return f.HasMoved()
   227  	}
   228  	return false, sqlite3.NOTFOUND
   229  }
   230  
   231  func (h *hbshFile) Overwrite() error {
   232  	if f, ok := h.File.(vfs.FileOverwrite); ok {
   233  		return f.Overwrite()
   234  	}
   235  	return sqlite3.NOTFOUND
   236  }
   237  
   238  func (h *hbshFile) CommitPhaseTwo() error {
   239  	if f, ok := h.File.(vfs.FileCommitPhaseTwo); ok {
   240  		return f.CommitPhaseTwo()
   241  	}
   242  	return sqlite3.NOTFOUND
   243  }
   244  
   245  func (h *hbshFile) BeginAtomicWrite() error {
   246  	if f, ok := h.File.(vfs.FileBatchAtomicWrite); ok {
   247  		return f.BeginAtomicWrite()
   248  	}
   249  	return sqlite3.NOTFOUND
   250  }
   251  
   252  func (h *hbshFile) CommitAtomicWrite() error {
   253  	if f, ok := h.File.(vfs.FileBatchAtomicWrite); ok {
   254  		return f.CommitAtomicWrite()
   255  	}
   256  	return sqlite3.NOTFOUND
   257  }
   258  
   259  func (h *hbshFile) RollbackAtomicWrite() error {
   260  	if f, ok := h.File.(vfs.FileBatchAtomicWrite); ok {
   261  		return f.RollbackAtomicWrite()
   262  	}
   263  	return sqlite3.NOTFOUND
   264  }
   265  
   266  func (h *hbshFile) CheckpointDone() error {
   267  	if f, ok := h.File.(vfs.FileCheckpoint); ok {
   268  		return f.CheckpointDone()
   269  	}
   270  	return sqlite3.NOTFOUND
   271  }
   272  
   273  func (h *hbshFile) CheckpointStart() error {
   274  	if f, ok := h.File.(vfs.FileCheckpoint); ok {
   275  		return f.CheckpointStart()
   276  	}
   277  	return sqlite3.NOTFOUND
   278  }