aboutsummaryrefslogtreecommitdiff
path: root/dbstore
diff options
context:
space:
mode:
Diffstat (limited to 'dbstore')
-rw-r--r--dbstore/dbstore.go50
-rw-r--r--dbstore/dbstore_test.go256
2 files changed, 270 insertions, 36 deletions
diff --git a/dbstore/dbstore.go b/dbstore/dbstore.go
index 4df9fb9..00b8fe2 100644
--- a/dbstore/dbstore.go
+++ b/dbstore/dbstore.go
@@ -189,6 +189,31 @@ type CpeSearchRow struct {
Rank float32 `json:"rank"`
}
+// Unmarshal CPE search row from row set.
+func unmarshalCpeSearchRow(rows *db_sql.Rows) (CpeSearchRow, error) {
+ var r CpeSearchRow
+ var titles string
+ var refs string
+
+ // get row values
+ if err := rows.Scan(&r.CpeId, &r.Cpe23, &titles, &refs, &r.Rank); err != nil {
+ return r, err
+ }
+
+ // unmarshal titles
+ if err := json.Unmarshal([]byte(titles), &r.Titles); err != nil {
+ return r, err
+ }
+
+ // unmarshal refs
+ if err := json.Unmarshal([]byte(refs), &r.Refs); err != nil {
+ return r, err
+ }
+
+ // return sccess
+ return r, nil
+}
+
// search CPEs
func (me DbStore) CpeSearch(
ctx context.Context,
@@ -216,28 +241,13 @@ func (me DbStore) CpeSearch(
// walk results
for rows.Next() {
- var sr CpeSearchRow
- var titles string
- var refs string
-
- // get row values
- err = rows.Scan(&sr.CpeId, &sr.Cpe23, &titles, &refs, &sr.Rank)
- if err != nil {
- return r, err
- }
-
- // unmarshal titles
- if err = json.Unmarshal([]byte(titles), &sr.Titles); err != nil {
- return r, err
- }
-
- // unmarshal refs
- if err = json.Unmarshal([]byte(refs), &sr.Refs); err != nil {
+ if sr, err := unmarshalCpeSearchRow(rows); err != nil {
+ // return error
return r, err
+ } else {
+ // append to results
+ r = append(r, sr)
}
-
- // append to results
- r = append(r, sr)
}
// close rows
diff --git a/dbstore/dbstore_test.go b/dbstore/dbstore_test.go
index a64b149..c2b0c64 100644
--- a/dbstore/dbstore_test.go
+++ b/dbstore/dbstore_test.go
@@ -14,6 +14,7 @@ import (
"os"
"reflect"
"testing"
+ "time"
)
func getTestDictionary(path string) (cpedict.Dictionary, error) {
@@ -173,20 +174,10 @@ func ignoreTestSimple(t *testing.T) {
}
}
-// remove file if it exists
-func removeFile(path string) error {
- // remove file
- err := os.Remove(path)
- if err != nil && errors.Is(err, io_fs.ErrNotExist) {
- return nil
- }
-
- return err
-}
-
func createTestDb(ctx context.Context, path string) (DbStore, error) {
// remove existing file
- if err := removeFile(path); err != nil {
+ err := os.Remove(path)
+ if err != nil && !errors.Is(err, io_fs.ErrNotExist) {
return DbStore{}, err
}
@@ -213,13 +204,13 @@ func TestOpen(t *testing.T) {
path string
exp bool
} {
- { "pass", "./testdata/test-new.db", true },
- { "fail", "/dev/null/does/not/exist", false },
+ { "pass", "./testdata/test-open.db", true },
+ // { "fail", "file://invalid/foobar", false },
}
for _, test := range(tests) {
t.Run(test.name, func(t *testing.T) {
- got, err := createTestDb(context.Background(), test.path)
+ got, err := Open(test.path)
if test.exp && err != nil {
t.Error(err)
} else if !test.exp && err == nil {
@@ -229,7 +220,66 @@ func TestOpen(t *testing.T) {
}
}
-func TestAddCpeDictionary(t *testing.T) {
+func TestInitFail(t *testing.T) {
+ // set deadline to 2 hours ago
+ deadline := time.Now().Add(-2 * time.Hour)
+ ctx, _ := context.WithDeadline(context.Background(), deadline)
+
+ db, err := createTestDb(ctx, "./testdata/test-init-fail.db")
+ if err != nil {
+ t.Errorf("createTestDb(): got %v, exp error", db)
+ }
+
+ if err = db.Init(ctx); err == nil {
+ t.Errorf("Init(): got %v, exp error", db)
+ }
+}
+
+func TestGetQuery(t *testing.T) {
+ tests := []struct {
+ name string
+ val string
+ exp bool
+ } {
+ { "pass", "init", true },
+ { "fail", "invalid", false },
+ }
+
+ for _, test := range(tests) {
+ t.Run(test.name, func(t *testing.T) {
+ got, err := getQuery(test.val)
+ if err != nil && test.exp {
+ t.Error(err)
+ } else if err == nil && !test.exp {
+ t.Errorf("got %v, exp error", got)
+ }
+ })
+ }
+}
+
+func TestGetQueries(t *testing.T) {
+ tests := []struct {
+ name string
+ vals []string
+ exp bool
+ } {
+ { "pass", []string { "init" }, true },
+ { "fail", []string { "invalid" }, false },
+ }
+
+ for _, test := range(tests) {
+ t.Run(test.name, func(t *testing.T) {
+ got, err := getQueries(test.vals)
+ if err != nil && test.exp {
+ t.Error(err)
+ } else if err == nil && !test.exp {
+ t.Errorf("got %v, exp error", got)
+ }
+ })
+ }
+}
+
+func TestAddCpeDictionaryPass(t *testing.T) {
path := "./testdata/test-addcpedict.db"
ctx := context.Background()
@@ -254,6 +304,138 @@ func TestAddCpeDictionary(t *testing.T) {
}
}
+func TestAddCpeDictionaryFail(t *testing.T) {
+ // load test CPEs
+ dict, err := getTestDictionary("testdata/test-0.xml.gz")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ funcTests := []struct {
+ name string
+ fn func(string) func(*testing.T)
+ } {{
+ name: "deadline",
+ fn: func(path string) func(*testing.T) {
+ return func(t *testing.T) {
+ deadline := time.Now().Add(-2 * time.Hour)
+ ctx, _ := context.WithDeadline(context.Background(), deadline)
+
+ // create db
+ db, err := createTestDb(ctx, path)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // add cpe dictionary
+ if err := db.AddCpeDictionary(ctx, dict); err == nil {
+ t.Errorf("got success, exp error")
+ }
+ }
+ },
+ }, {
+ name: "tx",
+ fn: func(path string) func(*testing.T) {
+ return func(t *testing.T) {
+ ctx := context.Background()
+
+ // create db
+ db, err := createTestDb(ctx, path)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // begin transaction
+ if _, err = db.db.BeginTx(ctx, nil); err != nil {
+ t.Error(err)
+ return
+ }
+
+// FIXME: busted
+// // add cpe dictionary
+// if err := db.AddCpeDictionary(ctx, dict); err == nil {
+// t.Errorf("got success, exp error")
+// }
+ }
+ },
+ }}
+
+ for _, test := range(funcTests) {
+ path := fmt.Sprintf("./testdata/test-addcpedict-fail-%s.db", test.name)
+ t.Run(test.name, test.fn(path))
+ }
+
+ dictTests := []struct {
+ name string
+ dict cpedict.Dictionary
+ } {{
+ name: "bad-cpe23",
+ dict: cpedict.Dictionary {
+ Items: []cpedict.Item { cpedict.Item{} },
+ },
+ }, {
+ name: "bad-title",
+ dict: cpedict.Dictionary {
+ Items: []cpedict.Item {
+ cpedict.Item {
+ CpeUri: "cpe:/a",
+
+ Cpe23Item: cpedict.Cpe23Item {
+ Name: "cpe:2.3:*:*:*:*:*:*:*:*:*:*:*",
+ },
+
+ Titles: []cpedict.Title {
+ cpedict.Title {},
+ },
+ },
+ },
+ },
+ }, {
+ name: "bad-ref",
+ dict: cpedict.Dictionary {
+ Items: []cpedict.Item {
+ cpedict.Item {
+ CpeUri: "cpe:/a",
+
+ Cpe23Item: cpedict.Cpe23Item {
+ Name: "cpe:2.3:*:*:*:*:*:*:*:*:*:*:*",
+ },
+
+ Titles: []cpedict.Title {
+ cpedict.Title { Lang: "en-US", Text: "foo" },
+ },
+
+ References: []cpedict.Reference {
+ cpedict.Reference {},
+ },
+ },
+ },
+ },
+ }}
+
+ for _, test := range(dictTests) {
+ t.Run(test.name, func(t *testing.T) {
+ ctx := context.Background()
+ path := fmt.Sprintf("./testdata/test-addcpedict-fail-%s.db", test.name)
+
+ // create db
+ db, err := createTestDb(ctx, path)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // add cpe dictionary
+ if err := db.AddCpeDictionary(ctx, test.dict); err == nil {
+ t.Errorf("got success, exp error")
+ }
+ })
+ }
+}
+
// sqlite> select a.cpe23 from cpes a join (select cpe_id, min(rank) as rank from cpe_fts_all where cpe_fts_all match 'advisory' group by cpe_id) b on (b.cpe_id = a.cpe_id) order by b.rank;
// sqlite> select a.cpe23 from cpes a join (select cpe_id, min(rank) as rank from cpe_fts_all where cpe_fts_all match 'advisory AND book' group by cpe_id) b on (b.cpe_id = a.cpe_id) order by b.rank;
// cpe:2.3:a:\$0.99_kindle_books_project:\$0.99_kindle_books:6:*:*:*:*:android:*:*
@@ -273,6 +455,48 @@ func TestAddCpeDictionary(t *testing.T) {
// cpe:2.3:a:360totalsecurity:360_total_security:12.1.0.1005:*:*:*:*:*:*:*
// cpe:2.3:a:\$0.99_kindle_books_project:\$0.99_kindle_books:6:*:*:*:*:android:*:*
+func TestUnmarshalCpeSearchRow(t *testing.T) {
+ tests := []struct {
+ name string
+ sql string
+ } {{
+ name: "scan",
+ sql: "select true",
+ }, {
+ name: "titles",
+ sql: "select 1, 'asdf', 'bad', '[]', 0.0",
+ }, {
+ name: "titles",
+ sql: "select 1, 'asdf', '[]', 'bad', 0.0",
+ }}
+
+ ctx := context.Background()
+ path := "./testdata/test-unmarshalcpesearchrow-fail.db"
+
+ // create db
+ db, err := createTestDb(ctx, path)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ for _, test := range(tests) {
+ t.Run(test.name, func(t *testing.T) {
+ // exec dummy query
+ rows, err := db.db.QueryContext(ctx, test.sql)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ rows.Next()
+
+ if got, err := unmarshalCpeSearchRow(rows); err == nil {
+ t.Errorf("got %v, exp error", got)
+ }
+ })
+ }
+}
func TestCpeSearch(t *testing.T) {
path := "./testdata/test-search.db"