Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/zh-cn/guide/tensor_slicing.ipynb
25115 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

张量切片简介

在处理目标检测和 NLP 等机器学习应用时,有时需要使用张量的子部分(切片)。例如,如果您的模型架构包含路由,则其中一层可能控制将哪个训练样本路由到下一层。在这种情况下,可以使用张量切片运算将张量拆分并以正确的顺序将它们重新组合在一起。

在 NLP 应用中,可以在训练时使用张量切片来执行单词遮盖。例如,可以通过在每个句子中选择要遮盖的单词索引,将单词作为标签,然后使用遮盖词例替换选中的单词,以从句子列表中生成训练数据。

在本指南中,您将学习如何使用 TensorFlow API 执行以下操作:

  • 从张量中提取切片

  • 在张量中的特定索引处插入数据

本指南假定您熟悉张量索引。在开始学习本指南之前,请阅读张量TensorFlow NumPy 指南的索引部分。

安装

import tensorflow as tf import numpy as np

提取张量切片

使用 tf.slice 执行类似 NumPy 的张量切片。

t1 = tf.constant([0, 1, 2, 3, 4, 5, 6, 7]) print(tf.slice(t1, begin=[1], size=[3]))

或者,您也可以使用更具 Python 风格的语法。请注意,张量切片在开始-停止范围内均匀分布。

print(t1[1:4])

print(t1[-3:])

对于二维张量,可以使用以下代码:

t2 = tf.constant([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]) print(t2[:-1, 1:3])

您也可以在更高维度的张量上使用 tf.slice

t3 = tf.constant([[[1, 3, 5, 7], [9, 11, 13, 15]], [[17, 19, 21, 23], [25, 27, 29, 31]] ]) print(tf.slice(t3, begin=[1, 1, 0], size=[1, 1, 2]))

还可以使用 tf.strided_slice 通过在张量维度上“跨步”来提取张量切片。

使用 tf.gather 从张量的单个轴中提取特定索引。

print(tf.gather(t1, indices=[0, 3, 6])) # This is similar to doing t1[::3]

tf.gather 不要求索引均匀分布。

alphabet = tf.constant(list('abcdefghijklmnopqrstuvwxyz')) print(tf.gather(alphabet, indices=[2, 0, 19, 18]))

要从张量的多个轴中提取切片,请使用 tf.gather_nd。当您想要收集矩阵的元素而不仅仅是它的行或列时,这非常有用。

t4 = tf.constant([[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]) print(tf.gather_nd(t4, indices=[[2], [3], [0]]))

t5 = np.reshape(np.arange(18), [2, 3, 3]) print(tf.gather_nd(t5, indices=[[0, 0, 0], [1, 2, 1]]))
# Return a list of two matrices print(tf.gather_nd(t5, indices=[[[0, 0], [0, 2]], [[1, 0], [1, 2]]]))
# Return one matrix print(tf.gather_nd(t5, indices=[[0, 0], [0, 2], [1, 0], [1, 2]]))

将数据插入张量

使用 tf.scatter_nd 在张量的特定切片/索引处插入数据。请注意,您将值插入的张量是用零初始化的。

t6 = tf.constant([10]) indices = tf.constant([[1], [3], [5], [7], [9]]) data = tf.constant([2, 4, 6, 8, 10]) print(tf.scatter_nd(indices=indices, updates=data, shape=t6))

tf.scatter_nd 这样需要零初始化张量的方法类似于稀疏张量初始值设定项。可以使用 tf.gather_ndtf.scatter_nd 来模拟稀疏张量运算的行为。

考虑一个将这两种方法结合使用来构造稀疏张量的示例。

# Gather values from one tensor by specifying indices new_indices = tf.constant([[0, 2], [2, 1], [3, 3]]) t7 = tf.gather_nd(t2, indices=new_indices)

# Add these values into a new tensor t8 = tf.scatter_nd(indices=new_indices, updates=t7, shape=tf.constant([4, 5])) print(t8)

这类似于:

t9 = tf.SparseTensor(indices=[[0, 2], [2, 1], [3, 3]], values=[2, 11, 18], dense_shape=[4, 5]) print(t9)
# Convert the sparse tensor into a dense tensor t10 = tf.sparse.to_dense(t9) print(t10)

要将数据插入到具有既有值的张量中,请使用 tf.tensor_scatter_nd_add

t11 = tf.constant([[2, 7, 0], [9, 0, 1], [0, 3, 8]]) # Convert the tensor into a magic square by inserting numbers at appropriate indices t12 = tf.tensor_scatter_nd_add(t11, indices=[[0, 2], [1, 1], [2, 0]], updates=[6, 5, 4]) print(t12)

类似地,可以使用 tf.tensor_scatter_nd_sub 从具有既有值的张量中减去值。

# Convert the tensor into an identity matrix t13 = tf.tensor_scatter_nd_sub(t11, indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [2, 1], [2, 2]], updates=[1, 7, 9, -1, 1, 3, 7]) print(t13)

使用 tf.tensor_scatter_nd_min 将逐元素最小值从一个张量复制到另一个。

t14 = tf.constant([[-2, -7, 0], [-9, 0, 1], [0, -3, -8]]) t15 = tf.tensor_scatter_nd_min(t14, indices=[[0, 2], [1, 1], [2, 0]], updates=[-6, -5, -4]) print(t15)

类似地,使用 tf.tensor_scatter_nd_max 将逐元素最大值从一个张量复制到另一个。

t16 = tf.tensor_scatter_nd_max(t14, indices=[[0, 2], [1, 1], [2, 0]], updates=[6, 5, 4]) print(t16)

补充阅读和资源

在本指南中,您学习了如何使用 TensorFlow 提供的张量切片运算来更好地控制张量中的元素。

  • 查看 TensorFlow NumPy 提供的切片运算,例如 tf.experimental.numpy.take_along_axistf.experimental.numpy.take

  • 另请查看张量指南变量指南