如何使用深度学习模型进行多标签文本分类

释放双眼,带上耳机,听听看~!
本文介绍了如何使用深度学习模型进行多标签文本分类。通过对摘要内容进行判断,可以预测文章所属的多个主题领域标签。文章详细介绍了数据准备过程,包括对摘要内容和标签的处理。同时,还提供了标签转换的方法,将标签转换为 multi-hot 的形式。通过本文的指导,读者可以学习到如何使用深度学习模型进行多标签文本分类,并且可以应用于实际项目中。

前文介绍

本文构建了一个常见的深度学习模型,实现多标签文本分类,可以根据论文摘要的文本内容预测其所属的多个主题领域标签。

数据准备

原始数据中分为三列 titles 、summaries 、termstitles 是文章标题,summaries 是摘要内容,terms 是所属的标签列表,我们主要任务是通过判断 summaries 中的内容来预测所属的 terms 。所以数据处理的主要工作在 summariesterms 两列。将每行的 summaries 和 terms 作为样本的(输入,标签)样本对。本次任务是多标签预测,每个摘要会有多个所属标签,所以我们要将标签转换为 multi-hot 的形式。

例如随机展示一个样本:

Abstract: 'Graph convolutional networks produce good predictions of unlabeled samplesndue to its transductive label propagation. Since samples have differentnpredicted confidences, we take high-confidence predictions as pseudo labels tonexpand the label set so that more samples are selected for updating models. Wenpropose a new training method named as mutual teaching, i.e., we train dualnmodels and let them teach each other during each batch. First, each networknfeeds forward all samples and selects samples with high-confidence predictions.nSecond, each model is updated by samples selected by its peer network. We viewnthe high-confidence predictions as useful knowledge, and the useful knowledgenof one network teaches the peer network with model updating in each batch. Innmutual teaching, the pseudo-label set of a network is from its peer network.nSince we use the new strategy of network training, performance improvesnsignificantly. Extensive experimental results demonstrate that our methodnachieves superior performance over state-of-the-art methods under very lownlabel rates.'
Label: ['cs.CV' 'cs.LG' 'stat.ML']

经过处理,标签列表会变成一个数据集中所有标签集合大小的数组,将该样本出现的标签对应的索引位置变成 1 ,其余位置变成 0 ,具体处理过程见代码:

Abstract: 'Visual saliency is a fundamental problem in both cognitive and computationalnsciences, including computer vision. In this CVPR 2015 paper, we discover thatna high-quality visual saliency model can be trained with multiscale featuresnextracted using a popular deep learning architecture, convolutional neuralnnetworks (CNNs), which have had many successes in visual recognition tasks. Fornlearning such saliency models, we introduce a neural network architecture,nwhich has fully connected layers on top of CNNs responsible for extractingnfeatures at three different scales. We then propose a refinement method tonenhance the spatial coherence of our saliency results. Finally, aggregatingnmultiple saliency maps computed for different levels of image segmentation cannfurther boost the performance, yielding saliency maps better than thosengenerated from a single segmentation. To promote further research andnevaluation of visual saliency models, we also construct a new large database ofn4447 challenging images and their pixelwise saliency annotation. Experimentalnresults demonstrate that our proposed method is capable of achievingnstate-of-the-art performance on all public benchmarks, improving the F-Measurenby 5.0% and 13.2% respectively on the MSRA-B dataset and our new datasetn(HKU-IS), and lowering the mean absolute error by 5.7% and 35.1% respectivelynon these two datasets.'
Label: [0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

模型训练

模型结构主要由以下三部分组成:

  1. layers.Dense(512, activation="relu") 将输入映射为 512 维的向量,并且使用 relu 激活函数进行非线性运算
  2. layers.Dense(256, activation="relu") 将输入映射为 256 维的向量,并且使用 relu 激活函数进行非线性运算
  3. layers.Dense(lookup.vocabulary_size(), activation="sigmoid") 输出词典大小维度的向量,并且使用 sigmoid 激活函数推断所属标签概率
  4. 编译模型时候,使用 binary_crossentropy 作为损失函数,使用 adam 作为优化器,使用 binary_accuracy 作为观测指标

训练过程日志打印如下:

    Epoch 1/20
    258/258 [==============================] - 8s 25ms/step - loss: 0.0334 - binary_accuracy: 0.9890 - val_loss: 0.0190 - val_binary_accuracy: 0.9941
    Epoch 2/20
    258/258 [==============================] - 6s 25ms/step - loss: 0.0031 - binary_accuracy: 0.9991 - val_loss: 0.0262 - val_binary_accuracy: 0.9938
    ...
    Epoch 20/20
    258/258 [==============================] - 6s 24ms/step - loss: 7.4884e-04 - binary_accuracy: 0.9998 - val_loss: 0.0550 - val_binary_accuracy: 0.9931
    15/15 [==============================] - 1s 28ms/step - loss: 0.0552 - binary_accuracy: 0.9932

将训练过程产生的损失值和准确率进行了绘制,如下所示:

如何使用深度学习模型进行多标签文本分类

如何使用深度学习模型进行多标签文本分类

测试效果

随机选取两个样本,使用训练好的模型进行标签预测,为每个样本最多预测 3 个概率最高的标签,并和原始标签进行对比,可以发现基本上所属的标签都会出现在预测结果的前几个。

    Abstract: b'Graph representation learning is a fundamental problem for modelingnrelational data and benefits a number of downstream applications. ..., The source code is available atnhttps://github.com/upperr/DLSM.'
    Label: ['cs.LG' 'stat.ML']
    Predicted Label(s): (cs.LG, stat.ML, cs.AI) 
    Abstract: b'In recent years, there has been a rapid progress in solving the binarynproblems in computer vision, ..., The SEE algorithm is split into 2 parts, SEE-Pre fornpreprocessing and SEE-Post pour postprocessing.'
    Label: ['cs.CV']
    Predicted Label(s): (cs.CV, I.4.9, cs.LG) 

参考

github.com/wangdayaya/…

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

GPT-3与chatGPT的区别,Stable Diffusion和Midjourney的比较

2023-11-29 15:49:14

AI教程

AI爆文变现脚本:0基础小白的保姆级操作教程-更新迭代

2023-11-29 16:02:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索