探索学习型排序算法:机器学习与排序的完美结合

释放双眼,带上耳机,听听看~!
本文探讨了学习型排序算法的核心思想以及机器学习模型在排序中的应用,通过使用递归模型索引构建高准确度的排序模型,实现了对大规模数据集的高效排序。

导语

这篇论文是Ani Kristo等几位大佬在2020年发表的,其尝试利用机器学习,设计一种能够直接为无序数据排序的方法——学习型排序算法,也就是说:传统的排序算法,如冒泡,快排等,都需要对于待排序的数据进行比较交换等操作,并且需要被比较/交换的数据定然包含了全部的待排序数据,这也使得目前的排序算法,其时间复杂度一定是大于O(N)O(N)的。而本文则尝试设计一种机器学习模型来直接预测无序数据集中每一个元素排序后的位置,从而直接完成排序。idea比较有意思,还是很值得一读。

论文总览

近几年机器学习发展迅速,并且被极为广泛的应用在各行各业中。如今,无论是什么类型的应用软件,其核心模块都或多或少的包括了机器学习的部分。但在系统软件领域,即类似于操作系统等为应用软件提供平台的软件,机器学习的应用还在起步阶段,但像FydeOS、SageDB 等项目也正在这条道路上进行探索。

本次分享的论文在 SageDB 的基础上进一步探讨了计算机科学中的经典难题:排序。依据文中的实验,其提出的学习型排序算法在处理高达十亿条数据的数据集时,性能比作为SOTA的基数排序算法高出了1.49倍。并且其使用自己的总时间(模型训练时间+排序时间)与基数排序的时间进行比较

核心思想

面临的问题

假设有一个无序的数据列表,如果有一个模型能够预测列表中的某个数据项在排序后的位置。且模型准确率为100%,那么通过遍历一次数据集,就可以把每个数据项完成排序。也就是指,可以在O(N)O(N)的时间复杂度来完成一个排序~

但该方法的有如下两个问题

  1. 机器学习模型需要数据集才能够进行训练,对于一个有监督训练的排序模型,其训练集必然要包含数据集中的所有数据,并使用每一条数据排序后的位置作为标签。但是这样的话我们何必还要对其进行训练呢?毕竟待排序数据已经被排好序了。
  2. 目前很难说有哪个机器学习模型在其对应的领域上能达到100%的准确度,因此,如果模型最后得到的序列不是一个完全有序的序列,我们该如何进行后续处理?

解决思路

我们面临的问题在于需要使用全量数据和其排序后的位置进行训练。因此,如果我们能够只使用部分数据就能训练出排序模型,就可以大大降低计算代价。

理论上的技术方案

机器学习的潜在假设是所有数据都是相互独立同分布的,即所有数据都来自于相同的分布,且每条数据的产生都相互之间不影响,这样我们可以通过抽样部分数据,利用部分数据并学习CDF(累积分布函数)的近似值,即构建所有数据所属的概率分布,来训练出一个排序模型。从而解决问题1

在构建一个高准确度的模型后(方法稍后讨论),我们就可以快速进行排序:

  1. 先扫描列表,根据模型预测为每个项找到大致的位置
  2. 再利用一个擅长处理几乎已排序数组的排序算法(如插入排序)来将这个几乎已排序的列表变为一个完全排序的列表。从而解决问题2。

模型构建

为此,作者使用了 《The case for learned index structures》 中提出的递归模型索引(RMI)架构。简单来说,RMI 是通过一系列模型层级组织起来的,可以类比成像多个专家的混合体。如下图所示:

探索学习型排序算法:机器学习与排序的完美结合

递归模型索引(RMI)

上图中,每一个Model都可以是任意一个机器学习模型,从最简单的线性回归(LR)到深度神经网络(DNN)都可以。实践中,越简单的模型越好(减小模型推理时间,并防止过拟合)。当进行查找时,最上层的模型(只有一个模型),将选择一个第二层的模型来处理这个key。然后第二层的模型,会接着选择一个下一层的模型来处理这个key,直到最底层的模型,才会给出这个key对应的预测位置

RMI的训练方式

整个Staged Model分层训练,先训练最顶层,然后顶层模型进行数据分发,使得下层模型能够进行训练。

数据分发为:上层模型将key分配给下层Model,下层Model只能将上层模型分发给它的训练数据作为训练集(比如第二层的三个模型,他们的训练集是不相同的,具体是哪些数据取决于第一层的模型分给它哪些)。所以随着层数的加深,以及每一层模型数量的提升,每个越底层模型拥有的训练数据是越少的。这样的优点是,底层模型可以非常容易的拟合这一部分数据的分布(缺点是较少的数据量带来了模型的选择限制,如果某个模型结构复杂,可能没法收敛)。

在推理阶段,模型的每一层都接收一个key作为输入,并进行线性转换来得到一个值,这个值为下一层模型的索引,即选择出一个下一层的模型进行继续的预测。

本文中的方案

本文中每一层的模型都使用简单的线性模型,但是作者没有使用最常用的基于均方误差损失(MSE Loss)的线性回归模型,而是采用线性样条拟合(linear spline fitting),这正是作者在这里引入的主要创新点

优点

相对于基于均方误差损失(MSE Loss)的线性回归模型,样条拟合计算成本更低,还能更好地保证单调性(从而缩短插入排序阶段的时间)。尽管每个独立的样条模型的拟合度不如传统的线性回归模型,但是模型层次的组合弥补了这一点。线性样条拟合的训练速度比同样的线性回归快了2.7倍,而且在插入排序过程中减少了多达35%的键值交换。(这可能也体现了集成学习的一种优势,虽然这个RMI并不是常见的集成学习架构,但这种多模型集成的思想是有效果的)

冲突问题

这里假设我们已经利用局部数据训练了一个排序模型,暂时称为Learned Sort ver1.0。
本文所提出的排序模型(Learned Sort ver1.0)并不会在原地进行排序,也就是说其需要额外的空间:其会创建一个和待排序数组等长的数组。每预测出一个待排序数据的位置,就会把这个数据插入到新数组对应的位置上。因此,不可避免的会出现“冲突”现象:如果模型预测元素i排在第287个,但这个位置在新数组里已经被占用了。

冲突问题的解决方案

针对这种情况,作者借鉴了hash表的冲突解决方案:

  1. 线性探测:扫描数组,找到最近的空闲位置,并把元素放在那里。
  2. 链式存储:每个位置使用一个列表或链来存储多个元素。
  3. 溢出桶:即构建一个缓冲区数组,称为“溢出桶(Spill bucket)”,如果目标位置已经被占用,就把本应放在该位置的新的数据放入这个缓冲区中。当模型处理完全部数据后,对溢出桶中的数据进行排序,并将其与目标数组合并。

探索学习型排序算法:机器学习与排序的完美结合

如上图所示,经过实验比较,作者发现“溢出桶”方案最适合他们的需求。而我们把使用了溢出桶方案的排序模型称为Learned Sort ver2.0。

0冲突模型的低效问题

模型预测的准确性直接影响了排序的性能,一个高质量的模型会减少碰撞现象的出现,从而最小化最终插入排序阶段需要调整的数据数量。但是作者发现:即使假设冲突问题不存在,即假设使用一个准确率100%的模型来进行排序——及模型预测的位置一定是该数据排序后所在的最终位置。这样的完美模型排序速度是慢于基数排序的!

论文对于这个低效问题进行了实验:在一个小规模实验中,尽管使用了完美模型,但将所有数据分配到它们最终排序的位置竟然耗时38.7秒,而基数排序只用了37.5秒。

低效问题原因:低下的缓存命中率&大量随机内存访问

原因何在?一个高性能的计算机算法必须要充分利用硬件特性。目前的基数排序算法充分利用了计算机的L2缓存,并且其内存访问都为连续的内存访问,避免了大量额外的IO,而目前的排序模型(Learned Sort ver2.0)则是在整个目标数组中进行随机访问。缓存命中率低下,需要大量无用的IO操作。

解决方案

怎样改进 “Learned Sort ver2.0” 使其更高效地利用缓存呢?作者给出的解决方案如下:

把 “Learned Sort ver2.0” 的预测过程改为分级的桶排序。模型不是预测每个数据在最终数组中的具体位置,而是利用预测每个元素应该放入哪一个桶(或区间)。从而得到了Learned Sort ver3.0。详细过程如下:

探索学习型排序算法:机器学习与排序的完美结合

  1. 假设桶的数量为MM。”Learned Sort ver3.0″先进行分级桶排序。
    • 首先,根据模型的预测,把输入的元素分配到 MM 个顺序桶中(放入MM个不同区间)。
    • 然后,每个桶再细分为 MM 个小桶,这个过程一直递归进行,直到桶的大小(区间中数据数量)达到预设阈值 tt
    • 如果某个元素根据模型预测被分配到了一个已经满了的桶,就把它移到一个备用桶里。
  2. 当桶的大小缩减到 t 时,就利用模型预测每个桶内的数据在桶内所在的位置,完成大致的排序。
  3. 把各个已排序的桶按顺序连接起来,然后用插入排序法对全部数据进行排序。
  4. 将溢出桶(模型预测时出现冲突的数据所在的缓冲数组)进行排序,并合并到已排序数组中。

为什么这种排序方法能加速呢?

“Learned Sort ver3.0” 加速的关键在于合理选择超参数 MMtt。从而减小随机内存访问的数量并提升缓存命中率。

关于 MM 的取值:较大的 MM 值可以更充分利用每一步的模型预测能力(毕竟分类问题中,类别越多,模型预测的可能越差),而较小的的 MM 则能减小将数据精排时产生缓存缺失的可能性。为了最佳性能,MM 需要与计算机的L2缓存大小相匹配。作者在实验中得到的 MM 的经验值大约是1000。

参数 tt 则影响着可能会进入溢出桶的数据数量(即冲突次数)。作者通过实践发现,当溢出桶中的数据量少于总数据量的5%时,系统的性能达到最优。作者认为对于大型数据集来说 tt 的经验最佳值大约是100。

在进行了这些调整后,如果待排序的元素数量接近键值域的大小(比如,用32位键排序2^32个元素),”Learned Sort ver3.0″ 的性能几乎能与基数排序相媲美。如果元素数量远小于键值域的大小,”Learned Sort ver3.0″ 则有可能大幅超越基数排序。

这里进行一些补充:作者对于数据集中的元素,使用了一个key

实验

作者在合成数据集和一些真实数据集上进行了测试,合成数据集为一个正态分布中抽样出的大量double型的浮点数,并将 “Learned Sort” 与多种进行了缓存优化和高度调校的排序算法进行了比较,作为SOTA的算法采用了C++来实现,结果如下图所示:图中横坐标为数据数量,纵坐标为排序速度。

探索学习型排序算法:机器学习与排序的完美结合

根据实验可以得到如下结论(来自原论文):

  1. “Learned Sort” 在各种规模的数据集上都表现出色,尤其在数据量超出L3缓存容量时,平均吞吐率比其他算法高出30%,优势非常明显。

  2. 研究结果显示,相较于C++ STL中的快排算法(std::sort()),我们的方法在性能上平均提升了3.38倍,相较于顺序 基数排序提升了1.49倍,而与 “TimSort” 的C++版本相比,提升了5.54倍,”TimSort” 是Java和Python中默认的排序函数。

探索学习型排序算法:机器学习与排序的完美结合

  1. “Learned Sort” 不仅在合成数据集上表现优异,在真实数据集(详见第6.1节的测试数据集)和不同类型的元素上也同样具有优势:

  2. “Learned Sort” 相比于目前的SOTA排序算法,表现出了显著的性能提升,这标志着机器学习增强算法和数据结构研究的重要进展。

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

AI时代对普通人的影响及应对策略

2023-11-24 9:33:14

AI教程

ChatGPT最佳实践系列第5篇:使用外部工具

2023-11-24 9:46:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索