From d34c68f709ba482dff8f79801971c16f19a6c9f7 Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Sat, 26 Feb 2022 13:56:25 -0500 Subject: dbstore/dbstore_test.go: add TestTx(), TestQuery(), and TestQueryRow() --- dbstore/dbstore_test.go | 254 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 254 insertions(+) diff --git a/dbstore/dbstore_test.go b/dbstore/dbstore_test.go index 09093c7..1b5d498 100644 --- a/dbstore/dbstore_test.go +++ b/dbstore/dbstore_test.go @@ -261,6 +261,260 @@ func TestInitFail(t *testing.T) { } } +func TestTx(t *testing.T) { + ctx := context.Background() + expCtx, _ := context.WithTimeout(ctx, 0) + nopeErr := fmt.Errorf("nope") + + // create temp dir + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Error(err) + return + } + defer os.RemoveAll(dir) + + // create dbstore + db, err := Open(filepath.Join(dir, "test-tx.db")) + if err != nil { + t.Error(err) + return + } + + passTests := []struct { + id string + exp error + fn func(Tx) error + } { + { "commit", nil, func(_ Tx) error { return nil } }, + { "rollback", nopeErr, func(_ Tx) error { return nopeErr } }, + } + + for _, test := range(passTests) { + t.Run(test.id, func(t *testing.T) { + got := db.Tx(ctx, []string {}, test.fn) + if got != test.exp { + t.Errorf("got %v, exp %v", got, test.exp) + } + }) + } + + // null transaction function + nullTxFn := func(_ Tx) error { return nil } + expSoonCtx, _ := context.WithTimeout(ctx, 100 * time.Millisecond) + + failTests := []struct { + id string // test ID + ctx context.Context // context + ids []string // query IDs + fn func(Tx) error // transaction callback + } {{ + id: "ctx", + ctx: expCtx, + ids: []string {}, + fn: nullTxFn, + }, { + id: "ids", + ctx: ctx, + ids: []string {"bad query"}, + fn: nullTxFn, + }, { + id: "rollback", // FIXME: doesn't work + ctx: expSoonCtx, + ids: []string {}, + fn: func(_ Tx) error { + time.Sleep(200 * time.Millisecond) + return nil + }, + }} + + for _, test := range(failTests) { + t.Run(test.id, func(t *testing.T) { + if db.Tx(test.ctx, test.ids, test.fn) == nil { + t.Errorf("got success, exp err") + } + }) + } +} + +func TestQuery(t *testing.T) { + ctx := context.Background() + + // create temp dir + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Error(err) + return + } + defer os.RemoveAll(dir) + + // create dbstore + db, err := Open(filepath.Join(dir, "test-query.db")) + if err != nil { + t.Error(err) + return + } + + passTests := []struct { + id string // query ID + args []interface{} // query args + exp []string // expected results + } {{ + id: "test/query/ids", + args: []interface{} {}, + exp: []string { "foo", "bar", "baz" }, + }, { + id: "test/query/ids-arg-named", + args: []interface{} { db_sql.Named("q", "bar") }, + exp: []string { "bar" }, + }, { + id: "test/query/ids-arg-pos", + args: []interface{} { "baz" }, + exp: []string { "baz" }, + }} + + for _, test := range(passTests) { + t.Run(fmt.Sprintf("pass/%s", test.id), func(t *testing.T) { + var got []string + + // exec query, build result + err := db.Query(ctx, test.id, test.args, func(rows *db_sql.Rows) error { + var s string + if err := rows.Scan(&s); err != nil { + return err + } else { + got = append(got, s) + return nil + } + }) + + // check for error + if err != nil { + t.Error(err) + return + } + + // check for expected value + if !reflect.DeepEqual(got, test.exp) { + t.Errorf("got %v, exp %v", got, test.exp) + } + }) + } + + failTests := []struct { + id string // query ID + args []interface{} // query args + } {{ + id: "test/query/ids-arg-named", + // args: []interface{} { "foo", true }, + }, { + id: "test/query/ids-arg-pos", + // args: []interface{} { "foo", true }, + }} + + var s string + someErr := errors.New("an error") + scanFn := func(rows *db_sql.Rows) error { return rows.Scan(&s) } + failFn := func(_ *db_sql.Rows) error { return someErr } + + // run fail tests + for _, test := range(failTests) { + t.Run(fmt.Sprintf("fail/%s", test.id), func(t *testing.T) { + if db.Query(ctx, test.id, test.args, scanFn) == nil { + t.Errorf("got success, exp error") + } + }) + } + + // test callback func error + t.Run("fail-func", func(t *testing.T) { + if db.Query(ctx, "test/query/ids", nil, failFn) == nil { + t.Errorf("got success, exp error") + } + }) +} + +func TestQueryRow(t *testing.T) { + ctx := context.Background() + + // create temp dir + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Error(err) + return + } + defer os.RemoveAll(dir) + + // create dbstore + db, err := Open(filepath.Join(dir, "test-queryrow.db")) + if err != nil { + t.Error(err) + return + } + + passTests := []struct { + id string // query ID + args []interface{} // query args + exp string // expected results + } {{ + id: "test/queryrow/ids", + exp: "foo", + }, { + id: "test/queryrow/ids-arg-named", + args: []interface{} { db_sql.Named("q", "bar") }, + exp: "bar", + }, { + id: "test/queryrow/ids-arg-pos", + args: []interface{} { "baz" }, + exp: "baz", + }} + + for _, test := range(passTests) { + t.Run(fmt.Sprintf("pass/%s", test.id), func(t *testing.T) { + // exec query, get result + var got string + err := db.QueryRow(ctx, test.id, test.args, func(row *db_sql.Row) error { + return row.Scan(&got) + }) + + // check for error, check result + if err != nil { + t.Error(err) + return + } else if got != test.exp { + t.Errorf("got %v, exp %v", got, test.exp) + } + }) + } + + failTests := []string { + "invalid query", + "test/queryrow/ids-arg-named", + "test/queryrow/ids-arg-pos", + } + + var got string + someErr := errors.New("an error") + scanFn := func(row *db_sql.Row) error { return row.Scan(&got) } + failFn := func(_ *db_sql.Row) error { return someErr } + + // run fail tests + for _, test := range(failTests) { + t.Run(fmt.Sprintf("fail/%s", test), func(t *testing.T) { + if db.QueryRow(ctx, test, nil, scanFn) == nil { + t.Errorf("got %s, exp error", got) + } + }) + } + + // test callback func error + t.Run("fail-func", func(t *testing.T) { + if db.QueryRow(ctx, "test/queryrow/ids", nil, failFn) == nil { + t.Errorf("got success, exp error") + } + }) +} + func TestAddCpeDictionaryPass(t *testing.T) { if testing.Short() { t.Skip("skipping TestAddCveFeeds() in short mode") -- cgit v1.2.3