aboutsummaryrefslogtreecommitdiff
path: root/internal/dbstore/dbstore_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/dbstore/dbstore_test.go')
-rw-r--r--internal/dbstore/dbstore_test.go172
1 files changed, 172 insertions, 0 deletions
diff --git a/internal/dbstore/dbstore_test.go b/internal/dbstore/dbstore_test.go
new file mode 100644
index 0000000..974573d
--- /dev/null
+++ b/internal/dbstore/dbstore_test.go
@@ -0,0 +1,172 @@
+package dbstore
+
+import (
+ "compress/gzip"
+ db_sql "database/sql"
+ "encoding/xml"
+ "embed"
+ "errors"
+ "fmt"
+ _ "github.com/mattn/go-sqlite3"
+ "github.com/pablotron/cvez/cpedict"
+ io_fs "io/fs"
+ "os"
+ "testing"
+)
+
+func getTestDictionary(path string) (cpedict.Dictionary, error) {
+ var dict cpedict.Dictionary
+
+ // open test data
+ f, err := os.Open(path)
+ if err != nil {
+ return dict, err
+ }
+ defer f.Close()
+
+ // create zip reader
+ gz, err := gzip.NewReader(f)
+ if err != nil {
+ return dict, err
+ }
+ defer gz.Close()
+
+ // create xml decoder
+ d := xml.NewDecoder(gz)
+
+ // decode xml
+ if err = d.Decode(&dict); err != nil {
+ return dict, err
+ }
+
+ // return success
+ return dict, nil
+}
+//go:embed testdata/sql/*.sql
+var testSqlFs embed.FS
+
+var testSqlIds = map[string]bool {
+ "init": false,
+ "insert-cpe": true,
+ "insert-title": true,
+ "insert-ref": true,
+}
+
+func getTestQueries() (map[string]string, error) {
+ r := make(map[string]string)
+
+ for id, _ := range(testSqlIds) {
+ path := fmt.Sprintf("testdata/sql/%s.sql", id)
+ if data, err := testSqlFs.ReadFile(path); err != nil {
+ return r, err
+ } else {
+ r[id] = string(data)
+ }
+ }
+
+ return r, nil
+}
+
+func TestSimple(t *testing.T) {
+ testDbPath := "./testdata/foo.db"
+ // get queries
+ queries, err := getTestQueries()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // load test CPEs
+ dict, err := getTestDictionary("testdata/test-0.xml.gz")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // does test db exist?
+ if _, err = os.Stat(testDbPath); err != nil {
+ if !errors.Is(err, io_fs.ErrNotExist) {
+ t.Error(err)
+ return
+ }
+ } else if err == nil {
+ // remove test db
+ if err = os.Remove(testDbPath); err != nil {
+ t.Error(err)
+ return
+ }
+ }
+
+ // init db
+ db, err := db_sql.Open("sqlite3", testDbPath)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer db.Close()
+
+ // init tables
+ if _, err := db.Exec(queries["init"]); err != nil {
+ t.Error(err)
+ return
+ }
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // build statements
+ sts := make(map[string]*db_sql.Stmt)
+ for id, use := range(testSqlIds) {
+ if use {
+ if st, err := tx.Prepare(queries[id]); err != nil {
+ t.Error(err)
+ return
+ } else {
+ sts[id] = st
+ defer sts[id].Close()
+ }
+ }
+ }
+
+ // add items
+ for _, item := range(dict.Items) {
+ // add cpe
+ rs, err := sts["insert-cpe"].Exec(item.CpeUri, item.Cpe23Item.Name);
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // get last row ID
+ id, err := rs.LastInsertId()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // add titles
+ for _, title := range(item.Titles) {
+ if _, err := sts["insert-title"].Exec(id, title.Lang, title.Text); err != nil {
+ t.Error(err)
+ return
+ }
+ }
+
+ // add refs
+ for _, ref := range(item.References) {
+ if _, err := sts["insert-ref"].Exec(id, ref.Href, ref.Text); err != nil {
+ t.Error(err)
+ return
+ }
+ }
+ }
+
+ // commit changes
+ if err = tx.Commit(); err != nil {
+ t.Error(err)
+ return
+ }
+}