diff --git a/encoder.lua b/encoder.lua
new file mode 100644
index 0000000..187813a
--- /dev/null
+++ b/encoder.lua
@@ -0,0 +1,175 @@
+local ftcsv = {
+    _VERSION = 'ftcsv 1.2.0',
+    _DESCRIPTION = 'CSV library for Lua',
+    _URL = 'https://github.com/FourierTransformer/ftcsv',
+    _LICENSE = [[
+        The MIT License (MIT)
+
+        Copyright (c) 2016-2020 Shakil Thakur
+
+        Permission is hereby granted, free of charge, to any person obtaining a copy
+        of this software and associated documentation files (the "Software"), to deal
+        in the Software without restriction, including without limitation the rights
+        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+        copies of the Software, and to permit persons to whom the Software is
+        furnished to do so, subject to the following conditions:
+
+        The above copyright notice and this permission notice shall be included in all
+        copies or substantial portions of the Software.
+
+        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+        SOFTWARE.
+    ]]
+}
+
+-- luajit/lua compatability layer
+global jit: table
+global _ENV: table
+global loadstring: function(string)
+local luaCompatibility: {string: function(string)} = {}
+if type(jit) == 'table' or _ENV then
+    -- luajit and lua 5.2+
+    luaCompatibility.load = _G.load
+else
+    -- lua 5.1
+    luaCompatibility.load = loadstring
+end
+
+-- The ENCODER code is below here
+-- This could be broken out, but is kept here for portability
+
+local type EncoderOptions = record
+    fieldsToKeep: {string}
+end
+local type GeneratorArgs = record
+    t: {CSVRow}
+    delimitField: function(string): string
+end
+local type CSVRow = {string: any}
+
+
+local function delimitField(field: string): string
+    field = tostring(field)
+    if field:find('"') then
+        return field:gsub('"', '""')
+    else
+        return field
+    end
+end
+
+local function escapeHeadersForLuaGenerator(headers: {string}): {string}
+    local escapedHeaders = {}
+    for i = 1, #headers do
+        if headers[i]:find('"') then
+            escapedHeaders[i] = headers[i]:gsub('"', '\\"')
+        else
+            escapedHeaders[i] = headers[i]
+        end
+    end
+    return escapedHeaders
+end
+
+-- a function that compiles some lua code to quickly print out the csv
+local function csvLineGenerator(inputTable: {CSVRow}, delimiter: string, headers: {string}): (function(string): (number, string), GeneratorArgs, number)
+    local escapedHeaders = escapeHeadersForLuaGenerator(headers)
+
+    local outputFunc = [[
+        local args, i = ...
+        i = i + 1;
+        if i > ]] .. #inputTable .. [[ then return nil end;
+        return i, '"' .. args.delimitField(args.t[i]["]] ..
+        table.concat(escapedHeaders, [["]) .. '"]] ..
+        delimiter .. [["' .. args.delimitField(args.t[i]["]]) ..
+        [["]) .. '"\r\n']]
+
+    local arguments: GeneratorArgs = {}
+    arguments.t = inputTable
+    -- we want to use the same delimitField throughout,
+    -- so we're just going to pass it in
+    arguments.delimitField = delimitField
+
+    return luaCompatibility.load(outputFunc), arguments, 0
+
+end
+
+local function validateHeaders(headers: {string}, inputTable: {CSVRow})
+    for i = 1, #headers do
+        if inputTable[1][headers[i]] == nil then
+            error("ftcsv: the field '" .. headers[i] .. "' doesn't exist in the inputTable")
+        end
+    end
+end
+
+local function initializeOutputWithEscapedHeaders(escapedHeaders: {string}, delimiter: string): {string}
+    local output = {}
+    output[1] = '"' .. table.concat(escapedHeaders, '"' .. delimiter .. '"') .. '"\r\n'
+    return output
+end
+
+local function escapeHeadersForOutput(headers: {string}): {string}
+    local escapedHeaders = {}
+    for i = 1, #headers do
+        escapedHeaders[i] = delimitField(headers[i])
+    end
+    return escapedHeaders
+end
+
+local function extractHeadersFromTable(inputTable: {CSVRow}): {string}
+    local headers = {}
+    for key, _ in pairs(inputTable[1]) do
+        headers[#headers+1] = key
+    end
+
+    -- lets make the headers alphabetical
+    table.sort(headers)
+
+    return headers
+end
+
+local function getHeadersFromOptions(options: EncoderOptions): {string}
+    local headers: {string} = nil
+    if options then
+        if options.fieldsToKeep ~= nil then
+            assert(
+                type(options.fieldsToKeep) == "table", "ftcsv only takes in a list (as a table) for the optional parameter 'fieldsToKeep'. You passed in '" .. tostring(options.fieldsToKeep) .. "' of type '" .. type(options.fieldsToKeep) .. "'.")
+            headers = options.fieldsToKeep
+
+        end
+    end
+    return headers
+end
+
+local function initializeGenerator(inputTable: {CSVRow}, delimiter: string, options: EncoderOptions): ({string}, {string})
+    -- delimiter MUST be one character
+    assert(#delimiter == 1 and type(delimiter) == "string", "the delimiter must be of string type and exactly one character")
+
+    local headers = getHeadersFromOptions(options)
+    if headers == nil then
+        headers = extractHeadersFromTable(inputTable)
+    end
+    validateHeaders(headers, inputTable)
+
+    local escapedHeaders = escapeHeadersForOutput(headers)
+    local output = initializeOutputWithEscapedHeaders(escapedHeaders, delimiter)
+    return output, headers
+end
+
+-- works really quickly with luajit-2.1, because table.concat life
+function ftcsv.encode(inputTable: {CSVRow}, delimiter: string, options: EncoderOptions): string
+    local output, headers = initializeGenerator(inputTable, delimiter, options)
+
+    for i, line in csvLineGenerator(inputTable, delimiter, headers) do
+        output[i+1] = line
+    end
+
+    -- combine and return final string
+    return table.concat(output)
+end
+
+return ftcsv
+
diff --git a/ftcsv.lua b/ftcsv.lua
index 74b455c..b8ae0f6 100644
--- a/ftcsv.lua
+++ b/ftcsv.lua
@@ -32,7 +32,16 @@ local sbyte = string.byte
 local ssub = string.sub
 
 -- luajit/lua compatability layer
-local luaCompatibility = {}
+global jit: table
+global _ENV: table
+global loadstring: function(string)
+local record LuaCompatibility
+    load: function(string)
+    LuaJIT: boolean
+    findClosingQuote: function(i: number, inputLength: number, inputString: string, quote: number, doubleQuoteEscape: boolean)
+end
+
+local luaCompatibility: LuaCompatibility = {}
 if type(jit) == 'table' or _ENV then
     -- luajit and lua 5.2+
     luaCompatibility.load = _G.load
@@ -41,14 +50,34 @@ else
     luaCompatibility.load = loadstring
 end
 
+local type CSVRow = {string: any}
+
+local record ParserOptions
+    loadFromString: boolean
+    rename: {string: string}
+    fieldsToKeep: {string}
+    ignoreQuotes: boolean
+    headerFunc: function(string): string
+    headers: boolean
+    inputLength: number
+    delimiter: string
+    totalColumnCount: number
+    bufferSize: number
+    buffered: boolean
+    endOfFile: boolean
+    rowOffset: number
+    headerField: {string}
+    headersMetamethod: function(table, string, string)
+end
+
 -- luajit specific speedups
 -- luajit performs faster with iterating over string.byte,
 -- whereas vanilla lua performs faster with string.find
 if type(jit) == 'table' then
     luaCompatibility.LuaJIT = true
     -- finds the end of an escape sequence
-    function luaCompatibility.findClosingQuote(i, inputLength, inputString, quote, doubleQuoteEscape)
-        local currentChar, nextChar = sbyte(inputString, i), nil
+    function luaCompatibility.findClosingQuote(i: number, inputLength: number, inputString: string, quote: number, doubleQuoteEscape: boolean)
+        local currentChar, nextChar: number, number = sbyte(inputString, i), nil
         while i <= inputLength do
             nextChar = sbyte(inputString, i+1)
 
@@ -73,8 +102,8 @@ else
     luaCompatibility.LuaJIT = false
 
     -- vanilla lua closing quote finder
-    function luaCompatibility.findClosingQuote(i, inputLength, inputString, quote, doubleQuoteEscape)
-        local j, difference
+    function luaCompatibility.findClosingQuote(i: number, inputLength: number, inputString: string, quote: number, doubleQuoteEscape: boolean)
+        local j, difference: number, number
         i, j = inputString:find('"+', i)
         if j == nil then
             return nil
@@ -90,9 +119,9 @@
 end
 
 -- determine the real headers as opposed to the header mapping
-local function determineRealHeaders(headerField, fieldsToKeep)
+local function determineRealHeaders(headerField: {string}, fieldsToKeep: {string: boolean}): {string}
     local realHeaders = {}
-    local headerSet = {}
+    local headerSet: {string: boolean} = {}
     for i = 1, #headerField do
         if not headerSet[headerField[i]] then
             if fieldsToKeep ~= nil and fieldsToKeep[headerField[i]] then
@@ -108,7 +137,7 @@ local function determineRealHeaders(headerField, fieldsToKeep)
 end
 
-local function determineTotalColumnCount(headerField, fieldsToKeep)
+local function determineTotalColumnCount(headerField: {string}, fieldsToKeep: {string: boolean}): number
     local totalColumnCount = 0
     local headerFieldSet = {}
     for _, header in pairs(headerField) do
@@ -123,7 +152,7 @@ local function determineTotalColumnCount(headerField, fieldsToKeep)
     return totalColumnCount
 end
 
-local function generateHeadersMetamethod(finalHeaders)
+local function generateHeadersMetamethod(finalHeaders: {string}): function(table, string, string)
     -- if a header field tries to escape, we will simply return nil
     -- the parser will still parse, but wont get the performance benefit of
     -- having headers predefined
@@ -139,20 +168,20 @@
 end
 
 -- main function used to parse
-local function parseString(inputString, i, options)
+local function parseString(inputString: string, i: number, options: ParserOptions): ({string: string}, number, number)
 
     -- keep track of my chars!
     local inputLength = options.inputLength or #inputString
-    local currentChar, nextChar = sbyte(inputString, i), nil
+    local currentChar, nextChar: number, number = sbyte(inputString, i), nil
     local skipChar = 0
-    local field
+    local field: string
    local fieldStart = i
     local fieldNum = 1
     local lineNum = 1
     local lineStart = i
     local doubleQuoteEscape, emptyIdentified = false, false
 
-    local skipIndex
+    local skipIndex: number
     local charPatternToSkip = "[" .. options.delimiter .. "\r\n]"
 
     --bytes
@@ -175,7 +204,7 @@ local function parseString(inputString, i, options)
     if headerField == nil then
         headerField = {}
         -- setup a metatable to simply return the key that's passed in
-        local headerMeta = {__index = function(_, key) return key end}
+        local headerMeta = {__index = function(_: table, key: string) return key end}
         setmetatable(headerField, headerMeta)
     end
 
@@ -208,7 +237,7 @@ local function parseString(inputString, i, options)
             if headerField[fieldNum] ~= nil then
                 outResults[lineNum][headerField[fieldNum]] = field
             else
-                error('ftcsv: too many columns in row ' .. options.rowOffset + lineNum)
+                error('ftcsv: too many columns in row ' .. (options.rowOffset + lineNum))
             end
         end
     end
@@ -265,7 +294,7 @@ local function parseString(inputString, i, options)
                 fieldStart = i + 1 + skipChar
                 lineStart = fieldStart
             else
-                error('ftcsv: too few columns in row ' .. options.rowOffset + lineNum)
+                error('ftcsv: too few columns in row ' .. (options.rowOffset + lineNum))
             end
         else
             lineNum = lineNum + 1
@@ -290,7 +319,7 @@ local function parseString(inputString, i, options)
             outResults[lineNum] = nil
             return outResults, lineStart
         else
-            error("ftcsv: can't find closing quote in row " .. options.rowOffset + lineNum ..
+            error("ftcsv: can't find closing quote in row " .. (options.rowOffset + lineNum) ..
                 ". Try running with the option ignoreQuotes=true if the source incorrectly uses quotes.")
         end
     end
@@ -321,14 +350,14 @@ local function parseString(inputString, i, options)
         if fieldNum == 1 and field == "" then
             outResults[lineNum] = nil
         else
-            error('ftcsv: too few columns in row ' .. options.rowOffset + lineNum)
+            error('ftcsv: too few columns in row ' .. (options.rowOffset + lineNum))
         end
     end
 
     return outResults, i, totalColumnCount
 end
 
-local function handleHeaders(headerField, options)
+local function handleHeaders(headerField: {string}, options: ParserOptions): {string | number}
     -- make sure a header isn't empty
     for _, headerName in ipairs(headerField) do
         if #headerName == 0 then
@@ -370,19 +399,24 @@
 end
 
 -- load an entire file into memory
-local function loadFile(textFile, amount)
+local function loadFile(textFile: string, amount: string | number)
     local file = io.open(textFile, "r")
     if not file then error("ftcsv: File not found at " .. textFile) end
-    local lines = file:read(amount)
+    local lines: string
+    if amount is string then
+        lines = file:read(amount)
+    else
+        lines = file:read(amount)
+    end
     if amount == "*all" then
         file:close()
     end
     return lines, file
 end
 
-local function initializeInputFromStringOrFile(inputFile, options, amount)
+local function initializeInputFromStringOrFile(inputFile: string, options: ParserOptions, amount: string | number): (string, FILE)
     -- handle input via string or file!
-    local inputString, file
+    local inputString, file: string, FILE
     if options.loadFromString then
         inputString = inputFile
     else
@@ -396,11 +430,11 @@
     return inputString, file
 end
 
-local function parseOptions(delimiter, options, fromParseLine)
+local function parseOptions(delimiter: string, options: ParserOptions, fromParseLine: boolean): (ParserOptions, {string: boolean})
     -- delimiter MUST be one character
     assert(#delimiter == 1 and type(delimiter) == "string", "the delimiter must be of string type and exactly one character")
 
-    local fieldsToKeep = nil
+    local fieldsToKeep: {string: boolean} = nil
 
     if options then
         if options.headers ~= nil then
@@ -452,7 +486,7 @@
 end
 
-local function findEndOfHeaders(str, entireFile)
+local function findEndOfHeaders(str: string, entireFile: boolean): number
     local i = 1
     local quote = sbyte('"')
     local newlines = {
@@ -482,7 +516,8 @@
     return i
 end
 
-local function determineBOMOffset(inputString)
+
+local function determineBOMOffset(inputString: string): number
     -- BOM files start with bytes 239, 187, 191
     if sbyte(inputString, 1) == 239
         and sbyte(inputString, 2) == 187
@@ -493,12 +528,12 @@
 
-local function parseHeadersAndSetupArgs(inputString, delimiter, options, fieldsToKeep, entireFile)
+local function parseHeadersAndSetupArgs(inputString: string, delimiter: string, options: ParserOptions, fieldsToKeep: {string: boolean}, entireFile: boolean): (number, ParserOptions, {string})
     local startLine = determineBOMOffset(inputString)
 
     local endOfHeaderRow = findEndOfHeaders(inputString, entireFile)
 
-    local parserArgs = {
+    local parserArgs: ParserOptions = {
         delimiter = delimiter,
         headerField = nil,
         fieldsToKeep = nil,
@@ -528,7 +563,7 @@ local function parseHeadersAndSetupArgs(inputString, delimiter, options, fieldsT
 end
 
 -- runs the show!
-function ftcsv.parse(inputFile, delimiter, options)
+function ftcsv.parse(inputFile: string, delimiter: string, options: ParserOptions)
     local options, fieldsToKeep = parseOptions(delimiter, options, false)
 
     local inputString = initializeInputFromStringOrFile(inputFile, options, "*all")
@@ -540,14 +575,14 @@
     return output, finalHeaders
 end
 
-local function getFileSize (file)
+local function getFileSize (file: FILE): number
     local current = file:seek()
     local size = file:seek("end")
     file:seek("set", current)
     return size
 end
 
-local function determineAtEndOfFile(file, fileSize)
+local function determineAtEndOfFile(file: FILE, fileSize: number): boolean
     if file:seek() >= fileSize then
         return true
     else
@@ -555,14 +590,14 @@
     end
 end
 
-local function initializeInputFile(inputString, options)
+local function initializeInputFile(inputString: string, options: ParserOptions): (string, FILE)
     if options.loadFromString == true then
         error("ftcsv: parseLine currently doesn't support loading from string")
     end
     return initializeInputFromStringOrFile(inputString, options, options.bufferSize)
 end
 
-function ftcsv.parseLine(inputFile, delimiter, userOptions)
+function ftcsv.parseLine(inputFile: string, delimiter: string, userOptions: ParserOptions)
     local options, fieldsToKeep = parseOptions(delimiter, userOptions, true)
 
     local inputString, file = initializeInputFile(inputFile, options)
@@ -580,7 +615,7 @@
     inputString = ssub(inputString, endOfParsedInput)
 
     local bufferIndex, returnedRowsCount = 0, 0
-    local currentRow, buffer
+    local currentRow, buffer: string, string
 
     return function()
         -- check parsed buffer for value
@@ -625,8 +660,16 @@ end
 
 -- The ENCODER code is below here
 -- This could be broken out, but is kept here for portability
 
+local type EncoderOptions = record
+    fieldsToKeep: {string}
+end
+local type GeneratorArgs = record
+    t: {CSVRow}
+    delimitField: function(string): string
+end
-local function delimitField(field)
+
+local function delimitField(field: string): string
     field = tostring(field)
     if field:find('"') then
         return field:gsub('"', '""')
@@ -635,7 +678,7 @@
     end
 end
 
-local function escapeHeadersForLuaGenerator(headers)
+local function escapeHeadersForLuaGenerator(headers: {string}): {string}
     local escapedHeaders = {}
     for i = 1, #headers do
         if headers[i]:find('"') then
@@ -648,7 +691,7 @@
 end
 
 -- a function that compiles some lua code to quickly print out the csv
-local function csvLineGenerator(inputTable, delimiter, headers)
+local function csvLineGenerator(inputTable: {CSVRow}, delimiter: string, headers: {string}): (function(string): (number, string), GeneratorArgs, number)
     local escapedHeaders = escapeHeadersForLuaGenerator(headers)
 
     local outputFunc = [[
@@ -660,7 +703,7 @@ local function csvLineGenerator(inputTable, delimiter, headers)
         delimiter .. [["' .. args.delimitField(args.t[i]["]]) ..
         [["]) .. '"\r\n']]
 
-    local arguments = {}
+    local arguments: GeneratorArgs = {}
     arguments.t = inputTable
     -- we want to use the same delimitField throughout,
     -- so we're just going to pass it in
@@ -670,7 +713,7 @@
 
 end
 
-local function validateHeaders(headers, inputTable)
+local function validateHeaders(headers: {string}, inputTable: {CSVRow})
     for i = 1, #headers do
         if inputTable[1][headers[i]] == nil then
             error("ftcsv: the field '" .. headers[i] .. "' doesn't exist in the inputTable")
@@ -678,13 +721,13 @@
     end
 end
 
-local function initializeOutputWithEscapedHeaders(escapedHeaders, delimiter)
+local function initializeOutputWithEscapedHeaders(escapedHeaders: {string}, delimiter: string): {string}
     local output = {}
     output[1] = '"' .. table.concat(escapedHeaders, '"' .. delimiter .. '"') .. '"\r\n'
     return output
 end
 
-local function escapeHeadersForOutput(headers)
+local function escapeHeadersForOutput(headers: {string}): {string}
     local escapedHeaders = {}
     for i = 1, #headers do
         escapedHeaders[i] = delimitField(headers[i])
@@ -692,7 +735,7 @@
     return escapedHeaders
 end
 
-local function extractHeadersFromTable(inputTable)
+local function extractHeadersFromTable(inputTable: {CSVRow}): {string}
     local headers = {}
     for key, _ in pairs(inputTable[1]) do
         headers[#headers+1] = key
@@ -704,19 +747,20 @@
     return headers
 end
 
-local function getHeadersFromOptions(options)
-    local headers = nil
+local function getHeadersFromOptions(options: EncoderOptions): {string}
+    local headers: {string} = nil
     if options then
         if options.fieldsToKeep ~= nil then
             assert(
-                type(options.fieldsToKeep) == "table", "ftcsv only takes in a list (as a table) for the optional parameter 'fieldsToKeep'. You passed in '" .. tostring(options.headers) .. "' of type '" .. type(options.headers) .. "'.")
+                type(options.fieldsToKeep) == "table", "ftcsv only takes in a list (as a table) for the optional parameter 'fieldsToKeep'. You passed in '" .. tostring(options.fieldsToKeep) .. "' of type '" .. type(options.fieldsToKeep) .. "'.")
             headers = options.fieldsToKeep
+
         end
     end
     return headers
 end
 
-local function initializeGenerator(inputTable, delimiter, options)
+local function initializeGenerator(inputTable: {CSVRow}, delimiter: string, options: EncoderOptions): ({string}, {string})
     -- delimiter MUST be one character
     assert(#delimiter == 1 and type(delimiter) == "string", "the delimiter must be of string type and exactly one character")
 
@@ -732,7 +776,7 @@
 end
 
 -- works really quickly with luajit-2.1, because table.concat life
-function ftcsv.encode(inputTable, delimiter, options)
+function ftcsv.encode(inputTable: {CSVRow}, delimiter: string, options: EncoderOptions): string
     local output, headers = initializeGenerator(inputTable, delimiter, options)
 
     for i, line in csvLineGenerator(inputTable, delimiter, headers) do
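
Note (not part of the patch): the string-building in csvLineGenerator compiles a small Lua chunk via luaCompatibility.load and reuses it for every row. For a hypothetical two-row table with headers {"name", "score"} and delimiter ",", the generated chunk would read roughly:

    local args, i = ...
    i = i + 1;
    if i > 2 then return nil end;
    return i, '"' .. args.delimitField(args.t[i]["name"]) .. '","' .. args.delimitField(args.t[i]["score"]) .. '"\r\n'

ftcsv.encode then drives this compiled function as the iterator of a plain for-in loop, so each row costs one call plus string concatenation, which is why the comment above ftcsv.encode singles out LuaJIT.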
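A minimal usage sketch, assuming the module layout shown in the diff (the sample rows are made up, not taken from the patch):

    local ftcsv = require("ftcsv")

    local rows = {
        { name = "alice", score = "12" },
        { name = "bob",   score = "7"  },
    }

    -- column order follows fieldsToKeep; without it, headers are sorted alphabetically
    local csv = ftcsv.encode(rows, ",", { fieldsToKeep = { "name", "score" } })
    -- csv comes out as '"name","score"\r\n"alice","12"\r\n"bob","7"\r\n'

    -- round-trip the string through the parser instead of reading a file
    local parsed, headers = ftcsv.parse(csv, ",", { loadFromString = true })
    assert(parsed[1].name == "alice" and headers[1] == "name")

The Teal annotations added here do not change this call shape; they only constrain what the options tables and row tables may contain.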