in src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java [171:465]
/*
 * Executes this parameterized builtin instruction by dispatching on its opcode.
 * Most branches follow the same pattern: acquire the inputs from the execution
 * context, compute the result, bind it to the output variable, and release the
 * acquired inputs.
 *
 * @param ec execution context that provides the inputs and receives the output
 * @throws DMLRuntimeException for an unsupported removeEmpty margin, an
 *         unsupported toString input, a failed transform meta-data read, or an
 *         unknown opcode
 */
public void processInstruction(ExecutionContext ec) {
	String opcode = getOpcode();
	if(opcode.equalsIgnoreCase(Opcodes.CDF.toString())
		|| opcode.equalsIgnoreCase(Opcodes.INVCDF.toString())) {
		// cdf and invcdf are evaluated identically; the concrete function
		// is carried by the SimpleOperator held in _optr
		SimpleOperator op = (SimpleOperator) _optr;
		double result = op.fn.execute(params);
		ec.setScalarOutput(output.getName(), new DoubleObject(result));
	}
	else if(opcode.equalsIgnoreCase(Opcodes.AUTODIFF.toString())) {
		ArrayList<Data> lineage = (ArrayList<Data>) ec.getListObject(params.get("lineage")).getData();
		MatrixObject mo = ec.getMatrixObject(params.get("output"));
		ListObject diffs = AutoDiff.getBackward(mo, lineage, ExecutionContextFactory.createContext());
		ec.setVariable(output.getName(), diffs);
	}
	else if(opcode.equalsIgnoreCase(Opcodes.GROUPEDAGG.toString())) {
		// acquire locks (weights are optional)
		MatrixBlock target = ec.getMatrixInput(params.get(Statement.GAGG_TARGET));
		MatrixBlock groups = ec.getMatrixInput(params.get(Statement.GAGG_GROUPS));
		MatrixBlock weights = null;
		if(params.get(Statement.GAGG_WEIGHTS) != null)
			weights = ec.getMatrixInput(params.get(Statement.GAGG_WEIGHTS));
		// -1 is passed through when no explicit number of groups is given
		int ngroups = -1;
		if(params.get(Statement.GAGG_NUM_GROUPS) != null)
			ngroups = (int) Double.parseDouble(params.get(Statement.GAGG_NUM_GROUPS));
		// compute the result
		int k = Integer.parseInt(params.get("k")); // num threads
		MatrixBlock soresBlock = groups.groupedAggOperations(target, weights, new MatrixBlock(), ngroups, _optr, k);
		ec.setMatrixOutput(output.getName(), soresBlock);
		// release locks (drop local references first)
		target = groups = weights = null;
		ec.releaseMatrixInput(params.get(Statement.GAGG_TARGET));
		ec.releaseMatrixInput(params.get(Statement.GAGG_GROUPS));
		if(params.get(Statement.GAGG_WEIGHTS) != null)
			ec.releaseMatrixInput(params.get(Statement.GAGG_WEIGHTS));
	}
	else if(opcode.equalsIgnoreCase(Opcodes.RMEMPTY.toString())) {
		String margin = params.get("margin");
		// constant-first comparison so a missing margin yields a DMLRuntimeException, not an NPE
		if(!("rows".equals(margin) || "cols".equals(margin)))
			throw new DMLRuntimeException("Unsupported margin identifier '" + margin + "'.");
		boolean rows = margin.equals("rows");
		// parse parameters before acquiring any input, so a parse failure cannot leave locks held
		boolean emptyReturn = Boolean.parseBoolean(params.get("empty.return").toLowerCase());
		if(ec.isFrameObject(params.get("target"))) {
			FrameBlock target = ec.getFrameInput(params.get("target"));
			MatrixBlock select = params.containsKey("select") ? ec.getMatrixInput(params.get("select")) : null;
			FrameBlock soresBlock = target.removeEmptyOperations(rows, emptyReturn, select);
			ec.setFrameOutput(output.getName(), soresBlock);
			ec.releaseFrameInput(params.get("target"));
			if(params.containsKey("select"))
				ec.releaseMatrixInput(params.get("select"));
		}
		else {
			// acquire locks
			MatrixBlock target = ec.getMatrixInput(params.get("target"));
			MatrixBlock select = params.containsKey("select") ? ec.getMatrixInput(params.get("select")) : null;
			// compute the result
			MatrixBlock ret = target.removeEmptyOperations(new MatrixBlock(), rows, emptyReturn, select);
			if( target == ret ) //short-circuit (avoid buffer pool pollution)
				ec.setVariable(output.getName(), ec.getVariable(params.get("target")));
			else
				ec.setMatrixOutput(output.getName(), ret);
			// release locks
			ec.releaseMatrixInput(params.get("target"));
			if(params.containsKey("select"))
				ec.releaseMatrixInput(params.get("select"));
		}
	}
	else if(opcode.equalsIgnoreCase(Opcodes.CONTAINS.toString())) {
		String varName = params.get("target");
		int k = Integer.parseInt(params.get("k")); //num threads
		MatrixBlock target = ec.getMatrixInput(varName);
		// the pattern is either a bound variable (scalar or matrix) or a numeric literal
		Data pattern = ec.getVariable(params.get("pattern"));
		if( pattern == null ) //literal
			pattern = ScalarObjectFactory.createScalarObject(ValueType.FP64, params.get("pattern"));
		boolean ret = pattern.getDataType().isScalar() ?
			target.containsValue(((ScalarObject)pattern).getDoubleValue(), k) :
			(target.containsVector(((MatrixObject)pattern).acquireRead(), true).size()>0);
		ec.releaseMatrixInput(varName);
		if(!pattern.getDataType().isScalar())
			ec.releaseMatrixInput(params.get("pattern"));
		ec.setScalarOutput(output.getName(), new BooleanObject(ret));
	}
	else if(opcode.equalsIgnoreCase(Opcodes.REPLACE.toString())) {
		if(ec.isFrameObject(params.get("target"))) {
			FrameBlock target = ec.getFrameInput(params.get("target"));
			String pattern = params.get("pattern");
			String replacement = params.get("replacement");
			FrameBlock ret = target.replaceOperations(pattern, replacement);
			ec.setFrameOutput(output.getName(), ret);
			ec.releaseFrameInput(params.get("target"));
		}
		else {
			MatrixObject targetObj = ec.getMatrixObject(params.get("target"));
			MatrixBlock target = targetObj.acquireRead();
			double pattern = Double.parseDouble(params.get("pattern"));
			double replacement = Double.parseDouble(params.get("replacement"));
			MatrixBlock ret = target.replaceOperations(new MatrixBlock(), pattern, replacement,
				InfrastructureAnalyzer.getLocalParallelism());
			if( ret == target ) //shallow copy (avoid bufferpool pollution)
				ec.setVariable(output.getName(), targetObj);
			else
				ec.setMatrixOutput(output.getName(), ret);
			targetObj.release();
		}
	}
	else if(opcode.equals(Opcodes.LOWERTRI.toString()) || opcode.equals(Opcodes.UPPERTRI.toString())) {
		MatrixBlock target = ec.getMatrixInput(params.get("target"));
		boolean lower = opcode.equals(Opcodes.LOWERTRI.toString());
		boolean diag = Boolean.parseBoolean(params.get("diag"));
		boolean values = Boolean.parseBoolean(params.get("values"));
		MatrixBlock ret = target.extractTriangular(new MatrixBlock(), lower, diag, values);
		ec.setMatrixOutput(output.getName(), ret);
		ec.releaseMatrixInput(params.get("target"));
	}
	else if(opcode.equalsIgnoreCase(Opcodes.REXPAND.toString())) {
		// acquire locks
		MatrixBlock target = ec.getMatrixInput(params.get("target"));
		// compute the result
		double maxVal = Double.parseDouble(params.get("max"));
		boolean dirVal = params.get("dir").equals("rows");
		boolean cast = Boolean.parseBoolean(params.get("cast"));
		boolean ignore = Boolean.parseBoolean(params.get("ignore"));
		int numThreads = Integer.parseInt(params.get("k"));
		MatrixBlock ret = target.rexpandOperations(new MatrixBlock(), maxVal, dirVal, cast, ignore, numThreads);
		// bind output and release locks
		ec.setMatrixOutput(output.getName(), ret);
		ec.releaseMatrixInput(params.get("target"));
	}
	else if(opcode.equalsIgnoreCase(Opcodes.TOKENIZE.toString())) {
		// acquire locks
		FrameBlock data = ec.getFrameInput(params.get("target"));
		// compute tokenizer
		Tokenizer tokenizer = TokenizerFactory.createTokenizer(getParameterMap().get("spec"),
			Integer.parseInt(getParameterMap().get("max_tokens")));
		FrameBlock fbout = tokenizer.tokenize(data, OptimizerUtils.getTokenizeNumThreads());
		// bind output and release locks
		ec.setFrameOutput(output.getName(), fbout);
		ec.releaseFrameInput(params.get("target"));
	}
	else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMAPPLY.toString())) {
		// acquire locks (embeddings are optional)
		FrameBlock data = ec.getFrameInput(params.get("target"));
		FrameBlock meta = ec.getFrameInput(params.get("meta"));
		MatrixBlock embeddings = params.get("embedding") != null ? ec.getMatrixInput(params.get("embedding")) : null;
		String[] colNames = data.getColumnNames();
		// compute transformapply
		MultiColumnEncoder encoder = EncoderFactory
			.createEncoder(params.get("spec"), colNames, data.getNumColumns(), meta, embeddings);
		MatrixBlock mbout = encoder.apply(data, OptimizerUtils.getTransformNumThreads());
		// bind output and release locks
		ec.setMatrixOutput(output.getName(), mbout);
		ec.releaseFrameInput(params.get("target"));
		ec.releaseFrameInput(params.get("meta"));
		if(params.get("embedding") != null)
			ec.releaseMatrixInput(params.get("embedding"));
	}
	else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMDECODE.toString())) {
		// acquire locks
		MatrixBlock data = ec.getMatrixInput(params.get("target"));
		FrameBlock meta = ec.getFrameInput(params.get("meta"));
		String[] colnames = meta.getColumnNames();
		// compute transformdecode
		Decoder decoder = DecoderFactory
			.createDecoder(getParameterMap().get("spec"), colnames, null, meta, data.getNumColumns());
		FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema()));
		// label the decoded columns with the leading meta-data column names
		fbout.setColumnNames(Arrays.copyOfRange(colnames, 0, fbout.getNumColumns()));
		// bind output and release locks
		ec.setFrameOutput(output.getName(), fbout);
		ec.releaseMatrixInput(params.get("target"));
		ec.releaseFrameInput(params.get("meta"));
	}
	else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMCOLMAP.toString())) {
		// acquire locks
		FrameBlock meta = ec.getFrameInput(params.get("target"));
		String[] colNames = meta.getColumnNames();
		// compute the column mapping of the encoder built from the spec
		MultiColumnEncoder encoder = EncoderFactory
			.createEncoder(params.get("spec"), colNames, meta.getNumColumns(), null, null);
		MatrixBlock mbout = encoder.getColMapping(meta);
		// bind output and release locks
		ec.setMatrixOutput(output.getName(), mbout);
		ec.releaseFrameInput(params.get("target"));
	}
	else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMMETA.toString())) {
		// get input spec, path and separator (defaults to TfUtils.TXMTD_SEP)
		String spec = getParameterMap().get("spec");
		String path = getParameterMap().get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_MTD);
		String delim = getParameterMap().getOrDefault("sep", TfUtils.TXMTD_SEP);
		// execute transform meta data read
		FrameBlock meta = null;
		try {
			meta = TfMetaUtils.readTransformMetaDataFromFile(spec, path, delim);
		}
		catch(Exception ex) {
			throw new DMLRuntimeException(ex);
		}
		// bind output (no input locks were acquired in this branch)
		ec.setFrameOutput(output.getName(), meta);
	}
	else if(opcode.equalsIgnoreCase(Opcodes.TOSTRING.toString())) {
		// handle input parameters, falling back to the TOSTRING_* defaults
		int rows = (getParam("rows") != null) ? Integer.parseInt(getParam("rows")) : TOSTRING_MAXROWS;
		int cols = (getParam("cols") != null) ? Integer.parseInt(getParam("cols")) : TOSTRING_MAXCOLS;
		int decimal = (getParam("decimal") != null) ? Integer.parseInt(getParam("decimal")) : TOSTRING_DECIMAL;
		boolean sparse = (getParam("sparse") != null) ? Boolean.parseBoolean(getParam("sparse")) : TOSTRING_SPARSE;
		String separator = (getParam("sep") != null) ? getParam("sep") : TOSTRING_SEPARATOR;
		String lineSeparator = (getParam("linesep") != null) ? getParam("linesep") : TOSTRING_LINESEPARATOR;
		// get input matrix/tensor/frame/list and convert to string
		String out = null;
		Data cacheData = ec.getVariable(getParam("target"));
		if(cacheData instanceof MatrixObject) {
			MatrixBlock matrix = ((MatrixObject) cacheData).acquireRead();
			warnOnTrunction(matrix, rows, cols);
			out = DataConverter.toString(matrix, sparse, separator, lineSeparator, rows, cols, decimal);
		}
		else if(cacheData instanceof TensorObject) {
			TensorBlock tensor = ((TensorObject) cacheData).acquireRead();
			// TODO improve truncation to check all dimensions
			warnOnTrunction(tensor, rows, cols);
			out = DataConverter.toString(tensor, sparse, separator, lineSeparator, "[", "]", rows, cols, decimal);
		}
		else if(cacheData instanceof FrameObject) {
			FrameBlock frame = ((FrameObject) cacheData).acquireRead();
			warnOnTrunction(frame, rows, cols);
			out = DataConverter.toString(frame, sparse, separator, lineSeparator, rows, cols, decimal);
		}
		else if(cacheData instanceof ListObject) {
			out = DataConverter.toString((ListObject) cacheData,
				rows, cols, sparse, separator, lineSeparator, rows, cols, decimal);
		}
		else {
			// cacheData is null if the target variable is unbound; report it instead of failing with an NPE
			String actualType = (cacheData == null) ? "null" : cacheData.getClass().getSimpleName();
			throw new DMLRuntimeException("toString only converts "
				+ "matrix, tensors, lists or frames to string: " + actualType);
		}
		// lists were not acquired above, so only cacheable inputs are released
		if(!(cacheData instanceof ListObject)) {
			ec.releaseCacheableData(getParam("target"));
		}
		ec.setScalarOutput(output.getName(), new StringObject(out));
	}
	else if(opcode.equals(Opcodes.NVLIST.toString())) {
		// obtain all input data objects and names in insertion order;
		// parameters that are not bound variables are treated as literals
		List<Data> data = params.values().stream()
			.map(d -> ec.containsVariable(d) ? ec.getVariable(d) :
				ScalarObjectFactory.createScalarObject(d))
			.collect(Collectors.toList());
		List<String> names = new ArrayList<>(params.keySet());
		ListObject list;
		if(DMLScript.LINEAGE) {
			CPOperand[] listOperands = names.stream().map(n -> ec.containsVariable(params.get(n))
				? new CPOperand(n, ec.getVariable(params.get(n)))
				: getStringLiteral(n)).toArray(CPOperand[]::new);
			LineageItem[] liList = LineageItemUtils.getLineage(ec, listOperands);
			// create list object over all inputs w/ the corresponding lineage items
			list = new ListObject(data, names, Arrays.asList(liList));
		}
		else {
			// create list object over all inputs
			list = new ListObject(data, names);
		}
		list.deriveAndSetStatusFromData();
		ec.setVariable(output.getName(), list);
	}
	else {
		throw new DMLRuntimeException("Unknown opcode : " + opcode);
	}
}