package sqlite3 import ( "database/sql" "fmt" "io" nurl "net/url" "strconv" "strings" "go.uber.org/atomic" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" "github.com/hashicorp/go-multierror" _ "github.com/mattn/go-sqlite3" ) func init() { database.Register("sqlite3", &Sqlite{}) } var DefaultMigrationsTable = "schema_migrations" var ( ErrDatabaseDirty = fmt.Errorf("database is dirty") ErrNilConfig = fmt.Errorf("no config") ErrNoDatabaseName = fmt.Errorf("no database name") ) type Config struct { MigrationsTable string DatabaseName string NoTxWrap bool } type Sqlite struct { db *sql.DB isLocked atomic.Bool config *Config } func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } if err := instance.Ping(); err != nil { return nil, err } if len(config.MigrationsTable) == 0 { config.MigrationsTable = DefaultMigrationsTable } mx := &Sqlite{ db: instance, config: config, } if err := mx.ensureVersionTable(); err != nil { return nil, err } return mx, nil } // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Sqlite type. func (m *Sqlite) ensureVersionTable() (err error) { if err = m.Lock(); err != nil { return err } defer func() { if e := m.Unlock(); e != nil { if err == nil { err = e } else { err = multierror.Append(err, e) } } }() query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); `, m.config.MigrationsTable, m.config.MigrationsTable) if _, err := m.db.Exec(query); err != nil { return err } return nil } func (m *Sqlite) Open(url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err } dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1) db, err := sql.Open("sqlite3", dbfile) if err != nil { return nil, err } qv := purl.Query() migrationsTable := qv.Get("x-migrations-table") if len(migrationsTable) == 0 { migrationsTable = DefaultMigrationsTable } noTxWrap := false if v := qv.Get("x-no-tx-wrap"); v != "" { noTxWrap, err = strconv.ParseBool(v) if err != nil { return nil, fmt.Errorf("x-no-tx-wrap: %s", err) } } mx, err := WithInstance(db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, NoTxWrap: noTxWrap, }) if err != nil { return nil, err } return mx, nil } func (m *Sqlite) Close() error { return m.db.Close() } func (m *Sqlite) Drop() (err error) { query := `SELECT name FROM sqlite_master WHERE type = 'table';` tables, err := m.db.Query(query) if err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } defer func() { if errClose := tables.Close(); errClose != nil { err = multierror.Append(err, errClose) } }() tableNames := make([]string, 0) for tables.Next() { var tableName string if err := tables.Scan(&tableName); err != nil { return err } if len(tableName) > 0 { tableNames = append(tableNames, tableName) } } if err := tables.Err(); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } if len(tableNames) > 0 { for _, t := range tableNames { query := "DROP TABLE " + t err = m.executeQuery(query) if err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } } query := "VACUUM" _, err = m.db.Query(query) if err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } } return nil } func (m *Sqlite) Lock() error { if !m.isLocked.CAS(false, true) { return database.ErrLocked } return nil } func (m *Sqlite) Unlock() error { if !m.isLocked.CAS(true, false) { return database.ErrNotLocked } return nil } func (m *Sqlite) Run(migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err } query := string(migr[:]) if m.config.NoTxWrap { return m.executeQueryNoTx(query) } return m.executeQuery(query) } func (m *Sqlite) executeQuery(query string) error { tx, err := m.db.Begin() if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} } if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { err = multierror.Append(err, errRollback) } return &database.Error{OrigErr: err, Query: []byte(query)} } if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } return nil } func (m *Sqlite) executeQueryNoTx(query string) error { if _, err := m.db.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } return nil } func (m *Sqlite) SetVersion(version int, dirty bool) error { tx, err := m.db.Begin() if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} } query := "DELETE FROM " + m.config.MigrationsTable if _, err := tx.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } // Also re-write the schema version for nil dirty versions to prevent // empty schema version for failed down migration on the first migration // See: https://github.com/golang-migrate/migrate/issues/330 if version >= 0 || (version == database.NilVersion && dirty) { query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, m.config.MigrationsTable) if _, err := tx.Exec(query, version, dirty); err != nil { if errRollback := tx.Rollback(); errRollback != nil { err = multierror.Append(err, errRollback) } return &database.Error{OrigErr: err, Query: []byte(query)} } } if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } return nil } func (m *Sqlite) Version() (version int, dirty bool, err error) { query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" err = m.db.QueryRow(query).Scan(&version, &dirty) if err != nil { return database.NilVersion, false, nil } return version, dirty, nil }