Keras学习笔记II-函数式模型

Keras函数式模型接口是用户定义多输出模型、非循环有向模型或具有共享层的模型等复杂模型的途径。这种模型的使用方式和函数式编程很像,下面看一个全连接网络的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from keras.layers import Input, Dense
from keras.models import Model

inputs = Input(shape=(784, ))

x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)

model = Model(inputs=inputs, outputs=preditions)

model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

model.fit(data, labels)

所有的模型都是可调用的,像层一样

1
2
x = Input(shape=(784, ))
y = model(x)
  • 例子

    判断两条微博是否出自同一用户,会使用到共享层

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    import keras
    from keras.layers import Input, LSTM, Dense
    from keras.models import Model

    tweet_a = Input(shape=(140,256))
    tweet_b = Input(shape=(140, 256))

    shared_lstm = LSTM(64)

    encoded_a = shared_lstm(tweet_a)
    encoded_b = shared_lstm(tweet_b)

    merged_vector = keras.layers.concatenate([encoded_a, encoded_b], axis=1)
    prediction = Dense(1, activation='sigmoid')(merged_vector)

    model = Model(inputs=[tweet_a, tweet_b], outputs=prediction)
    model.compile(optimizer='rmsprop', loss='binary_crossentropy', mertics=['accuracy'])
    model.fit([data_a, data_b], labels, epochs=10)