keras: CNNを使う時のshapeの注意点

例として、1次元のCNNを使う時を取り上げます。

1次元のデータに対して、CNNを使う時にfilter分の次元が増えます。 これの次元の存在をinputを渡す時点で作る必要があります。

  • input: 100次元のデータ
  • output: 2次元のデータ

Function APIで書きます。 (パラメーターは適当です。) 通常の多層のネットワークを作るときは、

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

inputs = Input(shape=(100,))
hidden = Dense(40, activation="sigmoid")(inputs)
outputs = Dense(2, activation="linear")(hidden)

model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="mean_squared_error")

でokです。

CNNで行うときは注意が必要です。Reshapeを行なって、3つ目の次元を作る必要があります。 また、最後にDenseに渡す前に、Flattenし忘れないように。

from tensorflow.keras.layers import Input, Dense, Conv1D, Flatten, Reshape
from tensorflow.keras.models import Model

inputs = Input(shape=(100,))
hidden = Reshape((100, 1))(inputs)
hidden = Conv1D(10, 3, strides=3, activation="relu")(hidden)
hidden = Flatten()(hidden)
outputs = Dense(2, activation="linear")(hidden)

model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="mean_squared_error")

2次元CNN

2次元データに対してCNNを当てるときも同様の考え方で、filter方向に次元を増やして渡す必要があります。