未验证 提交 66e9406d 编写于 作者: Y yaoxuefeng 提交者: GitHub

fix AutoInt (#133)

* fix AutoInt

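* bug fix: add a `d_model` hyper-parameter (96) and use it for the attention call and the residual projection size instead of `d_key * n_head`; stop rescaling `self.d_key` / `self.d_value` inside `interacting_layer`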
上级 5f6ab95e
......@@ -36,8 +36,9 @@ hyper_parameters:
class: SGD
learning_rate: 0.0001
sparse_feature_number: 1086460
sparse_feature_dim: 9
sparse_feature_dim: 96
num_field: 39
d_model: 96
d_key: 16
d_value: 16
n_head: 6
......
......@@ -31,6 +31,7 @@ class Model(ModelBase):
"hyper_parameters.sparse_feature_dim", None)
self.num_field = envs.get_global_env("hyper_parameters.num_field",
None)
self.d_model = envs.get_global_env("hyper_parameters.d_model", None)
self.d_key = envs.get_global_env("hyper_parameters.d_key", None)
self.d_value = envs.get_global_env("hyper_parameters.d_value", None)
self.n_head = envs.get_global_env("hyper_parameters.n_head", None)
......@@ -40,7 +41,7 @@ class Model(ModelBase):
"hyper_parameters.n_interacting_layers", 1)
def multi_head_attention(self, queries, keys, values, d_key, d_value,
n_head, dropout_rate):
d_model, n_head, dropout_rate):
keys = queries if keys is None else keys
values = keys if values is None else values
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3
......@@ -126,9 +127,8 @@ class Model(ModelBase):
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
q, k, v = __split_heads_qkv(q, k, v, n_head, d_key, d_value)
d_model = d_key * n_head
ctx_multiheads = scaled_dot_product_attention(q, k, v, d_model,
ctx_multiheads = scaled_dot_product_attention(q, k, v, self.d_model,
dropout_rate)
out = __combine_heads(ctx_multiheads)
......@@ -136,16 +136,14 @@ class Model(ModelBase):
return out
def interacting_layer(self, x):
attention_out = self.multi_head_attention(x, None, None, self.d_key,
self.d_value, self.n_head,
self.dropout_rate)
attention_out = self.multi_head_attention(
x, None, None, self.d_key, self.d_value, self.d_model, self.n_head,
self.dropout_rate)
W_0_x = fluid.layers.fc(input=x,
size=self.d_key * self.n_head,
size=self.d_model,
bias_attr=False,
num_flatten_dims=2)
res_out = fluid.layers.relu(attention_out + W_0_x)
self.d_key = self.d_key * self.n_head
self.d_value = self.d_value * self.n_head
return res_out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册