aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--dbstore/dbstore.go237
-rw-r--r--dbstore/dbstore_test.go420
-rw-r--r--dbstore/sql/test/junk.sql3
-rw-r--r--dbstore/sqlfs.go38
-rw-r--r--dbstore/tx.go105
-rw-r--r--dbstore/tx_test.go140
6 files changed, 793 insertions, 150 deletions
diff --git a/dbstore/dbstore.go b/dbstore/dbstore.go
index 25c3564..21a80a3 100644
--- a/dbstore/dbstore.go
+++ b/dbstore/dbstore.go
@@ -4,16 +4,12 @@ package dbstore
import (
"context"
db_sql "database/sql"
- "embed"
"fmt"
_ "github.com/mattn/go-sqlite3"
"github.com/pablotron/cvez/cpedict"
"github.com/pablotron/cvez/cpematch"
)
-//go:embed sql
-var sqlFs embed.FS
-
// sqlite3 backing store
type DbStore struct {
db *db_sql.DB
@@ -72,49 +68,68 @@ func (me DbStore) Init(ctx context.Context) error {
}
// read init query
- sql, err := sqlFs.ReadFile("sql/init.sql")
- if err != nil {
+ if sql, err := getQuery("init"); err != nil {
+ return err
+ } else {
+ // exec init query, return result
+ _, err = me.db.ExecContext(ctx, sql)
return err
}
+}
- // exec init query, return result
- _, err = me.db.ExecContext(ctx, string(sql))
- return err
+var addCpeDictionaryQueryIds = []string {
+ "cpe/insert",
+ "cpe/insert-title",
+ "cpe/insert-ref",
}
-// get single query from embedded filesystem
-func getQuery(id string) (string, error) {
- // read query
- if data, err := sqlFs.ReadFile(fmt.Sprintf("sql/%s.sql", id)); err != nil {
- return "", err
- } else {
- // return query
- return string(data), nil
- }
+// create new transaction
+func (me DbStore) Tx(ctx context.Context, queryIds []string) (Tx, error) {
+ return newTx(ctx, me.db, queryIds)
}
-// return query map
-func getQueries(ids []string) (map[string]string, error) {
- r := make(map[string]string)
+// Execute query and invoke callback with each row of result.
+func (me DbStore) Query(
+ ctx context.Context,
+ queryId string,
+ args []interface{},
+ fn func(*db_sql.Rows) error,
+) error {
+ // get query
+ sql, err := getQuery(queryId)
+ if err != nil {
+ return err
+ }
- for _, id := range(ids) {
- // read query
- if sql, err := getQuery(id); err != nil {
- return r, fmt.Errorf("%s: %s", id, err.Error())
- } else {
- // save query
- r[id] = sql
+ // exec query
+ rows, err := me.db.QueryContext(ctx, sql, args...)
+ if err != nil {
+ return err
+ }
+
+ // walk results
+ for rows.Next() {
+ if err = fn(rows); err != nil {
+ return err
}
}
- // return success
- return r, nil
-}
+ // close rows
+ // FIXME: is this correct? i am following the example from the
+ // database/sql documentation, but it is messy and it seems
+ // counterintuitive to close the row set and then do an additional
+ // test for iteration errors...
+ if err = rows.Close(); err != nil {
+ return err
+ }
-var addCpeDictionaryQueryIds = []string {
- "cpe/insert",
- "cpe/insert-title",
- "cpe/insert-ref",
+ // check for iteration errors
+ if err = rows.Err(); err != nil {
+ return err
+ }
+
+ // return success
+ return nil
}
// import CPE dictionary
@@ -124,33 +139,15 @@ func (me DbStore) AddCpeDictionary(ctx context.Context, dict cpedict.Dictionary)
return err
}
- // build query map
- queries, err := getQueries(addCpeDictionaryQueryIds)
- if err != nil {
- return err
- }
-
- // begin context
- tx, err := me.db.BeginTx(ctx, nil)
+ tx, err := me.Tx(ctx, addCpeDictionaryQueryIds)
if err != nil {
return err
}
- // build statements
- sts := make(map[string]*db_sql.Stmt)
- for id, sql := range(queries) {
- if st, err := tx.PrepareContext(ctx, sql); err != nil {
- return err
- } else {
- sts[id] = st
- defer sts[id].Close()
- }
- }
-
// add items
for _, item := range(dict.Items) {
// add cpe
- rs, err := sts["cpe/insert"].ExecContext(ctx, item.CpeUri, item.Cpe23Item.Name)
+ rs, err := tx.Exec(ctx, "cpe/insert", item.CpeUri, item.Cpe23Item.Name)
if err != nil {
return err
}
@@ -163,7 +160,7 @@ func (me DbStore) AddCpeDictionary(ctx context.Context, dict cpedict.Dictionary)
// add titles
for _, title := range(item.Titles) {
- _, err := sts["cpe/insert-title"].ExecContext(ctx, id, title.Lang, title.Text)
+ _, err := tx.Exec(ctx, "cpe/insert-title", id, title.Lang, title.Text)
if err != nil {
return err
}
@@ -171,7 +168,7 @@ func (me DbStore) AddCpeDictionary(ctx context.Context, dict cpedict.Dictionary)
// add refs
for _, ref := range(item.References) {
- _, err := sts["cpe/insert-ref"].ExecContext(ctx, id, ref.Href, ref.Text)
+ _, err := tx.Exec(ctx, "cpe/insert-ref", id, ref.Href, ref.Text)
if err != nil {
return err
}
@@ -195,51 +192,28 @@ func (me DbStore) CpeSearch(
return r, err
}
- // get query
- sql, err := getQuery(searchType.String())
- if err != nil {
- return r, err
- }
-
- // exec search query
- rows, err := me.db.QueryContext(ctx, sql, db_sql.Named("q", s))
- if err != nil {
- return r, err
- }
-
- // walk results
- for rows.Next() {
+ // get/exec search query
+ err := me.Query(ctx, searchType.String(), []interface{} {
+ db_sql.Named("q", s),
+ }, func(rows *db_sql.Rows) error {
if sr, err := unmarshalCpeSearchRow(rows); err != nil {
// return error
- return r, err
+ return err
} else {
// append to results
r = append(r, sr)
+ return nil
}
- }
-
- // close rows
- // FIXME: is this correct? i am following the example from the
- // database/sql documentation, but it is messy and it seems
- // counterintuitive to close the row set and then do an additional
- // test for iteration errors...
- if err = rows.Close(); err != nil {
- return r, err
- }
-
- // check for iteration errors
- if err = rows.Err(); err != nil {
- return r, err
- }
+ })
- // return success
- return r, nil
+ // return results
+ return r, err
}
// query IDs used by AddCpeMatches()
var addCpeMatchesQueryIds = []string {
"cpe-match/insert",
- "cpe-match/insert-vulnerability",
+ "cpe-match/insert-vulnerable",
"cpe-match/insert-version-min",
"cpe-match/insert-version-max",
"cpe-match/insert-name",
@@ -252,33 +226,16 @@ func (me DbStore) AddCpeMatches(ctx context.Context, matches cpematch.Matches) e
return err
}
- // build query map
- queries, err := getQueries(addCpeMatchesQueryIds)
+ // begin transaction
+ tx, err := me.Tx(ctx, addCpeMatchesQueryIds)
if err != nil {
return err
}
- // begin context
- tx, err := me.db.BeginTx(ctx, nil)
- if err != nil {
- return err
- }
-
- // build statements
- sts := make(map[string]*db_sql.Stmt)
- for id, sql := range(queries) {
- if st, err := tx.PrepareContext(ctx, sql); err != nil {
- return err
- } else {
- sts[id] = st
- defer sts[id].Close()
- }
- }
-
// add matches
for _, m := range(matches.Matches) {
// add cpe
- rs, err := sts["cpe-match/insert"].ExecContext(ctx, m.Cpe23Uri, m.Cpe22Uri)
+ rs, err := tx.Exec(ctx, "cpe-match/insert", m.Cpe23Uri, m.Cpe22Uri)
if err != nil {
return err
}
@@ -291,33 +248,37 @@ func (me DbStore) AddCpeMatches(ctx context.Context, matches cpematch.Matches) e
// add vulnerable
if m.Vulnerable != nil {
- _, err := sts["cpe-match/insert-vulnerable"].ExecContext(ctx, id, *m.Vulnerable)
+ _, err := tx.Exec(ctx, "cpe-match/insert-vulnerable", id, *m.Vulnerable)
if err != nil {
return err
}
}
// add version minimum
- if m.VersionStartIncluding != "" {
- _, err := sts["cpe-match/insert-versiom-min"].ExecContext(ctx, id, true, m.VersionStartIncluding)
+ if m.VersionStartIncluding != "" && m.VersionStartExcluding != "" {
+ return fmt.Errorf("cannot specify both VersionStartIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionStartIncluding, m.VersionStartExcluding)
+ } else if m.VersionStartIncluding != "" {
+ _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, true, m.VersionStartIncluding)
if err != nil {
return err
}
} else if m.VersionStartExcluding != "" {
- _, err := sts["cpe-match/insert-versiom-min"].ExecContext(ctx, id, false, m.VersionStartExcluding)
+ _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, false, m.VersionStartExcluding)
if err != nil {
return err
}
}
// add version maximum
- if m.VersionEndIncluding != "" {
- _, err := sts["cpe-match/insert-versiom-max"].ExecContext(ctx, id, true, m.VersionEndIncluding)
+ if m.VersionEndIncluding != "" && m.VersionEndExcluding != "" {
+ return fmt.Errorf("cannot specify both VersionEndIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionEndIncluding, m.VersionEndExcluding)
+ } else if m.VersionEndIncluding != "" {
+ _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, true, m.VersionEndIncluding)
if err != nil {
return err
}
} else if m.VersionEndExcluding != "" {
- _, err := sts["cpe-match/insert-versiom-max"].ExecContext(ctx, id, false, m.VersionEndExcluding)
+ _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, false, m.VersionEndExcluding)
if err != nil {
return err
}
@@ -325,7 +286,7 @@ func (me DbStore) AddCpeMatches(ctx context.Context, matches cpematch.Matches) e
// add names
for _, name := range(m.Names) {
- _, err := sts["cpe-match/insert-name"].ExecContext(ctx, id, name.Cpe23Uri, name.Cpe22Uri)
+ _, err := tx.Exec(ctx, "cpe-match/insert-name", id, name.Cpe23Uri, name.Cpe22Uri)
if err != nil {
return err
}
@@ -348,45 +309,21 @@ func (me DbStore) CpeMatchSearch(
return r, err
}
- // get query
- // FIXME: cache this?
- sql, err := getQuery("cpe-match/search.sql")
- if err != nil {
- return r, err
- }
-
// exec search query
- rows, err := me.db.QueryContext(ctx, sql, match)
- if err != nil {
- return r, err
- }
-
- // walk results
- for rows.Next() {
+ err := me.Query(ctx, "cpe-match/search", []interface{} {
+ match,
+ }, func(rows *db_sql.Rows) error {
var s string
if err := rows.Scan(&s); err != nil {
// return error
- return r, err
+ return err
} else {
// append to results
r = append(r, s)
+ return nil
}
- }
-
- // close rows
- // FIXME: is this correct? i am following the example from the
- // database/sql documentation, but it is messy and it seems
- // counterintuitive to close the row set and then do an additional
- // test for iteration errors...
- if err = rows.Close(); err != nil {
- return r, err
- }
+ })
- // check for iteration errors
- if err = rows.Err(); err != nil {
- return r, err
- }
-
- // return success
- return r, nil
+ // return result
+ return r, err
}
diff --git a/dbstore/dbstore_test.go b/dbstore/dbstore_test.go
index 1e03d72..628addb 100644
--- a/dbstore/dbstore_test.go
+++ b/dbstore/dbstore_test.go
@@ -10,8 +10,10 @@ import (
"fmt"
_ "github.com/mattn/go-sqlite3"
"github.com/pablotron/cvez/cpedict"
+ "github.com/pablotron/cvez/cpematch"
io_fs "io/fs"
"os"
+ "path/filepath"
"reflect"
"testing"
"time"
@@ -526,3 +528,421 @@ func TestCpeSearch(t *testing.T) {
})
}
}
+
+func TestAddCpeMatches(t *testing.T) {
+ // cache context, create temp dir
+ ctx := context.Background()
+ dir, err := os.MkdirTemp("", "")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer os.RemoveAll(dir)
+
+ vuln := true
+ passTests := []struct {
+ name string // test name and path
+ seed string // db seed query
+ matches []cpematch.Match // matches
+ } {{
+ name: "pass-basic",
+
+ seed: `
+ INSERT INTO cpes(cpe_uri, cpe23) VALUES (
+ 'cpe:/1',
+ 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/2',
+ 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/3',
+ 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*'
+ );
+ `,
+
+ matches: []cpematch.Match {
+ cpematch.Match {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*",
+
+ Vulnerable: &vuln,
+
+ VersionStartIncluding: "1.0.0",
+ VersionEndIncluding: "1.6.3",
+
+ Names: []cpematch.Name {
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*",
+ },
+ },
+ },
+ },
+ }, {
+ name: "pass-excluding",
+
+ seed: `
+ INSERT INTO cpes(cpe_uri, cpe23) VALUES (
+ 'cpe:/1',
+ 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/2',
+ 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/3',
+ 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*'
+ );
+ `,
+
+ matches: []cpematch.Match {
+ cpematch.Match {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*",
+
+ Vulnerable: &vuln,
+
+ VersionStartExcluding: "1.0.0",
+ VersionEndExcluding: "1.6.3",
+
+ Names: []cpematch.Name {
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*",
+ },
+ },
+ },
+ },
+ }}
+
+ for _, test := range(passTests) {
+ t.Run(test.name, func(t *testing.T) {
+ // build test matches
+ matches := cpematch.Matches { Matches: test.matches }
+ // open db
+ db, err := Open(filepath.Join(dir, fmt.Sprintf("%s.db", test.name)))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // init db
+ if err := db.Init(ctx); err != nil {
+ t.Error(err)
+ return
+ }
+
+ // seed db
+ if _, err = db.db.ExecContext(ctx, test.seed); err != nil {
+ t.Error(err)
+ return
+ }
+
+ // add matches
+ if err = db.AddCpeMatches(ctx, matches); err != nil {
+ t.Error(err)
+ }
+ })
+ }
+
+ failTests := []struct {
+ name string // test name (path inferred)
+ seed string // db seed query
+ matches []cpematch.Match // matches
+ } {{
+ name: "bad-match-cpe23",
+
+ seed: `
+ INSERT INTO cpes(cpe_uri, cpe23) VALUES (
+ 'cpe:/1',
+ 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/2',
+ 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/3',
+ 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*'
+ );
+ `,
+
+ matches: []cpematch.Match {
+ cpematch.Match {
+ Cpe23Uri: "cpe:",
+
+ VersionStartIncluding: "1.0.0",
+ VersionEndIncluding: "1.6.3",
+
+ Names: []cpematch.Name {
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*",
+ },
+ },
+ },
+ },
+ }, {
+ name: "bad-cpe",
+
+ seed: `
+ INSERT INTO cpes(cpe_uri, cpe23) VALUES (
+ 'cpe:/1',
+ 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/2',
+ 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/3',
+ 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*'
+ );
+ `,
+
+ matches: []cpematch.Match {
+ cpematch.Match {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*",
+
+ VersionStartIncluding: "1.0.0",
+ VersionEndIncluding: "1.6.3",
+
+ Names: []cpematch.Name {
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*",
+ },
+ },
+ },
+ },
+ }, {
+ name: "dup-versionstart",
+
+ seed: `
+ INSERT INTO cpes(cpe_uri, cpe23) VALUES (
+ 'cpe:/1',
+ 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/2',
+ 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/3',
+ 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*'
+ );
+ `,
+
+ matches: []cpematch.Match {
+ cpematch.Match {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*",
+
+ VersionStartIncluding: "1.0.0",
+ VersionStartExcluding: "1.1.0",
+
+ Names: []cpematch.Name {
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*",
+ },
+ },
+ },
+ },
+ }, {
+ name: "dup-versionend",
+
+ seed: `
+ INSERT INTO cpes(cpe_uri, cpe23) VALUES (
+ 'cpe:/1',
+ 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/2',
+ 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/3',
+ 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*'
+ );
+ `,
+
+ matches: []cpematch.Match {
+ cpematch.Match {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*",
+
+ VersionEndIncluding: "1.0.0",
+ VersionEndExcluding: "1.1.0",
+
+ Names: []cpematch.Name {
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*",
+ },
+ },
+ },
+ },
+ }}
+
+ for _, test := range(failTests) {
+ t.Run(test.name, func(t *testing.T) {
+ // build test matches
+ matches := cpematch.Matches { Matches: test.matches }
+ // open db
+ db, err := Open(filepath.Join(dir, fmt.Sprintf("%s.db", test.name)))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // init db
+ if err := db.Init(ctx); err != nil {
+ t.Error(err)
+ return
+ }
+
+ // seed db
+ if _, err = db.db.ExecContext(ctx, test.seed); err != nil {
+ t.Error(err)
+ return
+ }
+
+ // add matches
+ if err = db.AddCpeMatches(ctx, matches); err == nil {
+ t.Error("got success, exp error")
+ }
+ })
+ }
+}
+
+func TestCpeMatchSearch(t *testing.T) {
+ // cache context, create temp dir
+ ctx := context.Background()
+ dir, err := os.MkdirTemp("", "")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer os.RemoveAll(dir)
+
+ // db cpe seed query
+ seed := `
+ INSERT INTO cpes(cpe_uri, cpe23) VALUES (
+ 'cpe:/1',
+ 'cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/2',
+ 'cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*'
+ ), (
+ 'cpe:/3',
+ 'cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*'
+ );
+ `
+
+ matches := cpematch.Matches {
+ Matches: []cpematch.Match {
+ cpematch.Match {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*",
+
+ VersionStartIncluding: "1.0.0",
+ VersionEndIncluding: "1.6.3",
+
+ Names: []cpematch.Name {
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*",
+ },
+
+ cpematch.Name {
+ Cpe23Uri: "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*",
+ },
+ },
+ },
+ },
+ }
+
+ // build test matches
+ // open db
+ db, err := Open(filepath.Join(dir, "TestCpeMatchSearch.db"))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // init db
+ if err := db.Init(ctx); err != nil {
+ t.Error(err)
+ return
+ }
+
+ // seed db
+ if _, err = db.db.ExecContext(ctx, seed); err != nil {
+ t.Error(err)
+ return
+ }
+
+ // add matches
+ if err = db.AddCpeMatches(ctx, matches); err != nil {
+ t.Error(err)
+ return
+ }
+
+ tests := []struct {
+ val string // search val
+ exp []string // expected results
+ } {{
+ val: "nothing",
+ exp: []string {},
+ }, {
+ val: "cpe:2.3:a:101_project:101:*:*:*:*:*:node.js:*:*",
+ exp: []string {
+ "cpe:2.3:a:101_project:101:1.0.0:*:*:*:*:node.js:*:*",
+ "cpe:2.3:a:101_project:101:1.1.0:*:*:*:*:node.js:*:*",
+ "cpe:2.3:a:101_project:101:1.1.1:*:*:*:*:node.js:*:*",
+ },
+ }}
+
+ for _, test := range(tests) {
+ t.Run(test.val, func(t *testing.T) {
+ if got, err := db.CpeMatchSearch(ctx, test.val); err != nil {
+ t.Error(err)
+ } else if ((len(got) > 0) || (len(test.exp) > 0)) &&
+ !reflect.DeepEqual(got, test.exp) {
+ t.Errorf("got %v, exp %v", got, test.exp)
+ }
+ })
+ }
+}
diff --git a/dbstore/sql/test/junk.sql b/dbstore/sql/test/junk.sql
new file mode 100644
index 0000000..535ef88
--- /dev/null
+++ b/dbstore/sql/test/junk.sql
@@ -0,0 +1,3 @@
+-- test junk query to trigger prepare error in tests
+alks dlk asdlfk alfk leake f fa
+
diff --git a/dbstore/sqlfs.go b/dbstore/sqlfs.go
new file mode 100644
index 0000000..99626a2
--- /dev/null
+++ b/dbstore/sqlfs.go
@@ -0,0 +1,38 @@
+package dbstore
+
+import (
+ "embed"
+ "fmt"
+)
+
+//go:embed sql
+var sqlFs embed.FS
+
+// Get single query from embedded sqlFs.
+func getQuery(id string) (string, error) {
+ // read query
+ if data, err := sqlFs.ReadFile(fmt.Sprintf("sql/%s.sql", id)); err != nil {
+ return "", err
+ } else {
+ // return query
+ return string(data), nil
+ }
+}
+
+// Return queries from embedded sqlFs and return map of query ID to SQL.
+func getQueries(ids []string) (map[string]string, error) {
+ r := make(map[string]string)
+
+ for _, id := range(ids) {
+ // read query
+ if sql, err := getQuery(id); err != nil {
+ return r, fmt.Errorf("%s: %s", id, err.Error())
+ } else {
+ // save query
+ r[id] = sql
+ }
+ }
+
+ // return success
+ return r, nil
+}
diff --git a/dbstore/tx.go b/dbstore/tx.go
new file mode 100644
index 0000000..460bbe1
--- /dev/null
+++ b/dbstore/tx.go
@@ -0,0 +1,105 @@
+// database storage
+package dbstore
+
+import (
+ "context"
+ db_sql "database/sql"
+ "fmt"
+ _ "github.com/mattn/go-sqlite3"
+)
+
+type Tx struct {
+ tx *db_sql.Tx // underlying transaction
+ sts map[string]*db_sql.Stmt // prepared statements
+ closed bool // are the statements closed?
+ done bool // is the transaction done?
+ err error // last error
+}
+
+func newTx(ctx context.Context, db *db_sql.DB, queryIds []string) (Tx, error) {
+ var r Tx
+
+ // begin context
+ if tx, err := db.BeginTx(ctx, nil); err != nil {
+ return r, err
+ } else {
+ r.tx = tx
+ }
+
+ // build query map
+ queries, err := getQueries(queryIds)
+ if err != nil {
+ return r, err
+ }
+
+ // build statements
+ sts := make(map[string]*db_sql.Stmt)
+ for id, sql := range(queries) {
+ if st, err := r.tx.PrepareContext(ctx, sql); err != nil {
+ return r, err
+ } else {
+ sts[id] = st
+ }
+ }
+ r.sts = sts
+
+ // return success
+ return r, nil
+}
+
+// Finalize statements.
+func (tx Tx) Close() {
+ if !tx.closed {
+ for id, st := range(tx.sts) {
+ // close statement
+ st.Close()
+
+ // delete key
+ delete(tx.sts, id)
+ }
+
+ // mark transaction as closed
+ tx.closed = true
+ }
+}
+
+// Finalize statements, commit transaction.
+func (tx Tx) Commit() error {
+ // close statements
+ // FIXME: this isn't really necessary, rollback and commit will take
+ // care of it according to the database/sql docs
+ tx.Close()
+
+ if !tx.done {
+ tx.err = tx.tx.Commit()
+ tx.done = true
+ }
+
+ // return last error
+ return tx.err
+}
+
+// Finalize statements, rollback transaction.
+func (tx Tx) Rollback() error {
+ // close statements
+ // FIXME: this isn't really necessary, rollback and commit will take
+ // care of it according to the database/sql docs
+ tx.Close()
+
+ if !tx.done {
+ tx.err = tx.tx.Rollback()
+ tx.done = true
+ }
+
+ // return last error
+ return tx.err
+}
+
+// execute given prepared statement
+func (tx Tx) Exec(ctx context.Context, id string, args... interface{}) (db_sql.Result, error) {
+ if st, ok := tx.sts[id]; !ok {
+ return nil, fmt.Errorf("unknown statement: %s", id)
+ } else {
+ return st.ExecContext(ctx, args...)
+ }
+}
diff --git a/dbstore/tx_test.go b/dbstore/tx_test.go
new file mode 100644
index 0000000..df15f9e
--- /dev/null
+++ b/dbstore/tx_test.go
@@ -0,0 +1,140 @@
+package dbstore
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ // "reflect"
+ "testing"
+ // "time"
+)
+
+func TestNewTxDupTx(t *testing.T) {
+ // cache context, create temp dir
+ ctx := context.Background()
+ dir, err := os.MkdirTemp("", "")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer os.RemoveAll(dir)
+
+ // open db
+ db, err := Open(filepath.Join(dir, "newtxduptx.db"))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // init db
+ if err := db.Init(ctx); err != nil {
+ t.Error(err)
+ return
+ }
+
+ // create first transaction
+ if _, err = db.Tx(ctx, []string{}); err != nil {
+ t.Error(err)
+ return
+ }
+
+ // create second transaction
+ got, err := db.Tx(ctx, []string{})
+ if err != nil { // FIXME
+ t.Errorf("got %v, exp error", got)
+ return
+ }
+}
+
+func TestNewTxBadQueries(t *testing.T) {
+ badQueryIds := []string { "query/does/not/exist" }
+
+ // cache context, create temp dir
+ ctx := context.Background()
+ dir, err := os.MkdirTemp("", "")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer os.RemoveAll(dir)
+
+ // open db
+ db, err := Open(filepath.Join(dir, "newtx-bad-queries.db"))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // init db
+ if err := db.Init(ctx); err != nil {
+ t.Error(err)
+ return
+ }
+
+ // create first transaction
+ if got, err := db.Tx(ctx, badQueryIds); err == nil {
+ t.Errorf("got %v, exp error", got)
+ return
+ }
+}
+
+func TestNewTxPrepareFail(t *testing.T) {
+ // cache context, create temp dir
+ ctx := context.Background()
+ dir, err := os.MkdirTemp("", "")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer os.RemoveAll(dir)
+
+ // open db
+ db, err := Open(filepath.Join(dir, "newtx-prepare-fail.db"))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // create transaction
+ if got, err := db.Tx(ctx, []string { "test/junk" }); err == nil {
+ t.Errorf("got %v, exp error", got)
+ }
+}
+
+func TestTxExecFail(t *testing.T) {
+ // cache context, create temp dir
+ ctx := context.Background()
+ dir, err := os.MkdirTemp("", "")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer os.RemoveAll(dir)
+
+ // open db
+ db, err := Open(filepath.Join(dir, "txexecfail.db"))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // init db
+ if err := db.Init(ctx); err != nil {
+ t.Error(err)
+ return
+ }
+
+ // create transaction
+ tx, err := db.Tx(ctx, []string { "init" })
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // close transaction (this will clear statements)
+ tx.Close()
+
+ if got, err := tx.Exec(ctx, "init"); err == nil {
+ t.Errorf("got %v, exp error", got)
+ }
+}