tzrec/protos/seq_encoder.proto (78 lines of code) (raw):
syntax = "proto2";
package tzrec.protos;
import "tzrec/protos/module.proto";
// Configuration of a DIN sequence encoder.
message DINEncoder {
  // Name of this sequence encoder.
  optional string name = 1;

  // Name of the sequence feature to encode.
  required string input = 2;

  // MLP configuration used to compute the target attention score.
  required MLP attn_mlp = 3;
}
// Configuration of a simple attention sequence encoder.
message SimpleAttention {
  // Name of this sequence encoder.
  optional string name = 1;

  // Name of the sequence feature to encode.
  required string input = 2;
}
// Configuration of a pooling sequence encoder.
message PoolingEncoder {
  // Name of this sequence encoder.
  optional string name = 1;

  // Name of the sequence feature to encode.
  required string input = 2;

  // Pooling type: "sum" or "mean". Defaults to "mean" when unset.
  optional string pooling_type = 3 [default = "mean"];
}
// Configuration of a multi-window DIN sequence encoder.
message MultiWindowDINEncoder {
  // Name of this sequence encoder.
  optional string name = 1;

  // Name of the sequence feature to encode.
  required string input = 2;

  // MLP configuration used to compute the target attention score.
  required MLP attn_mlp = 3;

  // Time window lengths, one entry per window.
  // NOTE(review): the unit of each length (e.g. time units vs. sequence
  // steps) is defined by the encoder implementation -- confirm there.
  repeated uint32 windows_len = 4;
}
// Configuration of an HSTU sequence encoder.
message HSTUEncoder {
  // Name of this sequence encoder.
  optional string name = 1;

  // Name of the sequence feature to encode.
  required string input = 2;

  // Dimension of the input sequence.
  optional int32 sequence_dim = 3;

  // Dimension of the attention projections.
  optional int32 attn_dim = 4 [default = 64];

  // Dimension of the linear layers.
  optional int32 linear_dim = 5 [default = 64];

  // Maximum sequence length.
  // NOTE(review): the meaning of the default 0 is decided by the encoder
  // implementation -- confirm there.
  optional int32 max_seq_length = 6 [default = 0];

  // Dropout rate applied to the positional embeddings.
  optional float pos_dropout_rate = 7 [default = 0.2];

  // Dropout rate applied to the linear layers.
  optional float linear_dropout_rate = 8 [default = 0.2];

  // Dropout rate applied to the attention.
  optional float attn_dropout_rate = 9 [default = 0.0];

  // Normalization type; currently only "rel_bias" is supported.
  optional string normalization = 10 [default = "rel_bias"];

  // Activation function of the linear layers; currently only "silu" is
  // supported.
  optional string linear_activation = 11 [default = "silu"];

  // Linear configuration type; currently only "uvqk" is supported.
  optional string linear_config = 12 [default = "uvqk"];

  // Number of attention heads.
  optional int32 num_heads = 13 [default = 4];

  // Number of transformer blocks.
  optional int32 num_blocks = 14 [default = 4];

  // Maximum output sequence length.
  optional int32 max_output_len = 15 [default = 2];

  // Size of the time buckets used for relative attention.
  optional int32 time_bucket_size = 16 [default = 128];
}
// Container that selects one sequence encoder configuration.
message SeqEncoderConfig {
  // The concrete sequence encoder. As a oneof, at most one member may be
  // set; setting one clears the others.
  oneof seq_module {
    DINEncoder din_encoder = 1;
    SimpleAttention simple_attention = 2;
    PoolingEncoder pooling_encoder = 3;
    MultiWindowDINEncoder multi_window_din_encoder = 4;
    HSTUEncoder hstu_encoder = 5;
  }
}