I am following the text classification demo example given by tensorflow to run on Android Studio. However when running the app, after hitting the predict button, the app crashes with the following error.
E/AndroidRuntime: FATAL EXCEPTION: main
Process: com.example.mltest, PID: 6318
java.lang.NullPointerException: Attempt to invoke virtual method 'void org.tensorflow.lite.Interpreter.run(java.lang.Object, java.lang.Object)' on a null object reference
at com.example.mltest.TextClassificationClient.classify(TextClassificationClient.java:154)
at com.example.mltest.MainActivity.lambda$classify$3$MainActivity(MainActivity.java:73)
at com.example.mltest.-$$Lambda$MainActivity$iZpagZiqjnywt769FNidzy-9BHU.run(Unknown Source:4)
at android.os.Handler.handleCallback(Handler.java:873)
at android.os.Handler.dispatchMessage(Handler.java:99)
at android.os.Looper.loop(Looper.java:193)
at android.app.ActivityThread.main(ActivityThread.java:6669)
at java.lang.reflect.Method.invoke(Native Method)
at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:493)
at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:858)
Here is the TextClassificationClient java file.
package com.example.mltest;
public class TextClassificationClient {
private static final String TAG = "TextClassificationDemo";
private static final String MODEL_PATH = "text_classification.tflite";
private static final String DIC_PATH = "text_classification_vocab.txt";
private static final String LABEL_PATH = "text_classification_labels.txt";
private static final int SENTENCE_LEN = 256;
private static final String SIMPLE_SPACE_OR_PUNCTUATION = " |\\\\,|\\\\.|\\\\!|\\\\?|\\n";
private static final String START = "<START>";
private static final String PAD = "<PAD>";
private static final String UNKNOWN = "<UNKNOWN>";
private static final int MAX_RESULTS = 3;
private final Context context;
private final Map<String, Integer> dic = new HashMap<>();
private final List<String> labels = new ArrayList<>();
private Interpreter tflite;
public static class Result {
private final String id;
private final String title;
private final Float confidence;
public Result(String id, String title, Float confidence) {
this.id = id;
this.title = title;
this.confidence = confidence;
}
public String getId() {
return id;
}
public String getTitle() {
return title;
}
public Float getConfidence() {
return confidence;
}
@SuppressLint("DefaultLocale")
@Override
public String toString() {
String resultString = "";
if (id != null) {
resultString += "[" + id + "] ";
}
if (title != null) {
resultString += title + " ";
}
if (confidence != null) {
resultString += String.format("(%.1f%%) ", confidence * 100.0f);
}
return resultString.trim();
}
};
public TextClassificationClient(Context context) {
this.context = context;
}
@WorkerThread
public void load() {
loadModel();
loadDictionary();
loadLabels();
}
@WorkerThread
private synchronized void loadModel() {
try {
ByteBuffer buffer = loadModelFile(this.context.getAssets());
tflite = new Interpreter(buffer);
Log.v(TAG, "TFLite Model Loaded");
} catch (IOException ex) {
Log.v(TAG, ex.getMessage());
}
}
@WorkerThread
private synchronized void loadDictionary() {
try {
loadDictionaryFile(this.context.getAssets());
Log.v(TAG, "Dictionary Loaded");
} catch (IOException ex) {
Log.v(TAG, ex.getMessage());
}
}
@WorkerThread
private synchronized void loadLabels() {
try {
loadLabelFile(this.context.getAssets());
Log.v(TAG, "Labels Loaded");
} catch (IOException ex) {
Log.v(TAG, ex.getMessage());
}
}
@WorkerThread
private synchronized void unload(){
tflite.close();
dic.clear();
labels.clear();
}
@WorkerThread
public synchronized List<Result> classify(String text) {
float[][] input = tokenizeInputText(text);
Log.v(TAG, "Classifying with TFLite");
float[][] output = new float[1][labels.size()];
System.out.println("input inside classify in textclient" + Arrays.deepToString(input) + " and labels size is " + labels.size());
System.out.println("Out put is " + Arrays.deepToString(output));
tflite.run(input, output);
PriorityQueue<Result> pq = new PriorityQueue<>(
MAX_RESULTS, (lhs, rhs) -> Float.compare(rhs.getConfidence(), lhs.getConfidence()));
for(int i = 0; i < labels.size(); i++) {
pq.add(new Result("" + i, labels.get(i), output[0][i]));
}
final ArrayList<Result> results = new ArrayList<>();
while (!pq.isEmpty()){
results.add(pq.poll());
}
return results;
}
private static MappedByteBuffer loadModelFile(AssetManager assetManager) throws IOException {
try(AssetFileDescriptor fileDescriptor = assetManager.openFd(MODEL_PATH);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
}
private void loadLabelFile(AssetManager assetManager) throws IOException{
try (InputStream ins = assetManager.open(LABEL_PATH);
BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(ins))){
while (bufferedReader.ready()) {
labels.add(bufferedReader.readLine());
}
}
}
private void loadDictionaryFile(AssetManager assetManager) throws IOException{
try (InputStream ins = assetManager.open(DIC_PATH);
BufferedReader reader = new BufferedReader(new InputStreamReader(ins))){
while (reader.ready()){
List<String> line = Arrays.asList(reader.readLine().split(" "));
if (line.size() < 2){
continue;
}
dic.put(line.get(0), Integer.parseInt(line.get(1)));
}
}
}
float[][] tokenizeInputText(String text) {
float[] tmp = new float[SENTENCE_LEN];
List<String> array = Arrays.asList(text.split(SIMPLE_SPACE_OR_PUNCTUATION));
int index = 0;
// Prepend <START> if it is in vocabulary file.
if (dic.containsKey(START)) {
tmp[index++] = dic.get(START);
}
for (String word : array) {
if (index >= SENTENCE_LEN) {
break;
}
tmp[index++] = dic.containsKey(word) ? dic.get(word) : (int) dic.get(UNKNOWN);
}
// Padding and wrapping.
Arrays.fill(tmp, index, SENTENCE_LEN - 1, (int) dic.get(PAD));
float[][] ans = {tmp};
return ans;
}
Map<String, Integer> getDic() {
return this.dic;
}
Interpreter getTflite() {
return this.tflite;
}
List<String> getLabels(){
return this.labels;
}
}
And the MainActivity java file.
public class MainActivity extends AppCompatActivity {
private static final String TAG = "TextClassificationDemo";
private TextClassificationClient client;
private TextView resultTextView;
private EditText inputEditText;
private Handler handler;
private ScrollView scrollView;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Log.v(TAG, "On Create");
client = new TextClassificationClient(getApplicationContext());
handler = new Handler();
Button classifyButton = findViewById(R.id.button);
classifyButton.setOnClickListener( (View v) -> {
classify(inputEditText.getText().toString());
});
resultTextView = findViewById(R.id.result_text_view);
inputEditText = findViewById(R.id.input_text);
scrollView = findViewById(R.id.scroll_view);
}
@Override
protected void onStart(){
super.onStart();
Log.v(TAG, "OnStart");
handler.post(
() -> {
client.load();
}
);
}
@Override
protected void onStop(){
super.onStop();
Log.v(TAG, "OnStop");
handler.post(
() -> {
client.load();
}
);
}
private void classify(final String text) {
System.out.println("Text inside classify of Main Activity " + text);
handler.post(
() -> {
List<TextClassificationClient.Result> results = client.classify(text);
showResult(text, results);
}
);
}
private void showResult(final String inputText, final List<TextClassificationClient.Result> results){
runOnUiThread(
() -> {
String textToShow = "Input : " + inputText + "\nOutput : \n";
for (int i = 0; i < results.size(); i++) {
TextClassificationClient.Result result = results.get(i);
textToShow += String.format(" %s: %s\\n", result.getTitle(), result.getConfidence());
}
textToShow += "---------\\n";
resultTextView.append(textToShow);
inputEditText.getText().clear();
scrollView.post(() -> scrollView.fullScroll(View.FOCUS_DOWN));
}
);
}
}
Here is my gradle file.
apply plugin: 'com.android.application'
android {
compileSdkVersion 28
buildToolsVersion "30.0.2"
defaultConfig {
applicationId "com.example.mltest"
minSdkVersion 28
targetSdkVersion 28
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
aaptOptions {
noCompress "tflite"
noCompress "lite"
}
}
dependencies {
implementation fileTree(dir: "libs", include: ["*.jar"])
implementation 'androidx.appcompat:appcompat:1.2.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.1'
implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
implementation 'org.tensorflow:tensorflow-lite-task-text:0.0.0-nightly'
implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
}
I have followed other links at SO where the same issue was raised, but they haven't been of any help. Please help me fix this problem. Thank you in advance!