in ImageSegmentation/app/src/main/java/org/pytorch/imagesegmentation/MainActivity.java [118:170]
/**
 * Runs semantic segmentation on {@code mBitmap} and displays the resulting class mask.
 *
 * <p>Steps:
 * <ol>
 *   <li>Normalize the bitmap with the torchvision mean/std.</li>
 *   <li>Run {@code mModule}, timing the forward pass.</li>
 *   <li>Read the {@code "out"} entry of the returned dictionary as per-class scores. Scores are
 *       indexed class-major ({@code class * width * height + row * width + column}).</li>
 *   <li>For every pixel, pick the highest-scoring class. PERSON pixels are painted red, DOG
 *       green, SHEEP blue, and everything else opaque black.</li>
 * </ol>
 * The mask is posted to the UI thread. That thread also re-enables the segment button and
 * hides the progress bar.
 *
 * <p>NOTE(review): the indexing assumes the model's score map has the same width and height
 * as {@code mBitmap}. Confirm this against the exported model.
 *
 * @throws IllegalStateException if the model output has no {@code "out"} entry, or holds fewer
 *     than {@code CLASSNUM * width * height} scores
 */
public void run() {
    final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(mBitmap,
            TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
    final long startTime = SystemClock.elapsedRealtime();
    final Map<String, IValue> outTensors = mModule.forward(IValue.from(inputTensor)).toDictStringKey();
    final long inferenceTime = SystemClock.elapsedRealtime() - startTime;
    Log.d("ImageSegmentation", "inference time (ms): " + inferenceTime);

    final IValue out = outTensors.get("out");
    if (out == null) {
        throw new IllegalStateException(
                "model output has no \"out\" entry; available keys: " + outTensors.keySet());
    }
    final float[] scores = out.toTensor().getDataAsFloatArray();

    final int width = mBitmap.getWidth();
    final int height = mBitmap.getHeight();
    // Each class occupies one contiguous plane of width * height scores.
    final int pixelCount = width * height;
    if (scores.length < CLASSNUM * pixelCount) {
        throw new IllegalStateException("expected at least " + (CLASSNUM * pixelCount)
                + " scores for a " + width + "x" + height + " image, got " + scores.length);
    }

    final int[] pixels = new int[pixelCount];
    for (int p = 0; p < pixelCount; p++) {
        // Argmax over classes. Strict '>' keeps the lowest class index on ties, and NaN / -inf
        // scores are never selected; if no class qualifies, class 0 is used.
        int bestClass = 0;
        float bestScore = Float.NEGATIVE_INFINITY;
        for (int c = 0; c < CLASSNUM; c++) {
            final float score = scores[c * pixelCount + p];
            if (score > bestScore) {
                bestScore = score;
                bestClass = c;
            }
        }

        final int color;
        if (bestClass == PERSON) {
            color = 0xFFFF0000;
        } else if (bestClass == DOG) {
            color = 0xFF00FF00;
        } else if (bestClass == SHEEP) {
            color = 0xFF0000FF;
        } else {
            color = 0xFF000000;
        }
        pixels[p] = color;
    }

    // Build the mask straight from the packed ARGB values. This needs one allocation and no
    // intermediate scaled copies. A fixed ARGB_8888 config is used because copying with
    // mBitmap's own config can fail (null or non-mutable configs).
    final Bitmap transferredBitmap = Bitmap.createBitmap(pixels, width, height, Bitmap.Config.ARGB_8888);

    runOnUiThread(new Runnable() {
        @Override
        public void run() {
            mImageView.setImageBitmap(transferredBitmap);
            mButtonSegment.setEnabled(true);
            mButtonSegment.setText(getString(R.string.segment));
            mProgressBar.setVisibility(ProgressBar.INVISIBLE);
        }
    });
}