From d1483e75727afccc251f70f5a58e756efaff8b9e Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Fri, 18 Feb 2022 20:57:40 -0500 Subject: dbstore: refactor sqlfs and tx, add tx tests --- dbstore/dbstore.go | 237 ++++++++++++++++++++--------------------------------- 1 file changed, 87 insertions(+), 150 deletions(-) (limited to 'dbstore/dbstore.go') diff --git a/dbstore/dbstore.go b/dbstore/dbstore.go index 25c3564..21a80a3 100644 --- a/dbstore/dbstore.go +++ b/dbstore/dbstore.go @@ -4,16 +4,12 @@ package dbstore import ( "context" db_sql "database/sql" - "embed" "fmt" _ "github.com/mattn/go-sqlite3" "github.com/pablotron/cvez/cpedict" "github.com/pablotron/cvez/cpematch" ) -//go:embed sql -var sqlFs embed.FS - // sqlite3 backing store type DbStore struct { db *db_sql.DB @@ -72,49 +68,68 @@ func (me DbStore) Init(ctx context.Context) error { } // read init query - sql, err := sqlFs.ReadFile("sql/init.sql") - if err != nil { + if sql, err := getQuery("init"); err != nil { + return err + } else { + // exec init query, return result + _, err = me.db.ExecContext(ctx, sql) return err } +} - // exec init query, return result - _, err = me.db.ExecContext(ctx, string(sql)) - return err +var addCpeDictionaryQueryIds = []string { + "cpe/insert", + "cpe/insert-title", + "cpe/insert-ref", } -// get single query from embedded filesystem -func getQuery(id string) (string, error) { - // read query - if data, err := sqlFs.ReadFile(fmt.Sprintf("sql/%s.sql", id)); err != nil { - return "", err - } else { - // return query - return string(data), nil - } +// create new transaction +func (me DbStore) Tx(ctx context.Context, queryIds []string) (Tx, error) { + return newTx(ctx, me.db, queryIds) } -// return query map -func getQueries(ids []string) (map[string]string, error) { - r := make(map[string]string) +// Execute query and invoke callback with each row of result. +func (me DbStore) Query( + ctx context.Context, + queryId string, + args []interface{}, + fn func(*db_sql.Rows) error, +) error { + // get query + sql, err := getQuery(queryId) + if err != nil { + return err + } - for _, id := range(ids) { - // read query - if sql, err := getQuery(id); err != nil { - return r, fmt.Errorf("%s: %s", id, err.Error()) - } else { - // save query - r[id] = sql + // exec query + rows, err := me.db.QueryContext(ctx, sql, args...) + if err != nil { + return err + } + + // walk results + for rows.Next() { + if err = fn(rows); err != nil { + return err } } - // return success - return r, nil -} + // close rows + // FIXME: is this correct? i am following the example from the + // database/sql documentation, but it is messy and it seems + // counterintuitive to close the row set and then do an additional + // test for iteration errors... + if err = rows.Close(); err != nil { + return err + } -var addCpeDictionaryQueryIds = []string { - "cpe/insert", - "cpe/insert-title", - "cpe/insert-ref", + // check for iteration errors + if err = rows.Err(); err != nil { + return err + } + + // return success + return nil } // import CPE dictionary @@ -124,33 +139,15 @@ func (me DbStore) AddCpeDictionary(ctx context.Context, dict cpedict.Dictionary) return err } - // build query map - queries, err := getQueries(addCpeDictionaryQueryIds) - if err != nil { - return err - } - - // begin context - tx, err := me.db.BeginTx(ctx, nil) + tx, err := me.Tx(ctx, addCpeDictionaryQueryIds) if err != nil { return err } - // build statements - sts := make(map[string]*db_sql.Stmt) - for id, sql := range(queries) { - if st, err := tx.PrepareContext(ctx, sql); err != nil { - return err - } else { - sts[id] = st - defer sts[id].Close() - } - } - // add items for _, item := range(dict.Items) { // add cpe - rs, err := sts["cpe/insert"].ExecContext(ctx, item.CpeUri, item.Cpe23Item.Name) + rs, err := tx.Exec(ctx, "cpe/insert", item.CpeUri, item.Cpe23Item.Name) if err != nil { return err } @@ -163,7 +160,7 @@ func (me DbStore) AddCpeDictionary(ctx context.Context, dict cpedict.Dictionary) // add titles for _, title := range(item.Titles) { - _, err := sts["cpe/insert-title"].ExecContext(ctx, id, title.Lang, title.Text) + _, err := tx.Exec(ctx, "cpe/insert-title", id, title.Lang, title.Text) if err != nil { return err } @@ -171,7 +168,7 @@ func (me DbStore) AddCpeDictionary(ctx context.Context, dict cpedict.Dictionary) // add refs for _, ref := range(item.References) { - _, err := sts["cpe/insert-ref"].ExecContext(ctx, id, ref.Href, ref.Text) + _, err := tx.Exec(ctx, "cpe/insert-ref", id, ref.Href, ref.Text) if err != nil { return err } @@ -195,51 +192,28 @@ func (me DbStore) CpeSearch( return r, err } - // get query - sql, err := getQuery(searchType.String()) - if err != nil { - return r, err - } - - // exec search query - rows, err := me.db.QueryContext(ctx, sql, db_sql.Named("q", s)) - if err != nil { - return r, err - } - - // walk results - for rows.Next() { + // get/exec search query + err := me.Query(ctx, searchType.String(), []interface{} { + db_sql.Named("q", s), + }, func(rows *db_sql.Rows) error { if sr, err := unmarshalCpeSearchRow(rows); err != nil { // return error - return r, err + return err } else { // append to results r = append(r, sr) + return nil } - } - - // close rows - // FIXME: is this correct? i am following the example from the - // database/sql documentation, but it is messy and it seems - // counterintuitive to close the row set and then do an additional - // test for iteration errors... - if err = rows.Close(); err != nil { - return r, err - } - - // check for iteration errors - if err = rows.Err(); err != nil { - return r, err - } + }) - // return success - return r, nil + // return results + return r, err } // query IDs used by AddCpeMatches() var addCpeMatchesQueryIds = []string { "cpe-match/insert", - "cpe-match/insert-vulnerability", + "cpe-match/insert-vulnerable", "cpe-match/insert-version-min", "cpe-match/insert-version-max", "cpe-match/insert-name", @@ -252,33 +226,16 @@ func (me DbStore) AddCpeMatches(ctx context.Context, matches cpematch.Matches) e return err } - // build query map - queries, err := getQueries(addCpeMatchesQueryIds) + // begin transaction + tx, err := me.Tx(ctx, addCpeMatchesQueryIds) if err != nil { return err } - // begin context - tx, err := me.db.BeginTx(ctx, nil) - if err != nil { - return err - } - - // build statements - sts := make(map[string]*db_sql.Stmt) - for id, sql := range(queries) { - if st, err := tx.PrepareContext(ctx, sql); err != nil { - return err - } else { - sts[id] = st - defer sts[id].Close() - } - } - // add matches for _, m := range(matches.Matches) { // add cpe - rs, err := sts["cpe-match/insert"].ExecContext(ctx, m.Cpe23Uri, m.Cpe22Uri) + rs, err := tx.Exec(ctx, "cpe-match/insert", m.Cpe23Uri, m.Cpe22Uri) if err != nil { return err } @@ -291,33 +248,37 @@ func (me DbStore) AddCpeMatches(ctx context.Context, matches cpematch.Matches) e // add vulnerable if m.Vulnerable != nil { - _, err := sts["cpe-match/insert-vulnerable"].ExecContext(ctx, id, *m.Vulnerable) + _, err := tx.Exec(ctx, "cpe-match/insert-vulnerable", id, *m.Vulnerable) if err != nil { return err } } // add version minimum - if m.VersionStartIncluding != "" { - _, err := sts["cpe-match/insert-versiom-min"].ExecContext(ctx, id, true, m.VersionStartIncluding) + if m.VersionStartIncluding != "" && m.VersionStartExcluding != "" { + return fmt.Errorf("cannot specify both VersionStartIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionStartIncluding, m.VersionStartExcluding) + } else if m.VersionStartIncluding != "" { + _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, true, m.VersionStartIncluding) if err != nil { return err } } else if m.VersionStartExcluding != "" { - _, err := sts["cpe-match/insert-versiom-min"].ExecContext(ctx, id, false, m.VersionStartExcluding) + _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, false, m.VersionStartExcluding) if err != nil { return err } } // add version maximum - if m.VersionEndIncluding != "" { - _, err := sts["cpe-match/insert-versiom-max"].ExecContext(ctx, id, true, m.VersionEndIncluding) + if m.VersionEndIncluding != "" && m.VersionEndExcluding != "" { + return fmt.Errorf("cannot specify both VersionEndIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionEndIncluding, m.VersionEndExcluding) + } else if m.VersionEndIncluding != "" { + _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, true, m.VersionEndIncluding) if err != nil { return err } } else if m.VersionEndExcluding != "" { - _, err := sts["cpe-match/insert-versiom-max"].ExecContext(ctx, id, false, m.VersionEndExcluding) + _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, false, m.VersionEndExcluding) if err != nil { return err } @@ -325,7 +286,7 @@ func (me DbStore) AddCpeMatches(ctx context.Context, matches cpematch.Matches) e // add names for _, name := range(m.Names) { - _, err := sts["cpe-match/insert-name"].ExecContext(ctx, id, name.Cpe23Uri, name.Cpe22Uri) + _, err := tx.Exec(ctx, "cpe-match/insert-name", id, name.Cpe23Uri, name.Cpe22Uri) if err != nil { return err } @@ -348,45 +309,21 @@ func (me DbStore) CpeMatchSearch( return r, err } - // get query - // FIXME: cache this? - sql, err := getQuery("cpe-match/search.sql") - if err != nil { - return r, err - } - // exec search query - rows, err := me.db.QueryContext(ctx, sql, match) - if err != nil { - return r, err - } - - // walk results - for rows.Next() { + err := me.Query(ctx, "cpe-match/search", []interface{} { + match, + }, func(rows *db_sql.Rows) error { var s string if err := rows.Scan(&s); err != nil { // return error - return r, err + return err } else { // append to results r = append(r, s) + return nil } - } - - // close rows - // FIXME: is this correct? i am following the example from the - // database/sql documentation, but it is messy and it seems - // counterintuitive to close the row set and then do an additional - // test for iteration errors... - if err = rows.Close(); err != nil { - return r, err - } + }) - // check for iteration errors - if err = rows.Err(); err != nil { - return r, err - } - - // return success - return r, nil + // return result + return r, err } -- cgit v1.2.3