// calc_cell_state_with_cifg
// Extracted from tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc
// (original file lines 203-375).

// Updates the LSTM cell state in place for the CIFG (Coupled Input and
// Forget Gate) variant, vectorized with Cadence HiFi SIMD intrinsics.
//
// Element-wise over num_elms int16 values this computes (in fixed point):
//   cell_state = clamp( (cell_state * forget_gate)        >> shift1
//                     + (cell_gate  * (32767 - forget_gate)) >> shift2 )
// i.e. with CIFG the input gate is derived from the forget gate as
// (1 - forget_gate), represented here as (32767 - forget_gate) in Q15.
//
// Parameters:
//   cell_state  - in/out buffer of num_elms int16 cell-state values.
//   forget_gate - forget-gate activations (Q15), read-only.
//   cell_gate   - cell-gate (candidate) activations, read-only.
//   shift1      - right shift applied to cell_state * forget_gate products.
//                 When shift1 == 15 a faster path uses AE_MULFP16X4RS
//                 (fused Q15 fractional multiply with rounding/saturation)
//                 instead of a widening multiply + shift + saturate.
//   shift2      - right shift applied to cell_gate * (1 - forget_gate).
//   clip        - if > 0, result is clamped to [-clip, clip]; otherwise the
//                 full int16 range [-32768, 32767] is used.
//   num_elms    - number of elements to process.
//
// Structure: a main loop processes 8 int16 lanes per iteration via unaligned
// 128-bit load/store streams, followed by a scalar loop for the up-to-7
// residual elements. Reads and the write to cell_state walk the same buffer
// in lockstep, so the in-place update is safe (writes trail the reads).
void calc_cell_state_with_cifg(int16_t* cell_state, const int16_t* forget_gate,
                               const int16_t* cell_gate, int shift1, int shift2,
                               int clip, int num_elms) {
  // Read streams over cell_state, forget_gate and cell_gate (8 lanes each).
  const ae_int16x8 *p16x8_cs_r, *p16x8_fg_r;
  const ae_int16x8* p16x8_cg_r;

  // Write stream back into cell_state (in-place update).
  ae_int16x8* p16x8_cs_w;

  // Alignment registers for the unaligned 128-bit load/store primitives.
  ae_valignx2 align_cs_r, align_fg_r;
  ae_valignx2 align_cg_r;
  ae_valignx2 align_cs_w;

  ae_int16x4 d_cs_r_0, d_cs_r_1;
  ae_int16x4 d_fg_0, d_fg_1;
  ae_int16x4 d_cg_0, d_cg_1;
  ae_int16x4 d_1mfg_0, d_1mfg_1;  // (1 - forget_gate) in Q15 per lane
  ae_int16x4 d_cs_w_0, d_cs_w_1;
  ae_int32x2 d_mul_0, d_mul_1, d_mul_2, d_mul_3;  // 32-bit widened products
  ae_int32x2 d_mul_4, d_mul_5, d_mul_6, d_mul_7;

  ae_int16x4 d_min, d_max, d_one;

  int i = 0;
  p16x8_cs_r = (const ae_int16x8*)cell_state;
  p16x8_fg_r = (const ae_int16x8*)forget_gate;
  p16x8_cg_r = (const ae_int16x8*)cell_gate;

  p16x8_cs_w = (ae_int16x8*)cell_state;

  // Prime the alignment registers for unaligned loads.
  align_cs_r = AE_LA128_PP(p16x8_cs_r);
  align_fg_r = AE_LA128_PP(p16x8_fg_r);
  align_cg_r = AE_LA128_PP(p16x8_cg_r);

  align_cs_w = AE_ZALIGN128();

  // Output clamp range: caller-provided symmetric clip, or full int16 range.
  if (clip > 0) {
    d_min = AE_MOVDA16(-clip);
    d_max = AE_MOVDA16(clip);
  } else {
    d_min = AE_MOVDA16(-32768);
    d_max = AE_MOVDA16(32767);
  }
  d_one = AE_MOVDA16(32767);  // "1.0" in Q15, used to form (1 - forget_gate)

// NOTE(review): this pragma precedes the `if`, not a loop — presumably it
// was meant to annotate the first vector loop below; confirm against the
// Xtensa compiler's pragma-scoping rules before moving it.
#pragma concurrent
  if (shift1 == 15) {
    // Fast path: shift1 == 15 lets cs*fg be done as a single Q15 fractional
    // multiply with rounding/saturation (AE_MULFP16X4RS), skipping the
    // widen/shift/saturate sequence used in the general path below.
    for (i = 0; i < (num_elms >> 3); i++) {
      AE_LA16X4X2_IP(d_cs_r_0, d_cs_r_1, align_cs_r, p16x8_cs_r);
      AE_LA16X4X2_IP(d_fg_0, d_fg_1, align_fg_r, p16x8_fg_r);
      AE_LA16X4X2_IP(d_cg_0, d_cg_1, align_cg_r, p16x8_cg_r);

      // cs * fg in Q15 with round/saturate.
      d_cs_w_0 = AE_MULFP16X4RS(d_cs_r_0, d_fg_0);
      d_cs_w_1 = AE_MULFP16X4RS(d_cs_r_1, d_fg_1);

      // cg * (1 - fg): widen to 32-bit, arithmetic shift by shift2 with
      // symmetric rounding, then saturate back to int16.
      d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
      d_1mfg_1 = AE_SUB16S(d_one, d_fg_1);
      AE_MUL16X4(d_mul_4, d_mul_5, d_cg_0, d_1mfg_0);
      AE_MUL16X4(d_mul_6, d_mul_7, d_cg_1, d_1mfg_1);
      d_mul_4 = AE_SRAA32SYMS(d_mul_4, shift2);
      d_mul_5 = AE_SRAA32SYMS(d_mul_5, shift2);
      d_mul_6 = AE_SRAA32SYMS(d_mul_6, shift2);
      d_mul_7 = AE_SRAA32SYMS(d_mul_7, shift2);
      d_cg_0 = AE_SAT16X4(d_mul_4, d_mul_5);
      d_cg_1 = AE_SAT16X4(d_mul_6, d_mul_7);

      // Saturating add of the two terms, then clamp to [d_min, d_max].
      d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
      d_cs_w_1 = AE_ADD16S(d_cs_w_1, d_cg_1);

      AE_MINMAX16(d_cs_w_0, d_min, d_max);
      AE_MINMAX16(d_cs_w_1, d_min, d_max);

      AE_SA16X4X2_IP(d_cs_w_0, d_cs_w_1, align_cs_w, p16x8_cs_w);
    }
    AE_SA128POS_FP(align_cs_w, p16x8_cs_w);  // finalize the stream

    // Scalar pointers continue from where the vector streams stopped.
    const ae_int16 *p16_cs_r, *p16_fg_r;
    const ae_int16* p16_cg_r;

    ae_int16* p16_cs_w;

    p16_cs_r = (const ae_int16*)p16x8_cs_r;
    p16_fg_r = (const ae_int16*)p16x8_fg_r;
    p16_cg_r = (const ae_int16*)p16x8_cg_r;

    p16_cs_w = (ae_int16*)p16x8_cs_w;
// residue iterations: same math as the vector loop, one element at a time.
#pragma concurrent
#pragma loop_count max = 7
    for (i = 0; i < ((num_elms)&7); i++) {
      d_cs_r_0 = p16_cs_r[i];
      d_fg_0 = p16_fg_r[i];
      d_cg_0 = p16_cg_r[i];

      d_cs_w_0 = AE_MULFP16X4RS(d_cs_r_0, d_fg_0);

      d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
      AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_1mfg_0);
      d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_1);

      d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
      AE_MINMAX16(d_cs_w_0, d_min, d_max);
      p16_cs_w[i] = d_cs_w_0;
    }
  } else {
    // General path: cs * fg uses a widening multiply, arithmetic shift by
    // shift1 with symmetric rounding, and saturation back to int16.
    for (i = 0; i < (num_elms >> 3); i++) {
      AE_LA16X4X2_IP(d_cs_r_0, d_cs_r_1, align_cs_r, p16x8_cs_r);
      AE_LA16X4X2_IP(d_fg_0, d_fg_1, align_fg_r, p16x8_fg_r);
      AE_LA16X4X2_IP(d_cg_0, d_cg_1, align_cg_r, p16x8_cg_r);

      // (cs * fg) >> shift1, saturated to int16.
      AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
      AE_MUL16X4(d_mul_2, d_mul_3, d_cs_r_1, d_fg_1);
      d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
      d_mul_1 = AE_SRAA32SYMS(d_mul_1, shift1);
      d_mul_2 = AE_SRAA32SYMS(d_mul_2, shift1);
      d_mul_3 = AE_SRAA32SYMS(d_mul_3, shift1);
      d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_1);
      d_cs_w_1 = AE_SAT16X4(d_mul_2, d_mul_3);

      // (cg * (1 - fg)) >> shift2, saturated to int16.
      d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
      d_1mfg_1 = AE_SUB16S(d_one, d_fg_1);
      AE_MUL16X4(d_mul_4, d_mul_5, d_cg_0, d_1mfg_0);
      AE_MUL16X4(d_mul_6, d_mul_7, d_cg_1, d_1mfg_1);
      d_mul_4 = AE_SRAA32SYMS(d_mul_4, shift2);
      d_mul_5 = AE_SRAA32SYMS(d_mul_5, shift2);
      d_mul_6 = AE_SRAA32SYMS(d_mul_6, shift2);
      d_mul_7 = AE_SRAA32SYMS(d_mul_7, shift2);
      d_cg_0 = AE_SAT16X4(d_mul_4, d_mul_5);
      d_cg_1 = AE_SAT16X4(d_mul_6, d_mul_7);

      // Saturating add of the two terms, then clamp to [d_min, d_max].
      d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
      d_cs_w_1 = AE_ADD16S(d_cs_w_1, d_cg_1);

      AE_MINMAX16(d_cs_w_0, d_min, d_max);
      AE_MINMAX16(d_cs_w_1, d_min, d_max);

      AE_SA16X4X2_IP(d_cs_w_0, d_cs_w_1, align_cs_w, p16x8_cs_w);
    }
    AE_SA128POS_FP(align_cs_w, p16x8_cs_w);  // finalize the stream

    // Scalar pointers continue from where the vector streams stopped.
    const ae_int16 *p16_cs_r, *p16_fg_r;
    const ae_int16* p16_cg_r;

    ae_int16* p16_cs_w;

    p16_cs_r = (const ae_int16*)p16x8_cs_r;
    p16_fg_r = (const ae_int16*)p16x8_fg_r;
    p16_cg_r = (const ae_int16*)p16x8_cg_r;

    p16_cs_w = (ae_int16*)p16x8_cs_w;
// residue iterations: same math as the vector loop, one element at a time.
#pragma concurrent
#pragma loop_count max = 7
    for (i = 0; i < ((num_elms)&7); i++) {
      d_cs_r_0 = p16_cs_r[i];
      d_fg_0 = p16_fg_r[i];
      d_cg_0 = p16_cg_r[i];

      AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
      d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
      d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_1);

      d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
      AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_1mfg_0);
      d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_1);

      d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
      AE_MINMAX16(d_cs_w_0, d_min, d_max);
      p16_cs_w[i] = d_cs_w_0;
    }
  }
}