3차원 이상 배열 곱셈
00. Intro
데이터 프레임워크를 사용하다보면 행렬 곱셈의 차원(shape)이 매번 헷갈린다. 2차원은 간단하지만, 3차원부터는 자꾸 찾아보게 된다. 그래서 그냥 내가 보려고 정리를 해두었다.
01. 2차원 곱셈
일단 기본 numpy 행렬곱으로 shape 특성을 이해해본다. 아래 예시에서 확인할 수 있듯이, 2차원의 행렬에서 shape은 (행, 열)을 의미한다. 이 상태에서는 크게 어려울 것이 없다.
import numpy as np
a = np.array([[1, 0],
[0, 1],
[1, 1]])
print(a.shape) # (3, 2)
b = np.array([[4, 1],
[2, 2]])
print(b.shape) # (2, 2)
c = np.matmul(a, b)
print(c)
#array([[4, 1],
# [2, 2],
# [6, 3]])
print(c.shape) # (3, 2)
02. 3차원 이상 곱셈
여기서부터 매번 헷갈리기 시작한다. 하지만, 데이터분석 framework 상에서 계산들은 기본적으로 2차원 연산을 상위 차원에서 반복한다. 다음 코드를 보면 두 개의 3차원 내에서 (3,4)◦(4,3) 곱셈을 진행한다. 그러므로 행렬곱의 결과 차원은 (2, 3, 3)이 된다.
- 즉 1차원과 2차원은 행렬곱이 가능한 shape이어야 하고,
- 3차원 이상의 값은 곱하려는 행렬이 서로 같아야 한다.
import numpy as np
a = np.arange(24).reshape(2,3,4)
a
#array([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],
#
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]])
b = np.arange(24).reshape(2,4,3)
b
#array([[[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]],
#
# [[12, 13, 14],
# [15, 16, 17],
# [18, 19, 20],
# [21, 22, 23]]])
np.matmul(a,b)
#array([[[ 42, 48, 54],
# [ 114, 136, 158],
# [ 186, 224, 262]],
#
# [[ 906, 960, 1014],
# [1170, 1240, 1310],
# [1434, 1520, 1606]]])
np.matmul(a,b).shape
# (2, 3, 3)
03. 이론적 3차원 곱셈
데이터분석 프레임워크를 사용할 때는 알 필요 없는, 이론적인 3차원 곱셈은 그 방법론만 소개한다.
height, widht, length차원을 각각 d1, d2 d3로 정의하면, 다음과 같은 결과를 만들어 낸다고 한다.
- outer prod: (d1, d3), (d3, d2), (d2, d1)
- inner prod: (d3, d1), (d2, d3), (d1, d2)
그러므로 shape(1, 1, 2) 행렬과 shape(1, 2, 1) 행렬의 곱셈 결과는 shape(1, 2, 2)가 된다.