in lib/Graph/Graph.cpp [4912:5284]
/// Build an ONNX LSTM operator out of basic graph nodes by unrolling the
/// recurrence over every time step of the input sequence.
///
/// Each gate is assembled from Slice/Reshape/Transpose nodes (to extract the
/// per-gate weights), FullyConnected nodes (or MatMul nodes when no bias \p B
/// is given), Add/Mul nodes and the caller supplied activations.
///
/// \param namePrefix prefix used for the names of all the created nodes.
/// \param X input sequence, dims [seqLength, batchSize, inputSize].
/// \param W input weights, dims [numDirections, 4 * hiddenSize, inputSize].
///        The gate blocks along dim 1 are read in the order i, o, f, c.
/// \param R recurrence weights, dims [numDirections, 4 * hiddenSize,
///        hiddenSize], same gate order as \p W.
/// \param B optional bias, dims [numDirections, 8 * hiddenSize]: the four
///        input biases (i, o, f, c) followed by the four recurrence biases.
/// \param initial_h optional initial hidden state, dims [numDirections,
///        batchSize, hiddenSize]. A zero Splat is used when it is missing.
/// \param initial_c optional initial cell state, dims [numDirections,
///        batchSize, hiddenSize]. A zero Splat is used when it is missing.
/// \param P optional peephole weights, dims [numDirections, 3 * hiddenSize],
///        read in the order i, o, f.
/// \param Y [out] hidden state for every time step, dims [seqLength,
///        numDirections, batchSize, hiddenSize].
/// \param Y_h [out] last hidden state of each direction, dims
///        [numDirections, batchSize, hiddenSize].
/// \param Y_c [out] last cell state of each direction, dims [numDirections,
///        batchSize, hiddenSize].
/// \param hiddenSize number of hidden units.
/// \param direction Forward, Reverse or Bidirectional. numDirections is 2 for
///        Bidirectional and 1 otherwise.
/// \param activations three activations per direction, in the order f, g, h:
///        f is applied to the forget/input/output gates, g to the cell state
///        candidate and h to the cell state when computing the hidden state.
/// \param inputForget when true the input gate is computed as (1 - forget
///        gate) instead of from its own weights.
void Function::createOnnxLSTM(llvm::StringRef namePrefix, NodeValue X,
                              NodeValue W, NodeValue R, NodeValue B,
                              NodeValue initial_h, NodeValue initial_c,
                              NodeValue P, NodeValue &Y, NodeValue &Y_h,
                              NodeValue &Y_c, unsigned hiddenSize,
                              RnnDirection direction,
                              std::vector<RnnActivation> &activations,
                              bool inputForget) {

// The *_SLICE_RANGE macros expand to the two index lists (start, end) passed
// to createSlice. They rely on the locals batchSize, inputSize and hiddenSize
// being in scope at the point of use.

// Time step `idx` of X.
#define LSTM_X_SLICE_RANGE(idx)                                                \
  {(idx) + 0, 0, 0}, { (idx) + 1, batchSize, inputSize }

// Direction `idx` of initial_h.
#define LSTM_H_SLICE_RANGE(idx)                                                \
  {(idx) + 0, 0, 0}, { (idx) + 1, batchSize, hiddenSize }

// Direction `idx` of initial_c.
#define LSTM_C_SLICE_RANGE(idx)                                                \
  {(idx) + 0, 0, 0}, { (idx) + 1, batchSize, hiddenSize }

// Gate block `idx1` of W for direction `idx0`.
#define LSTM_W_SLICE_RANGE(idx0, idx1)                                         \
  {(idx0), (idx1) * hiddenSize, 0}, {                                          \
    (idx0) + 1, ((idx1) + 1) * hiddenSize, inputSize                           \
  }

// Gate block `idx1` of R for direction `idx0`.
#define LSTM_R_SLICE_RANGE(idx0, idx1)                                         \
  {(idx0), (idx1) * hiddenSize, 0}, {                                          \
    (idx0) + 1, ((idx1) + 1) * hiddenSize, hiddenSize                          \
  }

// Bias block `idx1` of B for direction `idx0`.
#define LSTM_B_SLICE_RANGE(idx0, idx1)                                         \
  {(idx0), (idx1) * hiddenSize}, { (idx0) + 1, ((idx1) + 1) * hiddenSize }

// Peephole block `idx1` of P for direction `idx0`.
#define LSTM_P_SLICE_RANGE(idx0, idx1)                                         \
  {(idx0), (idx1) * hiddenSize}, { (idx0) + 1, ((idx1) + 1) * hiddenSize }

// LHS * RHS + BIAS when a bias node exists, plain LHS * RHS otherwise.
#define LSTM_CREATE_FC(name, LHS, RHS, BIAS)                                   \
  ((BIAS) ? static_cast<Node *>(createFullyConnected(name, LHS, RHS, BIAS))    \
          : static_cast<Node *>(createMatMul(name, LHS, RHS)))

  // Operator name.
  const std::string opName = namePrefix.str();

  // Get all size parameters.
  dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1;
  assert(X.dims().size() == 3 &&
         "ONNX LSTM input 'X' should have 3 dimensions!");
  dim_t seqLength = X.dims()[0];
  dim_t batchSize = X.dims()[1];
  dim_t inputSize = X.dims()[2];

  // Validate W size.
  assert(W.dims().size() == 3 &&
         "ONNX LSTM input 'W' should have 3 dimensions!");
  assert(W.dims()[0] == numDirections && W.dims()[1] == 4 * hiddenSize &&
         W.dims()[2] == inputSize && "ONNX LSTM 'W' tensor size invalid!");

  // Validate R size.
  assert(R.dims().size() == 3 &&
         "ONNX LSTM input 'R' should have 3 dimensions!");
  assert(R.dims()[0] == numDirections && R.dims()[1] == 4 * hiddenSize &&
         R.dims()[2] == hiddenSize && "ONNX LSTM 'R' tensor size invalid!");

  // Validate B size.
  if (B.getNode()) {
    assert(B.dims().size() == 2 &&
           "ONNX LSTM input 'B' should have 2 dimensions!");
    assert(B.dims()[0] == numDirections && B.dims()[1] == 8 * hiddenSize &&
           "ONNX LSTM 'B' tensor size invalid!");
  }

  // Validate initial_h size if given else create Splat with 0.
  // NOTE(review): the Splat element kind is hard-coded to FloatTy; confirm
  // whether inputs of another element kind need to be supported here.
  if (initial_h.getNode()) {
    assert(initial_h.dims().size() == 3 &&
           "ONNX LSTM input 'initial_h' should have 3 dimensions!");
    assert(initial_h.dims()[0] == numDirections &&
           initial_h.dims()[1] == batchSize &&
           initial_h.dims()[2] == hiddenSize &&
           "ONNX LSTM 'initial_h' tensor size invalid!");
  } else {
    auto splatTy = getParent()->uniqueType(
        ElemKind::FloatTy, {numDirections, batchSize, hiddenSize});
    initial_h = createSplat(opName + ".initial_h", splatTy, 0.0);
  }

  // Validate initial_c size if given else create Splat with 0.
  if (initial_c.getNode()) {
    assert(initial_c.dims().size() == 3 &&
           "ONNX LSTM input 'initial_c' should have 3 dimensions!");
    assert(initial_c.dims()[0] == numDirections &&
           initial_c.dims()[1] == batchSize &&
           initial_c.dims()[2] == hiddenSize &&
           "ONNX LSTM 'initial_c' tensor size invalid!");
  } else {
    auto splatTy = getParent()->uniqueType(
        ElemKind::FloatTy, {numDirections, batchSize, hiddenSize});
    initial_c = createSplat(opName + ".initial_c", splatTy, 0.0);
  }

  // Validate P size.
  if (P.getNode()) {
    assert(P.dims().size() == 2 &&
           "ONNX LSTM input 'P' should have 2 dimensions!");
    assert(P.dims()[0] == numDirections && P.dims()[1] == 3 * hiddenSize &&
           "ONNX LSTM 'P' tensor size invalid!");
  }

  // Validate number of activations (f, g, h for each direction).
  assert(activations.size() == numDirections * 3 &&
         "ONNX LSTM activations vector invalid!");

  // Create X slices: one [batchSize, inputSize] node per time step. They are
  // shared by the forward and the backward cell.
  std::vector<Node *> Xslices;
  for (dim_t t = 0; t < seqLength; t++) {
    auto XsliceName = opName + ".X" + std::to_string(t) + ".slice";
    Node *Xt = createSlice(XsliceName, X, LSTM_X_SLICE_RANGE(t));
    auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape";
    Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize});
    Xslices.push_back(Xt);
  }

  // Lambda to load forward/backward LSTM cell. It appends the hidden state of
  // every processed step to Yslices (in processing order) and returns the
  // last hidden/cell state through Hslice/Cslice.
  auto loadLSTMCell = [&](bool forward, std::vector<NodeValue> &Yslices,
                          NodeValue &Hslice, NodeValue &Cslice) {
    // Name prefix. The direction label is only added when both directions
    // are instantiated.
    std::string dirLabel = forward ? ".fw" : ".bw";
    std::string prefix = opName + ((numDirections > 1) ? dirLabel : "");

    // Slice index used for creating weights slices. With a single direction
    // the weights always live at index 0, even for the reverse cell.
    dim_t sliceIdx0 = 0;
    if (direction == RnnDirection::Bidirectional) {
      sliceIdx0 = forward ? 0 : 1;
    }

    // Activations.
    size_t activationOffset = sliceIdx0 * 3;
    auto activationF = activations[activationOffset + 0];
    auto activationG = activations[activationOffset + 1];
    auto activationH = activations[activationOffset + 2];

    // Create W slices (Required). Each slice is reshaped to
    // [hiddenSize, inputSize] and transposed so it can be used as the right
    // hand side operand of Xt * W.
    NodeValue Wi =
        createSlice(prefix + ".Wi.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 0));
    NodeValue Wo =
        createSlice(prefix + ".Wo.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 1));
    NodeValue Wf =
        createSlice(prefix + ".Wf.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 2));
    NodeValue Wc =
        createSlice(prefix + ".Wc.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 3));

    Wi = createReshape(prefix + ".Wi.reshape", Wi, {hiddenSize, inputSize});
    Wo = createReshape(prefix + ".Wo.reshape", Wo, {hiddenSize, inputSize});
    Wf = createReshape(prefix + ".Wf.reshape", Wf, {hiddenSize, inputSize});
    Wc = createReshape(prefix + ".Wc.reshape", Wc, {hiddenSize, inputSize});

    Wi = createTranspose(prefix + ".Wi.transp", Wi, {1, 0});
    Wo = createTranspose(prefix + ".Wo.transp", Wo, {1, 0});
    Wf = createTranspose(prefix + ".Wf.transp", Wf, {1, 0});
    Wc = createTranspose(prefix + ".Wc.transp", Wc, {1, 0});

    // Create R slices (Required), prepared the same way as the W slices.
    NodeValue Ri =
        createSlice(prefix + ".Ri.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 0));
    NodeValue Ro =
        createSlice(prefix + ".Ro.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 1));
    NodeValue Rf =
        createSlice(prefix + ".Rf.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 2));
    NodeValue Rc =
        createSlice(prefix + ".Rc.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 3));

    Ri = createReshape(prefix + ".Ri.reshape", Ri, {hiddenSize, hiddenSize});
    Ro = createReshape(prefix + ".Ro.reshape", Ro, {hiddenSize, hiddenSize});
    Rf = createReshape(prefix + ".Rf.reshape", Rf, {hiddenSize, hiddenSize});
    Rc = createReshape(prefix + ".Rc.reshape", Rc, {hiddenSize, hiddenSize});

    Ri = createTranspose(prefix + ".Ri.transp", Ri, {1, 0});
    Ro = createTranspose(prefix + ".Ro.transp", Ro, {1, 0});
    Rf = createTranspose(prefix + ".Rf.transp", Rf, {1, 0});
    Rc = createTranspose(prefix + ".Rc.transp", Rc, {1, 0});

    // Create B slices (optional). When B is missing the bias values stay
    // null and LSTM_CREATE_FC falls back to a MatMul.
    NodeValue bWi = nullptr;
    NodeValue bWo = nullptr;
    NodeValue bWf = nullptr;
    NodeValue bWc = nullptr;
    NodeValue bRi = nullptr;
    NodeValue bRo = nullptr;
    NodeValue bRf = nullptr;
    NodeValue bRc = nullptr;

    if (B) {
      bWi = createSlice(prefix + ".bWi.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 0));
      bWo = createSlice(prefix + ".bWo.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 1));
      bWf = createSlice(prefix + ".bWf.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 2));
      bWc = createSlice(prefix + ".bWc.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 3));
      bRi = createSlice(prefix + ".bRi.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 4));
      bRo = createSlice(prefix + ".bRo.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 5));
      bRf = createSlice(prefix + ".bRf.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 6));
      bRc = createSlice(prefix + ".bRc.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 7));

      bWi = createReshape(prefix + ".bWi.reshape", bWi, {hiddenSize});
      bWo = createReshape(prefix + ".bWo.reshape", bWo, {hiddenSize});
      bWf = createReshape(prefix + ".bWf.reshape", bWf, {hiddenSize});
      bWc = createReshape(prefix + ".bWc.reshape", bWc, {hiddenSize});
      bRi = createReshape(prefix + ".bRi.reshape", bRi, {hiddenSize});
      bRo = createReshape(prefix + ".bRo.reshape", bRo, {hiddenSize});
      bRf = createReshape(prefix + ".bRf.reshape", bRf, {hiddenSize});
      bRc = createReshape(prefix + ".bRc.reshape", bRc, {hiddenSize});
    }

    // Create P slices (optional).
    NodeValue Pi = nullptr;
    NodeValue Po = nullptr;
    NodeValue Pf = nullptr;

    if (P) {
      Pi = createSlice(prefix + ".Pi.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 0));
      Po = createSlice(prefix + ".Po.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 1));
      Pf = createSlice(prefix + ".Pf.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 2));

      // Repeat P slices to match [batchSize, hiddenSize].
      Pi = createTile(prefix + ".Pi.repeat", Pi, batchSize, 0);
      Po = createTile(prefix + ".Po.repeat", Po, batchSize, 0);
      Pf = createTile(prefix + ".Pf.repeat", Pf, batchSize, 0);
    }

    // Create H slice for this direction.
    Node *Hinit = createSlice(prefix + ".H.slice", initial_h,
                              LSTM_H_SLICE_RANGE(sliceIdx0));
    Hinit =
        createReshape(prefix + ".H.reshape", Hinit, {batchSize, hiddenSize});

    // Create C slice for this direction.
    Node *Cinit = createSlice(prefix + ".C.slice", initial_c,
                              LSTM_C_SLICE_RANGE(sliceIdx0));
    Cinit =
        createReshape(prefix + ".C.reshape", Cinit, {batchSize, hiddenSize});

    // Initialize the running hidden/cell state.
    Node *Ht = Hinit;
    Node *Ct = Cinit;

    // Unroll LSTM cell for all time steps.
    for (dim_t t = 0; t < seqLength; t++) {

      // Input for current time step.
      // For the reverse LSTM cell the inputs are provided in reverse order.
      Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t];

      // Forget gate: ft = f(Xt * Wf + bWf + Ht-1 * Rf + bRf + Pf . Ct-1).
      Node *ft = createAdd(prefix + ".F.add1",
                           LSTM_CREATE_FC(prefix + ".F.fc1", Xt, Wf, bWf),
                           LSTM_CREATE_FC(prefix + ".F.fc2", Ht, Rf, bRf));
      if (Pf) {
        ft = createAdd(prefix + ".F.add2", ft,
                       createMul(prefix + ".F.mult", Pf, Ct));
      }
      ft = activationF(prefix + ".F.act", ft);

      // Cell state candidate: ctild = g(Xt * Wc + bWc + Ht-1 * Rc + bRc).
      Node *ctild =
          createAdd(prefix + ".Ctild.add",
                    LSTM_CREATE_FC(prefix + ".Ctild.fc1", Xt, Wc, bWc),
                    LSTM_CREATE_FC(prefix + ".Ctild.fc2", Ht, Rc, bRc));
      ctild = activationG(prefix + ".Ctild.act", ctild);

      // Input gate:
      // For inputForget == true:
      //   it = 1 - ft.
      // For inputForget == false:
      //   it = f(Xt * Wi + bWi + Ht-1 * Ri + bRi + Pi . Ct-1).
      Node *it;
      if (inputForget) {
        auto splatTy = ft->getNthResult(0).getType();
        it = createSub(prefix + ".I.sub",
                       createSplat(prefix + ".I.splat", splatTy, 1.0), ft);
      } else {
        it = createAdd(prefix + ".I.add1",
                       LSTM_CREATE_FC(prefix + ".I.fc1", Xt, Wi, bWi),
                       LSTM_CREATE_FC(prefix + ".I.fc2", Ht, Ri, bRi));
        if (Pi) {
          it = createAdd(prefix + ".I.add2", it,
                         createMul(prefix + ".I.mult", Pi, Ct));
        }
        it = activationF(prefix + ".I.act", it);
      }

      // Cell state update: Ct = ft . Ct-1 + it . ctild.
      Ct = createAdd(prefix + ".C.add", createMul(prefix + ".C.mult1", ft, Ct),
                     createMul(prefix + ".C.mult2", it, ctild));

      // Output gate: ot = f(Xt * Wo + bWo + Ht-1 * Ro + bRo + Po . Ct).
      // The peephole term uses the cell state that was just updated above.
      Node *ot = createAdd(prefix + ".O.add1",
                           LSTM_CREATE_FC(prefix + ".O.fc1", Xt, Wo, bWo),
                           LSTM_CREATE_FC(prefix + ".O.fc2", Ht, Ro, bRo));
      if (Po) {
        ot = createAdd(prefix + ".O.add2", ot,
                       createMul(prefix + ".O.mult", Po, Ct));
      }
      ot = activationF(prefix + ".O.act", ot);

      // Hidden state update: Ht = ot . h(Ct).
      Ht =
          createMul(prefix + ".H.mult", ot, activationH(prefix + ".H.act", Ct));

      // Output.
      Yslices.push_back(Ht);
    }

    // Updated states nodes.
    Hslice = Ht;
    Cslice = Ct;
  }; // End of local lambda "loadLSTMCell".

  bool forwardEnabled = ((direction == RnnDirection::Forward) ||
                         (direction == RnnDirection::Bidirectional));
  bool backwardEnabled = ((direction == RnnDirection::Reverse) ||
                          (direction == RnnDirection::Bidirectional));

  std::vector<NodeValue> YSlices;
  std::vector<NodeValue> Hslices;
  std::vector<NodeValue> Cslices;

  // Load forward LSTM.
  std::vector<NodeValue> forwardYslices;
  if (forwardEnabled) {
    NodeValue forwardHslice;
    NodeValue forwardCslice;
    loadLSTMCell(/* forward */ true, forwardYslices, forwardHslice,
                 forwardCslice);
    Hslices.push_back(forwardHslice);
    Cslices.push_back(forwardCslice);
  }

  // Load backward LSTM.
  std::vector<NodeValue> backwardYslices;
  if (backwardEnabled) {
    NodeValue backwardHslice;
    NodeValue backwardCslice;
    loadLSTMCell(/* forward */ false, backwardYslices, backwardHslice,
                 backwardCslice);
    Hslices.push_back(backwardHslice);
    Cslices.push_back(backwardCslice);
  }

  // Gather Y slices, interleaving the directions for each time step. The
  // backward slices were produced in reverse time order, so they are indexed
  // from the end to line them up with the original time step.
  // NOTE(review): with seqLength == 0 YSlices stays empty; confirm that
  // createConcat accepts an empty input list or reject that case earlier.
  for (dim_t t = 0; t < seqLength; t++) {
    if (forwardEnabled) {
      YSlices.push_back(forwardYslices[t]);
    }
    if (backwardEnabled) {
      YSlices.push_back(backwardYslices[seqLength - 1 - t]);
    }
  }

  // Concatenate Y slices.
  // Y size is [seqLength, numDirections, batchSize, hiddenSize].
  Y = createReshape(opName + ".Y.reshape",
                    createConcat(opName + ".Y.concat", YSlices, 0),
                    {seqLength, numDirections, batchSize, hiddenSize});

  // Concatenate Y_h slices.
  // Y_h size is [numDirections, batchSize, hiddenSize].
  Y_h = createReshape(opName + ".Y_h.reshape",
                      createConcat(opName + ".Y_h.concat", Hslices, 0),
                      {numDirections, batchSize, hiddenSize});

  // Concatenate Y_c slices.
  // Y_c size is [numDirections, batchSize, hiddenSize].
  Y_c = createReshape(opName + ".Y_c.reshape",
                      createConcat(opName + ".Y_c.concat", Cslices, 0),
                      {numDirections, batchSize, hiddenSize});

#undef LSTM_X_SLICE_RANGE
#undef LSTM_H_SLICE_RANGE
#undef LSTM_C_SLICE_RANGE
#undef LSTM_W_SLICE_RANGE
#undef LSTM_R_SLICE_RANGE
#undef LSTM_B_SLICE_RANGE
#undef LSTM_P_SLICE_RANGE
#undef LSTM_CREATE_FC
}