aboutsummaryrefslogtreecommitdiff
path: root/km-load.c
blob: 3615aa92d1591131067673a55602c26a51602b99 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <stdbool.h> // bool
#include <stdio.h> // fscanf()
#include <stdlib.h> // malloc(), calloc(), free()
#include <string.h> // strerror()
#include <errno.h> // errno
#include "util.h"
#include "km.h"

// Report a formatted error and bail out of the enclosing function.
//
// Formats the printf-style arguments into a 1024-byte stack buffer
// (snprintf() truncates anything longer), hands the text to
// cbs->on_error along with cb_data when both cbs and the callback are
// non-NULL, and then returns false from the function that used the
// macro.
//
// The macro refers to `cbs` and `cb_data` by name, so both must be in
// scope wherever it is used.  Because it returns directly, anything
// the calling function still owns (buffers, file handles) has to be
// released before FAIL is invoked.
#define FAIL(...) do { \
  if (cbs && cbs->on_error) { \
    char buf[1024]; \
    snprintf(buf, sizeof(buf), __VA_ARGS__); \
    cbs->on_error(buf, cb_data); \
  } \
  return false; \
} while (0)

// Read rows from fh until end of file.
//
// Each row is shape->num_floats float values followed by
// shape->num_ints int values, separated by whitespace.  The values of
// the current row are stored in the caller-provided floats and ints
// scratch buffers and then passed to cbs->on_row (if set).
//
// Returns true once end of file is reached.  Returns false when a
// value cannot be read (reported through FAIL) or when on_row returns
// false.  This function never frees the buffers: the caller owns them,
// which is what allows the early returns here without leaking.
static _Bool
km_load_rows(
  FILE * const fh,
  const km_shape_t * const shape,
  float * const floats,
  int * const ints,
  const km_load_cbs_t * const cbs,
  void * const cb_data
) {
  // a row without columns consumes no input, so the loop below would
  // never reach end of file; treat an empty shape as "no rows"
  if (shape->num_floats == 0 && shape->num_ints == 0) {
    return true;
  }

  // the trailing space in each format skips the whitespace after a
  // value, so feof() becomes true right after the last value in the file
  for (size_t row = 0; !feof(fh); row++) {
    // read floats
    for (size_t i = 0; i < shape->num_floats; i++) {
      if (fscanf(fh, " %f ", floats + i) != 1) {
        FAIL("[%zu, %zu] float fscanf() failed: %s", row, i, strerror(errno));
      }
    }

    // read ints
    for (size_t i = 0; i < shape->num_ints; i++) {
      if (fscanf(fh, " %d ", ints + i) != 1) {
        FAIL("[%zu, %zu] int fscanf() failed: %s", row, i, strerror(errno));
      }
    }

    if (cbs && cbs->on_row) {
      // emit row
      if (!cbs->on_row(floats, ints, cb_data)) {
        // return failure
        return false;
      }
    }
  }

  // return success
  return true;
}

// Load a data set from an open stream.
//
// The stream starts with a shape header ("<num_floats> <num_ints>"),
// followed by zero or more rows of num_floats floats and num_ints
// ints, all whitespace-separated.
//
// fh      - stream to read from; it is not closed by this function.
// cbs     - optional callbacks (may be NULL).  on_shape is called once
//           with the parsed shape, on_row once per row, and on_error
//           with a message when loading fails.  A false return from
//           on_shape or on_row aborts the load.
// cb_data - opaque pointer passed through to every callback.
//
// The row buffers passed to on_row are owned by this function and are
// released before it returns, on success and on every failure path.
//
// Returns true if the whole stream was read, false otherwise.
_Bool
km_load(
  FILE * const fh,
  const km_load_cbs_t * const cbs,
  void * const cb_data
) {
  // read shape; the trailing space in the format skips the whitespace
  // after the header so that feof() is already set for a stream that
  // contains no rows
  km_shape_t shape = { 0, 0 };
  if (fscanf(fh, "%zu %zu ", &(shape.num_floats), &(shape.num_ints)) != 2) {
    FAIL("shape fscanf() failed: %s", strerror(errno));
  }

  if (cbs && cbs->on_shape) {
    // emit shape
    if (!cbs->on_shape(&shape, cb_data)) {
      // return failure
      return false;
    }
  }

  // alloc floats buffer; calloc() rejects a count whose byte size
  // would overflow, which matters because the count comes from the file
  float *floats = NULL;
  if (shape.num_floats > 0) {
    floats = calloc(shape.num_floats, sizeof(*floats));
    if (!floats) {
      FAIL("floats malloc() failed: %s", strerror(errno));
    }
  }

  // alloc ints buffer
  int *ints = NULL;
  if (shape.num_ints > 0) {
    ints = calloc(shape.num_ints, sizeof(*ints));
    if (!ints) {
      // save errno first, free() is allowed to change it
      const int err = errno;
      free(floats);
      FAIL("ints malloc() failed: %s", strerror(err));
    }
  }

  // read rows, get result
  const bool ok = km_load_rows(fh, &shape, floats, ints, cbs, cb_data);

  // free buffers (free(NULL) is a no-op)
  free(floats);
  free(ints);

  // return result
  return ok;
}

// Load a data set from the file at the given path.
//
// Opens path for reading in binary mode, passes the stream to
// km_load() together with cbs and cb_data, closes the stream again and
// returns whatever km_load() returned.  If the file cannot be opened,
// the failure is reported through cbs->on_error (when set) and false
// is returned.
_Bool
km_load_path(
  const char * const path,
  const km_load_cbs_t * const cbs,
  void * const cb_data
) {
  FILE * const stream = fopen(path, "rb");

  if (stream != NULL) {
    // stream is open: parse it, then release the handle regardless
    // of the outcome
    const bool ok = km_load(stream, cbs, cb_data);
    fclose(stream);
    return ok;
  }

  // could not open the file; FAIL returns false
  FAIL("fopen(\"%s\") failed: %s", path, strerror(errno));
}