Numpy split()函数是用于将一个numpy数组沿着指定轴(axis)分割成多个子数组。其语法如下:
numpy.split(ary, indices_or_sections, axis=0)
其中,参数ary是要分割的numpy数组,indices_or_sections用于指定分割点或分割段的数量或者位置,axis是指定分割轴的方向。返回值是分割后的子数组列表。
下面我们通过两个实例来说明split()函数的使用方法。
我们首先创建一个3行4列的数组a:
import numpy as np
a = np.array([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
我们可以通过split()函数将这个数组按行分割成三个子数组:
b = np.split(a, 3, axis=0)
print(b)
输出结果如下:
[array([[1, 2, 3, 4]]), array([[5, 6, 7, 8]]), array([[ 9, 10, 11, 12]])]
这里,我们指定分割点为3,即按照行数将数组a分割成三个子数组,并且指定了分割轴axis=0。输出结果中每一个子数组都是一个numpy数组。
我们可以通过split()函数将数组跨列分割成多个子数组,例如:
c = np.array([1, 2, 3, 4, 5, 6, 7, 8])
d = np.split(c, [3, 5])
print(d)
输出结果如下:
[array([1, 2, 3]), array([4, 5]), array([6, 7, 8])]
这里,我们通过指定分割点[3, 5],将数组c分割成三个子数组。第一个子数组包含原数组的前三个元素,第二个子数组包含第四个和第五个元素,最后一个子数组包含后三个元素。
总而言之,split()函数是一种非常实用的数组处理函数,可以满足快速对numpy数组进行分割的需求。
本文链接:http://task.lmcjl.com/news/17057.html