diff --git a/lua-rpc/README.md b/lua-rpc/README.md new file mode 100644 index 0000000..2e9a4d4 --- /dev/null +++ b/lua-rpc/README.md @@ -0,0 +1,58 @@ +**Lua-RPC** is a simple premote procedure call protocol implemented in Lua. This +is meant to be a proof of concept. Please don't use it in productive environments. + + +Dependencies +============ + +- luasocket: http://w3.impa.br/~diego/software/luasocket/ +- any class commons enabled class library: https://github.com/bartbes/Class-Commons + --> hump.class is included + +Usage example +============= + +server.lua: + + require 'class' -- any class commons enabled library + local RPC = require 'rpc' + + -- open server at port 12345 on localhost + server = RPC.server(12345, '127.0.0.1') + + -- register 'print' function as remotely callable + server:register('print', print) + + -- register a function that returns a value + server:register('twice', function(x) return 2 * x end) + + -- yet another way to define callable functions + function server.registry.thrice(x) return 3 * x end + + -- run the server + while true do + server:serve() + end + + +client.lua: + + require 'class' + local RPC = require 'rpc' + + -- create new client to call functions from server at localhost:12345 + client = RPC.client('127.0.0.1', 12345) + + -- set actions what to do on success/failure of the remote call. + -- these are the defaults + client.on_success = print + client.on_failure = error + + -- queue some functions + client.rpc.print("Hello world!\nHello remote server!") + client.rpc.twice(2) + + -- you can also define function-specific callbacks. + -- prototype is client:call(function_name, on_success, on_failure, ...) + client:call('thrice', function(result) print('3 * 2 = ', result) end, function(err) print("RPC error:", err), 3) + diff --git a/lua-rpc/class.lua b/lua-rpc/class.lua new file mode 100644 index 0000000..7ce7252 --- /dev/null +++ b/lua-rpc/class.lua @@ -0,0 +1,116 @@ +--[[ +Copyright (c) 2010-2011 Matthias Richter + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +Except as contained in this notice, the name(s) of the above copyright holders +shall not be used in advertising or otherwise to promote the sale, use or +other dealings in this Software without prior written authorization. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +]]-- + +local function __NULL__() end + +-- class "inheritance" by copying functions +local function inherit(class, interface, ...) + if not interface then return end + assert(type(interface) == "table", "Can only inherit from other classes.") + + -- __index and construct are not overwritten as for them class[name] is defined + for name, func in pairs(interface) do + if not class[name] then + class[name] = func + end + end + for super in pairs(interface.__is_a or {}) do + class.__is_a[super] = true + end + + return inherit(class, ...) +end + +-- class builder +local function new(args) + local super = {} + local name = '' + local constructor = args or __NULL__ + if type(args) == "table" then + -- nasty hack to check if args.inherits is a table of classes or a class or nil + super = (args.inherits or {}).__is_a and {args.inherits} or args.inherits or {} + name = args.name or name + constructor = args[1] or __NULL__ + end + assert(type(constructor) == "function", 'constructor has to be nil or a function') + + -- build class + local class = {} + class.__index = class + class.__tostring = function() return (""):format(tostring(class)) end + class.construct = constructor or __NULL__ + class.inherit = inherit + class.__is_a = {[class] = true} + class.is_a = function(self, other) return not not self.__is_a[other] end + + -- intercept assignment in global environment to infer the class name + if not (type(args) == "table" and args.name) then + local env, env_meta, interceptor = getfenv(0), getmetatable(getfenv(0)), {} + function interceptor:__newindex(key, value) + if value == class then + local name = tostring(key) + getmetatable(class).__tostring = function() return name end + end + -- restore old metatable and insert value + setmetatable(env, env_meta) + if env.global then env.global(key) end -- handle 'strict' module + env[key] = value + end + setmetatable(env, interceptor) + end + + -- inherit superclasses (see above) + inherit(class, unpack(super)) + + -- syntactic sugar + local meta = { + __call = function(self, ...) + local obj = {} + setmetatable(obj, self) + self.construct(obj, ...) + return obj + end, + __tostring = function() return name end + } + return setmetatable(class, meta) +end + +-- interface for cross class-system compatibility (see https://github.com/bartbes/Class-Commons). +if class_commons ~= false and not common then + common = {} + function common.class(name, prototype, parent) + local init = prototype.init or (parent or {}).init + return new{name = name, inherits = {prototype, parent}, init} + end + function common.instance(class, ...) + return class(...) + end +end + + +-- the module +return setmetatable({new = new, inherit = inherit}, + {__call = function(_,...) return new(...) end}) diff --git a/lua-rpc/rpc.lua b/lua-rpc/rpc.lua new file mode 100644 index 0000000..165d86e --- /dev/null +++ b/lua-rpc/rpc.lua @@ -0,0 +1,221 @@ +local socket = require 'socket' +assert(common and common.class, "A Class Commons implementation is required") + +-- TODO: more serializers +local function is_serializable(t,v) + return t == "number" or + t == "boolean" or + t == "string" or + t == "nil" or + (getmetatable(v) or {}).__tostring ~= nil +end + +local function serialize(...) + local args = {n = select('#', ...), ...} + local serialized = {} + for i = 1,args.n do + local t, v = type(args[i]), args[i] + if not is_serializable(t,v) then + error(("Cannot serialize values of type `%s'."):format(t)) + end + serialized[i] = ("%s<%s>"):format(t,tostring(v)) + end + return table.concat(serialized, ",") +end + +local converter = { + ['nil'] = function() return nil end, + string = function(v) return v end, + number = tonumber, + boolean = function(v) return v == 'true' end, +} +local function deserialize_helper(iter) + local token = iter() + if not token then return end + + local t,v = token:match('(%w+)(%b<>)') + return (converter[t] or error)(v:sub(2,-2)), deserialize_helper(iter) +end +local function deserialize(str) + return deserialize_helper(str:gmatch('%w+%b<>')) +end + +-- +-- RPC SERVER +-- +local server = {} + +local function capabilities(self, pattern) + pattern = pattern or ".*" + local ret = {} + for name, _ in pairs(self.registry) do + if name:match(pattern) then + ret[#ret+1] = name + end + end + return table.concat(ret, "\n") +end + +function server:init(port, address) + port = port or 0 + address = address or '*' + self.socket = assert(socket.bind(address, port)) + self.socket:settimeout(0) + self.address, self.port = self.socket:getsockname() + self.registry = {} + + function self.registry.capabilities(...) return capabilities(self, ...) end +end + +function server:register(name, func) + assert(name, "Missing argument: `name'") + assert(func, "Missing argument: `func'") + self.registry[name] = func +end + +function server:remove(name) + assert(name, "Missing argument: `name'") + self.registry[name] = nil +end + +function server:defined(name) + assert(name, "Missing argument: `name'") + return self.registry[name] ~= nil +end + +local function execute(self, func, args) + if not self.registry[func] then + return false, ("Tried to execute unknown function `%s'"):format(func) + end + + return (function(pcall_ok, ...) + if pcall_ok then + return true, serialize(...) + end + return false, ... + end)(pcall(self.registry[func], deserialize(args))) +end + +function server:serve() + assert(self.socket, "Server socket not initialized") + + local client,err = self.socket:accept() + if client then + local line = client:receive('*l') + local bytes = tonumber(line:match('^RPC:(%d+)$')) + if bytes then + line = client:receive(bytes) + local token, func, args = line:match('^([^:]+):([^:]+):(.*)%s*') + if token and func and args then + local ok, ret = execute(self, func, args) + local str = ("RPC:%s:%s:%s\r\n"):format(token, tostring(ok), ret) + client:send(str) + end + end + client:close() + elseif not client and err ~= 'timeout' then + error(err) + end +end + +-- +-- RPC CLIENT +-- +local client = {} + +function client:init(address, port) + assert(address and port, "Need server address and port") + self.address, self.port = address, port + self.workers = {} + self.on_success = print + self.on_failure = error + self.rpc = setmetatable({}, {__index = function(_,func) + return function(...) return self:call(func, self.on_success, self.on_failure,...) end + end}) +end + +local function query(self, func, args) + local client = socket.tcp() + client:settimeout(10) + local _, err = client:connect(self.address, self.port) + if err then + return false, ("Cannot connect to %s[%s]: %s"):format(self.address, self.port, err) + end + client:settimeout(0) + + local token = ("%d-%s"):format(os.time(), math.random()) + local str = ("%s:%s:%s\r\n"):format(token, func, args) + + _, err = client:send(("RPC:%d\r\n"):format(str:len())) + if err then + client:close() + return false, ("Cannot send query header to %s[%s]: %s"):format(self.address, self.port, err) + end + + _, err = client:send(str) + if err then + client:close() + return false, ("Cannot send query message to %s[%s]: %s"):format(self.address, self.port, err) + end + + local lines = {} + while true do + local line, err = client:receive('*l') + if line then + lines[#lines+1] = line + elseif err == "closed" then + local ret = table.concat(lines,'\n') + local ret_token, success, values = ret:match("^RPC:([^:]+):([^:]+):(.*)%s*$") + if not (ret_token and success and values) then + return false, ("Malformated answer: `%s'"):format(ret) + end + + if ret_token == token then + return (success == 'true'), values + else + return false, ("Token mismatch: expected `%s', got `%s'"):format(token, ret_token) + end + end + -- err == 'timeout' + coroutine.yield() + end +end + +function client:call(func, on_success, on_failure, ...) + local args = serialize(...) + local q = coroutine.create(function() return query(self, func, args) end) + local worker = function() + local coroutine_ok, call_ok, returns = coroutine.resume(q) + if coroutine_ok and call_ok ~= nil and returns ~= nil then + if call_ok then + on_success(deserialize(returns)) + else + on_failure(returns) + end + return false + end + return coroutine_ok + end + self.workers[worker] = worker +end + +function client:dispatch() + local to_remove = {} + for _, worker in pairs(self.workers) do + if not worker() then + to_remove[worker] = worker + end + end + + for _, worker in pairs(to_remove) do + self.workers[worker] = nil + end +end + +-- +-- THE MODULE +-- +return { + server = common.class("RPC.server", server), + client = common.class("RPC.client", client), +}