aboutsummaryrefslogtreecommitdiff
path: root/dbstore
diff options
context:
space:
mode:
authorPaul Duncan <pabs@pablotron.org>2022-02-19 16:21:51 -0500
committerPaul Duncan <pabs@pablotron.org>2022-02-19 16:21:51 -0500
commit4e478d0b49e777bdb61269ec06f36d0c3ee9c68c (patch)
tree8b4ae7c0c45784f63102d6068d9f6d6454baca47 /dbstore
parent6ea658ce4193bf4c0e9c928473184fe708c2c6d1 (diff)
downloadcvez-4e478d0b49e777bdb61269ec06f36d0c3ee9c68c.tar.bz2
cvez-4e478d0b49e777bdb61269ec06f36d0c3ee9c68c.zip
dbstore/dbstore_test.go: add TestAddCveFeed()
Diffstat (limited to 'dbstore')
-rw-r--r--dbstore/dbstore_test.go59
1 files changed, 59 insertions, 0 deletions
diff --git a/dbstore/dbstore_test.go b/dbstore/dbstore_test.go
index 628addb..ea5d284 100644
--- a/dbstore/dbstore_test.go
+++ b/dbstore/dbstore_test.go
@@ -4,6 +4,7 @@ import (
"compress/gzip"
"context"
db_sql "database/sql"
+ "encoding/json"
"encoding/xml"
"embed"
"errors"
@@ -11,6 +12,7 @@ import (
_ "github.com/mattn/go-sqlite3"
"github.com/pablotron/cvez/cpedict"
"github.com/pablotron/cvez/cpematch"
+ nvd_feed "github.com/pablotron/cvez/feed"
io_fs "io/fs"
"os"
"path/filepath"
@@ -19,6 +21,27 @@ import (
"time"
)
+// Load gzip-compressed test feed.
+func getFeed(path string) (nvd_feed.Feed, error) {
+ var feed nvd_feed.Feed
+
+ // open file for reading
+ file, err := os.Open(path)
+ if err != nil {
+ return feed, err
+ }
+
+ // wrap in reader, return success
+ src, err := gzip.NewReader(file)
+ if err != nil {
+ return feed, err
+ }
+
+ // create decoder, decode feed, return result
+ d := json.NewDecoder(src)
+ return feed, d.Decode(&feed)
+}
+
func getTestDictionary(path string) (cpedict.Dictionary, error) {
var dict cpedict.Dictionary
@@ -946,3 +969,39 @@ func TestCpeMatchSearch(t *testing.T) {
})
}
}
+
+func TestAddCveFeed(t *testing.T) {
+ ctx := context.Background()
+
+ tests := []string {
+ "nvdcve-1.1-2002",
+ "nvdcve-1.1-2003",
+ "nvdcve-1.1-2021",
+ }
+
+ // create test db
+ db, err := createTestDb(ctx, "testdata/test-addcvefeed.db")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // run tests
+ for _, test := range(tests) {
+ t.Run(test, func(t *testing.T) {
+ // add feed
+ feed, err := getFeed(fmt.Sprintf("testdata/%s.json.gz", test))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // add feed
+ _, err = db.AddCveFeed(ctx, feed)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ })
+ }
+}