package dbstore

import (
	"compress/gzip"
	"context"
	db_sql "database/sql"
	"embed"
	"encoding/xml"
	"errors"
	"fmt"
	io_fs "io/fs"
	"os"
	"reflect"
	"testing"
	"time"

	_ "github.com/mattn/go-sqlite3"
	"github.com/pablotron/cvez/cpedict"
)

// getTestDictionary opens the gzip-compressed CPE dictionary XML file at
// path and decodes it into a cpedict.Dictionary.
//
// On error the returned dictionary may be partially populated and should
// not be used.
func getTestDictionary(path string) (cpedict.Dictionary, error) {
	var dict cpedict.Dictionary

	// open test data
	f, err := os.Open(path)
	if err != nil {
		return dict, err
	}
	defer f.Close()

	// create gzip reader
	gz, err := gzip.NewReader(f)
	if err != nil {
		return dict, err
	}
	defer gz.Close()

	// decode xml
	if err := xml.NewDecoder(gz).Decode(&dict); err != nil {
		return dict, err
	}

	// return success
	return dict, nil
}

// testSqlFs holds the raw SQL files read by getTestQueries.
//
//go:embed testdata/sql/*.sql
var testSqlFs embed.FS

// testSqlIds lists the query IDs (file names without the ".sql" suffix in
// testdata/sql/) loaded by getTestQueries. The value records whether
// ignoreTestSimple prepares the query as a statement inside its insert
// transaction; "init" is executed directly instead.
var testSqlIds = map[string]bool{
	"init":         false,
	"insert-cpe":   true,
	"insert-title": true,
	"insert-ref":   true,
}

// getTestQueries reads every query listed in testSqlIds from the embedded
// testdata/sql directory and returns the query text keyed by query ID.
func getTestQueries() (map[string]string, error) {
	r := make(map[string]string, len(testSqlIds))
	for id := range testSqlIds {
		path := fmt.Sprintf("testdata/sql/%s.sql", id)
		data, err := testSqlFs.ReadFile(path)
		if err != nil {
			return r, err
		}
		r[id] = string(data)
	}
	return r, nil
}

// ignoreTestSimple is intentionally disabled: its name lacks the "Test"
// prefix, so `go test` does not run it. It rebuilds ./testdata/foo.db from
// the raw SQL in testdata/sql/ and inserts every item of
// testdata/test-0.xml.gz (CPE, titles, references) through prepared
// statements inside a single transaction.
func ignoreTestSimple(t *testing.T) {
	testDbPath := "./testdata/foo.db"

	// get queries
	queries, err := getTestQueries()
	if err != nil {
		t.Fatal(err)
	}

	// load test CPEs
	dict, err := getTestDictionary("testdata/test-0.xml.gz")
	if err != nil {
		t.Fatal(err)
	}

	// remove any existing test db; a missing file is not an error
	if err := os.Remove(testDbPath); err != nil && !errors.Is(err, io_fs.ErrNotExist) {
		t.Fatal(err)
	}

	// init db
	db, err := db_sql.Open("sqlite3", testDbPath)
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	// init tables
	if _, err := db.Exec(queries["init"]); err != nil {
		t.Fatal(err)
	}

	tx, err := db.Begin()
	if err != nil {
		t.Fatal(err)
	}
	// roll back if we bail out before Commit; after a successful Commit this
	// is a harmless no-op (it returns sql.ErrTxDone)
	defer tx.Rollback()

	// build statements
	sts := make(map[string]*db_sql.Stmt)
	for id, use := range testSqlIds {
		if !use {
			continue
		}
		st, err := tx.Prepare(queries[id])
		if err != nil {
			t.Fatal(err)
		}
		sts[id] = st
		defer st.Close()
	}

	// add items
	for _, item := range dict.Items {
		// add cpe
		rs, err := sts["insert-cpe"].Exec(item.CpeUri, item.Cpe23Item.Name)
		if err != nil {
			t.Fatal(err)
		}

		// get last row ID; it links the titles and refs below to this CPE
		id, err := rs.LastInsertId()
		if err != nil {
			t.Fatal(err)
		}

		// add titles
		for _, title := range item.Titles {
			if _, err := sts["insert-title"].Exec(id, title.Lang, title.Text); err != nil {
				t.Fatal(err)
			}
		}

		// add refs
		for _, ref := range item.References {
			if _, err := sts["insert-ref"].Exec(id, ref.Href, ref.Text); err != nil {
				t.Fatal(err)
			}
		}
	}

	// commit changes
	if err := tx.Commit(); err != nil {
		t.Fatal(err)
	}
}

// createTestDb deletes any existing database file at path (a missing file
// is not an error) and opens a fresh DbStore there. ctx is currently unused.
func createTestDb(ctx context.Context, path string) (DbStore, error) {
	// remove existing file
	if err := os.Remove(path); err != nil && !errors.Is(err, io_fs.ErrNotExist) {
		return DbStore{}, err
	}

	// open db
	return Open(path)
}

// seedTestDb populates db with the CPE dictionary in testdata/test-0.xml.gz.
func seedTestDb(ctx context.Context, db DbStore) error {
	// load test CPEs
	dict, err := getTestDictionary("testdata/test-0.xml.gz")
	if err != nil {
		return err
	}

	// add cpe dictionary
	// TODO: seed with other data
	return db.AddCpeDictionary(ctx, dict)
}

// TestOpen checks that Open succeeds for a writable file path.
func TestOpen(t *testing.T) {
	tests := []struct {
		name string
		path string
		exp  bool
	}{
		{"pass", "./testdata/test-open.db", true},
		// {"fail", "file://invalid/foobar", false},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			got, err := Open(test.path)
			if test.exp && err != nil {
				t.Error(err)
			} else if !test.exp && err == nil {
				t.Errorf("got %v, exp error", got)
			}
		})
	}
}

// TestInitFail checks that DbStore.Init fails when its context deadline has
// already passed.
func TestInitFail(t *testing.T) {
	// set deadline to 2 hours ago
	deadline := time.Now().Add(-2 * time.Hour)
	ctx, cancel := context.WithDeadline(context.Background(), deadline)
	defer cancel()

	// creating the db must succeed; only Init is expected to fail
	db, err := createTestDb(ctx, "./testdata/test-init-fail.db")
	if err != nil {
		t.Fatalf("createTestDb(): %v", err)
	}

	if err := db.Init(ctx); err == nil {
		t.Errorf("Init(): got %v, exp error", db)
	}
}

// TestGetQuery checks getQuery for a known and an unknown query ID.
func TestGetQuery(t *testing.T) {
	tests := []struct {
		name string
		val  string
		exp  bool
	}{
		{"pass", "init", true},
		{"fail", "invalid", false},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			got, err := getQuery(test.val)
			if err != nil && test.exp {
				t.Error(err)
			} else if err == nil && !test.exp {
				t.Errorf("got %v, exp error", got)
			}
		})
	}
}

// TestGetQueries checks getQueries for a list of known and unknown IDs.
func TestGetQueries(t *testing.T) {
	tests := []struct {
		name string
		vals []string
		exp  bool
	}{
		{"pass", []string{"init"}, true},
		{"fail", []string{"invalid"}, false},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			got, err := getQueries(test.vals)
			if err != nil && test.exp {
				t.Error(err)
			} else if err == nil && !test.exp {
				t.Errorf("got %v, exp error", got)
			}
		})
	}
}

// TestAddCpeDictionaryPass checks that the test dictionary can be added to a
// fresh database.
func TestAddCpeDictionaryPass(t *testing.T) {
	path := "./testdata/test-addcpedict.db"
	ctx := context.Background()

	// create db
	db, err := createTestDb(ctx, path)
	if err != nil {
		t.Fatal(err)
	}

	// load test CPEs
	dict, err := getTestDictionary("testdata/test-0.xml.gz")
	if err != nil {
		t.Fatal(err)
	}

	// add cpe dictionary
	if err := db.AddCpeDictionary(ctx, dict); err != nil {
		t.Fatal(err)
	}
}

// TestAddCpeDictionaryFail checks that AddCpeDictionary reports errors for
// an expired context and for dictionaries with malformed items.
func TestAddCpeDictionaryFail(t *testing.T) {
	// load test CPEs
	dict, err := getTestDictionary("testdata/test-0.xml.gz")
	if err != nil {
		t.Fatal(err)
	}

	// each fn receives the db path for its subtest and returns the subtest body
	funcTests := []struct {
		name string
		fn   func(string) func(*testing.T)
	}{{
		name: "deadline",
		fn: func(path string) func(*testing.T) {
			return func(t *testing.T) {
				deadline := time.Now().Add(-2 * time.Hour)
				ctx, cancel := context.WithDeadline(context.Background(), deadline)
				defer cancel()

				// create db
				db, err := createTestDb(ctx, path)
				if err != nil {
					t.Fatal(err)
				}

				// add cpe dictionary
				if err := db.AddCpeDictionary(ctx, dict); err == nil {
					t.Errorf("got success, exp error")
				}
			}
		},
	}, {
		name: "tx",
		fn: func(path string) func(*testing.T) {
			return func(t *testing.T) {
				ctx := context.Background()

				// create db
				db, err := createTestDb(ctx, path)
				if err != nil {
					t.Fatal(err)
				}

				// begin transaction
				tx, err := db.db.BeginTx(ctx, nil)
				if err != nil {
					t.Fatal(err)
				}
				// don't leave the transaction open past the end of this subtest
				defer tx.Rollback()

				// FIXME: busted
				//
				// // add cpe dictionary
				// if err := db.AddCpeDictionary(ctx, dict); err == nil {
				//   t.Errorf("got success, exp error")
				// }
			}
		},
	}}

	for _, test := range funcTests {
		path := fmt.Sprintf("./testdata/test-addcpedict-fail-%s.db", test.name)
		t.Run(test.name, test.fn(path))
	}

	dictTests := []struct {
		name string
		dict cpedict.Dictionary
	}{{
		name: "bad-cpe23",
		dict: cpedict.Dictionary{
			Items: []cpedict.Item{cpedict.Item{}},
		},
	}, {
		name: "bad-title",
		dict: cpedict.Dictionary{
			Items: []cpedict.Item{
				cpedict.Item{
					CpeUri: "cpe:/a",
					Cpe23Item: cpedict.Cpe23Item{
						Name: "cpe:2.3:*:*:*:*:*:*:*:*:*:*:*",
					},
					Titles: []cpedict.Title{
						cpedict.Title{},
					},
				},
			},
		},
	}, {
		name: "bad-ref",
		dict: cpedict.Dictionary{
			Items: []cpedict.Item{
				cpedict.Item{
					CpeUri: "cpe:/a",
					Cpe23Item: cpedict.Cpe23Item{
						Name: "cpe:2.3:*:*:*:*:*:*:*:*:*:*:*",
					},
					Titles: []cpedict.Title{
						cpedict.Title{Lang: "en-US", Text: "foo"},
					},
					References: []cpedict.Reference{
						cpedict.Reference{},
					},
				},
			},
		},
	}}

	for _, test := range dictTests {
		t.Run(test.name, func(t *testing.T) {
			ctx := context.Background()
			path := fmt.Sprintf("./testdata/test-addcpedict-fail-%s.db", test.name)

			// create db
			db, err := createTestDb(ctx, path)
			if err != nil {
				t.Fatal(err)
			}

			// add cpe dictionary
			if err := db.AddCpeDictionary(ctx, test.dict); err == nil {
				t.Errorf("got success, exp error")
			}
		})
	}
}

// TestCpeSearch seeds a fresh database from testdata/test-0.xml.gz and checks
// CpeSearch results (in order) for each search type, plus rejection of an
// unknown search type.
func TestCpeSearch(t *testing.T) {
	path := "./testdata/test-search.db"
	ctx := context.Background()

	// tests that are expected to pass
	passTests := []struct {
		t   CpeSearchType // search type
		q   string        // query string
		exp []string      // expected search results (cpe23s)
	}{{
		t: CpeSearchAll,
		q: "advisory AND book",
		exp: []string{
			"cpe:2.3:a:\\$0.99_kindle_books_project:\\$0.99_kindle_books:6:*:*:*:*:android:*:*",
		},
	}, {
		t: CpeSearchTitle,
		q: "project",
		exp: []string{
			"cpe:2.3:a:\\@thi.ng\\/egf_project:\\@thi.ng\\/egf:-:*:*:*:*:node.js:*:*",
			"cpe:2.3:a:\\@thi.ng\\/egf_project:\\@thi.ng\\/egf:0.1.0:*:*:*:*:node.js:*:*",
			"cpe:2.3:a:\\@thi.ng\\/egf_project:\\@thi.ng\\/egf:0.2.0:*:*:*:*:node.js:*:*",
			"cpe:2.3:a:\\@thi.ng\\/egf_project:\\@thi.ng\\/egf:0.2.1:*:*:*:*:node.js:*:*",
			"cpe:2.3:a:\\$0.99_kindle_books_project:\\$0.99_kindle_books:6:*:*:*:*:android:*:*",
		},
	}, {
		t: CpeSearchRef,
		q: "advisory",
		exp: []string{
			"cpe:2.3:a:\\@thi.ng\\/egf_project:\\@thi.ng\\/egf:-:*:*:*:*:node.js:*:*",
			"cpe:2.3:a:\\@thi.ng\\/egf_project:\\@thi.ng\\/egf:0.1.0:*:*:*:*:node.js:*:*",
			"cpe:2.3:a:\\@thi.ng\\/egf_project:\\@thi.ng\\/egf:0.2.0:*:*:*:*:node.js:*:*",
			"cpe:2.3:a:\\@thi.ng\\/egf_project:\\@thi.ng\\/egf:0.2.1:*:*:*:*:node.js:*:*",
			"cpe:2.3:a:360totalsecurity:360_total_security:12.1.0.1005:*:*:*:*:*:*:*",
			"cpe:2.3:a:\\$0.99_kindle_books_project:\\$0.99_kindle_books:6:*:*:*:*:android:*:*",
		},
	}}

	// create db
	db, err := createTestDb(ctx, path)
	if err != nil {
		t.Fatal(err)
	}

	// seed test database
	if err := seedTestDb(ctx, db); err != nil {
		t.Fatal(err)
	}

	for _, test := range passTests {
		t.Run(test.t.String(), func(t *testing.T) {
			rows, err := db.CpeSearch(ctx, test.t, test.q)
			if err != nil {
				t.Fatal(err)
			}

			// collect the cpe23 name of each result row
			got := make([]string, len(rows))
			for i, row := range rows {
				got[i] = row.Cpe23
			}

			if !reflect.DeepEqual(got, test.exp) {
				t.Errorf("got \"%v\", exp \"%v\"", got, test.exp)
			}
		})
	}

	// tests that are expected to fail
	failTests := []struct {
		name string        // test name
		t    CpeSearchType // search type
		q    string        // query string
	}{
		{"bad-search-type", CpeSearchType(255), ""},
	}

	for _, test := range failTests {
		t.Run(test.name, func(t *testing.T) {
			if got, err := db.CpeSearch(ctx, test.t, test.q); err == nil {
				t.Errorf("got \"%v\", exp error", got)
			}
		})
	}
}