From 7413c56a6ac3d7a14ebde6ebb504efc4f3d5d575 Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Sat, 26 Feb 2022 09:59:49 -0500 Subject: dbstore: add dbstore/util.go w/ tests --- dbstore/dbstore.go | 29 ++--------------------------- dbstore/util.go | 27 +++++++++++++++++++++++++++ dbstore/util_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 27 deletions(-) create mode 100644 dbstore/util.go create mode 100644 dbstore/util_test.go diff --git a/dbstore/dbstore.go b/dbstore/dbstore.go index 840f268..6ff0ced 100644 --- a/dbstore/dbstore.go +++ b/dbstore/dbstore.go @@ -365,31 +365,6 @@ func (me DbStore) CpeMatchSearch( return r, err } -// Check data type, data format, and data version. -func (me DbStore) checkData( - dataType nvd_feed.DataType, - dataFormat nvd_feed.DataFormat, - dataVersion nvd_feed.DataVersion, -) error { - // check data type - if dataType != nvd_feed.CveType { - return fmt.Errorf("unknown data type: %s", dataType) - } - - // check data format - if dataFormat != nvd_feed.MitreFormat { - return fmt.Errorf("unknown data format: %s", dataFormat) - } - - // check data version - if dataVersion != nvd_feed.V40 { - return fmt.Errorf("unknown data version: %s", dataVersion) - } - - // return success - return nil -} - // Add description. func (me DbStore) addDescriptions(ctx context.Context, tx Tx, ds []nvd_feed.Description) ([]int64, error) { r := make([]int64, len(ds)) @@ -427,7 +402,7 @@ func (me DbStore) addCve(ctx context.Context, tx Tx, itemId int64, cve nvd_feed. var cveId int64 // check data type, data format, and data version - if err := me.checkData(cve.DataType, cve.DataFormat, cve.DataVersion); err != nil { + if err := checkNvdData(cve.DataType, cve.DataFormat, cve.DataVersion); err != nil { return err } @@ -650,7 +625,7 @@ func (me DbStore) addFeed(ctx context.Context, tx Tx, feed nvd_feed.Feed) (int64 var feedId int64 // check feed data type, data format, and data version - if err := me.checkData(feed.DataType, feed.DataFormat, feed.DataVersion); err != nil { + if err := checkNvdData(feed.DataType, feed.DataFormat, feed.DataVersion); err != nil { return feedId, err } diff --git a/dbstore/util.go b/dbstore/util.go new file mode 100644 index 0000000..2b8d444 --- /dev/null +++ b/dbstore/util.go @@ -0,0 +1,27 @@ +package dbstore + +import ( + "fmt" + "github.com/pablotron/cvez/feed" +) + +// Check data type, data format, and data version in NVD feed entry. +func checkNvdData(dt feed.DataType, df feed.DataFormat, dv feed.DataVersion) error { + // check data type + if dt != feed.CveType { + return fmt.Errorf("unknown data type: %s", dt) + } + + // check data format + if df != feed.MitreFormat { + return fmt.Errorf("unknown data format: %s", df) + } + + // check data version + if dv != feed.V40 { + return fmt.Errorf("unknown data version: %s", dv) + } + + // return success + return nil +} diff --git a/dbstore/util_test.go b/dbstore/util_test.go new file mode 100644 index 0000000..a949878 --- /dev/null +++ b/dbstore/util_test.go @@ -0,0 +1,44 @@ +package dbstore + +import ( + "github.com/pablotron/cvez/feed" + "testing" +) + +func TestCheckNvdData(t *testing.T) { + passTests := []struct { + name string + dt feed.DataType + df feed.DataFormat + dv feed.DataVersion + } { + { "valid", feed.CveType, feed.MitreFormat, feed.V40 }, + } + + for _, test := range(passTests) { + t.Run(test.name, func(t *testing.T) { + if err := checkNvdData(test.dt, test.df, test.dv); err != nil { + t.Error(err) + } + }) + } + + failTests := []struct { + name string + dt feed.DataType + df feed.DataFormat + dv feed.DataVersion + } { + { "bad type", feed.DataType(255), feed.MitreFormat, feed.V40 }, + { "bad format", feed.CveType, feed.DataFormat(255), feed.V40 }, + { "bad version", feed.CveType, feed.MitreFormat, feed.DataVersion(255) }, + } + + for _, test := range(failTests) { + t.Run(test.name, func(t *testing.T) { + if checkNvdData(test.dt, test.df, test.dv) == nil { + t.Errorf("got success, exp error") + } + }) + } +} -- cgit v1.2.3