chunk

paddle. chunk ( x, chunks, axis=0, name=None ) [源代码]

将输入 Tensor 分割成多个子 Tensor。

参数

  • x (Tensor) - 输入变量,数据类型为 bool, float16, float32,float64,int32,int64 的多维 Tensor。

  • chunks (int) - chunks 是一个整数,表示将输入 Tensor 划分成多少个相同大小的子 Tensor。

  • axis (int|Tensor,可选) - 整数或者形状为[]的 0-D Tensor,数据类型为 int32 或 int64。表示需要分割的维度。如果 axis < 0,则划分的维度为 rank(x) + axis。默认值为 0。

  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。

返回

分割后的 Tensor 列表。

代码示例

>>> import paddle

>>> x = paddle.rand([3, 9, 5])

>>> out0, out1, out2 = paddle.chunk(x, chunks=3, axis=1)
>>> # out0.shape [3, 3, 5]
>>> # out1.shape [3, 3, 5]
>>> # out2.shape [3, 3, 5]


>>> # axis is negative, the real axis is (rank(x) + axis) which real
>>> # value is 1.
>>> out0, out1, out2 = paddle.chunk(x, chunks=3, axis=-2)
>>> # out0.shape [3, 3, 5]
>>> # out1.shape [3, 3, 5]
>>> # out2.shape [3, 3, 5]