diff --git a/middleclass.lua b/middleclass.lua index dae6496..1822bc5 100644 --- a/middleclass.lua +++ b/middleclass.lua @@ -6,43 +6,67 @@ local _classes = setmetatable({}, {__mode = "k"}) -local function _initializeClass(klass, super) - +local function _setClassDictionariesMetatables(klass) local dict = klass.__instanceDict + local super = klass.superclass + + dict.__index = dict if super then - setmetatable(dict, { __index = super.__instanceDict }) + setmetatable(dict, super.__instanceDict) setmetatable(klass.static, { __index = function(_,k) return dict[k] or super[k] end }) else setmetatable(klass.static, { __index = function(_,k) return dict[k] end }) end +end +local function _setClassMetatable(klass) setmetatable(klass, { __tostring = function() return "class " .. klass.name end, __index = klass.static, __newindex = klass.__instanceDict, __call = function(_, ...) return klass:new(...) end }) - - _classes[klass] = true end -Object = { - name = "Object", - static = {}, - __mixins = {}, - __instanceDict = {}, - __metamethods = { '__add', '__call', '__concat', '__div', '__le', '__lt', - '__mod', '__mul', '__pow', '__sub', '__tostring', '__unm' } -} +local function _createClass(name, super) + local klass = { name = name, superclass = super, static = {}, __mixins = {}, __instanceDict={} } -_initializeClass(Object) + _setClassDictionariesMetatables(klass) + _setClassMetatable(klass) + _classes[klass] = true -Object.initialize = function() end + return klass +end + +local function _createLookupMetamethod(klass, methodName) + return function(...) + local method = klass.superclass[methodName] + assert( type(method)=='function', tostring(klass) .. " doesn't implement metamethod '" .. methodName .. "'" ) + return method(...) + end +end + +local function _setClassMetamethods(klass) + for _,m in ipairs(klass.__metamethods) do + klass[m]= _createLookupMetamethod(klass, m) + end +end + +local function _setDefaultInitializeMethod(klass) + klass.initialize = function(instance, ...) + return klass.superclass.initialize(instance, ...) + end +end + +Object = _createClass("Object", nil) + +Object.static.__metamethods = { '__add', '__call', '__concat', '__div', '__le', '__lt', + '__mod', '__mul', '__pow', '__sub', '__tostring', '__unm' } function Object.static:allocate() assert(_classes[self], "Make sure that you are using 'Class:allocate' instead of 'Class.allocate'") - return setmetatable({ class = self }, {__index = self.__instanceDict }) + return setmetatable({ class = self }, self.__instanceDict) end function Object.static:new(...) @@ -55,38 +79,22 @@ function Object.static:subclass(name) assert(_classes[self], "Make sure that you are using 'Class:subclass' instead of 'Class.subclass'") assert(type(name) == "string", "You must provide a name(string) for your class") - local subclass = { name = name, superclass = self, static = {}, __mixins = {}, __instanceDict={} } - - _initializeClass(subclass, self) + local subclass = _createClass(name, self) + _setClassMetamethods(subclass) + _setDefaultInitializeMethod(subclass) return subclass end +function Object:initialize() end + +function Object:__tostring() return "instance of " .. tostring(self.class) end + --[[ -- creates a subclass function Object.subclass(klass, name) - setmetatable(thesubclass, { - __index = dict, -- look for stuff on the dict - __newindex = function(_, methodName, method) -- ensure that __index isn't modified by mistake - assert(methodName ~= '__index', "Can't modify __index. Include middleclass-extras.Indexable and use 'index' instead") - rawset(dict, methodName , method) - end, - __tostring = function() return ("class ".. name) end, -- allows tostring(MyClass) - __call = function(_, ...) return thesubclass:new(...) end -- allows MyClass(...) instead of MyClass:new(...) - }) - - for _,mmName in ipairs(klass.__metamethods) do -- Creates the initial metamethods - dict[mmName]= function(...) -- by default, they just 'look up' for an implememtation - local method = superDict[mmName] -- and if none found, they throw an error - assert( type(method)=='function', tostring(thesubclass) .. " doesn't implement metamethod '" .. mmName .. "'" ) - return method(...) - end - end - - thesubclass.initialize = function(instance,...) klass.initialize(instance, ...) end - _classes[thesubclass]= true -- registers the new class on the list of _classes klass:subclassed(thesubclass) -- hook method. By default it does nothing return thesubclass diff --git a/spec/Object_spec.lua b/spec/Object_spec.lua index 2bc660b..8581be6 100644 --- a/spec/Object_spec.lua +++ b/spec/Object_spec.lua @@ -21,64 +21,6 @@ context('Object', function() end) end) - context('instance creation', function() - - local MyClass - - before(function() - MyClass = Object:subclass('MyClass') - function MyClass:initialize() self.mark=true end - end) - - context('allocate', function() - - test('allocates instances properly', function() - local instance = MyClass:allocate() - assert_equal(instance.class, MyClass) - end) - - test('throws an error when used without the :', function() - assert_error(Object.allocate) - end) - - test('does not call the initializer', function() - local allocated = MyClass:allocate() - assert_nil(allocated.mark) - end) - - test('can be overriden', function() - function MyClass.static:allocate() - local instance = Object:allocate() - instance.mark = true - return instance - end - - local allocated = MyClass:allocate() - assert_true(allocated.mark) - end) - - end) - - context('new', function() - - test('initializes instances properly', function() - local instance = MyClass:new() - assert_equal(instance.class, MyClass) - end) - - test('throws an error when used without the :', function() - assert_error(MyClass.new) - end) - - test('calls the initializer', function() - local allocated = MyClass:new() - assert_true(allocated.mark) - end) - - end) - - end) - context('subclass', function() test('throws an error when used without the :', function() @@ -87,14 +29,14 @@ context('Object', function() context('when given a class name', function() - local MyClass = Object:subclass('MyClass') + local SubClass = Object:subclass('SubClass') test('it returns a class with the correct name', function() - assert_equal(MyClass.name, 'MyClass') + assert_equal(SubClass.name, 'SubClass') end) test('it returns a class with the correct superclass', function() - assert_equal(MyClass.superclass, Object) + assert_equal(SubClass.superclass, Object) end) end) @@ -106,189 +48,76 @@ context('Object', function() end) -end) + context('instance creation', function() + local SubClass + local classes = { Object, SubClass } - - - ---[[ - - context('Metamethods', function() - - test('__index should throw an error', function() - local NonIndexable = class('NonIndexable') - - assert_error(function() function NonIndexable:__index(name) end end) + before(function() + SubClass = Object:subclass('SubClass') + function SubClass:initialize() self.mark=true end end) - context('Custom Metamethods', function() - -- Tests all metamethods. Note that __len is missing (lua makes table length unoverridable) - -- I'll use a() to note the length of vector "a" (I would have preferred to use #a, but it's not possible) - -- I'll be using 'a' instead of 'self' on this example since it is shorter - local Vector= class('Vector') - function Vector.initialize(a,x,y,z) a.x, a.y, a.z = x,y,z end - function Vector.__tostring(a) return a.class.name .. '[' .. a.x .. ',' .. a.y .. ',' .. a.z .. ']' end - function Vector.__eq(a,b) return a.x==b.x and a.y==b.y and a.z==b.z end - function Vector.__lt(a,b) return a() < b() end - function Vector.__le(a,b) return a() <= b() end - function Vector.__add(a,b) return Vector:new(a.x+b.x, a.y+b.y ,a.z+b.z) end - function Vector.__sub(a,b) return Vector:new(a.x-b.x, a.y-b.y, a.z-b.z) end - function Vector.__div(a,s) return Vector:new(a.x/s, a.y/s, a.z/s) end - function Vector.__unm(a) return Vector:new(-a.x, -a.y, -a.z) end - function Vector.__concat(a,b) return a.x*b.x+a.y*b.y+a.z*b.z end - function Vector.__call(a) return math.sqrt(a.x*a.x+a.y*a.y+a.z*a.z) end - function Vector.__pow(a,b) - return Vector:new(a.y*b.z-a.z*b.y,a.z*b.x-a.x*b.z,a.x*b.y-a.y*b.x) - end - function Vector.__mul(a,b) - if type(b)=="number" then return Vector:new(a.x*b, a.y*b, a.z*b) end - if type(a)=="number" then return Vector:new(a*b.x, a*b.y, a*b.z) end - end + for _,theClass in ipairs(classes) do + context(theClass.name, function() - local a = Vector:new(1,2,3) - local b = Vector:new(2,4,6) + context('allocate', function() - for metamethod,values in pairs({ - __tostring = { tostring(a), "Vector[1,2,3]" }, - __eq = { a, a}, - __lt = { a