aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--dbstore/dbstore.go190
-rw-r--r--dbstore/tx_test.go10
2 files changed, 109 insertions, 91 deletions
diff --git a/dbstore/dbstore.go b/dbstore/dbstore.go
index 21a80a3..ccddd72 100644
--- a/dbstore/dbstore.go
+++ b/dbstore/dbstore.go
@@ -83,11 +83,35 @@ var addCpeDictionaryQueryIds = []string {
"cpe/insert-ref",
}
-// create new transaction
-func (me DbStore) Tx(ctx context.Context, queryIds []string) (Tx, error) {
+// Begin new transaction and create prepared statements.
+func (me DbStore) Begin(ctx context.Context, queryIds []string) (Tx, error) {
return newTx(ctx, me.db, queryIds)
}
+// Create a transaction, pass it to callback, then commit the transaction
+// if the callback returns success and rollback the transaction if the
+// callback returns an error.
+func (me DbStore) Tx(ctx context.Context, queryIds []string, fn func(Tx) error) error {
+ // create transaction
+ tx, err := me.Begin(ctx, queryIds)
+ if err != nil {
+ return err
+ }
+
+ if err := fn(tx); err != nil {
+ // rollback
+ if rb_err := tx.Rollback(); rb_err != nil {
+ return rb_err
+ }
+
+ // return error
+ return err
+ } else {
+ // commit transaction
+ return tx.Commit()
+ }
+}
+
// Execute query and invoke callback with each row of result.
func (me DbStore) Query(
ctx context.Context,
@@ -139,44 +163,41 @@ func (me DbStore) AddCpeDictionary(ctx context.Context, dict cpedict.Dictionary)
return err
}
- tx, err := me.Tx(ctx, addCpeDictionaryQueryIds)
- if err != nil {
- return err
- }
-
- // add items
- for _, item := range(dict.Items) {
- // add cpe
- rs, err := tx.Exec(ctx, "cpe/insert", item.CpeUri, item.Cpe23Item.Name)
- if err != nil {
- return err
- }
-
- // get last row ID
- id, err := rs.LastInsertId()
- if err != nil {
- return err
- }
-
- // add titles
- for _, title := range(item.Titles) {
- _, err := tx.Exec(ctx, "cpe/insert-title", id, title.Lang, title.Text)
+ return me.Tx(ctx, addCpeDictionaryQueryIds, func(tx Tx) error {
+ // add items
+ for _, item := range(dict.Items) {
+ // add cpe
+ rs, err := tx.Exec(ctx, "cpe/insert", item.CpeUri, item.Cpe23Item.Name)
if err != nil {
return err
}
- }
- // add refs
- for _, ref := range(item.References) {
- _, err := tx.Exec(ctx, "cpe/insert-ref", id, ref.Href, ref.Text)
+ // get last row ID
+ id, err := rs.LastInsertId()
if err != nil {
return err
}
+
+ // add titles
+ for _, title := range(item.Titles) {
+ _, err := tx.Exec(ctx, "cpe/insert-title", id, title.Lang, title.Text)
+ if err != nil {
+ return err
+ }
+ }
+
+ // add refs
+ for _, ref := range(item.References) {
+ _, err := tx.Exec(ctx, "cpe/insert-ref", id, ref.Href, ref.Text)
+ if err != nil {
+ return err
+ }
+ }
}
- }
- // commit changes, return result
- return tx.Commit()
+ // return success
+ return nil
+ })
}
// search CPEs
@@ -227,74 +248,71 @@ func (me DbStore) AddCpeMatches(ctx context.Context, matches cpematch.Matches) e
}
// begin transaction
- tx, err := me.Tx(ctx, addCpeMatchesQueryIds)
- if err != nil {
- return err
- }
-
- // add matches
- for _, m := range(matches.Matches) {
- // add cpe
- rs, err := tx.Exec(ctx, "cpe-match/insert", m.Cpe23Uri, m.Cpe22Uri)
- if err != nil {
- return err
- }
-
- // get last row ID
- id, err := rs.LastInsertId()
- if err != nil {
- return err
- }
-
- // add vulnerable
- if m.Vulnerable != nil {
- _, err := tx.Exec(ctx, "cpe-match/insert-vulnerable", id, *m.Vulnerable)
- if err != nil {
+ return me.Tx(ctx, addCpeMatchesQueryIds, func(tx Tx) error {
+ // add matches
+ for _, m := range(matches.Matches) {
+ // add cpe
+ rs, err := tx.Exec(ctx, "cpe-match/insert", m.Cpe23Uri, m.Cpe22Uri)
+ if err != nil {
return err
}
- }
- // add version minimum
- if m.VersionStartIncluding != "" && m.VersionStartExcluding != "" {
- return fmt.Errorf("cannot specify both VersionStartIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionStartIncluding, m.VersionStartExcluding)
- } else if m.VersionStartIncluding != "" {
- _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, true, m.VersionStartIncluding)
- if err != nil {
+ // get last row ID
+ id, err := rs.LastInsertId()
+ if err != nil {
return err
}
- } else if m.VersionStartExcluding != "" {
- _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, false, m.VersionStartExcluding)
- if err != nil {
- return err
+
+ // add vulnerable
+ if m.Vulnerable != nil {
+ _, err := tx.Exec(ctx, "cpe-match/insert-vulnerable", id, *m.Vulnerable)
+ if err != nil {
+ return err
+ }
}
- }
- // add version maximum
- if m.VersionEndIncluding != "" && m.VersionEndExcluding != "" {
- return fmt.Errorf("cannot specify both VersionEndIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionEndIncluding, m.VersionEndExcluding)
- } else if m.VersionEndIncluding != "" {
- _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, true, m.VersionEndIncluding)
- if err != nil {
- return err
+ // add version minimum
+ if m.VersionStartIncluding != "" && m.VersionStartExcluding != "" {
+ return fmt.Errorf("cannot specify both VersionStartIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionStartIncluding, m.VersionStartExcluding)
+ } else if m.VersionStartIncluding != "" {
+ _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, true, m.VersionStartIncluding)
+ if err != nil {
+ return err
+ }
+ } else if m.VersionStartExcluding != "" {
+ _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, false, m.VersionStartExcluding)
+ if err != nil {
+ return err
+ }
}
- } else if m.VersionEndExcluding != "" {
- _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, false, m.VersionEndExcluding)
- if err != nil {
- return err
+
+ // add version maximum
+ if m.VersionEndIncluding != "" && m.VersionEndExcluding != "" {
+ return fmt.Errorf("cannot specify both VersionEndIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionEndIncluding, m.VersionEndExcluding)
+ } else if m.VersionEndIncluding != "" {
+ _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, true, m.VersionEndIncluding)
+ if err != nil {
+ return err
+ }
+ } else if m.VersionEndExcluding != "" {
+ _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, false, m.VersionEndExcluding)
+ if err != nil {
+ return err
+ }
}
- }
- // add names
- for _, name := range(m.Names) {
- _, err := tx.Exec(ctx, "cpe-match/insert-name", id, name.Cpe23Uri, name.Cpe22Uri)
- if err != nil {
- return err
+ // add names
+ for _, name := range(m.Names) {
+ _, err := tx.Exec(ctx, "cpe-match/insert-name", id, name.Cpe23Uri, name.Cpe22Uri)
+ if err != nil {
+ return err
+ }
}
}
- }
- // commit changes, return result
- return tx.Commit()
+ // return success
+ return nil
+ })
}
// search CPE matches
diff --git a/dbstore/tx_test.go b/dbstore/tx_test.go
index df15f9e..0e3ed23 100644
--- a/dbstore/tx_test.go
+++ b/dbstore/tx_test.go
@@ -33,13 +33,13 @@ func TestNewTxDupTx(t *testing.T) {
}
// create first transaction
- if _, err = db.Tx(ctx, []string{}); err != nil {
+ if _, err = db.Begin(ctx, []string{}); err != nil {
t.Error(err)
return
}
// create second transaction
- got, err := db.Tx(ctx, []string{})
+ got, err := db.Begin(ctx, []string{})
if err != nil { // FIXME
t.Errorf("got %v, exp error", got)
return
@@ -72,7 +72,7 @@ func TestNewTxBadQueries(t *testing.T) {
}
// create first transaction
- if got, err := db.Tx(ctx, badQueryIds); err == nil {
+ if got, err := db.Begin(ctx, badQueryIds); err == nil {
t.Errorf("got %v, exp error", got)
return
}
@@ -96,7 +96,7 @@ func TestNewTxPrepareFail(t *testing.T) {
}
// create transaction
- if got, err := db.Tx(ctx, []string { "test/junk" }); err == nil {
+ if got, err := db.Begin(ctx, []string { "test/junk" }); err == nil {
t.Errorf("got %v, exp error", got)
}
}
@@ -125,7 +125,7 @@ func TestTxExecFail(t *testing.T) {
}
// create transaction
- tx, err := db.Tx(ctx, []string { "init" })
+ tx, err := db.Begin(ctx, []string { "init" })
if err != nil {
t.Error(err)
return