From 14ba42b83fc9ea6f811f8ff44b449f9ad6a90f85 Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Fri, 18 Feb 2022 21:12:33 -0500 Subject: dbstore: rename Tx to Begin, add Tx which accepts callback and handles commit/rollback --- dbstore/dbstore.go | 190 +++++++++++++++++++++++++++++------------------------ dbstore/tx_test.go | 10 +-- 2 files changed, 109 insertions(+), 91 deletions(-) (limited to 'dbstore') diff --git a/dbstore/dbstore.go b/dbstore/dbstore.go index 21a80a3..ccddd72 100644 --- a/dbstore/dbstore.go +++ b/dbstore/dbstore.go @@ -83,11 +83,35 @@ var addCpeDictionaryQueryIds = []string { "cpe/insert-ref", } -// create new transaction -func (me DbStore) Tx(ctx context.Context, queryIds []string) (Tx, error) { +// Begin new transaction and create prepared statements. +func (me DbStore) Begin(ctx context.Context, queryIds []string) (Tx, error) { return newTx(ctx, me.db, queryIds) } +// Create a transaction, pass it to callback, then commit the transaction +// if the callback returns success and rollback the transaction if the +// callback returns an error. +func (me DbStore) Tx(ctx context.Context, queryIds []string, fn func(Tx) error) error { + // create transaction + tx, err := me.Begin(ctx, queryIds) + if err != nil { + return err + } + + if err := fn(tx); err != nil { + // rollback + if rb_err := tx.Rollback(); rb_err != nil { + return rb_err + } + + // return error + return err + } else { + // commit transaction + return tx.Commit() + } +} + // Execute query and invoke callback with each row of result. func (me DbStore) Query( ctx context.Context, @@ -139,44 +163,41 @@ func (me DbStore) AddCpeDictionary(ctx context.Context, dict cpedict.Dictionary) return err } - tx, err := me.Tx(ctx, addCpeDictionaryQueryIds) - if err != nil { - return err - } - - // add items - for _, item := range(dict.Items) { - // add cpe - rs, err := tx.Exec(ctx, "cpe/insert", item.CpeUri, item.Cpe23Item.Name) - if err != nil { - return err - } - - // get last row ID - id, err := rs.LastInsertId() - if err != nil { - return err - } - - // add titles - for _, title := range(item.Titles) { - _, err := tx.Exec(ctx, "cpe/insert-title", id, title.Lang, title.Text) + return me.Tx(ctx, addCpeDictionaryQueryIds, func(tx Tx) error { + // add items + for _, item := range(dict.Items) { + // add cpe + rs, err := tx.Exec(ctx, "cpe/insert", item.CpeUri, item.Cpe23Item.Name) if err != nil { return err } - } - // add refs - for _, ref := range(item.References) { - _, err := tx.Exec(ctx, "cpe/insert-ref", id, ref.Href, ref.Text) + // get last row ID + id, err := rs.LastInsertId() if err != nil { return err } + + // add titles + for _, title := range(item.Titles) { + _, err := tx.Exec(ctx, "cpe/insert-title", id, title.Lang, title.Text) + if err != nil { + return err + } + } + + // add refs + for _, ref := range(item.References) { + _, err := tx.Exec(ctx, "cpe/insert-ref", id, ref.Href, ref.Text) + if err != nil { + return err + } + } } - } - // commit changes, return result - return tx.Commit() + // return success + return nil + }) } // search CPEs @@ -227,74 +248,71 @@ func (me DbStore) AddCpeMatches(ctx context.Context, matches cpematch.Matches) e } // begin transaction - tx, err := me.Tx(ctx, addCpeMatchesQueryIds) - if err != nil { - return err - } - - // add matches - for _, m := range(matches.Matches) { - // add cpe - rs, err := tx.Exec(ctx, "cpe-match/insert", m.Cpe23Uri, m.Cpe22Uri) - if err != nil { - return err - } - - // get last row ID - id, err := rs.LastInsertId() - if err != nil { - return err - } - - // add vulnerable - if m.Vulnerable != nil { - _, err := tx.Exec(ctx, "cpe-match/insert-vulnerable", id, *m.Vulnerable) - if err != nil { + return me.Tx(ctx, addCpeMatchesQueryIds, func(tx Tx) error { + // add matches + for _, m := range(matches.Matches) { + // add cpe + rs, err := tx.Exec(ctx, "cpe-match/insert", m.Cpe23Uri, m.Cpe22Uri) + if err != nil { return err } - } - // add version minimum - if m.VersionStartIncluding != "" && m.VersionStartExcluding != "" { - return fmt.Errorf("cannot specify both VersionStartIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionStartIncluding, m.VersionStartExcluding) - } else if m.VersionStartIncluding != "" { - _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, true, m.VersionStartIncluding) - if err != nil { + // get last row ID + id, err := rs.LastInsertId() + if err != nil { return err } - } else if m.VersionStartExcluding != "" { - _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, false, m.VersionStartExcluding) - if err != nil { - return err + + // add vulnerable + if m.Vulnerable != nil { + _, err := tx.Exec(ctx, "cpe-match/insert-vulnerable", id, *m.Vulnerable) + if err != nil { + return err + } } - } - // add version maximum - if m.VersionEndIncluding != "" && m.VersionEndExcluding != "" { - return fmt.Errorf("cannot specify both VersionEndIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionEndIncluding, m.VersionEndExcluding) - } else if m.VersionEndIncluding != "" { - _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, true, m.VersionEndIncluding) - if err != nil { - return err + // add version minimum + if m.VersionStartIncluding != "" && m.VersionStartExcluding != "" { + return fmt.Errorf("cannot specify both VersionStartIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionStartIncluding, m.VersionStartExcluding) + } else if m.VersionStartIncluding != "" { + _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, true, m.VersionStartIncluding) + if err != nil { + return err + } + } else if m.VersionStartExcluding != "" { + _, err := tx.Exec(ctx, "cpe-match/insert-version-min", id, false, m.VersionStartExcluding) + if err != nil { + return err + } } - } else if m.VersionEndExcluding != "" { - _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, false, m.VersionEndExcluding) - if err != nil { - return err + + // add version maximum + if m.VersionEndIncluding != "" && m.VersionEndExcluding != "" { + return fmt.Errorf("cannot specify both VersionEndIncluding = \"%s\", VersionEndIncluding \"%s\"", m.VersionEndIncluding, m.VersionEndExcluding) + } else if m.VersionEndIncluding != "" { + _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, true, m.VersionEndIncluding) + if err != nil { + return err + } + } else if m.VersionEndExcluding != "" { + _, err := tx.Exec(ctx, "cpe-match/insert-version-max", id, false, m.VersionEndExcluding) + if err != nil { + return err + } } - } - // add names - for _, name := range(m.Names) { - _, err := tx.Exec(ctx, "cpe-match/insert-name", id, name.Cpe23Uri, name.Cpe22Uri) - if err != nil { - return err + // add names + for _, name := range(m.Names) { + _, err := tx.Exec(ctx, "cpe-match/insert-name", id, name.Cpe23Uri, name.Cpe22Uri) + if err != nil { + return err + } } } - } - // commit changes, return result - return tx.Commit() + // return success + return nil + }) } // search CPE matches diff --git a/dbstore/tx_test.go b/dbstore/tx_test.go index df15f9e..0e3ed23 100644 --- a/dbstore/tx_test.go +++ b/dbstore/tx_test.go @@ -33,13 +33,13 @@ func TestNewTxDupTx(t *testing.T) { } // create first transaction - if _, err = db.Tx(ctx, []string{}); err != nil { + if _, err = db.Begin(ctx, []string{}); err != nil { t.Error(err) return } // create second transaction - got, err := db.Tx(ctx, []string{}) + got, err := db.Begin(ctx, []string{}) if err != nil { // FIXME t.Errorf("got %v, exp error", got) return @@ -72,7 +72,7 @@ func TestNewTxBadQueries(t *testing.T) { } // create first transaction - if got, err := db.Tx(ctx, badQueryIds); err == nil { + if got, err := db.Begin(ctx, badQueryIds); err == nil { t.Errorf("got %v, exp error", got) return } @@ -96,7 +96,7 @@ func TestNewTxPrepareFail(t *testing.T) { } // create transaction - if got, err := db.Tx(ctx, []string { "test/junk" }); err == nil { + if got, err := db.Begin(ctx, []string { "test/junk" }); err == nil { t.Errorf("got %v, exp error", got) } } @@ -125,7 +125,7 @@ func TestTxExecFail(t *testing.T) { } // create transaction - tx, err := db.Tx(ctx, []string { "init" }) + tx, err := db.Begin(ctx, []string { "init" }) if err != nil { t.Error(err) return -- cgit v1.2.3