目的
Python numpyを使って、N次元空間内のK個の位置とL個の位置の間の距離を一括で計算します。 2重ループは使いません。計算方法
最初にnumpyをimportします。
import numpy as np
# Make 3 points in 2D (embedded into 1x3 matrix)
p1 = np.array([[[1,2],[3,4],[5,6]]])
p1.shape = (1, 3, 2)となります。数式っぽく書くと、 \[ p_1 = \begin{pmatrix} (1, 2) && (3, 4) && (5, 6) \end{pmatrix} \] です。 同様に、N次元の位置を要素としてL個含む列ベクトルを作ります。ここでは、N=2, L=3としました。
# Make 5 points in 2D (embedded into 5x1 matrix)
p2 = np.array([[[0,1],[1,2],[2,3],[1,1],[2,2]]]).reshape(5, 1, -1)
p2.shape = (5, 1, 2)となります。数式っぽく書くと、 \[ p_1 = \begin{pmatrix} (0,1) \\ (1,2) \\ (2,3) \\ (1,1) \\ (2,2) \end{pmatrix} \] です。 行ベクトルp1を列方向に5個コピーして、5x3の行列m1を作ります。 同様に、列ベクトルp2を行方向に3個コピーして、5x3の行列m2を作ります。
# Make 5x3 matrix by copying elements
m1 = np.tile(p1, (5, 1, 1))
m2 = np.tile(p2, (1, 3, 1))
# Make vectors from p2 to p1
mv = m1 - m2
# Calculate square distances
ms = np.sum(np.square(mv), axis=2)
ms = [[ 2 18 50] [ 0 8 32] [ 2 2 18] [ 1 13 41] [ 1 5 25]]数式で書くと、 \[ M_s = \begin{pmatrix} 2&&18&&50\\ 0&&8&&32\\ 2&&2&&18\\ 1&&13&&41\\ 1&&5&&25\end{pmatrix} \] となります。 あとは各要素の平方根を求めれば、各位置間の距離が得られます。
# Calculate distances
md = np.sqrt(ms)
md = [[ 1.41421356 4.24264069 7.07106781] [ 0. 2.82842712 5.65685425] [ 1.41421356 1.41421356 4.24264069] [ 1. 3.60555128 6.40312424] [ 1. 2.23606798 5. ]]となります。数式で書くと、 \[ M_d = \begin{pmatrix} 1.41421356 && 4.24264069 && 7.07106781 \\ 0 && 2.82842712 && 5.65685425 \\ 1.41421356 && 1.41421356 && 4.24264069 \\ 1 && 3.60555128 && 6.40312424 \\ 1 && 2.23606798 && 5 \end{pmatrix} \] です。
コードまとめ
上記のコードをまとめておきます。
import numpy as np
# Make 3 points in 2D (embedded into 1x3 matrix)
p1 = np.array([[[1,2],[3,4],[5,6]]])
# Make 5 points in 2D (embedded into 5x1 matrix)
p2 = np.array([[[0,1],[1,2],[2,3],[1,1],[2,2]]]).reshape(5, 1, -1)
# Make 5x3 matrix by copying elements
m1 = np.tile(p1, (5, 1, 1))
m2 = np.tile(p2, (1, 3, 1))
# Make vectors from p2 to p1
mv = m1 - m2
# Calculate square distances
ms = np.sum(np.square(mv), axis=2)
# Calculate distances
md = np.sqrt(ms)
print("p1.shape =", p1.shape)
print("p2.shape =", p2.shape)
print("ms =", ms)
print("md =", md)
0 件のコメント :
コメントを投稿