slam/lua-rpc/rpc.lua
2011-12-02 14:07:55 +01:00

222 lines
5.5 KiB
Lua

local socket = require 'socket'
assert(common and common.class, "A Class Commons implementation is required")
-- TODO: more serializers
--- Tell whether a value can be handled by `serialize`.
-- @param t  the value's type name, as produced by `type(v)`
-- @param v  the value itself
-- @return true for numbers, booleans, strings and nil, and for any other
--         value whose metatable defines `__tostring`; false otherwise
local function is_serializable(t, v)
  if t == "number" or t == "boolean" or t == "string" or t == "nil" then
    return true
  end
  -- Anything else qualifies only if it knows how to render itself.
  local mt = getmetatable(v)
  if not mt then
    return false
  end
  return mt.__tostring ~= nil
end
--- Encode all arguments into one comma separated string.
-- Every argument becomes `type<tostring(value)>`; trailing and embedded nils
-- are kept because the argument count comes from `select('#', ...)`.
-- Raises an error for any argument rejected by `is_serializable`.
-- NOTE(review): values are written unescaped, so a string holding an
-- unbalanced `<` or `>` would not survive the `%b<>` matching used to decode.
-- @param ...  values to encode
-- @return the encoded string ("" when called without arguments)
local function serialize(...)
  local count = select('#', ...)
  local values = {...}
  local chunks = {}
  local index = 0
  while index < count do
    index = index + 1
    local value = values[index]
    local kind = type(value)
    if not is_serializable(kind, value) then
      error(("Cannot serialize values of type `%s'."):format(kind))
    end
    chunks[index] = ("%s<%s>"):format(kind, tostring(value))
  end
  return table.concat(chunks, ",")
end
-- Decoders keyed by the type tag found in a serialized token. Each one
-- receives the text between the angle brackets and returns the Lua value.
local converter = {}
converter['nil'] = function() return nil end
converter.string = function(text) return text end
converter.number = tonumber
converter.boolean = function(text) return text == 'true' end
--- Decode tokens pulled from an iterator into multiple return values.
-- Each token has the form `type<payload>`; the payload (without the angle
-- brackets) is handed to the matching entry in `converter`.
-- Raises "Cannot deserialize values of type ..." when the type tag has no
-- converter, instead of raising the raw payload as the error message.
-- @param iter  function returning the next token, or nil when exhausted
-- @return the decoded values in token order (no values when exhausted)
local function deserialize_helper(iter)
  local token = iter()
  if not token then return end
  local t, v = token:match('(%w+)(%b<>)')
  local convert = converter[t]
  if not convert then
    error(("Cannot deserialize values of type `%s'."):format(tostring(t)))
  end
  -- The converter result is truncated to one value so that a nil keeps its
  -- position; the recursion supplies the remaining values.
  return convert(v:sub(2,-2)), deserialize_helper(iter)
end
--- Decode a string produced by `serialize` back into Lua values.
-- The string is scanned for `type<payload>` tokens; anything between tokens
-- is ignored.
-- @param str  the serialized string
-- @return the decoded values as multiple results
local function deserialize(str)
  local tokens = str:gmatch('%w+%b<>')
  return deserialize_helper(tokens)
end
--
-- RPC SERVER
--
-- Method table for the server class. The functions defined on it below
-- (init, register, remove, defined, serve) become the class methods when the
-- table is passed to common.class at the end of the file.
local server = {}
--- List the names of registered functions that match a Lua pattern.
-- @param self     server object; its `registry` table is scanned
-- @param pattern  Lua pattern tested against each name (defaults to ".*")
-- @return the matching names joined by newlines, in unspecified order
local function capabilities(self, pattern)
  if not pattern then
    pattern = ".*"
  end
  local names = {}
  for name in pairs(self.registry) do
    if name:match(pattern) then
      table.insert(names, name)
    end
  end
  return table.concat(names, "\n")
end
--- Set up the listening socket and an empty function registry.
-- Binds with `socket.bind` (raising if that fails), switches the socket to
-- non-blocking mode (timeout 0) and records the bound address and port as
-- reported by `getsockname`. The registry starts with a single built-in
-- entry, `capabilities`.
-- @param port     port to bind (defaults to 0)
-- @param address  address to bind (defaults to '*')
function server:init(port, address)
  self.socket = assert(socket.bind(address or '*', port or 0))
  self.socket:settimeout(0)
  self.address, self.port = self.socket:getsockname()
  self.registry = {
    -- Built-in entry: forwards its arguments to the local `capabilities`.
    capabilities = function(...)
      return capabilities(self, ...)
    end,
  }
end
--- Make a function callable over RPC under the given name.
-- An existing entry with the same name is replaced.
-- @param name  key under which the function is stored in the registry
-- @param func  the value to store; must not be nil or false
function server:register(name, func)
  assert(name, "Missing argument: `name'")
  assert(func, "Missing argument: `func'")
  local registry = self.registry
  registry[name] = func
end
--- Drop a registry entry so it can no longer be called.
-- Removing a name that was never registered is a no-op.
-- @param name  registry key to clear
function server:remove(name)
  assert(name, "Missing argument: `name'")
  local registry = self.registry
  registry[name] = nil
end
--- Check whether the registry has an entry for a name.
-- @param name  registry key to look up
-- @return true if an entry exists, false otherwise
function server:defined(name)
  assert(name, "Missing argument: `name'")
  local entry = self.registry[name]
  return entry ~= nil
end
--- Run a registered function with serialized arguments.
-- Decoding the arguments, calling the function and encoding its results all
-- happen inside one protected call, so a malformed argument string or an
-- unserializable return value is reported as a failure instead of raising
-- out of the server.
-- @param self  server object; the function is looked up in `self.registry`
-- @param func  name of the registered function
-- @param args  argument string in the format produced by `serialize`
-- @return true and the serialized results on success;
--         false and an error message (always a string) otherwise
local function execute(self, func, args)
  local handler = self.registry[func]
  if not handler then
    return false, ("Tried to execute unknown function `%s'"):format(func)
  end
  local ok, result = pcall(function()
    return serialize(handler(deserialize(args)))
  end)
  if not ok then
    -- Error values need not be strings; callers format this into a reply.
    return false, tostring(result)
  end
  return true, result
end
--- Accept at most one pending connection and answer its request.
-- Expected request: a header line `RPC:<bytes>` followed by <bytes> bytes
-- holding `<token>:<function>:<serialized args>`. The reply is
-- `RPC:<token>:<true|false>:<result or error message>` and the connection is
-- closed afterwards. Requests that are incomplete or malformed are dropped
-- without a reply.
-- Returns immediately when no connection is pending (accept timed out);
-- raises for any other accept error and when the server has no socket.
-- NOTE(review): the receive calls use whatever timeout the accepted socket
-- has; confirm whether a silent peer can stall this call.
function server:serve()
  assert(self.socket, "Server socket not initialized")
  local client, err = self.socket:accept()
  if not client then
    if err ~= 'timeout' then
      error(err)
    end
    return
  end

  -- receive() yields nil when the peer disconnects early; every step below
  -- is skipped in that case rather than indexing nil.
  local header = client:receive('*l')
  local bytes = header and tonumber(header:match('^RPC:(%d+)$'))
  local payload = bytes and client:receive(bytes)
  if payload then
    local token, func, args = payload:match('^([^:]+):([^:]+):(.*)%s*')
    if token and func and args then
      -- Protected call: an error while handling the request must not skip
      -- client:close() below or take the server down.
      local call_ok, ok, ret = pcall(execute, self, func, args)
      if not call_ok then
        ok, ret = false, ok
      end
      local str = ("RPC:%s:%s:%s\r\n"):format(token, tostring(ok), tostring(ret))
      client:send(str)
    end
  end
  client:close()
end
--
-- RPC CLIENT
--
-- Method table for the client class. The functions defined on it below
-- (init, call, dispatch) become the class methods when the table is passed
-- to common.class at the end of the file.
local client = {}
--- Remember the server location and prepare the call machinery.
-- Sets the default callbacks (`print` on success, `error` on failure), an
-- empty worker set, and `self.rpc`: a proxy table where indexing any name
-- yields a function that queues a call to that remote function using the
-- callbacks stored on the client at call time.
-- @param address  server address (required)
-- @param port     server port (required)
function client:init(address, port)
  assert(address and port, "Need server address and port")
  self.address = address
  self.port = port
  self.workers = {}
  self.on_success = print
  self.on_failure = error

  local function lookup(_, func)
    return function(...)
      return self:call(func, self.on_success, self.on_failure, ...)
    end
  end
  self.rpc = setmetatable({}, {__index = lookup})
end
--- Send one request to the server and collect the answer.
-- Must run inside a coroutine: after every receive attempt that does not end
-- the exchange it yields, so the caller can interleave other work.
-- The socket is closed on every exit path. Data returned alongside a
-- `timeout` is carried into the next receive call, so a line that arrives in
-- several pieces is reassembled instead of being truncated.
-- @param self  client object providing `address` and `port`
-- @param func  name of the remote function
-- @param args  argument string in the format produced by `serialize`
-- @return true and the serialized return values when the call succeeded;
--         false and an error message otherwise
local function query(self, func, args)
  local client, err = socket.tcp()
  if not client then
    return false, ("Cannot create socket: %s"):format(tostring(err))
  end
  client:settimeout(10)
  local _
  _, err = client:connect(self.address, self.port)
  if err then
    client:close()
    return false, ("Cannot connect to %s[%s]: %s"):format(self.address, self.port, err)
  end
  client:settimeout(0)

  -- The token lets us check that the answer belongs to this request.
  local token = ("%d-%s"):format(os.time(), math.random())
  local str = ("%s:%s:%s\r\n"):format(token, func, args)
  _, err = client:send(("RPC:%d\r\n"):format(str:len()))
  if err then
    client:close()
    return false, ("Cannot send query header to %s[%s]: %s"):format(self.address, self.port, err)
  end
  _, err = client:send(str)
  if err then
    client:close()
    return false, ("Cannot send query message to %s[%s]: %s"):format(self.address, self.port, err)
  end

  local lines = {}
  local partial = nil -- bytes of an unfinished line, fed back as prefix
  while true do
    local line, recv_err, rest = client:receive('*l', partial)
    if line then
      lines[#lines+1] = line
      partial = nil
    elseif recv_err == 'timeout' then
      partial = rest
    else
      client:close()
      if recv_err ~= 'closed' then
        -- Unexpected socket error: give up instead of polling forever.
        return false, ("Cannot receive answer from %s[%s]: %s"):format(self.address, self.port, tostring(recv_err))
      end
      -- Keep a final line that was not newline terminated.
      if rest and rest ~= '' then
        lines[#lines+1] = rest
      end
      local ret = table.concat(lines,'\n')
      local ret_token, success, values = ret:match("^RPC:([^:]+):([^:]+):(.*)%s*$")
      if not (ret_token and success and values) then
        return false, ("Malformated answer: `%s'"):format(ret)
      end
      if ret_token == token then
        return (success == 'true'), values
      else
        return false, ("Token mismatch: expected `%s', got `%s'"):format(token, ret_token)
      end
    end
    -- Nothing final yet: hand control back to the dispatcher.
    coroutine.yield()
  end
end
--- Queue a remote call; it makes progress each time `dispatch` runs.
-- The arguments are serialized immediately (raising if one of them cannot
-- be serialized); the network exchange runs in a coroutine driven by a
-- worker function stored in `self.workers`.
-- An error raised inside the query coroutine is passed to `on_failure`
-- rather than being dropped silently.
-- @param func        name of the remote function
-- @param on_success  called with the deserialized return values
-- @param on_failure  called with an error message
-- @param ...         arguments for the remote function
function client:call(func, on_success, on_failure, ...)
  local args = serialize(...)
  local q = coroutine.create(function() return query(self, func, args) end)
  local finished = false
  -- The worker returns true while the call is still pending, false once done.
  local worker = function()
    if finished then
      -- A callback raised on an earlier run; do not resume a dead coroutine
      -- or report the same outcome twice.
      return false
    end
    local coroutine_ok, call_ok, returns = coroutine.resume(q)
    if not coroutine_ok then
      finished = true
      on_failure(call_ok) -- on failure, resume's second result is the error
      return false
    end
    if call_ok ~= nil and returns ~= nil then
      finished = true
      if call_ok then
        on_success(deserialize(returns))
      else
        on_failure(returns)
      end
      return false
    end
    return true
  end
  self.workers[worker] = worker
end
--- Give every pending call one chance to make progress.
-- Each worker in `self.workers` is invoked once; workers that report
-- completion (a false result) are removed.
-- The workers are copied into a list first: a success or failure callback
-- may queue a new call, and adding keys to a table while it is being
-- traversed with `pairs` is undefined. Calls queued during this pass start
-- running on the next dispatch.
function client:dispatch()
  local pending = {}
  for _, worker in pairs(self.workers) do
    pending[#pending + 1] = worker
  end
  for i = 1, #pending do
    local worker = pending[i]
    if not worker() then
      self.workers[worker] = nil
    end
  end
end
--
-- THE MODULE
--
-- Build both classes through the Class Commons interface and export them.
local rpc_server = common.class("RPC.server", server)
local rpc_client = common.class("RPC.client", client)
return {
  server = rpc_server,
  client = rpc_client,
}