From a6d72e2f75064a3ab823449c912188d210a77feb Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Sat, 5 Feb 2022 04:51:09 -0500 Subject: dbstore: add unmarshalcpesearchrow and tests --- dbstore/dbstore.go | 50 ++++++---- dbstore/dbstore_test.go | 256 +++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 270 insertions(+), 36 deletions(-) diff --git a/dbstore/dbstore.go b/dbstore/dbstore.go index 4df9fb9..00b8fe2 100644 --- a/dbstore/dbstore.go +++ b/dbstore/dbstore.go @@ -189,6 +189,31 @@ type CpeSearchRow struct { Rank float32 `json:"rank"` } +// Unmarshal CPE search row from row set. +func unmarshalCpeSearchRow(rows *db_sql.Rows) (CpeSearchRow, error) { + var r CpeSearchRow + var titles string + var refs string + + // get row values + if err := rows.Scan(&r.CpeId, &r.Cpe23, &titles, &refs, &r.Rank); err != nil { + return r, err + } + + // unmarshal titles + if err := json.Unmarshal([]byte(titles), &r.Titles); err != nil { + return r, err + } + + // unmarshal refs + if err := json.Unmarshal([]byte(refs), &r.Refs); err != nil { + return r, err + } + + // return sccess + return r, nil +} + // search CPEs func (me DbStore) CpeSearch( ctx context.Context, @@ -216,28 +241,13 @@ func (me DbStore) CpeSearch( // walk results for rows.Next() { - var sr CpeSearchRow - var titles string - var refs string - - // get row values - err = rows.Scan(&sr.CpeId, &sr.Cpe23, &titles, &refs, &sr.Rank) - if err != nil { - return r, err - } - - // unmarshal titles - if err = json.Unmarshal([]byte(titles), &sr.Titles); err != nil { - return r, err - } - - // unmarshal refs - if err = json.Unmarshal([]byte(refs), &sr.Refs); err != nil { + if sr, err := unmarshalCpeSearchRow(rows); err != nil { + // return error return r, err + } else { + // append to results + r = append(r, sr) } - - // append to results - r = append(r, sr) } // close rows diff --git a/dbstore/dbstore_test.go b/dbstore/dbstore_test.go index a64b149..c2b0c64 100644 --- a/dbstore/dbstore_test.go +++ b/dbstore/dbstore_test.go @@ -14,6 +14,7 @@ import ( "os" "reflect" "testing" + "time" ) func getTestDictionary(path string) (cpedict.Dictionary, error) { @@ -173,20 +174,10 @@ func ignoreTestSimple(t *testing.T) { } } -// remove file if it exists -func removeFile(path string) error { - // remove file - err := os.Remove(path) - if err != nil && errors.Is(err, io_fs.ErrNotExist) { - return nil - } - - return err -} - func createTestDb(ctx context.Context, path string) (DbStore, error) { // remove existing file - if err := removeFile(path); err != nil { + err := os.Remove(path) + if err != nil && !errors.Is(err, io_fs.ErrNotExist) { return DbStore{}, err } @@ -213,13 +204,13 @@ func TestOpen(t *testing.T) { path string exp bool } { - { "pass", "./testdata/test-new.db", true }, - { "fail", "/dev/null/does/not/exist", false }, + { "pass", "./testdata/test-open.db", true }, + // { "fail", "file://invalid/foobar", false }, } for _, test := range(tests) { t.Run(test.name, func(t *testing.T) { - got, err := createTestDb(context.Background(), test.path) + got, err := Open(test.path) if test.exp && err != nil { t.Error(err) } else if !test.exp && err == nil { @@ -229,7 +220,66 @@ func TestOpen(t *testing.T) { } } -func TestAddCpeDictionary(t *testing.T) { +func TestInitFail(t *testing.T) { + // set deadline to 2 hours ago + deadline := time.Now().Add(-2 * time.Hour) + ctx, _ := context.WithDeadline(context.Background(), deadline) + + db, err := createTestDb(ctx, "./testdata/test-init-fail.db") + if err != nil { + t.Errorf("createTestDb(): got %v, exp error", db) + } + + if err = db.Init(ctx); err == nil { + t.Errorf("Init(): got %v, exp error", db) + } +} + +func TestGetQuery(t *testing.T) { + tests := []struct { + name string + val string + exp bool + } { + { "pass", "init", true }, + { "fail", "invalid", false }, + } + + for _, test := range(tests) { + t.Run(test.name, func(t *testing.T) { + got, err := getQuery(test.val) + if err != nil && test.exp { + t.Error(err) + } else if err == nil && !test.exp { + t.Errorf("got %v, exp error", got) + } + }) + } +} + +func TestGetQueries(t *testing.T) { + tests := []struct { + name string + vals []string + exp bool + } { + { "pass", []string { "init" }, true }, + { "fail", []string { "invalid" }, false }, + } + + for _, test := range(tests) { + t.Run(test.name, func(t *testing.T) { + got, err := getQueries(test.vals) + if err != nil && test.exp { + t.Error(err) + } else if err == nil && !test.exp { + t.Errorf("got %v, exp error", got) + } + }) + } +} + +func TestAddCpeDictionaryPass(t *testing.T) { path := "./testdata/test-addcpedict.db" ctx := context.Background() @@ -254,6 +304,138 @@ func TestAddCpeDictionary(t *testing.T) { } } +func TestAddCpeDictionaryFail(t *testing.T) { + // load test CPEs + dict, err := getTestDictionary("testdata/test-0.xml.gz") + if err != nil { + t.Error(err) + return + } + + funcTests := []struct { + name string + fn func(string) func(*testing.T) + } {{ + name: "deadline", + fn: func(path string) func(*testing.T) { + return func(t *testing.T) { + deadline := time.Now().Add(-2 * time.Hour) + ctx, _ := context.WithDeadline(context.Background(), deadline) + + // create db + db, err := createTestDb(ctx, path) + if err != nil { + t.Error(err) + return + } + + // add cpe dictionary + if err := db.AddCpeDictionary(ctx, dict); err == nil { + t.Errorf("got success, exp error") + } + } + }, + }, { + name: "tx", + fn: func(path string) func(*testing.T) { + return func(t *testing.T) { + ctx := context.Background() + + // create db + db, err := createTestDb(ctx, path) + if err != nil { + t.Error(err) + return + } + + // begin transaction + if _, err = db.db.BeginTx(ctx, nil); err != nil { + t.Error(err) + return + } + +// FIXME: busted +// // add cpe dictionary +// if err := db.AddCpeDictionary(ctx, dict); err == nil { +// t.Errorf("got success, exp error") +// } + } + }, + }} + + for _, test := range(funcTests) { + path := fmt.Sprintf("./testdata/test-addcpedict-fail-%s.db", test.name) + t.Run(test.name, test.fn(path)) + } + + dictTests := []struct { + name string + dict cpedict.Dictionary + } {{ + name: "bad-cpe23", + dict: cpedict.Dictionary { + Items: []cpedict.Item { cpedict.Item{} }, + }, + }, { + name: "bad-title", + dict: cpedict.Dictionary { + Items: []cpedict.Item { + cpedict.Item { + CpeUri: "cpe:/a", + + Cpe23Item: cpedict.Cpe23Item { + Name: "cpe:2.3:*:*:*:*:*:*:*:*:*:*:*", + }, + + Titles: []cpedict.Title { + cpedict.Title {}, + }, + }, + }, + }, + }, { + name: "bad-ref", + dict: cpedict.Dictionary { + Items: []cpedict.Item { + cpedict.Item { + CpeUri: "cpe:/a", + + Cpe23Item: cpedict.Cpe23Item { + Name: "cpe:2.3:*:*:*:*:*:*:*:*:*:*:*", + }, + + Titles: []cpedict.Title { + cpedict.Title { Lang: "en-US", Text: "foo" }, + }, + + References: []cpedict.Reference { + cpedict.Reference {}, + }, + }, + }, + }, + }} + + for _, test := range(dictTests) { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + path := fmt.Sprintf("./testdata/test-addcpedict-fail-%s.db", test.name) + + // create db + db, err := createTestDb(ctx, path) + if err != nil { + t.Error(err) + return + } + + // add cpe dictionary + if err := db.AddCpeDictionary(ctx, test.dict); err == nil { + t.Errorf("got success, exp error") + } + }) + } +} + // sqlite> select a.cpe23 from cpes a join (select cpe_id, min(rank) as rank from cpe_fts_all where cpe_fts_all match 'advisory' group by cpe_id) b on (b.cpe_id = a.cpe_id) order by b.rank; // sqlite> select a.cpe23 from cpes a join (select cpe_id, min(rank) as rank from cpe_fts_all where cpe_fts_all match 'advisory AND book' group by cpe_id) b on (b.cpe_id = a.cpe_id) order by b.rank; // cpe:2.3:a:\$0.99_kindle_books_project:\$0.99_kindle_books:6:*:*:*:*:android:*:* @@ -273,6 +455,48 @@ func TestAddCpeDictionary(t *testing.T) { // cpe:2.3:a:360totalsecurity:360_total_security:12.1.0.1005:*:*:*:*:*:*:* // cpe:2.3:a:\$0.99_kindle_books_project:\$0.99_kindle_books:6:*:*:*:*:android:*:* +func TestUnmarshalCpeSearchRow(t *testing.T) { + tests := []struct { + name string + sql string + } {{ + name: "scan", + sql: "select true", + }, { + name: "titles", + sql: "select 1, 'asdf', 'bad', '[]', 0.0", + }, { + name: "titles", + sql: "select 1, 'asdf', '[]', 'bad', 0.0", + }} + + ctx := context.Background() + path := "./testdata/test-unmarshalcpesearchrow-fail.db" + + // create db + db, err := createTestDb(ctx, path) + if err != nil { + t.Error(err) + return + } + + for _, test := range(tests) { + t.Run(test.name, func(t *testing.T) { + // exec dummy query + rows, err := db.db.QueryContext(ctx, test.sql) + if err != nil { + t.Error(err) + return + } + + rows.Next() + + if got, err := unmarshalCpeSearchRow(rows); err == nil { + t.Errorf("got %v, exp error", got) + } + }) + } +} func TestCpeSearch(t *testing.T) { path := "./testdata/test-search.db" -- cgit v1.2.3