『TensorFlow』Tensor Concatenation, Dimension Adjustment, and Slicing

1. tf.concat

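tf.concat joins tensors along a specified dimension and leaves the other dimensions unchanged (those other dimensions must match across the inputs). Since TensorFlow 1.0 the argument order is tf.concat(values, axis), with the list of tensors first (previously it was tf.concat(concat_dim, values)):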

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
# concatenate along dimension 0
tf.concat([t1, t2], 0) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
# concatenate along dimension 1
tf.concat([t1, t2], 1) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
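To make the axis semantics concrete without a TensorFlow install, here is a minimal pure-Python sketch for 2-D nested lists. `concat_2d` is a hypothetical helper written only for illustration, not part of TensorFlow; it mimics what `tf.concat` does for axis 0 and axis 1 and nothing more.

```python
def concat_2d(tensors, axis):
    """Concatenate 2-D nested lists along axis 0 or 1 (mimics tf.concat)."""
    if axis == 0:
        # Axis 0: append the rows of each tensor one after another.
        return [list(row) for t in tensors for row in t]
    if axis == 1:
        # Axis 1: row counts must match; join corresponding rows end to end.
        n_rows = len(tensors[0])
        if any(len(t) != n_rows for t in tensors):
            raise ValueError("all tensors need the same number of rows for axis=1")
        return [[v for t in tensors for v in t[i]] for i in range(n_rows)]
    raise ValueError("this sketch only handles axis 0 or 1")


t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
print(concat_2d([t1, t2], 0))  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
print(concat_2d([t1, t2], 1))  # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
```

The only dimension that grows is the one named by `axis`; every other dimension is carried over unchanged.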

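For reference: when merging neural-network outputs along the depth direction (as in inception_v3), the axis is 3, because the tensor layout is [batch, height, width, depth].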
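The axis-3 rule is easy to check at the shape level: every dimension except the concatenation axis must match, and the sizes along that axis add up. This is a minimal sketch assuming the [batch, height, width, depth] layout; `concat_shape` is a hypothetical helper, and the four branch shapes are made-up values, not taken from inception_v3.

```python
def concat_shape(shapes, axis):
    """Result shape of concatenating tensors with the given shapes along `axis`."""
    first = shapes[0]
    for s in shapes[1:]:
        if len(s) != len(first):
            raise ValueError("all inputs must have the same rank")
        for d, (a, b) in enumerate(zip(first, s)):
            # Every dimension other than the concat axis must agree.
            if d != axis and a != b:
                raise ValueError(f"dimension {d} differs: {a} vs {b}")
    out = list(first)
    out[axis] = sum(s[axis] for s in shapes)  # only the concat axis grows
    return out


# Hypothetical branch outputs in [batch, height, width, depth] layout.
branches = [[8, 35, 35, 64], [8, 35, 35, 64], [8, 35, 35, 96], [8, 35, 35, 32]]
print(concat_shape(branches, axis=3))  # [8, 35, 35, 256]
```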

2. tf.stack

Usage: stack(values, axis=0, name="stack")

"""Stacks a list of rank-R tensors into one rank-(R+1) tensor."""

x = tf.constant([1, 4])
y = tf.constant([2, 5])
z = tf.constant([3, 6])
tf.stack([x,y,z]) ==> [[1,4],[2,5],[3,6]]
tf.stack([x,y,z],axis=0) ==> [[1,4],[2,5],[3,6]]
tf.stack([x,y,z],axis=1) ==> [[1, 2, 3], [4, 5, 6]]
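Unlike concatenation, stacking creates a new axis of length N (the number of inputs) at position `axis`. A minimal pure-Python sketch for a list of 1-D lists reproduces the results above; `stack_1d` is a hypothetical helper for illustration, not a TensorFlow function.

```python
def stack_1d(vectors, axis=0):
    """Stack equal-length 1-D lists into a 2-D nested list (mimics tf.stack)."""
    length = len(vectors[0])
    if any(len(v) != length for v in vectors):
        raise ValueError("all inputs must have the same shape")
    if axis == 0:
        # New axis in front: each input becomes one row -> shape [N, length].
        return [list(v) for v in vectors]
    if axis == 1:
        # New axis at the end: element j of every input forms row j -> shape [length, N].
        return [[v[j] for v in vectors] for j in range(length)]
    raise ValueError("axis must be 0 or 1 for 1-D inputs")


x, y, z = [1, 4], [2, 5], [3, 6]
print(stack_1d([x, y, z]))          # [[1, 4], [2, 5], [3, 6]]
print(stack_1d([x, y, z], axis=1))  # [[1, 2, 3], [4, 5, 6]]
```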

tf.stack turns a list of rank-R tensors into a single rank-(R+1) tensor. Note: tf.pack has been renamed to tf.stack.

3. tf.squeeze

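Reduces the number of dimensions by removing dimensions of size 1; only dimensions whose size is exactly 1 can be removed.

If no axis is specified, every dimension of size 1 is removed.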

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.truncated_normal)

arr = tf.Variable(tf.truncated_normal([3, 4, 1, 6, 1], stddev=0.1))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(arr).shape)                      # (3, 4, 1, 6, 1)
print(sess.run(tf.squeeze(arr, [2])).shape)     # (3, 4, 6, 1): only dim 2 removed
print(sess.run(tf.squeeze(arr, [2, 4])).shape)  # (3, 4, 6)
print(sess.run(tf.squeeze(arr)).shape)          # (3, 4, 6): all size-1 dims removed
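The same rule can be written at the shape level. The sketch below drops the requested size-1 dimensions, or all of them when no axis is given, and raises when asked to squeeze a dimension whose size is not 1, a case tf.squeeze also rejects. `squeeze_shape` is a hypothetical helper, not TensorFlow's implementation.

```python
def squeeze_shape(shape, axis=None):
    """Return `shape` with size-1 dimensions removed (mimics tf.squeeze on shapes)."""
    if axis is None:
        # No axis given: drop every dimension of size 1.
        return tuple(d for d in shape if d != 1)
    rank = len(shape)
    axes = set()
    for a in axis:
        if not -rank <= a < rank:
            raise ValueError(f"axis {a} out of range for rank {rank}")
        axes.add(a % rank)  # normalize negative axes, e.g. -1 -> rank - 1
    for a in axes:
        if shape[a] != 1:
            raise ValueError(f"cannot squeeze dimension {a} of size {shape[a]}")
    return tuple(d for i, d in enumerate(shape) if i not in axes)


shape = (3, 4, 1, 6, 1)
print(squeeze_shape(shape, [2]))     # (3, 4, 6, 1)
print(squeeze_shape(shape, [2, 4]))  # (3, 4, 6)
print(squeeze_shape(shape))          # (3, 4, 6)
```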

4. tf.split

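tf.split behaves differently depending on whether its second argument is a scalar or a vector. A scalar n splits the tensor along axis into n equal pieces. A vector gives the size of each piece along axis, so its elements must add up to the size of that dimension; they are not boundary indices (that is how numpy.split reads a list).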

def split(value, num_or_size_splits, axis=0, num=None, name="split"):
  """Splits a tensor into sub tensors.

  If `num_or_size_splits` is an integer type, `num_split`, then splits `value`
  along dimension `axis` into `num_split` smaller tensors.
  Requires that `num_split` evenly divides `value.shape[axis]`.

  If `num_or_size_splits` is not an integer type, it is presumed to be a Tensor
  `size_splits`, then splits `value` into `len(size_splits)` pieces. The shape
  of the `i`-th piece has the same size as the `value` except along dimension
  `axis` where the size is `size_splits[i]`.

  For example:

  ```python
  # 'value' is a tensor with shape [5, 30]
  # Split 'value' into 3 tensors with sizes [4, 15, 11] along dimension 1
  split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
  tf.shape(split0)  # [5, 4]
  tf.shape(split1)  # [5, 15]
  tf.shape(split2)  # [5, 11]
  # Split 'value' into 3 tensors along dimension 1
  split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
  tf.shape(split0)  # [5, 10]
  ```
  """
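The difference between the two forms of the second argument, and the fact that a vector lists piece sizes rather than cut positions, can be checked with a pure-Python sketch on a 1-D list. `split_1d` is a hypothetical helper that only mimics the documented tf.split behaviour along axis 0.

```python
def split_1d(values, num_or_size_splits):
    """Split a 1-D list the way tf.split does along axis 0."""
    n = len(values)
    if isinstance(num_or_size_splits, int):
        # Scalar: N equal pieces; N must divide the length evenly.
        if n % num_or_size_splits != 0:
            raise ValueError(f"{num_or_size_splits} does not evenly divide {n}")
        sizes = [n // num_or_size_splits] * num_or_size_splits
    else:
        # Vector: each element is the SIZE of one piece; sizes must add up to n.
        sizes = list(num_or_size_splits)
        if sum(sizes) != n:
            raise ValueError(f"sizes {sizes} do not sum to {n}")
    pieces, start = [], 0
    for size in sizes:
        pieces.append(values[start:start + size])
        start += size
    return pieces


value = list(range(30))  # stands in for one row of the [5, 30] tensor in the docstring
print([len(p) for p in split_1d(value, [4, 15, 11])])  # [4, 15, 11]
print([len(p) for p in split_1d(value, 3)])            # [10, 10, 10]
```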

5. Tensor slicing

tf.slice

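Signature: slice(input_, begin, size, name=None), which extracts a slice from a tensor. begin gives the starting index in each dimension, and size gives how many elements to take in each dimension.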

Suppose input is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]] (shape [3, 2, 3]). Then:

(1) tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]

(2) tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3], [4, 4, 4]]]

(3) tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]], [[5, 5, 5]]]
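In every dimension d, tf.slice takes size[d] elements starting at begin[d]. A recursive pure-Python sketch over nested lists reproduces the three results above; `slice_nested` is a hypothetical helper, not TensorFlow code. It also follows the documented tf.slice convention that a size of -1 means "all remaining elements in that dimension".

```python
def slice_nested(data, begin, size):
    """Mimic tf.slice on a nested list: take size[d] items from begin[d] in each dim."""
    if not begin:  # no dimensions left to slice: return the element/sub-list as-is
        return data
    b, s = begin[0], size[0]
    end = len(data) if s == -1 else b + s  # -1 means "to the end of this dimension"
    if b < 0 or end > len(data):
        raise ValueError("slice out of bounds")
    return [slice_nested(item, begin[1:], size[1:]) for item in data[b:end]]


inp = [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]
print(slice_nested(inp, [1, 0, 0], [1, 1, 3]))     # [[[3, 3, 3]]]
print(slice_nested(inp, [1, 0, 0], [1, 2, 3]))     # [[[3, 3, 3], [4, 4, 4]]]
print(slice_nested(inp, [1, 0, 0], [2, 1, 3]))     # [[[3, 3, 3]], [[5, 5, 5]]]
print(slice_nested(inp, [1, 0, 0], [-1, -1, -1]))  # [[[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]
```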

tf.strided_slice

tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

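When reading the CIFAR-10 example you will inevitably run into tf.strided_slice. Its official docstring is long and obscure and explains next to nothing, and the explanations found online are poor or missing altogether. The function's signature looks like this: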

def strided_slice(input_,
                  begin,
                  end,
                  strides=None,
                  begin_mask=0,
                  end_mask=0,
                  ellipsis_mask=0,
                  new_axis_mask=0,
                  shrink_axis_mask=0,
                  var=None,
                  name=None):
  """Extracts a strided slice from a tensor.

  # 'input' is [[[1, 1, 1], [2, 2, 2]],
  #             [[3, 3, 3], [4, 4, 4]],
  #             [[5, 5, 5], [6, 6, 6]]]
  ...
  """

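Laid out differently, the input can be viewed as a 3-D tensor whose dimensions are numbered 1, 2, 3 from the outermost to the innermost: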

[[[1, 1, 1], [2, 2, 2]],
 [[3, 3, 3], [4, 4, 4]],
 [[5, 5, 5], [6, 6, 6]]]

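Take the call tf.strided_slice(input, [0,0,0], [2,2,2], [1,2,1]) as an example: start (the begin argument) = [0,0,0], end = [2,2,2], stride = [1,2,1]. In each dimension the call takes the elements in [start, end), stepping by stride; note that end is exclusive.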
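The same rule can be written as a formula. Assuming a positive stride, no masks set, and begin/end inside the tensor's bounds, dimension d keeps the indices i_k below (b_d, e_d, s_d are the begin, end and stride entries), and their count n_d is that dimension's output size. The last line applies it to the call above:

```latex
i_k = b_d + k\,s_d, \quad k = 0, 1, 2, \dots, \quad b_d \le i_k < e_d
n_d = \left\lceil \frac{e_d - b_d}{s_d} \right\rceil \quad (e_d > b_d,\ s_d > 0)
n_1 = \lceil (2-0)/1 \rceil = 2, \quad n_2 = \lceil (2-0)/2 \rceil = 1, \quad n_3 = \lceil (2-0)/1 \rceil = 2
```

So the result has shape [2, 1, 2].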

Dimension 1: start = 0, end = 2, stride = 1, so indices 0 and 1 are taken. The output at this point is:

[[[1, 1, 1], [2, 2, 2]],
 [[3, 3, 3], [4, 4, 4]]]

Dimension 2: start = 0, end = 2, stride = 2, so only index 0 is taken. The output at this point is:

[[[1, 1, 1]],
 [[3, 3, 3]]]

Dimension 3: start = 0, end = 2, stride = 1, so indices 0 and 1 are taken. This gives the final output:

[[[1, 1]],
 [[3, 3]]]

Written on one line, the final output, of shape [2, 1, 2], is:

[[[1, 1]], [[3, 3]]]

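The walkthrough above corresponds to the following code (TensorFlow 1.x API, since it uses tf.Session):

import tensorflow as tf

data = [[[1, 1, 1], [2, 2, 2]],
        [[3, 3, 3], [4, 4, 4]],
        [[5, 5, 5], [6, 6, 6]]]
# begin=[0, 0, 0], end=[2, 2, 2] (exclusive), strides=[1, 2, 1]
x = tf.strided_slice(data, [0, 0, 0], [2, 2, 2], [1, 2, 1])
with tf.Session() as sess:
    print(sess.run(x))  # value: [[[1, 1]], [[3, 3]]]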
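For positive strides with all masks left at 0, tf.strided_slice selects the same elements as Python's extended slicing begin[d]:end[d]:strides[d] applied dimension by dimension. The recursive sketch below makes that correspondence explicit and reproduces the walkthrough. `strided_slice_nested` is a hypothetical helper that ignores every *_mask argument and negative strides; it is not TensorFlow's implementation. The last call mirrors the 1-D CIFAR-10 usage, where the default stride of 1 simply takes the first label_bytes elements; label_bytes = 1 and the record values are assumptions made up for the example.

```python
def strided_slice_nested(data, begin, end, strides=None):
    """Mimic tf.strided_slice (positive strides, no masks) on a nested list."""
    if strides is None:
        strides = [1] * len(begin)  # default: step 1 in every dimension
    if not begin:  # all specified dimensions handled; keep the rest whole
        return data
    # Python's start:stop:step uses the same half-open [begin, end) rule.
    picked = data[begin[0]:end[0]:strides[0]]
    return [strided_slice_nested(item, begin[1:], end[1:], strides[1:]) for item in picked]


data = [[[1, 1, 1], [2, 2, 2]],
        [[3, 3, 3], [4, 4, 4]],
        [[5, 5, 5], [6, 6, 6]]]
print(strided_slice_nested(data, [0, 0, 0], [2, 2, 2], [1, 2, 1]))  # [[[1, 1]], [[3, 3]]]

# 1-D case shaped like the CIFAR-10 call; label_bytes and record are hypothetical.
label_bytes = 1
record = [7, 10, 20, 30]
print(strided_slice_nested(record, [0], [label_bytes]))  # [7]
```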