Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
-- Benchmark of several ways to expose one row of an FFI-backed matrix.
--
-- The first command-line argument is a free-form option string. Each keyword
-- is detected by substring search, so "row,ffi,gc" enables three of them.
--
-- Cases:
--   1: (no keyword)  direct A[r][c] element access
--   2: row           each row is a Lua table with a metatable
--   3: row,ffi       each row is an FFI struct pointing into the matrix data
--   4: row,ffi,gc    as 3, with a __gc finalizer attached to the row type
--   5: row,copy      each row is an FFI struct holding a copy of the data
--   +: any of cases 1-5 combined with "unroll" (hand-unrolled inner loop)
local ffi = require 'ffi'
-- Keep the builtins used inside metamethods in locals.
local type, setmetatable = type, setmetatable
-- `arg` is only provided by the standalone interpreter; when this chunk is
-- embedded or loaded some other way, fall back to "no options" instead of
-- raising on an index of nil.
local opt = (arg and arg[1]) or ''
-- Metatable for rows represented as plain Lua tables.
-- Instances carry three fields (see where mat_mt.row builds them):
--   _o  the object indexed to reach the data, _r  the row number,
--   _m  the value reported by len().
local row_mt = {}

-- Numeric keys read element k by indexing _o with the row number and then
-- with k; every other key is looked up on row_mt itself, which is how
-- methods such as len() are found.
function row_mt.__index(self, k)
  if type(k) ~= 'number' then
    return row_mt[k]
  end
  return self._o[self._r][k]
end

-- Element count of the row. Exposed as a method because there is
-- no __len for tables.
function row_mt.len(self)
  return self._m
end
-- Metatable used for the FFI-struct row types (fields _m and _p).
local row_ffi_mt = {}

-- Numeric keys read the k-th element: k is 1-based, _p is indexed from 0.
-- Any other key is resolved on row_ffi_mt, which is how len() is found.
function row_ffi_mt.__index(self, k)
  if type(k) ~= 'number' then
    return row_ffi_mt[k]
  end
  return self._p[k - 1]
end

-- Element count of the row, as stored in the struct.
function row_ffi_mt.len(self)
  return self._m
end

-- A finalizer is installed only when the option string contains 'gc' and
-- does not contain 'copy'; otherwise the metatable has no __gc entry.
if opt:find('gc') and not opt:find('copy') then
  row_ffi_mt.__gc = function()
    -- Not important.
  end
end
-- Row type that refers to existing storage: _m is an element count and _p a
-- pointer to doubles.
local row_ffi = ffi.metatype(
  ffi.typeof('struct { int32_t _m; double* _p; }'),
  row_ffi_mt
)
-- Row type that owns its storage: _p is a variable-length array of doubles
-- allocated inline with the struct.
local row_copy = ffi.metatype(
  ffi.typeof('struct { int32_t _m; double _p[?]; }'),
  row_ffi_mt
)
-- Metatable for the matrix type. Instances are expected to have the fields
-- _n (row count), _m (column count) and _p (element storage, laid out one
-- row after another, as the index arithmetic below implies).
local mat_mt = {}

-- Constructor hook: allocates room for nrow*ncol elements and records the
-- dimensions. ffi.new is called directly; it creates the object without
-- going through this hook again.
function mat_mt.__new(ct, nrow, ncol)
  local m = ffi.new(ct, nrow * ncol)
  m._n = nrow
  m._m = ncol
  return m
end

-- Numeric keys yield a pointer positioned one element before the start of
-- row k, so that indexing the result with a 1-based column lands on the
-- right element. Any other key is resolved on mat_mt (methods).
function mat_mt.__index(self, k)
  if type(k) ~= 'number' then
    return mat_mt[k]
  end
  return self._p - 1 + (k - 1) * self._m
end

-- Number of rows.
function mat_mt.nrow(self)
  return self._n
end

-- Number of columns.
function mat_mt.ncol(self)
  return self._m
end
-- Pick the implementation of A:row(r) from the option string.
-- 'ffi' takes precedence over 'copy'; with neither, rows are Lua tables.
if opt:find('ffi') then
  -- Row as an FFI struct holding the column count and a pointer to the
  -- first element of row r inside the matrix storage (no data is copied).
  mat_mt.row = function(self, r)
    return row_ffi(self._m, self._p + (r - 1) * self._m)
  end
elseif opt:find('copy') then
  -- Byte width of one element, computed once instead of a hard-coded 8.
  local DOUBLE_SIZE = ffi.sizeof('double')
  -- Row as an FFI struct with its own storage: allocate _m elements, record
  -- the count, then copy row r's bytes out of the matrix.
  mat_mt.row = function(self, r)
    local o = row_copy(self._m)
    o._m = self._m
    ffi.copy(o._p, self._p + (r - 1) * self._m, DOUBLE_SIZE * self._m)
    return o
  end
else
  -- Row as a Lua table that remembers the matrix and the row number and
  -- reads elements through row_mt.
  mat_mt.row = function(self, r)
    return setmetatable({ _o = self, _r = r, _m = self._m }, row_mt)
  end
end
-- Matrix type: two int32 dimensions followed by a variable-length array of
-- doubles, bound to mat_mt.
local mat = ffi.metatype(
  ffi.typeof('struct { int32_t _n, _m; double _p[?]; }'),
  mat_mt
)

-- Problem size and number of timed sweeps.
local NROW, NCOL = 5, 5
local NSIM = 1e6

-- Test matrix: element (r, c) holds r*10 + c.
local A = mat(NROW, NCOL)
for r = 1, NROW do
  local row = A[r]
  for c = 1, NCOL do
    row[c] = r*10 + c
  end
end
local min, max = math.min, math.max

-- Range (largest element minus smallest element) of a row-like object.
-- x must provide x:len() and 1-based numeric indexing x[i].
-- The running minimum starts at +inf and the running maximum at -inf, so
-- the result is never negative for a non-empty row (an empty row gives -inf).
-- x[i] is deliberately read twice per element, keeping the number of index
-- operations the same as before for timing purposes.
local function range(x)
  local lo, hi = 1/0, -1/0
  for i = 1, x:len() do
    lo, hi = min(lo, x[i]), max(hi, x[i])
  end
  return hi - lo
end
-- Hand-unrolled range (largest minus smallest) over exactly the first five
-- elements x[1]..x[5] of a row-like object; x:len() is not consulted.
-- The running minimum starts at +inf and the running maximum at -inf, so
-- the result is never negative.
-- Each element is deliberately read twice, keeping the number of index
-- operations the same as before for timing purposes.
local function range_unroll(x)
  local lo, hi = 1/0, -1/0
  lo, hi = min(lo, x[1]), max(hi, x[1])
  lo, hi = min(lo, x[2]), max(hi, x[2])
  lo, hi = min(lo, x[3]), max(hi, x[3])
  lo, hi = min(lo, x[4]), max(hi, x[4])
  lo, hi = min(lo, x[5]), max(hi, x[5])
  return hi - lo
end
-- Select the per-row routine. The unrolled code paths read exactly five
-- elements of each row, so it is the column count (not the row count) that
-- must equal 5 for them to be valid.
local algo = range
if opt:find('unroll') then
  assert(NCOL == 5, 'unrolled variants require NCOL == 5')
  algo = range_unroll
end
-- Timed section: NSIM sweeps over every row of A, accumulating each row's
-- range (largest minus smallest element) into range_sum. Prints the sum
-- and then the CPU time spent, as reported by os.clock().
-- In the direct-access branches every element is deliberately read twice
-- per step, keeping the number of index operations unchanged for timing.
local clock_start = os.clock()
local range_sum = 0
if opt:find('row') then
  -- Via rows access: build a row object per row and hand it to algo.
  for _ = 1, NSIM do
    for r = 1, A:nrow() do
      range_sum = range_sum + algo(A:row(r))
    end
  end
elseif opt:find('unroll') then
  -- Direct access, inner loop unrolled for five columns.
  for _ = 1, NSIM do
    for r = 1, A:nrow() do
      -- lo tracks the minimum (starts at +inf), hi the maximum (starts at -inf).
      local lo, hi = 1/0, -1/0
      lo, hi = min(lo, A[r][1]), max(hi, A[r][1])
      lo, hi = min(lo, A[r][2]), max(hi, A[r][2])
      lo, hi = min(lo, A[r][3]), max(hi, A[r][3])
      lo, hi = min(lo, A[r][4]), max(hi, A[r][4])
      lo, hi = min(lo, A[r][5]), max(hi, A[r][5])
      range_sum = range_sum + (hi - lo)
    end
  end
else
  -- Direct access, looping over the columns.
  for _ = 1, NSIM do
    for r = 1, A:nrow() do
      -- lo tracks the minimum (starts at +inf), hi the maximum (starts at -inf).
      local lo, hi = 1/0, -1/0
      for c = 1, A:ncol() do
        lo, hi = min(lo, A[r][c]), max(hi, A[r][c])
      end
      range_sum = range_sum + (hi - lo)
    end
  end
end
print(range_sum)
print(os.clock() - clock_start)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement