使用本地Jupyter Notebook搭载TensorFlow相关库进行操作

释放双眼,带上耳机,听听看~!
本文介绍如何在本地Jupyter Notebook中使用TensorFlow相关库进行操作,包括读取TFRecords文件、Example解析、图像调整等操作。

注意:本文使用tensorflow1.x版本进行演示

使用本地Jupyter Notebook搭载TensorFlow相关库进行操作

1. 读取TFRecords文件

其实读取TFRecords文件大体思路与常规文件读取思路(构造队列、读取、解码、批处理队列)比较一致。但是,还是有一点不一样,在解码操作之前,需要解析Example操作(因为TFRecords比其他文件多了个Example结构),TFRecords文件读取步骤如下所示:

  • 构造文件名队列
  • 读取
  • 解析Example
    • tf.parse_single_example()
      • tf.FixedLenFeature(shape, dtype)
  • 解码
  • 构造批处理队列

接下来,我们将对TFRecords文件读取中用到的函数进行详细说明:

  • tf.parse_single_example(serialized, features=None, name=None)

    • 用来解析一个单一的Example原型
    • serialized:标量字符串Tensor,一个序列化的Example
    • features:dict字典数据,键为读取的名字,值为FixedLenFeature
    • return:返回一个键值对组成的字典,键为读取的名字。想拿到解析后的example数据,需要通过字典形式访问。
  • tf.FixedLenFeature(shape, dtype)

    • 这个函数和上一个函数其实是嵌套使用的,上一个函数中的features参数中的一部分(字典中值的部分)需要用本函数填充
    • shape:输入数据的形状,一般不指定即为空列表
    • dtype:输入数据类型,与存储进文件的类型要一样
    • 类型只能是float32,int64,string

2. 代码演示

导入所需模块,因为本地下载的是Tensorflow2.x版本,想运行Tensorflow1的语法,需要开启兼容模型,以支持Tensorflow1语法正常运行。

import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os

从读取TFRecords文件的角度,进行函数定义,对已保存到本地的TFRecords文件进行读取。具体代码如下所示:

  • 首先在函数中需要构造文件名队列,通过变量file_queue接收

  • 然后,使用tf.TFRecordReader()读取器,使用该读取器下面的read方法进行文件读取,使用变量key与value接受元组

  • 接下来以上述介绍的API进行Example解析,可以将中间结果image与label打印出来看看

  • 别忘记开启会话tf.Session()才能看到具体的值

    • 会话中tf.train.Coordinator()开启线程
    • sess.run()运行一下用以查看具体的值
    • 回收资源,回收线程
  • 然后是解码操作,我们可以将其解码成uint8

  • 打印出的是一维数组,我们需要进行图像调整将其调整成32323

  • 最终,将其放入批处理队列。

class Cifar():

    def __init__(self):

        # 设置图像大小
        self.height = 32
        self.width = 32
        self.channel = 3

        # 设置图像字节数
        self.image = self.height * self.width * self.channel
        self.label = 1
        self.sample = self.image + self.label
    def read_tfrecords(self):
        """
        读取tfrecords文件
        """
        # 1. 构造文件名队列
        file_queue = tf.train.string_input_producer(["cifar10.tfrecords"])

        # 2. 读取与解码
        # 2.1 读取
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)

        # 2.2 解析example
        feature = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        image = feature["image"]
        label = feature["label"]
        print("read_tf_image:n", image)
        print("read_tf_label:n", label)

        # 2.3 解码
        image_decoded = tf.decode_raw(image, tf.uint8)
        print("image_decoded:n", image_decoded)
        # 图像形状调整
        image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channel])

        # 3. 构造批处理队列
        image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=5, capacity=100)
        print("image_batch:n", image_batch)
        print("label_batch:n", label_batch)
        
        # 开启会话
        with tf.Session() as sess:
            # 开启线程
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            image_value, label_value, image_decoded, image_batch, label_batch = sess.run([image, label, image_decoded, image_batch, label_batch])
            print("image_value:n", image_value)
            print("label_value:n", label_value)

            # 回收资源
            coord.request_stop()
            coord.join(threads)
        return None
cifar = Cifar()
cifar.read_tfrecords()

部分读取结果如下图所示:

使用本地Jupyter Notebook搭载TensorFlow相关库进行操作
本文正在参加「金石计划 . 瓜分6万现金大奖」

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

深度聚类方法详解与应用

2023-12-15 13:51:14

AI教程

AIGC的历史发展及早期萌芽阶段

2023-12-15 14:03:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索