TensorFlow 内置重要函数解析

概要

本部分介绍一些在 TensorFlow 中内置的重要函数,了解这些函数有时候更加方便我们进行数据的处理或者构建神经网络。

这些函数如下:

    tf.one_hot()

    tf.random_shuffle()


主要内容

tf.one_hot()

这是一个用来生成符合 one_hot 编码的张量的函数。完整参数形式是:

tf.one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)

下面我们一一通过实例来了解各个参数表示什么意思。

为了容易理解,我们举个例子,比如我们熟悉的 mnist 数据集中标签的 one_hot 编码中,数字 4 是用向量 \([0,0,0,0,1,0,0,0,0,0]\) 来表示的。

  • on_value ,float 类型,表示在 one_hot 编码中标签标记值,在上述编码中 on_value 的值就是 1
  • off_value, float 类型,就是标记点除外的其它值,即 off_value 为 0
  • indices ,一个列表,表示要生成的 one_hot 张量中标记值所在索引,即 indices = [4]
  • depth,int 类型,表示要生成的 one_hot 张量的长度,即 depth = 10
  • Axis,取值为 -1,0 或 1,Axis 取 -1 时造成的张量 shape=[indices 长度, depth],默认值虽是 None,但是和取 -1 效果一样。为 0 时 shape=[depth, indices 长度],取 1 时,比较复杂,是指在三维以上情况下,比方考虑批量输入中,有个批 batch 大小, shape=[batch, indices 长度, depth],具体的可以做下实验验证就好,不需要刻意去记。

下面用代码验证一下:

# -*- coding: utf-8 -*-
"""
Created on Mon Jun  4 08:56:57 2018

@author: zhoukui
"""

import tensorflow as tf

tf.reset_default_graph()

indices = [0, 2, -1, 1, 2]
depth = 4
on_value = 3.0
off_value = 0.0
axis = -1

t = tf.one_hot(indices, depth, on_value, off_value, axis)

with tf.Session() as sess:
    print(sess.run(t))  #输出 [[ 3.  0.  0.  0.]
                        #     [ 0.  0.  3.  0.]
                        #     [ 0.  0.  0.  0.]
                        #     [ 0.  3.  0.  0.]
                        #     [ 0.  0.  3.  0.]]

tf.random_shuffle()

这个函数相对简单,它就一个参数 input,表示沿着 input 的第一维度进行随机重新排列,在进行数据分批的时候特别实用。实例如下:

# -*- coding: utf-8 -*-
"""
Created on Mon Jun  4 08:56:57 2018

@author: zhoukui
"""

import tensorflow as tf

tf.reset_default_graph()

input = tf.reshape(tf.linspace(1.0, 10.0, 10), (-1,2))

tf.set_random_seed(666)  # 可以选择固定种子
t = tf.random_shuffle(input)

with tf.Session() as sess:
    
    print(sess.run(input)) # 输出 [[  1.   2.]
                           #       [  3.   4.]
                           #       [  5.   6.]
                           #       [  7.   8.]
                           #       [  9.  10.]]
    
    print(sess.run(t))  #输出 [[  7.   8.]
                        #      [  5.  6.]
                        #      [  1.   2.]
                        #      [  3.   4.]
                        #      [  9.   10.]]