-- path: root/profile.lua
-- blob: 4609d3f8f5cc6a04563ad5d2f0e28eb263d13a68
-- Time source used for every measurement. Defaults to os.clock and can be
-- replaced through profile.setclock.
local clock = os.clock

-- Module table: the public functions are attached to it and it is returned
-- at the end of the file.
local profile = {}

-- All of the tables below are keyed by function value.

-- function labels: the most recent name debug.getinfo reported for the function
local _labeled = {}
-- function definitions: a "short_src:linedefined" string, set the first time
-- the hook sees the function
local _defined = {}
-- time of last call: clock value stored on a "call" event; the entry is
-- cleared again once the interval has been added to _telapsed
local _tcalled = {}
-- total execution time: sum of the measured intervals
local _telapsed = {}
-- number of calls: incremented by the hook on events other than "call" and
-- "tail call" (i.e. when the function returns)
local _ncalls = {}
-- list of internal profiler functions: a set (function -> true) of the
-- profiler's own functions, which the hook ignores
local _internal = {}

--- Debug hook that records timing and call counts per function.
-- Installed by profile.start through debug.sethook. It also calls itself to
-- split a "tail call" event into a "return" followed by a "call".
-- @param event Hook event name (e.g. "call", "return", "tail call")
-- @param line Line argument supplied by the hook; only forwarded on recursion
-- @param info Optional debug.getinfo result; when omitted, the function at
--   stack level 2 is inspected
function profile.hooker(event, line, info)
  if not info then
    info = debug.getinfo(2, 'fnS')
  end
  local fn = info.func
  -- never measure the profiler's own functions
  if _internal[fn] then
    return
  end
  -- only functions reported as "Lua" are tracked
  if info.what ~= "Lua" then
    return
  end
  -- remember the latest name reported for this function, if any
  local label = info.name
  if label then
    _labeled[fn] = label
  end
  -- first sighting: store where it was defined and zero its counters
  if not _defined[fn] then
    _defined[fn] = info.short_src..":"..info.linedefined
    _ncalls[fn] = 0
    _telapsed[fn] = 0
  end
  -- close the timing interval that is currently open for this function
  local started = _tcalled[fn]
  if started then
    _telapsed[fn] = _telapsed[fn] + (clock() - started)
    _tcalled[fn] = nil
  end
  if event == 'call' then
    -- open a new timing interval
    _tcalled[fn] = clock()
  elseif event == "tail call" then
    -- treat the function at stack level 3 as returning, then enter this one
    profile.hooker("return", line, debug.getinfo(3, 'fnS'))
    profile.hooker("call", line, info)
  else
    -- any other event (a return) completes one call
    _ncalls[fn] = _ncalls[fn] + 1
  end
end

--- Replaces the time source used by the profiler.
-- Raises an error ("clock must be a function") for any non-function argument.
-- @param f Clock function that returns a number
function profile.setclock(f)
  if type(f) ~= "function" then
    -- level 0 keeps the message free of position info, as assert does
    error("clock must be a function", 0)
  end
  clock = f
end

--- Begins collecting profiling data.
-- When a global `jit` table exists (looked up with rawget so no metamethod on
-- _G is triggered), its off() and flush() functions are called first.
-- profile.hooker is then installed as the debug hook with the "cr" mask
-- (call and return events).
function profile.start()
  local luajit = rawget(_G, 'jit')
  if luajit then
    luajit.off()
    luajit.flush()
  end
  debug.sethook(profile.hooker, "cr")
end

--- Ends data collection.
-- Removes the debug hook, closes every timing interval that is still open,
-- and folds the statistics of functions that share the same label and
-- definition site into a single entry. Finishes with a full garbage
-- collection.
function profile.stop()
  debug.sethook()
  -- account for functions that were still running when the hook was removed
  for fn, started in pairs(_tcalled) do
    _telapsed[fn] = _telapsed[fn] + (clock() - started)
    _tcalled[fn] = nil
  end
  -- merge closures: entries with an identical label + definition string are
  -- combined into whichever one the traversal encounters first
  local canonical = {}
  for fn, where in pairs(_defined) do
    local key = (_labeled[fn] or '?')..where
    local kept = canonical[key]
    if not kept then
      canonical[key] = fn
    else
      _ncalls[kept] = _ncalls[kept] + (_ncalls[fn] or 0)
      _telapsed[kept] = _telapsed[kept] + (_telapsed[fn] or 0)
      -- drop every record of the duplicate
      _defined[fn], _labeled[fn] = nil, nil
      _ncalls[fn], _telapsed[fn] = nil, nil
    end
  end
  collectgarbage('collect')
end

--- Clears the collected statistics.
-- Call counts and elapsed times are set back to zero for every known
-- function and all open timing intervals are discarded. Labels and
-- definition strings are kept. Finishes with a full garbage collection.
function profile.reset()
  -- zero the counters, call counts first and then elapsed times
  for _, counters in ipairs({ _ncalls, _telapsed }) do
    for fn in pairs(counters) do
      counters[fn] = 0
    end
  end
  -- forget intervals that are still open
  for fn in pairs(_tcalled) do
    _tcalled[fn] = nil
  end
  collectgarbage('collect')
end

--- Ordering function used with table.sort by profile.query.
-- Functions with a larger total elapsed time come first; when the times are
-- equal, the function with more calls comes first.
-- @param a First function
-- @param b Second function
-- @return true when a must be placed before b
function profile.comp(a, b)
  local diff = _telapsed[b] - _telapsed[a]
  if diff ~= 0 then
    return diff < 0
  end
  return _ncalls[a] > _ncalls[b]
end

--- Collects the per-function statistics gathered so far.
-- Only functions whose recorded call count is greater than zero are
-- included. Rows are ordered with profile.comp. For a function whose timing
-- interval is still open, the time elapsed so far is added to the reported
-- total, so this may be called while the profiler is running.
-- @param limit Maximum number of rows to return (optional). A limit of zero
--   or less yields an empty list.
-- @return Array of rows, each { rank, name or '?', calls, time, definition }
function profile.query(limit)
  local t = {}
  for f, n in pairs(_ncalls) do
    if n > 0 then
      t[#t + 1] = f
    end
  end
  table.sort(t, profile.comp)
  if limit then
    -- Stop once the list is empty: table.remove on an empty table removes
    -- nothing, so without the second condition a negative limit would loop
    -- forever.
    while #t > limit and #t > 0 do
      table.remove(t)
    end
  end
  for i, f in ipairs(t) do
    -- include the time of a call that has not returned yet
    local dt = 0
    if _tcalled[f] then
      dt = clock() - _tcalled[f]
    end
    t[i] = { i, _labeled[f] or '?', _ncalls[f], _telapsed[f] + dt, _defined[f] }
  end
  return t
end

-- width in characters of each report column: #, Function, Calls, Time, Code
local cols = { 3, 29, 11, 24, 32 }

--- Builds a text table from the results of profile.query.
-- Every cell is converted with tostring, then padded on the right with
-- spaces or, when too long, cut down to its rightmost characters so that it
-- matches the column width exactly.
-- @param n Maximum number of rows, forwarded to profile.query (optional)
-- @return String containing the formatted table
function profile.report(n)
  local lines = {}
  for i, row in ipairs(profile.query(n)) do
    for j = 1, 5 do
      local width = cols[j]
      local cell = tostring(row[j])
      local len = #cell
      if len > width then
        -- too long: keep only the last `width` characters
        cell = cell:sub(-width)
      elseif len < width then
        cell = cell..string.rep(' ', width - len)
      end
      row[j] = cell
    end
    lines[i] = table.concat(row, ' | ')
  end

  local rule = " +-----+-------------------------------+-------------+--------------------------+----------------------------------+ \n"
  local header = " | #   | Function                      | Calls       | Time                     | Code                             | \n"
  local parts = { '\n', rule, header, rule }
  if #lines > 0 then
    parts[#parts + 1] = ' | '
    parts[#parts + 1] = table.concat(lines, ' | \n | ')
    parts[#parts + 1] = ' | \n'
  end
  parts[#parts + 1] = rule
  return table.concat(parts)
end

-- Mark every function attached to the module table as internal so that
-- profile.hooker skips the profiler's own calls.
for _, member in pairs(profile) do
  if type(member) == "function" then
    _internal[member] = true
  end
end

return profile