github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/vfs/memdb/memdb.go (about)

     1  package memdb
     2  
     3  import (
     4  	"io"
     5  	"runtime"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/ncruces/go-sqlite3"
    10  	"github.com/ncruces/go-sqlite3/vfs"
    11  )
    12  
    13  // Must be a multiple of 64K (the largest page size).
    14  const sectorSize = 65536
    15  
    16  type memVFS struct{}
    17  
    18  func (memVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
    19  	// For simplicity, we do not support reading or writing data
    20  	// across "sector" boundaries.
    21  	//
    22  	// This is not a problem for most SQLite file types:
    23  	// - databases, which only do page aligned reads/writes;
    24  	// - temp journals, as used by the sorter, which does the same:
    25  	//   https://github.com/sqlite/sqlite/blob/b74eb0/src/vdbesort.c#L409-L412
    26  	//
    27  	// We refuse to open all other file types,
    28  	// but returning OPEN_MEMORY means SQLite won't ask us to.
    29  	const types = vfs.OPEN_MAIN_DB |
    30  		vfs.OPEN_TEMP_DB |
    31  		vfs.OPEN_TEMP_JOURNAL
    32  	if flags&types == 0 {
    33  		return nil, flags, sqlite3.CANTOPEN
    34  	}
    35  
    36  	// A shared database has a name that begins with "/".
    37  	shared := len(name) > 1 && name[0] == '/'
    38  
    39  	var db *memDB
    40  	if shared {
    41  		name = name[1:]
    42  		memoryMtx.Lock()
    43  		defer memoryMtx.Unlock()
    44  		db = memoryDBs[name]
    45  	}
    46  	if db == nil {
    47  		if flags&vfs.OPEN_CREATE == 0 {
    48  			return nil, flags, sqlite3.CANTOPEN
    49  		}
    50  		db = &memDB{name: name}
    51  	}
    52  	if shared {
    53  		db.refs++ // +checklocksforce: memoryMtx is held
    54  		memoryDBs[name] = db
    55  	}
    56  
    57  	return &memFile{
    58  		memDB:    db,
    59  		readOnly: flags&vfs.OPEN_READONLY != 0,
    60  	}, flags | vfs.OPEN_MEMORY, nil
    61  }
    62  
    63  func (memVFS) Delete(name string, dirSync bool) error {
    64  	return sqlite3.IOERR_DELETE
    65  }
    66  
    67  func (memVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
    68  	return false, nil
    69  }
    70  
    71  func (memVFS) FullPathname(name string) (string, error) {
    72  	return name, nil
    73  }
    74  
    75  type memDB struct {
    76  	name string
    77  
    78  	// +checklocks:lockMtx
    79  	pending *memFile
    80  	// +checklocks:lockMtx
    81  	reserved *memFile
    82  
    83  	// +checklocks:dataMtx
    84  	data []*[sectorSize]byte
    85  
    86  	// +checklocks:dataMtx
    87  	size int64
    88  
    89  	// +checklocks:lockMtx
    90  	shared int
    91  
    92  	// +checklocks:memoryMtx
    93  	refs int
    94  
    95  	lockMtx sync.Mutex
    96  	dataMtx sync.RWMutex
    97  }
    98  
    99  func (m *memDB) release() {
   100  	memoryMtx.Lock()
   101  	defer memoryMtx.Unlock()
   102  	if m.refs--; m.refs == 0 && m == memoryDBs[m.name] {
   103  		delete(memoryDBs, m.name)
   104  	}
   105  }
   106  
   107  type memFile struct {
   108  	*memDB
   109  	lock     vfs.LockLevel
   110  	readOnly bool
   111  }
   112  
   113  var (
   114  	// Ensure these interfaces are implemented:
   115  	_ vfs.FileLockState = &memFile{}
   116  	_ vfs.FileSizeHint  = &memFile{}
   117  )
   118  
   119  func (m *memFile) Close() error {
   120  	m.release()
   121  	return m.Unlock(vfs.LOCK_NONE)
   122  }
   123  
   124  func (m *memFile) ReadAt(b []byte, off int64) (n int, err error) {
   125  	m.dataMtx.RLock()
   126  	defer m.dataMtx.RUnlock()
   127  
   128  	if off >= m.size {
   129  		return 0, io.EOF
   130  	}
   131  
   132  	base := off / sectorSize
   133  	rest := off % sectorSize
   134  	have := int64(sectorSize)
   135  	if base == int64(len(m.data))-1 {
   136  		have = modRoundUp(m.size, sectorSize)
   137  	}
   138  	n = copy(b, (*m.data[base])[rest:have])
   139  	if n < len(b) {
   140  		// Assume reads are page aligned.
   141  		return 0, io.ErrNoProgress
   142  	}
   143  	return n, nil
   144  }
   145  
   146  func (m *memFile) WriteAt(b []byte, off int64) (n int, err error) {
   147  	m.dataMtx.Lock()
   148  	defer m.dataMtx.Unlock()
   149  
   150  	base := off / sectorSize
   151  	rest := off % sectorSize
   152  	for base >= int64(len(m.data)) {
   153  		m.data = append(m.data, new([sectorSize]byte))
   154  	}
   155  	n = copy((*m.data[base])[rest:], b)
   156  	if n < len(b) {
   157  		// Assume writes are page aligned.
   158  		return n, io.ErrShortWrite
   159  	}
   160  	if size := off + int64(len(b)); size > m.size {
   161  		m.size = size
   162  	}
   163  	return n, nil
   164  }
   165  
   166  func (m *memFile) Truncate(size int64) error {
   167  	m.dataMtx.Lock()
   168  	defer m.dataMtx.Unlock()
   169  	return m.truncate(size)
   170  }
   171  
   172  // +checklocks:m.dataMtx
   173  func (m *memFile) truncate(size int64) error {
   174  	if size < m.size {
   175  		base := size / sectorSize
   176  		rest := size % sectorSize
   177  		if rest != 0 {
   178  			clear((*m.data[base])[rest:])
   179  		}
   180  	}
   181  	sectors := divRoundUp(size, sectorSize)
   182  	for sectors > int64(len(m.data)) {
   183  		m.data = append(m.data, new([sectorSize]byte))
   184  	}
   185  	clear(m.data[sectors:])
   186  	m.data = m.data[:sectors]
   187  	m.size = size
   188  	return nil
   189  }
   190  
   191  func (m *memFile) Sync(flag vfs.SyncFlag) error {
   192  	return nil
   193  }
   194  
   195  func (m *memFile) Size() (int64, error) {
   196  	m.dataMtx.RLock()
   197  	defer m.dataMtx.RUnlock()
   198  	return m.size, nil
   199  }
   200  
   201  const spinWait = 25 * time.Microsecond
   202  
   203  func (m *memFile) Lock(lock vfs.LockLevel) error {
   204  	if m.lock >= lock {
   205  		return nil
   206  	}
   207  
   208  	if m.readOnly && lock >= vfs.LOCK_RESERVED {
   209  		return sqlite3.IOERR_LOCK
   210  	}
   211  
   212  	m.lockMtx.Lock()
   213  	defer m.lockMtx.Unlock()
   214  
   215  	switch lock {
   216  	case vfs.LOCK_SHARED:
   217  		if m.pending != nil {
   218  			return sqlite3.BUSY
   219  		}
   220  		m.shared++
   221  
   222  	case vfs.LOCK_RESERVED:
   223  		if m.reserved != nil {
   224  			return sqlite3.BUSY
   225  		}
   226  		m.reserved = m
   227  
   228  	case vfs.LOCK_EXCLUSIVE:
   229  		if m.lock < vfs.LOCK_PENDING {
   230  			if m.pending != nil {
   231  				return sqlite3.BUSY
   232  			}
   233  			m.lock = vfs.LOCK_PENDING
   234  			m.pending = m
   235  		}
   236  
   237  		for before := time.Now(); m.shared > 1; {
   238  			if time.Since(before) > spinWait {
   239  				return sqlite3.BUSY
   240  			}
   241  			m.lockMtx.Unlock()
   242  			runtime.Gosched()
   243  			m.lockMtx.Lock()
   244  		}
   245  	}
   246  
   247  	m.lock = lock
   248  	return nil
   249  }
   250  
   251  func (m *memFile) Unlock(lock vfs.LockLevel) error {
   252  	if m.lock <= lock {
   253  		return nil
   254  	}
   255  
   256  	m.lockMtx.Lock()
   257  	defer m.lockMtx.Unlock()
   258  
   259  	if m.pending == m {
   260  		m.pending = nil
   261  	}
   262  	if m.reserved == m {
   263  		m.reserved = nil
   264  	}
   265  	if lock < vfs.LOCK_SHARED {
   266  		m.shared--
   267  	}
   268  	m.lock = lock
   269  	return nil
   270  }
   271  
   272  func (m *memFile) CheckReservedLock() (bool, error) {
   273  	if m.lock >= vfs.LOCK_RESERVED {
   274  		return true, nil
   275  	}
   276  	m.lockMtx.Lock()
   277  	defer m.lockMtx.Unlock()
   278  	return m.reserved != nil, nil
   279  }
   280  
   281  func (m *memFile) SectorSize() int {
   282  	return sectorSize
   283  }
   284  
   285  func (m *memFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
   286  	return vfs.IOCAP_ATOMIC |
   287  		vfs.IOCAP_SEQUENTIAL |
   288  		vfs.IOCAP_SAFE_APPEND |
   289  		vfs.IOCAP_POWERSAFE_OVERWRITE
   290  }
   291  
   292  func (m *memFile) SizeHint(size int64) error {
   293  	m.dataMtx.Lock()
   294  	defer m.dataMtx.Unlock()
   295  	if size > m.size {
   296  		return m.truncate(size)
   297  	}
   298  	return nil
   299  }
   300  
   301  func (m *memFile) LockState() vfs.LockLevel {
   302  	return m.lock
   303  }
   304  
   305  func divRoundUp(a, b int64) int64 {
   306  	return (a + b - 1) / b
   307  }
   308  
   309  func modRoundUp(a, b int64) int64 {
   310  	return b - (b-a%b)%b
   311  }