aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nvdmirror/sync.go246
-rw-r--r--nvdmirror/synccontext.go249
2 files changed, 249 insertions, 246 deletions
diff --git a/nvdmirror/sync.go b/nvdmirror/sync.go
index 546a0ed..6f0b4d4 100644
--- a/nvdmirror/sync.go
+++ b/nvdmirror/sync.go
@@ -1,258 +1,12 @@
-// mirror files from upstream NVD source
package nvdmirror
import (
- "bytes"
- "crypto/sha256"
- "errors"
- "fmt"
- "github.com/pablotron/cvez/atomictemp"
- "github.com/pablotron/cvez/feed"
"github.com/rs/zerolog/log"
- "io"
- "io/fs"
- "net/http"
- "net/url"
- "os"
"path/filepath"
)
-// Fetch result.
-type fetchResult struct {
- src string // source URL
- err error // fetch result
- modified bool // Was the result modified?
- path string // Destination file.
- headers http.Header // response headers
-}
-
-// Check result.
-type checkResult struct {
- metaUrl string // meta full url
- metaPath string // meta file path
- fullPath string // full file path
- err error // error
- match bool // true if size and hash match
-}
-
-type syncMessage struct {
- fetch fetchResult // fetch result
- check checkResult // check result
-}
-
-// sync context
-type syncContext struct {
- config SyncConfig // sync config
- client *http.Client // shared HTTP client
- cache Cache // cache
- dstDir string // destination directory
- ch chan syncMessage // sync message channel
-}
-
-// Create sync context.
-func newSyncContext(config SyncConfig, cache Cache, dstDir string) syncContext {
- // create shared transport and client
- tr := &http.Transport {
- MaxIdleConns: config.MaxIdleConns,
- IdleConnTimeout: config.IdleConnTimeout,
- }
-
- return syncContext {
- config: config,
- client: &http.Client{Transport: tr},
- cache: cache,
- dstDir: dstDir,
- ch: make(chan syncMessage),
- }
-}
-
-// Build request
-func (me syncContext) getRequest(srcUrl string) (*http.Request, error) {
- // create HTTP request
- req, err := http.NewRequest("GET", srcUrl, nil)
- if err != nil {
- return nil, err
- }
-
- // Add user-agent, if-none-match, and if-modified-since headers.
- req.Header.Add("user-agent", me.config.GetUserAgent())
- if headers, ok := me.cache.Get(srcUrl); ok {
- for k, v := range(headers) {
- req.Header.Add(k, v)
- }
- }
-
- // return success
- return req, nil
-}
-
-// Fetch URL and write result to destination directory.
-//
-// Note: This method is called from a goroutine and writes the results
-// back via the member channel.
-func (me syncContext) fetch(srcUrl string) {
- // parse source url
- src, err := url.Parse(srcUrl)
- if err != nil {
- me.ch <- syncMessage {
- fetch: fetchResult { src: srcUrl, err: err },
- }
- return
- }
-
- // build destination path
- path := filepath.Join(me.dstDir, filepath.Base(src.Path))
- log.Debug().Str("url", srcUrl).Str("path", path).Send()
-
- // create request
- req, err := me.getRequest(srcUrl)
- if err != nil {
- me.ch <- syncMessage {
- fetch: fetchResult { src: srcUrl, err: err },
- }
- return
- }
-
- // send request
- resp, err := me.client.Do(req)
- if err != nil {
- me.ch <- syncMessage {
- fetch: fetchResult { src: srcUrl, err: err },
- }
- return
- }
- defer resp.Body.Close()
-
- switch resp.StatusCode {
- case 200: // success
- // write to output file
- err := atomictemp.Create(path, func(f io.Writer) error {
- _, err := io.Copy(f, resp.Body)
- return err
- })
-
- if err != nil {
- // write failed
- me.ch <- syncMessage {
- fetch: fetchResult { src: srcUrl, err: err },
- }
- } else {
- me.ch <- syncMessage {
- fetch: fetchResult {
- src: srcUrl,
- modified: true,
- path: path,
- headers: resp.Header,
- },
- }
- }
- case 304: // not modified
- me.ch <- syncMessage {
- fetch: fetchResult { src: srcUrl },
- }
- default: // error
- code := resp.StatusCode
- err := fmt.Errorf("%d: %s", code, http.StatusText(code))
- me.ch <- syncMessage {
- fetch: fetchResult { src: srcUrl, err: err },
- }
- }
-}
-
-// read hash from given meta file.
-func (me syncContext) getMeta(path string) (*feed.Meta, error) {
- // open meta file
- f, err := os.Open(path)
- if err != nil {
- return nil, err
- }
- defer f.Close()
-
- // parse meta
- return feed.NewMeta(f)
-}
-
-// get hash of file in destination directory.
-func (me syncContext) getFileHash(path string) ([32]byte, error) {
- var r [32]byte
-
- // open file
- f, err := os.Open(path)
- if err != nil {
- return r, err
- }
- defer f.Close()
-
- // hash file
- hash := sha256.New()
- if _, err := io.Copy(hash, f); err != nil {
- return r, err
- }
-
- // copy sum to result, return success
- hash.Sum(r[:])
- return r, nil
-}
-
-// Check the size and hash in the metadata file against the full file.
-//
-// Note: This method is called from a goroutine and returns it's value
-// via the internal channel.
-func (me syncContext) check(metaUrl, fullUrl string) {
- // build result
- r := syncMessage {
- check: checkResult {
- metaUrl: metaUrl,
- // build paths
- metaPath: filepath.Join(me.dstDir, filepath.Base(metaUrl)),
- fullPath: filepath.Join(me.dstDir, filepath.Base(fullUrl)),
- },
- }
-
- // get size of full file
- size, err := getFileSize(r.check.fullPath)
- if errors.Is(err, fs.ErrNotExist) {
- r.check.match = false
- me.ch <- r
- return
- } else if err != nil {
- r.check.err = err
- me.ch <- r
- return
- }
-
- // get meta hash
- m, err := me.getMeta(r.check.metaPath)
- if err != nil {
- r.check.err = err
- me.ch <- r
- return
- }
-
- // check for file size match
- if size != m.GzSize {
- r.check.match = false
- me.ch <- r
- return
- }
-
- // get full hash
- fh, err := me.getFileHash(r.check.fullPath)
- if err != nil {
- r.check.err = err
- me.ch <- r
- return
- }
-
- // return result
- r.check.match = (bytes.Compare(m.Sha256[:], fh[:]) == 0)
- me.ch <- r
-}
-
// Sync to destination directory and return an array of updated files.
func Sync(config SyncConfig, cache Cache, dstDir string) []string {
- // log.Debug().Str("dstDir", dstDir).Msg("Sync")
-
// build sync context
ctx := newSyncContext(config, cache, dstDir)
diff --git a/nvdmirror/synccontext.go b/nvdmirror/synccontext.go
new file mode 100644
index 0000000..efc2f28
--- /dev/null
+++ b/nvdmirror/synccontext.go
@@ -0,0 +1,249 @@
+package nvdmirror
+
+import (
+ "bytes"
+ "crypto/sha256"
+ "errors"
+ "fmt"
+ "github.com/pablotron/cvez/atomictemp"
+ "github.com/pablotron/cvez/feed"
+ "github.com/rs/zerolog/log"
+ "io"
+ "io/fs"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+)
+
+// Fetch result.
+type fetchResult struct {
+ src string // source URL
+ err error // fetch result
+ modified bool // Was the result modified?
+ path string // Destination file.
+ headers http.Header // response headers
+}
+
+// Check result.
+type checkResult struct {
+ metaUrl string // meta full url
+ metaPath string // meta file path
+ fullPath string // full file path
+ err error // error
+ match bool // true if size and hash match
+}
+
+type syncMessage struct {
+ fetch fetchResult // fetch result
+ check checkResult // check result
+}
+
+// sync context
+type syncContext struct {
+ config SyncConfig // sync config
+ client *http.Client // shared HTTP client
+ cache Cache // cache
+ dstDir string // destination directory
+ ch chan syncMessage // sync message channel
+}
+
+// Create sync context.
+func newSyncContext(config SyncConfig, cache Cache, dstDir string) syncContext {
+ // create shared transport and client
+ tr := &http.Transport {
+ MaxIdleConns: config.MaxIdleConns,
+ IdleConnTimeout: config.IdleConnTimeout,
+ }
+
+ return syncContext {
+ config: config,
+ client: &http.Client{Transport: tr},
+ cache: cache,
+ dstDir: dstDir,
+ ch: make(chan syncMessage),
+ }
+}
+
+// Build request
+func (me syncContext) getRequest(srcUrl string) (*http.Request, error) {
+ // create HTTP request
+ req, err := http.NewRequest("GET", srcUrl, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ // Add user-agent, if-none-match, and if-modified-since headers.
+ req.Header.Add("user-agent", me.config.GetUserAgent())
+ if headers, ok := me.cache.Get(srcUrl); ok {
+ for k, v := range(headers) {
+ req.Header.Add(k, v)
+ }
+ }
+
+ // return success
+ return req, nil
+}
+
+// Fetch URL and write result to destination directory.
+//
+// Note: This method is called from a goroutine and writes the results
+// back via the member channel.
+func (me syncContext) fetch(srcUrl string) {
+ // parse source url
+ src, err := url.Parse(srcUrl)
+ if err != nil {
+ me.ch <- syncMessage {
+ fetch: fetchResult { src: srcUrl, err: err },
+ }
+ return
+ }
+
+ // build destination path
+ path := filepath.Join(me.dstDir, filepath.Base(src.Path))
+ log.Debug().Str("url", srcUrl).Str("path", path).Send()
+
+ // create request
+ req, err := me.getRequest(srcUrl)
+ if err != nil {
+ me.ch <- syncMessage {
+ fetch: fetchResult { src: srcUrl, err: err },
+ }
+ return
+ }
+
+ // send request
+ resp, err := me.client.Do(req)
+ if err != nil {
+ me.ch <- syncMessage {
+ fetch: fetchResult { src: srcUrl, err: err },
+ }
+ return
+ }
+ defer resp.Body.Close()
+
+ switch resp.StatusCode {
+ case 200: // success
+ // write to output file
+ err := atomictemp.Create(path, func(f io.Writer) error {
+ _, err := io.Copy(f, resp.Body)
+ return err
+ })
+
+ if err != nil {
+ // write failed
+ me.ch <- syncMessage {
+ fetch: fetchResult { src: srcUrl, err: err },
+ }
+ } else {
+ me.ch <- syncMessage {
+ fetch: fetchResult {
+ src: srcUrl,
+ modified: true,
+ path: path,
+ headers: resp.Header,
+ },
+ }
+ }
+ case 304: // not modified
+ me.ch <- syncMessage {
+ fetch: fetchResult { src: srcUrl },
+ }
+ default: // error
+ code := resp.StatusCode
+ err := fmt.Errorf("%d: %s", code, http.StatusText(code))
+ me.ch <- syncMessage {
+ fetch: fetchResult { src: srcUrl, err: err },
+ }
+ }
+}
+
+// read hash from given meta file.
+func (me syncContext) getMeta(path string) (*feed.Meta, error) {
+ // open meta file
+ f, err := os.Open(path)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ // parse meta
+ return feed.NewMeta(f)
+}
+
+// get hash of file in destination directory.
+func (me syncContext) getFileHash(path string) ([32]byte, error) {
+ var r [32]byte
+
+ // open file
+ f, err := os.Open(path)
+ if err != nil {
+ return r, err
+ }
+ defer f.Close()
+
+ // hash file
+ hash := sha256.New()
+ if _, err := io.Copy(hash, f); err != nil {
+ return r, err
+ }
+
+ // copy sum to result, return success
+ hash.Sum(r[:])
+ return r, nil
+}
+
+// Check the size and hash in the metadata file against the full file.
+//
+// Note: This method is called from a goroutine and returns it's value
+// via the internal channel.
+func (me syncContext) check(metaUrl, fullUrl string) {
+ // build result
+ r := syncMessage {
+ check: checkResult {
+ metaUrl: metaUrl,
+ // build paths
+ metaPath: filepath.Join(me.dstDir, filepath.Base(metaUrl)),
+ fullPath: filepath.Join(me.dstDir, filepath.Base(fullUrl)),
+ },
+ }
+
+ // get size of full file
+ size, err := getFileSize(r.check.fullPath)
+ if errors.Is(err, fs.ErrNotExist) {
+ r.check.match = false
+ me.ch <- r
+ return
+ } else if err != nil {
+ r.check.err = err
+ me.ch <- r
+ return
+ }
+
+ // get meta hash
+ m, err := me.getMeta(r.check.metaPath)
+ if err != nil {
+ r.check.err = err
+ me.ch <- r
+ return
+ }
+
+ // check for file size match
+ if size != m.GzSize {
+ r.check.match = false
+ me.ch <- r
+ return
+ }
+
+ // get full hash
+ fh, err := me.getFileHash(r.check.fullPath)
+ if err != nil {
+ r.check.err = err
+ me.ch <- r
+ return
+ }
+
+ // return result
+ r.check.match = (bytes.Compare(m.Sha256[:], fh[:]) == 0)
+ me.ch <- r
+}