package dbstore import ( "compress/gzip" db_sql "database/sql" "encoding/xml" "embed" "errors" "fmt" _ "github.com/mattn/go-sqlite3" "github.com/pablotron/cvez/cpedict" io_fs "io/fs" "os" "testing" ) func getTestDictionary(path string) (cpedict.Dictionary, error) { var dict cpedict.Dictionary // open test data f, err := os.Open(path) if err != nil { return dict, err } defer f.Close() // create zip reader gz, err := gzip.NewReader(f) if err != nil { return dict, err } defer gz.Close() // create xml decoder d := xml.NewDecoder(gz) // decode xml if err = d.Decode(&dict); err != nil { return dict, err } // return success return dict, nil } //go:embed testdata/sql/*.sql var testSqlFs embed.FS var testSqlIds = map[string]bool { "init": false, "insert-cpe": true, "insert-title": true, "insert-ref": true, } func getTestQueries() (map[string]string, error) { r := make(map[string]string) for id, _ := range(testSqlIds) { path := fmt.Sprintf("testdata/sql/%s.sql", id) if data, err := testSqlFs.ReadFile(path); err != nil { return r, err } else { r[id] = string(data) } } return r, nil } func TestSimple(t *testing.T) { testDbPath := "./testdata/foo.db" // get queries queries, err := getTestQueries() if err != nil { t.Error(err) return } // load test CPEs dict, err := getTestDictionary("testdata/test-0.xml.gz") if err != nil { t.Error(err) return } // does test db exist? if _, err = os.Stat(testDbPath); err != nil { if !errors.Is(err, io_fs.ErrNotExist) { t.Error(err) return } } else if err == nil { // remove test db if err = os.Remove(testDbPath); err != nil { t.Error(err) return } } // init db db, err := db_sql.Open("sqlite3", testDbPath) if err != nil { t.Error(err) return } defer db.Close() // init tables if _, err := db.Exec(queries["init"]); err != nil { t.Error(err) return } tx, err := db.Begin() if err != nil { t.Error(err) return } // build statements sts := make(map[string]*db_sql.Stmt) for id, use := range(testSqlIds) { if use { if st, err := tx.Prepare(queries[id]); err != nil { t.Error(err) return } else { sts[id] = st defer sts[id].Close() } } } // add items for _, item := range(dict.Items) { // add cpe rs, err := sts["insert-cpe"].Exec(item.CpeUri, item.Cpe23Item.Name); if err != nil { t.Error(err) return } // get last row ID id, err := rs.LastInsertId() if err != nil { t.Error(err) return } // add titles for _, title := range(item.Titles) { if _, err := sts["insert-title"].Exec(id, title.Lang, title.Text); err != nil { t.Error(err) return } } // add refs for _, ref := range(item.References) { if _, err := sts["insert-ref"].Exec(id, ref.Href, ref.Text); err != nil { t.Error(err) return } } } // commit changes if err = tx.Commit(); err != nil { t.Error(err) return } }