tzrec/protos/seq_encoder.proto (78 lines of code) (raw):

syntax = "proto2";

package tzrec.protos;

import "tzrec/protos/module.proto";

// Sequence encoder whose target attention score is produced by the MLP
// configured in `attn_mlp`.
message DINEncoder {
  // Name of this sequence encoder.
  optional string name = 1;
  // Name of the sequence feature used as input.
  required string input = 2;
  // MLP config for the target attention score.
  required MLP attn_mlp = 3;
}

// Attention-based sequence encoder that takes no configuration beyond its
// name and input sequence feature.
message SimpleAttention {
  // Name of this sequence encoder.
  optional string name = 1;
  // Name of the sequence feature used as input.
  required string input = 2;
}

// Sequence encoder that pools the input sequence feature.
message PoolingEncoder {
  // Name of this sequence encoder.
  optional string name = 1;
  // Name of the sequence feature used as input.
  required string input = 2;
  // Pooling type: "sum" or "mean".
  optional string pooling_type = 3 [default = "mean"];
}

// Target-attention sequence encoder configured with a list of time windows.
message MultiWindowDINEncoder {
  // Name of this sequence encoder.
  optional string name = 1;
  // Name of the sequence feature used as input.
  required string input = 2;
  // MLP config for the target attention score.
  required MLP attn_mlp = 3;
  // Lengths of the time windows.
  repeated uint32 windows_len = 4;
}

// HSTU sequence encoder configuration.
message HSTUEncoder {
  // Name of this sequence encoder.
  optional string name = 1;
  // Name of the sequence feature used as input.
  required string input = 2;
  // Sequence dimension.
  optional int32 sequence_dim = 3;
  // Attention dimension.
  optional int32 attn_dim = 4 [default = 64];
  // Linear dimension.
  optional int32 linear_dim = 5 [default = 64];
  // Maximum sequence length.
  optional int32 max_seq_length = 6 [default = 0];
  // Dropout rate for positional embeddings.
  optional float pos_dropout_rate = 7 [default = 0.2];
  // Dropout rate for linear layers.
  optional float linear_dropout_rate = 8 [default = 0.2];
  // Dropout rate for attention.
  optional float attn_dropout_rate = 9 [default = 0.0];
  // Normalization type; currently only "rel_bias" is supported.
  optional string normalization = 10 [default = "rel_bias"];
  // Activation function for linear layers; currently only "silu" is
  // supported.
  optional string linear_activation = 11 [default = "silu"];
  // Linear configuration type; currently only "uvqk" is supported.
  optional string linear_config = 12 [default = "uvqk"];
  // Number of attention heads.
  optional int32 num_heads = 13 [default = 4];
  // Number of transformer blocks.
  optional int32 num_blocks = 14 [default = 4];
  // Maximum output sequence length.
  optional int32 max_output_len = 15 [default = 2];
  // Size of time buckets for relative attention.
  optional int32 time_bucket_size = 16 [default = 128];
}

// Selects which sequence encoder to use. The members share a oneof, so at
// most one of them is set at a time.
message SeqEncoderConfig {
  oneof seq_module {
    DINEncoder din_encoder = 1;
    SimpleAttention simple_attention = 2;
    PoolingEncoder pooling_encoder = 3;
    MultiWindowDINEncoder multi_window_din_encoder = 4;
    HSTUEncoder hstu_encoder = 5;
  }
}