diff options
-rw-r--r-- | dbstore/dbstore.go | 237 | ||||
-rw-r--r-- | dbstore/dbstore_test.go | 420 | ||||
-rw-r--r-- | dbstore/sql/test/junk.sql | 3 | ||||
-rw-r--r-- | dbstore/sqlfs.go | 38 | ||||
-rw-r--r-- | dbstore/tx.go | 105 | ||||
-rw-r--r-- | dbstore/tx_test.go | 140 |
6 files changed, 793 insertions, 150 deletions
diff --git a/dbstore/dbstore.go b/dbstore/dbstore.go index 25c3564..21a80a3 100644 --- a/dbstore/dbstore.go +++ b/dbstore/dbstore.go @@ -4,16 +4,12 @@ package dbstore import ( "context" db_sql "database/sql" - "embed" "fmt" _ "github.com/mattn/go-sqlite3" "github.com/pablotron/cvez/cpedict" "github.com/pablotron/cvez/cpematch" ) -//go:embed sql -var sqlFs embed.FS - // sqlite3 backing store type DbStore struct { db *db_sql.DB @@ -72,49 +68,68 @@ func (me DbStore) Init(ctx context.Context) error { } // read init query - sql, err := sqlFs.ReadFile("sql/init.sql") - if err != nil { + if sql, err := getQuery("init"); err != nil { + return err + } else { + // exec init query, return result + _, err = me.db.ExecContext(ctx, sql) return err } +} - // exec init query, return result - _, err = me.db.ExecContext(ctx, string(sql)) - return err +var addCpeDictionaryQueryIds = []string { + "cpe/insert", + "cpe/insert-title", + "cpe/insert-ref", } -// get single query from embedded filesystem -func getQuery(id string) (string, error) { - // read query - if data, err := sqlFs.ReadFile(fmt.Sprintf("sql/%s.sql", id)); err != nil { - return "", err - } else { - // return query - return string(data), nil - } +// create new transaction +func (me DbStore) Tx(ctx context.Context, queryIds []string) (Tx, error) { + return newTx(ctx, me.db, queryIds) } -// return query map -func getQueries(ids []string) (map[string]string, error) { - r := make(map[string]string) +// Execute query and invoke callback with each row of result. +func (me DbStore) Query( + ctx context.Context, + queryId string, + args []interface{}, + fn func(*db_sql.Rows) error, +) error { + // get query + sql, err := getQuery(queryId) + if err != nil { + return err + } - for _, id := range(ids) { - // read query - if sql, err := getQuery(id); err != nil { - return r, fmt.Errorf("%s: %s", id, err.Error()) - } else { - // save query - r[id] = sql + // exec query + rows, err := me.db.QueryContext(ctx, sql, args...) + if err != nil { + return err + } + + // walk results + for rows.Next() { + if err = fn(rows); err != nil { + return err } } - // return success - return r, nil -} + // close rows + // FIXME: is this correct? i am following the example from the + // database/sql documentation, but it is messy and it seems + // counterintuitive to close the row set and then do an additional + // test for iteration errors... + if err = rows.Close(); err != nil { + return err + } -var addCpeDictionaryQueryIds = []string { - "cpe/insert", - "cpe/insert-title", - "cpe/insert-ref", + // check for iteration errors + if err = rows.Err(); err != nil { + return err + } + + // return success + return nil } // import CPE dictionary @@ -124,33 +139,15 @@ func (me DbStore) AddCpeDictionary(ctx context.Context, dict cpedict.Dictionary) return err } - // build query map - queries, err := getQueries(addCpeDictionaryQueryIds) - if err != nil { - return err - } - - // begin context - tx, err := me.db.BeginTx(ctx, nil) + tx, err := me.Tx(ctx, addCpeDictionaryQueryIds) if err != nil { return err } - // build statements - sts := make(map[string]*db_sql.Stmt) - for id, sql := range(queries) { - if st, err := tx.PrepareContext(ctx, sql); err != nil { - return err - } else { - sts[id] = st - defer sts[id].Close() - } - } - // add items for _, item := range(dict.Items) { // add cpe - rs, err := sts["cpe/insert"].ExecContext(ctx, item.CpeUri, item.Cpe23Item.Name) + rs, err := tx.Exec(ctx, "cpe/insert", item.CpeUri, item.Cpe23Item.Name) if err != nil { return err } @@ -163,7 +160,7 @@ func (me DbStore) AddCpeDictionary(ctx context.Context, dict cpedict.Dictionary) // add titles for _, title := range(item.Titles) { - _, err := sts["cpe/insert-title"].ExecContext(ctx, id, title.Lang, title.Text) + _, err := tx.Exec(ctx, "cpe/insert-title", id, title.Lang, title.Text) if err != nil { return err } @@ -171,7 +168,7 @@ func (me DbStore) AddCpeDictionary(ctx context.Context, dict cpedict.Dictionary) // add refs for _, ref := range(item.References) { - _, err := sts["cpe/insert-ref"].ExecContext(ctx, id, ref.Href, ref.Text) + _, err := tx.Exec(ctx, "cpe/insert-ref", id, ref.Href, ref.Text) if err != nil { return err } @@ -195,51 +192,28 @@ func (me DbStore) CpeSearch( return r, err } - // get query - sql, err := getQuery(searchType.String()) - if err != nil { - return r, err - } - - // exec search query - rows, err := me.db.QueryContext(ctx, sql, db_sql.Named("q", s)) - if err != nil { - return r, err - } - - // walk results - for rows.Next() { + // get/exec search query + err := me.Query(ctx, searchType.String(), []interface{} { + db_sql.Named("q", s), + }, func(rows *db_sql.Rows) error { if sr, err := unmarshalCpeSearchRow(rows); err != nil { // return error - return r, err + return err } else { // append to results r = append(r, sr) + return nil } - } - - // close rows - // FIXME: is this correct? i am following the example from the - // database/sql documentation, but it is messy and it seems - // counterintuitive to close the row set and then do an additional - // test for iteration errors... - if err = rows.Close(); err != nil { - return r, err - } - - // check for iteration errors - if err = rows.Err(); err != nil { - return r, err - } + }) - // return success - return r, nil + // return results + return r, err } // query IDs used by AddCpeMatches() var addCpeMatchesQueryIds = []string { "cpe-match/insert", - "cpe-match/insert-vulnerability", + "cpe-match/insert-vulnerable", "cpe-match/insert-version-min", "cpe-match/insert-version-max", "cpe-match/insert-name", @@ -252,33 +226,16 @@ func (me DbStore) AddCpeMatches(ctx context.Context, matches cpematch.Matches) e return err } - // build query map - queries, err := getQueries(addCpeMatchesQueryIds) + // begin transaction + tx, err := me.Tx(ctx, addCpeMatchesQueryIds) if err != nil { return err } - // begin context - tx, err := me.db.BeginTx(ctx, nil) - if err != nil { - return err - } - - // build statements - sts := make(map[string]*db_sql.Stmt) - for id, sql := range(queries) { - if st, err := tx.PrepareContext(ctx, sql); err != nil { - return err - } else { - sts[id] = st - defer sts[id].Close() - } - } - // add matches for _, m := range(matches.Matches) { // add cpe - rs, err := sts["cpe-match/insert"].ExecContext(ctx, m.Cpe23Uri, m.Cpe22Uri) + rs, err := tx.Exec(ctx, "cpe-match/insert", m.Cpe23Uri, m.Cpe22Uri) if err != nil { return err } @@ -291,33 +248,37 @@ func (me DbStore) AddCpeMatches(ctx context.Context, matches cpematch.Matches) e // add vulnerable if m.Vulnerable != nil { - _, err := sts["cpe-match/insert-vulnerable"].ExecContext(ctx, id, *m.Vulnerable) + _, err := tx.Exec(ctx, "cpe-match/insert-vulnerable", id, *m.Vulnerable) if err != nil { return err } } // add version minimum - if m.VersionStartIncluding != "" { - _, err := sts["cpe-match/insert-versiom-min"].ExecContext(ctx, id, true, m.VersionStartIncluding) + if m.VersionStartIncluding != "" && m.VersionStartExcluding != "" { + return fmt.Errorf("cannot specify both VersionStartIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionStartIncluding, m.VersionStartExcluding) + } else if m.VersionStartIncluding != "" { + _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, true, m.VersionStartIncluding) if err != nil { return err } } else if m.VersionStartExcluding != "" { - _, err := sts["cpe-match/insert-versiom-min"].ExecContext(ctx, id, false, m.VersionStartExcluding) + _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, false, m.VersionStartExcluding) if err != nil { return err } } // add version maximum - if m.VersionEndIncluding != "" { - _, err := sts["cpe-match/insert-versiom-max"].ExecContext(ctx, id, true, m.VersionEndIncluding) + if m.VersionEndIncluding != "" && m.VersionEndExcluding != "" { + return fmt.Errorf("cannot specify both VersionEndIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionEndIncluding, m.VersionEndExcluding) + } else if m.VersionEndIncluding != "" { + _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, true, m.VersionEndIncluding) if err != nil { return err } } else if m.VersionEndExcluding != "" { - _, err := sts["cpe-match/insert-versiom-max"].ExecContext(ctx, id, false, m.VersionEndExcluding) + _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, false, m.VersionEndExcluding) if err != nil { return err } @@ -325,7 +286,7 @@ func (me DbStore) AddCpeMatches(ctx context.Context, matches cpematch.Matches) e // add names for _, name := range(m.Names) { - _, err := sts["cpe-match/insert-name"].ExecContext(ctx, id, name.Cpe23Uri, name.Cpe22Uri) + _, err := tx.Exec(ctx, "cpe-match/insert-name", id, name.Cpe23Uri, name.Cpe22Uri) if err != nil { return err } @@ -348,45 +309,21 @@ func (me DbStore) CpeMatchSearch( return r, err } - // get query - // FIXME: cache this? - sql, err := getQuery("cpe-match/search.sql") - if err != nil { - return r, err - } - // exec search query - rows, err := me.db.QueryContext(ctx, sql, match) - if err != nil { - return r, err - } - - // walk results - for rows.Next() { + err := me.Query(ctx, "cpe-match/search", []interface{} { + match, + }, func(rows *db_sql.Rows) error { var s string if err := rows.Scan(&s); err != nil { // return error - return r, err + return err } else { // append to results r = append(r, s) + return nil } - } - - // close rows - // FIXME: is this correct? i am following the example from the - // database/sql documentation, but it is messy and it seems - // counterintuitive to close the row set and then do an additional - // test for iteration errors... - if err = rows.Close(); err != nil { - return r, err - } + }) - // check for iteration errors - if err = rows.Err(); err != nil { - return r, err - } - - // return success - return r, nil + // return result + return r, err } diff --git a/dbstore/dbstore_test.go b/dbstore/dbstore_test.go index 1e03d72..628addb 100644 --- a/dbstore/dbstore_test.go +++ b/dbstore/dbstore_test.go @@ -10,8 +10,10 @@ import ( "fmt" _ "github.com/mattn/go-sqlite3" "github.com/pablotron/cvez/cpedict" + "github.com/pablotron/cvez/cpematch" io_fs "io/fs" "os" + "path/filepath" "reflect" "testing" "time" @@ -526,3 +528,421 @@ func TestCpeSearch(t *testing.T) { }) } } + +func TestAddCpeMatches(t *testing.T) { + // cache context, create temp dir + ctx := context.Background() + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Error(err) + return + } + defer os.RemoveAll(dir) + + vuln := true + passTests := []struct { + name string // test name and path + seed string // db seed query + matches []cpematch.Match // matches + } {{ + name: "pass-basic", + + seed: ` + INSERT INTO cpes(cpe_uri, cpe23) VALUES ( + 'cpe:/1', + 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/2', + 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/3', + 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*' + ); + `, + + matches: []cpematch.Match { + cpematch.Match { + Cpe23Uri: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*", + + Vulnerable: &vuln, + + VersionStartIncluding: "1.0.0", + VersionEndIncluding: "1.6.3", + + Names: []cpematch.Name { + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*", + }, + }, + }, + }, + }, { + name: "pass-excluding", + + seed: ` + INSERT INTO cpes(cpe_uri, cpe23) VALUES ( + 'cpe:/1', + 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/2', + 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/3', + 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*' + ); + `, + + matches: []cpematch.Match { + cpematch.Match { + Cpe23Uri: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*", + + Vulnerable: &vuln, + + VersionStartExcluding: "1.0.0", + VersionEndExcluding: "1.6.3", + + Names: []cpematch.Name { + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*", + }, + }, + }, + }, + }} + + for _, test := range(passTests) { + t.Run(test.name, func(t *testing.T) { + // build test matches + matches := cpematch.Matches { Matches: test.matches } + // open db + db, err := Open(filepath.Join(dir, fmt.Sprintf("%s.db", test.name))) + if err != nil { + t.Error(err) + return + } + + // init db + if err := db.Init(ctx); err != nil { + t.Error(err) + return + } + + // seed db + if _, err = db.db.ExecContext(ctx, test.seed); err != nil { + t.Error(err) + return + } + + // add matches + if err = db.AddCpeMatches(ctx, matches); err != nil { + t.Error(err) + } + }) + } + + failTests := []struct { + name string // test name (path inferred) + seed string // db seed query + matches []cpematch.Match // matches + } {{ + name: "bad-match-cpe23", + + seed: ` + INSERT INTO cpes(cpe_uri, cpe23) VALUES ( + 'cpe:/1', + 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/2', + 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/3', + 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*' + ); + `, + + matches: []cpematch.Match { + cpematch.Match { + Cpe23Uri: "cpe:", + + VersionStartIncluding: "1.0.0", + VersionEndIncluding: "1.6.3", + + Names: []cpematch.Name { + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*", + }, + }, + }, + }, + }, { + name: "bad-cpe", + + seed: ` + INSERT INTO cpes(cpe_uri, cpe23) VALUES ( + 'cpe:/1', + 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/2', + 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/3', + 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*' + ); + `, + + matches: []cpematch.Match { + cpematch.Match { + Cpe23Uri: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*", + + VersionStartIncluding: "1.0.0", + VersionEndIncluding: "1.6.3", + + Names: []cpematch.Name { + cpematch.Name { + Cpe23Uri: "cpe:2.3", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*", + }, + }, + }, + }, + }, { + name: "dup-versionstart", + + seed: ` + INSERT INTO cpes(cpe_uri, cpe23) VALUES ( + 'cpe:/1', + 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/2', + 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/3', + 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*' + ); + `, + + matches: []cpematch.Match { + cpematch.Match { + Cpe23Uri: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*", + + VersionStartIncluding: "1.0.0", + VersionStartExcluding: "1.1.0", + + Names: []cpematch.Name { + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*", + }, + }, + }, + }, + }, { + name: "dup-versionend", + + seed: ` + INSERT INTO cpes(cpe_uri, cpe23) VALUES ( + 'cpe:/1', + 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/2', + 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/3', + 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*' + ); + `, + + matches: []cpematch.Match { + cpematch.Match { + Cpe23Uri: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*", + + VersionEndIncluding: "1.0.0", + VersionEndExcluding: "1.1.0", + + Names: []cpematch.Name { + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*", + }, + }, + }, + }, + }} + + for _, test := range(failTests) { + t.Run(test.name, func(t *testing.T) { + // build test matches + matches := cpematch.Matches { Matches: test.matches } + // open db + db, err := Open(filepath.Join(dir, fmt.Sprintf("%s.db", test.name))) + if err != nil { + t.Error(err) + return + } + + // init db + if err := db.Init(ctx); err != nil { + t.Error(err) + return + } + + // seed db + if _, err = db.db.ExecContext(ctx, test.seed); err != nil { + t.Error(err) + return + } + + // add matches + if err = db.AddCpeMatches(ctx, matches); err == nil { + t.Error("got success, exp error") + } + }) + } +} + +func TestCpeMatchSearch(t *testing.T) { + // cache context, create temp dir + ctx := context.Background() + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Error(err) + return + } + defer os.RemoveAll(dir) + + // db cpe seed query + seed := ` + INSERT INTO cpes(cpe_uri, cpe23) VALUES ( + 'cpe:/1', + 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/2', + 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*' + ), ( + 'cpe:/3', + 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*' + ); + ` + + matches := cpematch.Matches { + Matches: []cpematch.Match { + cpematch.Match { + Cpe23Uri: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*", + + VersionStartIncluding: "1.0.0", + VersionEndIncluding: "1.6.3", + + Names: []cpematch.Name { + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*", + }, + + cpematch.Name { + Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*", + }, + }, + }, + }, + } + + // build test matches + // open db + db, err := Open(filepath.Join(dir, "TestCpeMatchSearch.db")) + if err != nil { + t.Error(err) + return + } + + // init db + if err := db.Init(ctx); err != nil { + t.Error(err) + return + } + + // seed db + if _, err = db.db.ExecContext(ctx, seed); err != nil { + t.Error(err) + return + } + + // add matches + if err = db.AddCpeMatches(ctx, matches); err != nil { + t.Error(err) + return + } + + tests := []struct { + val string // search val + exp []string // expected results + } {{ + val: "nothing", + exp: []string {}, + }, { + val: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*", + exp: []string { + "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*", + "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*", + "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*", + }, + }} + + for _, test := range(tests) { + t.Run(test.val, func(t *testing.T) { + if got, err := db.CpeMatchSearch(ctx, test.val); err != nil { + t.Error(err) + } else if ((len(got) > 0) || (len(test.exp) > 0)) && + !reflect.DeepEqual(got, test.exp) { + t.Errorf("got %v, exp %v", got, test.exp) + } + }) + } +} diff --git a/dbstore/sql/test/junk.sql b/dbstore/sql/test/junk.sql new file mode 100644 index 0000000..535ef88 --- /dev/null +++ b/dbstore/sql/test/junk.sql @@ -0,0 +1,3 @@ +-- test junk query to trigger prepare error in tests +alks dlk asdlfk alfk leake f fa + diff --git a/dbstore/sqlfs.go b/dbstore/sqlfs.go new file mode 100644 index 0000000..99626a2 --- /dev/null +++ b/dbstore/sqlfs.go @@ -0,0 +1,38 @@ +package dbstore + +import ( + "embed" + "fmt" +) + +//go:embed sql +var sqlFs embed.FS + +// Get single query from embedded sqlFs. +func getQuery(id string) (string, error) { + // read query + if data, err := sqlFs.ReadFile(fmt.Sprintf("sql/%s.sql", id)); err != nil { + return "", err + } else { + // return query + return string(data), nil + } +} + +// Return queries from embedded sqlFs and return map of query ID to SQL. +func getQueries(ids []string) (map[string]string, error) { + r := make(map[string]string) + + for _, id := range(ids) { + // read query + if sql, err := getQuery(id); err != nil { + return r, fmt.Errorf("%s: %s", id, err.Error()) + } else { + // save query + r[id] = sql + } + } + + // return success + return r, nil +} diff --git a/dbstore/tx.go b/dbstore/tx.go new file mode 100644 index 0000000..460bbe1 --- /dev/null +++ b/dbstore/tx.go @@ -0,0 +1,105 @@ +// database storage +package dbstore + +import ( + "context" + db_sql "database/sql" + "fmt" + _ "github.com/mattn/go-sqlite3" +) + +type Tx struct { + tx *db_sql.Tx // underlying transaction + sts map[string]*db_sql.Stmt // prepared statements + closed bool // are the statements closed? + done bool // is the transaction done? + err error // last error +} + +func newTx(ctx context.Context, db *db_sql.DB, queryIds []string) (Tx, error) { + var r Tx + + // begin context + if tx, err := db.BeginTx(ctx, nil); err != nil { + return r, err + } else { + r.tx = tx + } + + // build query map + queries, err := getQueries(queryIds) + if err != nil { + return r, err + } + + // build statements + sts := make(map[string]*db_sql.Stmt) + for id, sql := range(queries) { + if st, err := r.tx.PrepareContext(ctx, sql); err != nil { + return r, err + } else { + sts[id] = st + } + } + r.sts = sts + + // return success + return r, nil +} + +// Finalize statements. +func (tx Tx) Close() { + if !tx.closed { + for id, st := range(tx.sts) { + // close statement + st.Close() + + // delete key + delete(tx.sts, id) + } + + // mark transaction as closed + tx.closed = true + } +} + +// Finalize statements, commit transaction. +func (tx Tx) Commit() error { + // close statements + // FIXME: this isn't really necessary, rollback and commit will take + // care of it according to the database/sql docs + tx.Close() + + if !tx.done { + tx.err = tx.tx.Commit() + tx.done = true + } + + // return last error + return tx.err +} + +// Finalize statements, rollback transaction. +func (tx Tx) Rollback() error { + // close statements + // FIXME: this isn't really necessary, rollback and commit will take + // care of it according to the database/sql docs + tx.Close() + + if !tx.done { + tx.err = tx.tx.Rollback() + tx.done = true + } + + // return last error + return tx.err +} + +// execute given prepared statement +func (tx Tx) Exec(ctx context.Context, id string, args... interface{}) (db_sql.Result, error) { + if st, ok := tx.sts[id]; !ok { + return nil, fmt.Errorf("unknown statement: %s", id) + } else { + return st.ExecContext(ctx, args...) + } +} diff --git a/dbstore/tx_test.go b/dbstore/tx_test.go new file mode 100644 index 0000000..df15f9e --- /dev/null +++ b/dbstore/tx_test.go @@ -0,0 +1,140 @@ +package dbstore + +import ( + "context" + "os" + "path/filepath" + // "reflect" + "testing" + // "time" +) + +func TestNewTxDupTx(t *testing.T) { + // cache context, create temp dir + ctx := context.Background() + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Error(err) + return + } + defer os.RemoveAll(dir) + + // open db + db, err := Open(filepath.Join(dir, "newtxduptx.db")) + if err != nil { + t.Error(err) + return + } + + // init db + if err := db.Init(ctx); err != nil { + t.Error(err) + return + } + + // create first transaction + if _, err = db.Tx(ctx, []string{}); err != nil { + t.Error(err) + return + } + + // create second transaction + got, err := db.Tx(ctx, []string{}) + if err != nil { // FIXME + t.Errorf("got %v, exp error", got) + return + } +} + +func TestNewTxBadQueries(t *testing.T) { + badQueryIds := []string { "query/does/not/exist" } + + // cache context, create temp dir + ctx := context.Background() + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Error(err) + return + } + defer os.RemoveAll(dir) + + // open db + db, err := Open(filepath.Join(dir, "newtx-bad-queries.db")) + if err != nil { + t.Error(err) + return + } + + // init db + if err := db.Init(ctx); err != nil { + t.Error(err) + return + } + + // create first transaction + if got, err := db.Tx(ctx, badQueryIds); err == nil { + t.Errorf("got %v, exp error", got) + return + } +} + +func TestNewTxPrepareFail(t *testing.T) { + // cache context, create temp dir + ctx := context.Background() + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Error(err) + return + } + defer os.RemoveAll(dir) + + // open db + db, err := Open(filepath.Join(dir, "newtx-prepare-fail.db")) + if err != nil { + t.Error(err) + return + } + + // create transaction + if got, err := db.Tx(ctx, []string { "test/junk" }); err == nil { + t.Errorf("got %v, exp error", got) + } +} + +func TestTxExecFail(t *testing.T) { + // cache context, create temp dir + ctx := context.Background() + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Error(err) + return + } + defer os.RemoveAll(dir) + + // open db + db, err := Open(filepath.Join(dir, "txexecfail.db")) + if err != nil { + t.Error(err) + return + } + + // init db + if err := db.Init(ctx); err != nil { + t.Error(err) + return + } + + // create transaction + tx, err := db.Tx(ctx, []string { "init" }) + if err != nil { + t.Error(err) + return + } + + // close transaction (this will clear statements) + tx.Close() + + if got, err := tx.Exec(ctx, "init"); err == nil { + t.Errorf("got %v, exp error", got) + } +} |