Jump to content

Module:stable sort

From Wiktionary, the free dictionary

This module provides a stable alternative to table.sort: elements that the given ordering considers equal are guaranteed to keep their original relative positions.

It takes the same parameters as table.sort. The first parameter is the array to sort. The second (optional) parameter is a comparison function, which must return true if and only if its first argument is strictly less than its second. If no comparison function is given, the default less-than operator < is used.

This function is anywhere from slightly to significantly slower than table.sort, and should only be used where sorting stability is a requirement.

Warning: if the comparison function throws an error and that error is caught, the array may be left in an indeterminate state.


-- Length of a table, delegated to Module:table's `length` function.
-- The dependency is only required on the first call. That call rebinds the
-- `table_len` upvalue to the loaded function, so later calls go straight to it.
local function table_len(...)
	local length = require("Module:table").length
	table_len = length
	return length(...)
end

-- Insertion sort of the slice tbl[i0] .. tbl[i1] (inclusive), done in place.
-- `less(a, b)` must return true iff a sorts strictly before b. When `less`
-- is nil, the `<` operator is used instead.
-- An element only moves past predecessors that are strictly greater than it,
-- so elements that compare equal keep their relative order (stable).
local function sort_insertion(tbl, less, i0, i1)
	for i = i0 + 1, i1 do
		local tmp = tbl[i]
		-- `j` must be local: the original declaration leaked it as a global.
		local j = i
		-- Shift strictly-greater predecessors one slot right to open a gap for tmp.
		while j > i0 and ((less and less(tmp, tbl[j - 1])) or (not less and tmp < tbl[j - 1])) do
			tbl[j] = tbl[j - 1]
			j = j - 1
		end
		tbl[j] = tmp
	end
end

-- Merge two adjacent sorted runs of `src` into `dst`:
--   left run:  src[i0] .. src[i2 - 1]
--   right run: src[i2] .. src[i3]
-- The merged output is written to dst[i0] .. dst[i3]. `src` is only read.
-- `less(a, b)` must return true iff a sorts strictly before b; nil means `<`.
-- On ties the element from the left run is emitted first, which keeps the
-- merge stable.
-- Precondition (not checked, for performance): i0 < i2 and i2 <= i3, i.e.
-- both runs are non-empty.
local function sort_merge(dst, src, less, i0, i2, i3)
	local i1 = i2 - 1 -- last index of the left run
	local a, b = i0, i2 -- read positions in the left and right runs

	for j = i0, i3 do
		if (less and less(src[b], src[a])) or (not less and src[b] < src[a]) then
			-- src[b] < src[a]: take the item from the right run
			dst[j] = src[b]
			if b >= i3 then
				-- Right run exhausted: the rest of the left run follows in order.
				-- A separate cursor is used so the loop variable `j` is never modified.
				local out = j
				for k = a, i1 do
					out = out + 1
					dst[out] = src[k]
				end
				return
			end
			b = b + 1
		else
			-- src[a] <= src[b]: take the item from the left run (keeps stability)
			dst[j] = src[a]
			if a >= i1 then
				-- Left run exhausted: the rest of the right run follows in order.
				local out = j
				for k = b, i3 do
					out = out + 1
					dst[out] = src[k]
				end
				return
			end
			a = a + 1
		end
	end
end

-- Stable in-place sort of the array `tbl`.
-- `comp(a, b)` (optional) must return true iff a sorts strictly before b.
-- When it is omitted the `<` operator is used, as with table.sort.
-- Elements that compare equal keep their original relative order.
-- Strategy: insertion-sort fixed-size blocks, then merge them bottom-up.
-- Passes alternate between `tbl` and a scratch buffer so no copy is needed
-- between passes. Returns nothing.
local function stable_sort(tbl, comp)
	local b, n = 1, table_len(tbl) -- first and last index to sort

	-- Initial run length: each block of k elements is sorted by insertion sort.
	local k = 8
	local sort_insertion = sort_insertion
	for i = b, n, k do
		local e = i + k - 1
		if e > n then e = n end
		sort_insertion(tbl, comp, i, e)
	end

	-- A single block covers the whole array, so there is nothing to merge.
	if n <= k then return end

	-- Bottom-up merge sort. Each pass merges adjacent sorted runs of length k
	-- into runs of length 2k, reading from `src` and writing to `dst`.
	local buf = {}
	local src, dst = tbl, buf

	local sort_merge = sort_merge
	repeat
		local k2 = k + k
		for i = b, n, k2 do
			-- e.g. k = 8: merge two 8-item runs into one 16-item run
			local s = i + k -- start of the right run
			if s > n then
				-- No right run: the trailing left run is already sorted, so carry it over.
				for j = i, n do
					dst[j] = src[j]
				end
				break
			end
			local e = s + k - 1 -- end of the right run, clipped to n
			if e > n then e = n end
			sort_merge(dst, src, comp, i, s, e)
		end
		k = k2
		-- This pass's output becomes the next pass's input.
		src, dst = dst, src
	until k >= n

	-- The final pass may have left the sorted data in the scratch buffer.
	if src ~= tbl then
		for i = b, n do
			tbl[i] = src[i]
		end
	end
end

-- The module's value is the sort function itself.
return stable_sort