aboutsummaryrefslogtreecommitdiff
path: root/dbstore
diff options
context:
space:
mode:
authorPaul Duncan <pabs@pablotron.org>2022-02-26 13:56:25 -0500
committerPaul Duncan <pabs@pablotron.org>2022-02-26 13:56:25 -0500
commitd34c68f709ba482dff8f79801971c16f19a6c9f7 (patch)
tree87c69cd1519779427d8bf20d073a4f3b208eabe5 /dbstore
parente2c77fc78be780d46c35584d5f33a771b14631fe (diff)
downloadcvez-d34c68f709ba482dff8f79801971c16f19a6c9f7.tar.bz2
cvez-d34c68f709ba482dff8f79801971c16f19a6c9f7.zip
dbstore/dbstore_test.go: add TestTx(), TestQuery(), and TestQueryRow()
Diffstat (limited to 'dbstore')
-rw-r--r--dbstore/dbstore_test.go254
1 files changed, 254 insertions, 0 deletions
diff --git a/dbstore/dbstore_test.go b/dbstore/dbstore_test.go
index 09093c7..1b5d498 100644
--- a/dbstore/dbstore_test.go
+++ b/dbstore/dbstore_test.go
@@ -261,6 +261,260 @@ func TestInitFail(t *testing.T) {
}
}
+func TestTx(t *testing.T) {
+ ctx := context.Background()
+ expCtx, _ := context.WithTimeout(ctx, 0)
+ nopeErr := fmt.Errorf("nope")
+
+ // create temp dir
+ dir, err := os.MkdirTemp("", "")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer os.RemoveAll(dir)
+
+ // create dbstore
+ db, err := Open(filepath.Join(dir, "test-tx.db"))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ passTests := []struct {
+ id string
+ exp error
+ fn func(Tx) error
+ } {
+ { "commit", nil, func(_ Tx) error { return nil } },
+ { "rollback", nopeErr, func(_ Tx) error { return nopeErr } },
+ }
+
+ for _, test := range(passTests) {
+ t.Run(test.id, func(t *testing.T) {
+ got := db.Tx(ctx, []string {}, test.fn)
+ if got != test.exp {
+ t.Errorf("got %v, exp %v", got, test.exp)
+ }
+ })
+ }
+
+ // null transaction function
+ nullTxFn := func(_ Tx) error { return nil }
+ expSoonCtx, _ := context.WithTimeout(ctx, 100 * time.Millisecond)
+
+ failTests := []struct {
+ id string // test ID
+ ctx context.Context // context
+ ids []string // query IDs
+ fn func(Tx) error // transaction callback
+ } {{
+ id: "ctx",
+ ctx: expCtx,
+ ids: []string {},
+ fn: nullTxFn,
+ }, {
+ id: "ids",
+ ctx: ctx,
+ ids: []string {"bad query"},
+ fn: nullTxFn,
+ }, {
+ id: "rollback", // FIXME: doesn't work
+ ctx: expSoonCtx,
+ ids: []string {},
+ fn: func(_ Tx) error {
+ time.Sleep(200 * time.Millisecond)
+ return nil
+ },
+ }}
+
+ for _, test := range(failTests) {
+ t.Run(test.id, func(t *testing.T) {
+ if db.Tx(test.ctx, test.ids, test.fn) == nil {
+ t.Errorf("got success, exp err")
+ }
+ })
+ }
+}
+
+func TestQuery(t *testing.T) {
+ ctx := context.Background()
+
+ // create temp dir
+ dir, err := os.MkdirTemp("", "")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer os.RemoveAll(dir)
+
+ // create dbstore
+ db, err := Open(filepath.Join(dir, "test-query.db"))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ passTests := []struct {
+ id string // query ID
+ args []interface{} // query args
+ exp []string // expected results
+ } {{
+ id: "test/query/ids",
+ args: []interface{} {},
+ exp: []string { "foo", "bar", "baz" },
+ }, {
+ id: "test/query/ids-arg-named",
+ args: []interface{} { db_sql.Named("q", "bar") },
+ exp: []string { "bar" },
+ }, {
+ id: "test/query/ids-arg-pos",
+ args: []interface{} { "baz" },
+ exp: []string { "baz" },
+ }}
+
+ for _, test := range(passTests) {
+ t.Run(fmt.Sprintf("pass/%s", test.id), func(t *testing.T) {
+ var got []string
+
+ // exec query, build result
+ err := db.Query(ctx, test.id, test.args, func(rows *db_sql.Rows) error {
+ var s string
+ if err := rows.Scan(&s); err != nil {
+ return err
+ } else {
+ got = append(got, s)
+ return nil
+ }
+ })
+
+ // check for error
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // check for expected value
+ if !reflect.DeepEqual(got, test.exp) {
+ t.Errorf("got %v, exp %v", got, test.exp)
+ }
+ })
+ }
+
+ failTests := []struct {
+ id string // query ID
+ args []interface{} // query args
+ } {{
+ id: "test/query/ids-arg-named",
+ // args: []interface{} { "foo", true },
+ }, {
+ id: "test/query/ids-arg-pos",
+ // args: []interface{} { "foo", true },
+ }}
+
+ var s string
+ someErr := errors.New("an error")
+ scanFn := func(rows *db_sql.Rows) error { return rows.Scan(&s) }
+ failFn := func(_ *db_sql.Rows) error { return someErr }
+
+ // run fail tests
+ for _, test := range(failTests) {
+ t.Run(fmt.Sprintf("fail/%s", test.id), func(t *testing.T) {
+ if db.Query(ctx, test.id, test.args, scanFn) == nil {
+ t.Errorf("got success, exp error")
+ }
+ })
+ }
+
+ // test callback func error
+ t.Run("fail-func", func(t *testing.T) {
+ if db.Query(ctx, "test/query/ids", nil, failFn) == nil {
+ t.Errorf("got success, exp error")
+ }
+ })
+}
+
+func TestQueryRow(t *testing.T) {
+ ctx := context.Background()
+
+ // create temp dir
+ dir, err := os.MkdirTemp("", "")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer os.RemoveAll(dir)
+
+ // create dbstore
+ db, err := Open(filepath.Join(dir, "test-queryrow.db"))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ passTests := []struct {
+ id string // query ID
+ args []interface{} // query args
+ exp string // expected results
+ } {{
+ id: "test/queryrow/ids",
+ exp: "foo",
+ }, {
+ id: "test/queryrow/ids-arg-named",
+ args: []interface{} { db_sql.Named("q", "bar") },
+ exp: "bar",
+ }, {
+ id: "test/queryrow/ids-arg-pos",
+ args: []interface{} { "baz" },
+ exp: "baz",
+ }}
+
+ for _, test := range(passTests) {
+ t.Run(fmt.Sprintf("pass/%s", test.id), func(t *testing.T) {
+ // exec query, get result
+ var got string
+ err := db.QueryRow(ctx, test.id, test.args, func(row *db_sql.Row) error {
+ return row.Scan(&got)
+ })
+
+ // check for error, check result
+ if err != nil {
+ t.Error(err)
+ return
+ } else if got != test.exp {
+ t.Errorf("got %v, exp %v", got, test.exp)
+ }
+ })
+ }
+
+ failTests := []string {
+ "invalid query",
+ "test/queryrow/ids-arg-named",
+ "test/queryrow/ids-arg-pos",
+ }
+
+ var got string
+ someErr := errors.New("an error")
+ scanFn := func(row *db_sql.Row) error { return row.Scan(&got) }
+ failFn := func(_ *db_sql.Row) error { return someErr }
+
+ // run fail tests
+ for _, test := range(failTests) {
+ t.Run(fmt.Sprintf("fail/%s", test), func(t *testing.T) {
+ if db.QueryRow(ctx, test, nil, scanFn) == nil {
+ t.Errorf("got %s, exp error", got)
+ }
+ })
+ }
+
+ // test callback func error
+ t.Run("fail-func", func(t *testing.T) {
+ if db.QueryRow(ctx, "test/queryrow/ids", nil, failFn) == nil {
+ t.Errorf("got success, exp error")
+ }
+ })
+}
+
func TestAddCpeDictionaryPass(t *testing.T) {
if testing.Short() {
t.Skip("skipping TestAddCveFeeds() in short mode")