in PTMobileWalkthruAndroid/app/src/main/java/com/example/ptmobilewalkthru/MainActivity.java [59:106]
protected void onCreate(Bundle savedInstanceState) {
    // Sets up the activity: loads the sample image (kitten.jpg) and the TorchScript model
    // (model.pt) from the app's assets, shows the image, and wires the infer button so that
    // a click classifies the image and shows the predicted class name.
    //
    // If either asset cannot be loaded or the image cannot be decoded, the activity is
    // finished and setup stops there, so the click handler is never wired to a null
    // bitmap/module.
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);
    try {
        // Scope the asset stream to the decode only, so it is closed before the model loads.
        try (java.io.InputStream imageStream = getAssets().open("kitten.jpg")) {
            bitmap = BitmapFactory.decodeStream(imageStream);
        }
        // loading serialized torchscript module packaged into the app as android asset model.pt
        // (app/src/main/assets/model.pt)
        module = Module.load(assetFilePath(this, "model.pt"));
    } catch (IOException e) {
        Log.e("PTMobileWalkthru", "Error reading assets", e);
        finish();
        // finish() only schedules teardown; without this return the UI below would be wired
        // to a null bitmap/module and crash.
        return;
    }
    // decodeStream signals undecodable data by returning null rather than throwing.
    if (bitmap == null) {
        Log.e("PTMobileWalkthru", "Could not decode kitten.jpg");
        finish();
        return;
    }
    // showing image on UI
    ImageView imageView = findViewById(R.id.imageView);
    imageView.setImageBitmap(bitmap);
    final Button button = findViewById(R.id.inferButton);
    button.setOnClickListener(new View.OnClickListener() {
        @Override
        public void onClick(View v) {
            final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
                    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
                    TensorImageUtils.TORCHVISION_NORM_STD_RGB);
            // running the model
            // NOTE(review): forward() runs synchronously inside the click callback; a slow
            // model would block the UI thread - consider moving inference to a worker thread.
            final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
            // getting tensor content as java array of floats
            final float[] scores = outputTensor.getDataAsFloatArray();
            // assumes the output holds one score per entry of IMAGENET_CLASSES - TODO confirm
            final int maxScoreIdx = indexOfMaxScore(scores);
            final String[] classes = org.pytorch.helloworld.ImageNetClasses.IMAGENET_CLASSES;
            if (maxScoreIdx < 0 || maxScoreIdx >= classes.length) {
                // Empty/all-NaN output, or more scores than known classes: don't crash.
                Log.e("PTMobileWalkthru",
                        "No usable class index in model output of length " + scores.length);
                return;
            }
            // showing className on UI
            TextView textView = findViewById(R.id.resultView);
            textView.setText(classes[maxScoreIdx]);
        }
    });
}

/**
 * Returns the index of the largest non-NaN value in {@code scores}.
 *
 * <p>Ties resolve to the first occurrence. {@code -Infinity} is a valid maximum.
 *
 * @param scores values to search; must not be null
 * @return the index of the maximum, or -1 if {@code scores} is empty or contains only NaN
 */
private static int indexOfMaxScore(float[] scores) {
    int bestIdx = -1;
    for (int i = 0; i < scores.length; i++) {
        if (Float.isNaN(scores[i])) {
            continue;
        }
        if (bestIdx < 0 || scores[i] > scores[bestIdx]) {
            bestIdx = i;
        }
    }
    return bestIdx;
}