中秋节图片搜索模型构建及应用

释放双眼,带上耳机,听听看~!
本文介绍了一个基于双塔深度学习模型的中秋节图片搜索模型构建,并应用于 MS-COCO 数据集,详细介绍了视觉编码器和文本编码器的建立过程。

前言

临近佳节,也要努力学习哦。

大家都知道中秋有很多节日元素,像玉兔,月亮,蟾蜍,吃螃蟹,月饼等,为了能够准确地找出语义相关的图片,本文构建了一个简易的双塔深度学习模型,使用自然语言来搜索相关的图像。该模型地主要架构思路是联合训练视觉编码器和文本编码器,将图像及其文本内容的表示投影到相同的嵌入空间中,使得文本嵌入在所描述的图像的嵌入附近,最后通过计算向量相似度返回 topk 个图片即可。

基础

  • tensorflow
  • python
  • nlp
  • cv

数据

我们本次使用的是 MS-COCO 数据集,因为我手头只有这些国外的数据集,就凑活用吧,意思到位就行了,国内真正包含中秋元素的图文对数据集估计还没有。该数据集包含超过 82000 张图像,每张图像至少有 5 个不同的文本描述,经过整理一一对应,也就是说图文对达到了 40 多万,然后分好训练集和测试集即可。

视觉编码器

这里介绍视觉编码器模型 vision_encoder ,该模型主要用于将输入的图像数据编码为特征表示。我们这里选用了预训练大模型 Xception 作为基础图像特征提取器,冻结它的参数,不进行训练,然后通过投影函数的一系列操作进行转换,包括:全连接层转换、Dropout、LayerNormalization、残差连接等步骤,来得到最终的图像特征编码,方便后续的双塔模型可以使用这些特征进行训练。

def create_vision_encoder(num_projection_layers, projection_dims, dropout_rate, trainable=False):
    xception = keras.applications.Xception(include_top=False, weights='imagenet', pooling='avg')
    for layer in xception.layers:
        layer.trainable = trainable
    inputs = layers.Input(shape=(299,299,3), name='image_input')
    xception_input = preprocess_input(inputs)
    embeddings = xception(xception_input)
    outputs = project_embeddings(embeddings, num_projection_layers, projection_dims, dropout_rate)  # [B, 128]
    return keras.Model(inputs, outputs, name='vision_encoder')

文本编码器

这里介绍文本编码器模型 text_encoder ,该模型用于将文本数据编码为特征表示。我们首先选用 bert_en_uncased_preprocess 处理器来对我们的输入文本进行处理,bert_en_uncased_preprocess 是一个 TensorFlow Hub 内置的模块,用于对文本数据进行预处理,以便满足输入到 BERT 模型中的格式。这省去了我们做分词、截断、填充、掩码、控制长度、添加特殊标记等工作。当然我们还要选用预训练大模型 bert_en_uncased 来进行文本的编码特征提取,并且也冻结了 bert 的权重参数。这里后续和上面一样要进行相同投影函数的一系列操作进行转换,结果我们可以知道每个图像和每个文本的特征维度最终会相同,这是后续进行双塔模型训练的基础。

def create_text_encoder(num_projection_layers, projection_dims, dropout_rate, trainable=False):
    bert = hub.KerasLayer( "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/2", trainable=trainable )
    preprocess = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
    inputs = layers.Input(shape=(), dtype=tf.string, name='text_input')
    bert_inputs = preprocess(inputs)
    embeddings = bert(bert_inputs)['pooled_output']
    outputs = project_embeddings(embeddings, num_projection_layers, projection_dims, dropout_rate)  # [B, 128]
    return keras.Model(inputs, outputs, name='text_encoder')

双塔编码器

这里介绍双塔编码器模型,该模型用于实现文本和图像编码器的训练和测试。结构很简单,就是将两个模型都融合到继承了 keras.Model 的 DualEncoder 类中,然后在前向过程中调用 call 函数,分别求出文本对应的 caption_embeddings 和图像对应的 image_embeddings ,然后使用这两个 embedding 来进行目标 loss 的计算和梯度的更新。此模型的主要目的在于训练过程中通过最小化损失函数来学习文本嵌入和图像嵌入,使它们在嵌入空间中彼此相似。

class DualEncoder(keras.Model):
   def __init__(self, text_endocer, image_encoder, temperature=1.0, **kwargs):
       ...

   def call(self, features, training=False):
       caption_embeddings = text_encoder(features['caption'], training=training)
       image_embeddings = vision_encoder(features['image'], training=training)
       return caption_embeddings, image_embeddings

   def compute_loss(self, caption_embeddings, image_embeddings):
       logits = ( tf.matmul(caption_embeddings, image_embeddings, transpose_b=True) / self.temperature)   # [B, B]
       images_similarity = tf.matmul(image_embeddings, image_embeddings, transpose_b=True)   # [B, B]
       caption_similarity = tf.matmul(caption_embeddings, caption_embeddings, transpose_b=True)  # [B, B]
       targets = keras.activations.softmax( (caption_similarity + images_similarity) / (2 * self.temperature))  # [B, B]
       caption_loss = keras.losses.categorical_crossentropy(y_true=targets, y_pred=logits, from_logits=True)  # [B,]
       images_loss = keras.losses.categorical_crossentropy(y_true=tf.transpose(targets), y_pred=tf.transpose(logits), from_logits=True)  # [B,]
       return (caption_loss + images_loss) / 2

   def train_step(self, features):
       ...

   def test_step(self, features):
       ...

编译和训练

整个模型的优化器使用的是 AdamW ,并且定义了回调函数 early_stopping ,保证模型在 3 个 epoch 后 val_loss 都没有下降便提前停止训练。整个模型的训练相当耗时,平均一个 epoch 一个小时左右。总共训练了 4 个 epoch ,日志如下:

# Epoch 1/100
# 2760/2760 [==============================] - 3521s 1s/step - loss: 10.6863 - val_loss: 3.3784 - lr: 3.0000e-04
# Epoch 2/100
# 2760/2760 [==============================] - 3503s 1s/step - loss: 7.5042 - val_loss: 15.8172 - lr: 3.0000e-04
# Epoch 3/100
# 2760/2760 [==============================] - 3502s 1s/step - loss: 12.3186 - val_loss: 4.3896 - lr: 3.0000e-04
# Epoch 4/100
# 2760/2760 [==============================] - 3505s 1s/step - loss: 30.6067 - val_loss: 10.6790 - lr: 3.0000e-04

效果展示

经过漫长的训练,我们就检查一下模型的效果吧。我们已经有训练好的文本编码器和图像编码器了,想要通过文字搜索图片,就需要先把所有的图像通过图像编码器都转换成编码 image_embeddings 保存好,然后我们输入文本,通过文本编码器将输入的文本转换为编码 query_embedding ,然后通过两个编码的矩阵相乘计算出相似度,image_embeddings 中和 query_embedding 最相近的若干张图像,下面大家展示一下效果。

我输入一些和我们中秋元素有关的文本描述,因为这个模型本身是个演示模型,结构很简单,而且洋人数据集中包含中秋元素的内容很少,所以训练出来的效果不是很好,很多返回的图片甚至风马牛不相及,不一定是语义最相近的图片,我努力在 topN 中挑选了一些有意思的图片,大家一起开心开心得嘞。

这贼眉鼠眼的是兔子?我感觉像乌龟。

中秋节图片搜索模型构建及应用

看来我家的兔子喜欢玩手机,把手机捧在脑壳里。

中秋节图片搜索模型构建及应用

你这兔爷站在厕所想干啥?偷看吗?

中秋节图片搜索模型构建及应用

这么小的月亮,WTF…

中秋节图片搜索模型构建及应用

这个月亮更小,我差点没发现,行不行啊细狗…

中秋节图片搜索模型构建及应用

好歹有一个正经月亮了,你看这个月亮,它又大又圆(哎?旋律有点熟悉的感觉)…

中秋节图片搜索模型构建及应用

别说,这个螃蟹大面包还挺可爱的,不过我更喜欢旁边的熊。

中秋节图片搜索模型构建及应用

果然物以类聚,人、狗、螃蟹,正不正经我还真不好说…

中秋节图片搜索模型构建及应用

你两偷偷摸摸在马桶旁边密谋什么呢,能不能跟我说说。

中秋节图片搜索模型构建及应用

蛤蟆也要敲代码是吧!跟我抢饭碗是吧!你瞅啥!

中秋节图片搜索模型构建及应用

这个大蛤蟆有一点蛤蟆仙人的感觉了,一看就懂修炼。

中秋节图片搜索模型构建及应用

蛤蟆在哪呢?不会是那个边上趴着的玩意吧……

中秋节图片搜索模型构建及应用

参考

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

如何在AWS上使用Walrus部署Llama-2大语言模型

2023-11-25 2:50:14

AI教程

Stable Diffusion安装及模型使用教程

2023-11-25 5:40:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索