mirror of
https://github.com/TangentFoxy/slam.git
synced 2024-11-14 10:24:21 +00:00
222 lines
5.5 KiB
Lua
222 lines
5.5 KiB
Lua
local socket = require 'socket'
|
|
assert(common and common.class, "A Class Commons implementation is required")
|
|
|
|
-- TODO: more serializers
|
|
local function is_serializable(t,v)
|
|
return t == "number" or
|
|
t == "boolean" or
|
|
t == "string" or
|
|
t == "nil" or
|
|
(getmetatable(v) or {}).__tostring ~= nil
|
|
end
|
|
|
|
local function serialize(...)
|
|
local args = {n = select('#', ...), ...}
|
|
local serialized = {}
|
|
for i = 1,args.n do
|
|
local t, v = type(args[i]), args[i]
|
|
if not is_serializable(t,v) then
|
|
error(("Cannot serialize values of type `%s'."):format(t))
|
|
end
|
|
serialized[i] = ("%s<%s>"):format(t,tostring(v))
|
|
end
|
|
return table.concat(serialized, ",")
|
|
end
|
|
|
|
local converter = {
|
|
['nil'] = function() return nil end,
|
|
string = function(v) return v end,
|
|
number = tonumber,
|
|
boolean = function(v) return v == 'true' end,
|
|
}
|
|
local function deserialize_helper(iter)
|
|
local token = iter()
|
|
if not token then return end
|
|
|
|
local t,v = token:match('(%w+)(%b<>)')
|
|
return (converter[t] or error)(v:sub(2,-2)), deserialize_helper(iter)
|
|
end
|
|
local function deserialize(str)
|
|
return deserialize_helper(str:gmatch('%w+%b<>'))
|
|
end
|
|
|
|
--
|
|
-- RPC SERVER
|
|
--
|
|
local server = {}
|
|
|
|
local function capabilities(self, pattern)
|
|
pattern = pattern or ".*"
|
|
local ret = {}
|
|
for name, _ in pairs(self.registry) do
|
|
if name:match(pattern) then
|
|
ret[#ret+1] = name
|
|
end
|
|
end
|
|
return table.concat(ret, "\n")
|
|
end
|
|
|
|
function server:init(port, address)
|
|
port = port or 0
|
|
address = address or '*'
|
|
self.socket = assert(socket.bind(address, port))
|
|
self.socket:settimeout(0)
|
|
self.address, self.port = self.socket:getsockname()
|
|
self.registry = {}
|
|
|
|
function self.registry.capabilities(...) return capabilities(self, ...) end
|
|
end
|
|
|
|
function server:register(name, func)
|
|
assert(name, "Missing argument: `name'")
|
|
assert(func, "Missing argument: `func'")
|
|
self.registry[name] = func
|
|
end
|
|
|
|
function server:remove(name)
|
|
assert(name, "Missing argument: `name'")
|
|
self.registry[name] = nil
|
|
end
|
|
|
|
function server:defined(name)
|
|
assert(name, "Missing argument: `name'")
|
|
return self.registry[name] ~= nil
|
|
end
|
|
|
|
local function execute(self, func, args)
|
|
if not self.registry[func] then
|
|
return false, ("Tried to execute unknown function `%s'"):format(func)
|
|
end
|
|
|
|
return (function(pcall_ok, ...)
|
|
if pcall_ok then
|
|
return true, serialize(...)
|
|
end
|
|
return false, ...
|
|
end)(pcall(self.registry[func], deserialize(args)))
|
|
end
|
|
|
|
function server:serve()
|
|
assert(self.socket, "Server socket not initialized")
|
|
|
|
local client,err = self.socket:accept()
|
|
if client then
|
|
local line = client:receive('*l')
|
|
local bytes = tonumber(line:match('^RPC:(%d+)$'))
|
|
if bytes then
|
|
line = client:receive(bytes)
|
|
local token, func, args = line:match('^([^:]+):([^:]+):(.*)%s*')
|
|
if token and func and args then
|
|
local ok, ret = execute(self, func, args)
|
|
local str = ("RPC:%s:%s:%s\r\n"):format(token, tostring(ok), ret)
|
|
client:send(str)
|
|
end
|
|
end
|
|
client:close()
|
|
elseif not client and err ~= 'timeout' then
|
|
error(err)
|
|
end
|
|
end
|
|
|
|
--
|
|
-- RPC CLIENT
|
|
--
|
|
local client = {}
|
|
|
|
function client:init(address, port)
|
|
assert(address and port, "Need server address and port")
|
|
self.address, self.port = address, port
|
|
self.workers = {}
|
|
self.on_success = print
|
|
self.on_failure = error
|
|
self.rpc = setmetatable({}, {__index = function(_,func)
|
|
return function(...) return self:call(func, self.on_success, self.on_failure,...) end
|
|
end})
|
|
end
|
|
|
|
local function query(self, func, args)
|
|
local client = socket.tcp()
|
|
client:settimeout(10)
|
|
local _, err = client:connect(self.address, self.port)
|
|
if err then
|
|
return false, ("Cannot connect to %s[%s]: %s"):format(self.address, self.port, err)
|
|
end
|
|
client:settimeout(0)
|
|
|
|
local token = ("%d-%s"):format(os.time(), math.random())
|
|
local str = ("%s:%s:%s\r\n"):format(token, func, args)
|
|
|
|
_, err = client:send(("RPC:%d\r\n"):format(str:len()))
|
|
if err then
|
|
client:close()
|
|
return false, ("Cannot send query header to %s[%s]: %s"):format(self.address, self.port, err)
|
|
end
|
|
|
|
_, err = client:send(str)
|
|
if err then
|
|
client:close()
|
|
return false, ("Cannot send query message to %s[%s]: %s"):format(self.address, self.port, err)
|
|
end
|
|
|
|
local lines = {}
|
|
while true do
|
|
local line, err = client:receive('*l')
|
|
if line then
|
|
lines[#lines+1] = line
|
|
elseif err == "closed" then
|
|
local ret = table.concat(lines,'\n')
|
|
local ret_token, success, values = ret:match("^RPC:([^:]+):([^:]+):(.*)%s*$")
|
|
if not (ret_token and success and values) then
|
|
return false, ("Malformated answer: `%s'"):format(ret)
|
|
end
|
|
|
|
if ret_token == token then
|
|
return (success == 'true'), values
|
|
else
|
|
return false, ("Token mismatch: expected `%s', got `%s'"):format(token, ret_token)
|
|
end
|
|
end
|
|
-- err == 'timeout'
|
|
coroutine.yield()
|
|
end
|
|
end
|
|
|
|
function client:call(func, on_success, on_failure, ...)
|
|
local args = serialize(...)
|
|
local q = coroutine.create(function() return query(self, func, args) end)
|
|
local worker = function()
|
|
local coroutine_ok, call_ok, returns = coroutine.resume(q)
|
|
if coroutine_ok and call_ok ~= nil and returns ~= nil then
|
|
if call_ok then
|
|
on_success(deserialize(returns))
|
|
else
|
|
on_failure(returns)
|
|
end
|
|
return false
|
|
end
|
|
return coroutine_ok
|
|
end
|
|
self.workers[worker] = worker
|
|
end
|
|
|
|
function client:dispatch()
|
|
local to_remove = {}
|
|
for _, worker in pairs(self.workers) do
|
|
if not worker() then
|
|
to_remove[worker] = worker
|
|
end
|
|
end
|
|
|
|
for _, worker in pairs(to_remove) do
|
|
self.workers[worker] = nil
|
|
end
|
|
end
|
|
|
|
--
|
|
-- THE MODULE
|
|
--
|
|
return {
|
|
server = common.class("RPC.server", server),
|
|
client = common.class("RPC.client", client),
|
|
}
|