diff options
-rw-r--r-- | dbstore/dbstore.go | 17 | ||||
-rw-r--r-- | dbstore/dbstore_test.go | 9 |
2 files changed, 19 insertions, 7 deletions
diff --git a/dbstore/dbstore.go b/dbstore/dbstore.go index 0ab6fcf..11ad02c 100644 --- a/dbstore/dbstore.go +++ b/dbstore/dbstore.go @@ -13,22 +13,33 @@ import ( //go:embed sql var sqlFs embed.FS +// sqlite3 backing store type DbStore struct { db *db_sql.DB } -// open database -func Open(path string) (DbStore, error) { +// Open database. +// +// This function is called by Open(). It is a separate package-private +// function to make Open() easier to test. +func openFull(dbType, path string) (DbStore, error) { var r DbStore + // init db - if db, err := db_sql.Open("sqlite3", path); err != nil { + if db, err := db_sql.Open(dbType, path); err != nil { return r, err } else { + // save handle r.db = db return r, nil } } +// Open database +func Open(path string) (DbStore, error) { + return openFull("sqlite3", path) +} + // initialized database version const initDbVersion = 314159 diff --git a/dbstore/dbstore_test.go b/dbstore/dbstore_test.go index fe653dc..1e03d72 100644 --- a/dbstore/dbstore_test.go +++ b/dbstore/dbstore_test.go @@ -198,19 +198,20 @@ func seedTestDb(ctx context.Context, db DbStore) error { // TODO: seed with other data } -func TestOpen(t *testing.T) { +func TestOpenFull(t *testing.T) { tests := []struct { name string + dbType string path string exp bool } { - { "pass", "./testdata/test-open.db", true }, - // { "fail", "file://invalid/foobar", false }, + { "pass", "sqlite3", "./testdata/test-open.db", true }, + { "fail", "invalid driver", "file://invalid/foobar", false }, } for _, test := range(tests) { t.Run(test.name, func(t *testing.T) { - got, err := Open(test.path) + got, err := openFull(test.dbType, test.path) if test.exp && err != nil { t.Error(err) } else if !test.exp && err == nil { |