Skip to content

Commit 69a51eb

Browse files
committed
Revert Android PRs #19099, #19124, #19092, #19028
This reverts the following commits: - 7e2ff8a Android: consistent error types across all modules (#19099) - b8f04aa Android: Module implements Closeable (#19124) - f9f29e7 Android: improve error diagnostics for LlmModule and exceptions (#19092) - 3ec63f4 Ignored Module tests: provide required input tensor (#19028) Authored with Claude.
1 parent 5252704 commit 69a51eb

9 files changed

Lines changed: 158 additions & 214 deletions

File tree

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import java.util.concurrent.atomic.AtomicInteger
1717
import org.apache.commons.io.FileUtils
1818
import org.junit.Assert
1919
import org.junit.Before
20+
import org.junit.Ignore
2021
import org.junit.Test
2122
import org.junit.runner.RunWith
2223
import org.pytorch.executorch.TestFileUtils.getTestFilePath
@@ -39,49 +40,48 @@ class ModuleInstrumentationTest {
3940
inputStream.close()
4041
}
4142

43+
@Ignore(
44+
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
45+
)
4246
@Test
4347
@Throws(IOException::class, URISyntaxException::class)
4448
fun testModuleLoadAndForward() {
4549
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
46-
try {
47-
val results = module.forward(EValue.from(dummyInput()))
48-
Assert.assertTrue(results[0].isTensor)
49-
} finally {
50-
module.destroy()
51-
}
50+
51+
val results = module.forward()
52+
Assert.assertTrue(results[0].isTensor)
5253
}
5354

5455
@Test
5556
@Throws(IOException::class, URISyntaxException::class)
5657
fun testMethodMetadata() {
5758
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
58-
module.destroy()
5959
}
6060

61+
@Ignore(
62+
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
63+
)
6164
@Test
6265
@Throws(IOException::class)
6366
fun testModuleLoadMethodAndForward() {
6467
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
65-
try {
66-
module.loadMethod(FORWARD_METHOD)
6768

68-
val results = module.forward(EValue.from(dummyInput()))
69-
Assert.assertTrue(results[0].isTensor)
70-
} finally {
71-
module.destroy()
72-
}
69+
module.loadMethod(FORWARD_METHOD)
70+
71+
val results = module.forward()
72+
Assert.assertTrue(results[0].isTensor)
7373
}
7474

75+
@Ignore(
76+
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
77+
)
7578
@Test
7679
@Throws(IOException::class)
7780
fun testModuleLoadForwardExplicit() {
7881
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
79-
try {
80-
val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput()))
81-
Assert.assertTrue(results[0].isTensor)
82-
} finally {
83-
module.destroy()
84-
}
82+
83+
val results = module.execute(FORWARD_METHOD)
84+
Assert.assertTrue(results[0].isTensor)
8585
}
8686

8787
@Test(expected = RuntimeException::class)
@@ -94,18 +94,15 @@ class ModuleInstrumentationTest {
9494
@Throws(IOException::class)
9595
fun testModuleLoadMethodNonExistantMethod() {
9696
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
97-
try {
98-
val exception =
99-
Assert.assertThrows(ExecutorchRuntimeException::class.java) {
100-
module.loadMethod(NONE_METHOD)
101-
}
102-
Assert.assertEquals(
103-
ExecutorchRuntimeException.INVALID_ARGUMENT,
104-
exception.getErrorCode(),
105-
)
106-
} finally {
107-
module.destroy()
108-
}
97+
98+
val exception =
99+
Assert.assertThrows(ExecutorchRuntimeException::class.java) {
100+
module.loadMethod(NONE_METHOD)
101+
}
102+
Assert.assertEquals(
103+
ExecutorchRuntimeException.INVALID_ARGUMENT,
104+
exception.getErrorCode(),
105+
)
109106
}
110107

111108
@Test(expected = RuntimeException::class)
@@ -138,6 +135,9 @@ class ModuleInstrumentationTest {
138135
Assert.assertThrows(IllegalStateException::class.java) { module.forward() }
139136
}
140137

138+
@Ignore(
139+
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
140+
)
141141
@Test
142142
@Throws(InterruptedException::class, IOException::class)
143143
fun testForwardFromMultipleThreads() {
@@ -151,7 +151,7 @@ class ModuleInstrumentationTest {
151151
try {
152152
latch.countDown()
153153
latch.await(5000, TimeUnit.MILLISECONDS)
154-
val results = module.forward(EValue.from(dummyInput()))
154+
val results = module.forward()
155155
Assert.assertTrue(results[0].isTensor)
156156
completed.incrementAndGet()
157157
} catch (_: InterruptedException) {}
@@ -168,7 +168,6 @@ class ModuleInstrumentationTest {
168168
}
169169

170170
Assert.assertEquals(numThreads.toLong(), completed.get().toLong())
171-
module.destroy()
172171
}
173172

174173
companion object {
@@ -177,8 +176,5 @@ class ModuleInstrumentationTest {
177176
private const val NON_PTE_FILE_NAME = "/test.txt"
178177
private const val FORWARD_METHOD = "forward"
179178
private const val NONE_METHOD = "none"
180-
private val inputShape = longArrayOf(1, 3, 224, 224)
181-
182-
private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT)
183179
}
184180
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,12 @@ public static ExecuTorchRuntime getRuntime() {
3636
/**
3737
* Validates that the given path points to a readable file.
3838
*
39-
* @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not
40-
* readable.
39+
* @throws RuntimeException if the file does not exist or is not readable.
4140
*/
4241
public static void validateFilePath(String path, String description) {
43-
if (path == null) {
44-
throw new IllegalArgumentException("Cannot load " + description + ": path is null");
45-
}
4642
File file = new File(path);
47-
if (!file.exists()) {
48-
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
49-
}
50-
if (!file.isFile()) {
51-
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
52-
}
53-
if (!file.canRead()) {
54-
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
43+
if (!file.canRead() || !file.isFile()) {
44+
throw new RuntimeException("Cannot load " + description + " " + path);
5545
}
5646
}
5747

extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,6 @@ public ExecutorchRuntimeException(int errorCode, String details) {
161161
this.errorCode = errorCode;
162162
}
163163

164-
public ExecutorchRuntimeException(int errorCode, String details, Throwable cause) {
165-
super(ErrorHelper.formatMessage(errorCode, details), cause);
166-
this.errorCode = errorCode;
167-
}
168-
169164
/** Returns the numeric error code from {@code runtime/core/error.h}. */
170165
public int getErrorCode() {
171166
return errorCode;

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import com.facebook.jni.annotations.DoNotStrip;
1313
import com.facebook.soloader.nativeloader.NativeLoader;
1414
import com.facebook.soloader.nativeloader.SystemDelegate;
15-
import java.io.Closeable;
1615
import java.util.HashMap;
1716
import java.util.Map;
1817
import java.util.concurrent.locks.Lock;
@@ -25,7 +24,7 @@
2524
* <p>Warning: These APIs are experimental and subject to change without notice
2625
*/
2726
@Experimental
28-
public class Module implements Closeable {
27+
public class Module {
2928

3029
static {
3130
if (!NativeLoader.isInitialized()) {
@@ -275,19 +274,12 @@ public boolean etdump() {
275274
public void destroy() {
276275
if (mLock.tryLock()) {
277276
try {
278-
if (mHybridData.isValid()) {
279-
mHybridData.resetNative();
280-
}
277+
mHybridData.resetNative();
281278
} finally {
282279
mLock.unlock();
283280
}
284281
} else {
285282
throw new IllegalStateException("Cannot destroy module while method is executing");
286283
}
287284
}
288-
289-
@Override
290-
public void close() {
291-
destroy();
292-
}
293285
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ package org.pytorch.executorch.extension.asr
1111
import java.io.Closeable
1212
import java.io.File
1313
import java.util.concurrent.atomic.AtomicLong
14-
import org.pytorch.executorch.ExecutorchRuntimeException
1514
import org.pytorch.executorch.annotations.Experimental
1615

1716
/**
@@ -54,10 +53,7 @@ class AsrModule(
5453

5554
val handle = nativeCreate(modelPath, tokenizerPath, dataPath, preprocessorPath)
5655
if (handle == 0L) {
57-
throw ExecutorchRuntimeException(
58-
ExecutorchRuntimeException.INTERNAL,
59-
"Failed to create native AsrModule",
60-
)
56+
throw RuntimeException("Failed to create native AsrModule")
6157
}
6258
nativeHandle.set(handle)
6359
}
@@ -133,7 +129,7 @@ class AsrModule(
133129
* @param callback Optional callback to receive tokens as they are generated (can be null)
134130
* @return The complete transcribed text
135131
* @throws IllegalStateException if the module has been destroyed
136-
* @throws ExecutorchRuntimeException if transcription fails (error code carried in exception)
132+
* @throws RuntimeException if transcription fails (non-zero result code)
137133
*/
138134
@JvmOverloads
139135
fun transcribe(
@@ -164,7 +160,7 @@ class AsrModule(
164160
)
165161

166162
if (status != 0) {
167-
throw ExecutorchRuntimeException(status, "Transcription failed")
163+
throw RuntimeException("Transcription failed with error code: $status")
168164
}
169165

170166
return result.toString()

extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ public static SGD create(Map<String, Tensor> namedParameters, double learningRat
9393
*/
9494
public void step(Map<String, Tensor> namedGradients) {
9595
if (!mHybridData.isValid()) {
96-
throw new IllegalStateException("SGD optimizer has been destroyed");
96+
throw new RuntimeException("Attempt to use a destroyed SGD optimizer");
9797
}
9898
stepNative(namedGradients);
9999
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88

99
package org.pytorch.executorch.training;
1010

11+
import android.util.Log;
1112
import com.facebook.jni.HybridData;
1213
import com.facebook.jni.annotations.DoNotStrip;
1314
import com.facebook.soloader.nativeloader.NativeLoader;
1415
import com.facebook.soloader.nativeloader.SystemDelegate;
15-
import java.io.Closeable;
16+
import java.util.HashMap;
1617
import java.util.Map;
1718
import org.pytorch.executorch.EValue;
1819
import org.pytorch.executorch.ExecuTorchRuntime;
@@ -25,7 +26,7 @@
2526
* <p>Warning: These APIs are experimental and subject to change without notice
2627
*/
2728
@Experimental
28-
public class TrainingModule implements Closeable {
29+
public class TrainingModule {
2930

3031
static {
3132
if (!NativeLoader.isInitialized()) {
@@ -36,7 +37,6 @@ public class TrainingModule implements Closeable {
3637
}
3738

3839
private final HybridData mHybridData;
39-
private boolean mDestroyed = false;
4040

4141
@DoNotStrip
4242
private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath);
@@ -45,10 +45,6 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
4545
mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath);
4646
}
4747

48-
private void checkNotDestroyed() {
49-
if (mDestroyed) throw new IllegalStateException("TrainingModule has been destroyed");
50-
}
51-
5248
/**
5349
* Loads a serialized ExecuTorch Training Module from the specified path on the disk.
5450
*
@@ -82,33 +78,35 @@ public static TrainingModule load(final String modelPath) {
8278
* @return return value(s) from the method.
8379
*/
8480
public EValue[] executeForwardBackward(String methodName, EValue... inputs) {
85-
checkNotDestroyed();
81+
if (!mHybridData.isValid()) {
82+
Log.e("ExecuTorch", "Attempt to use a destroyed module");
83+
return new EValue[0];
84+
}
8685
return executeForwardBackwardNative(methodName, inputs);
8786
}
8887

8988
@DoNotStrip
9089
private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);
9190

9291
public Map<String, Tensor> namedParameters(String methodName) {
93-
checkNotDestroyed();
92+
if (!mHybridData.isValid()) {
93+
Log.e("ExecuTorch", "Attempt to use a destroyed module");
94+
return new HashMap<String, Tensor>();
95+
}
9496
return namedParametersNative(methodName);
9597
}
9698

9799
@DoNotStrip
98100
private native Map<String, Tensor> namedParametersNative(String methodName);
99101

100102
public Map<String, Tensor> namedGradients(String methodName) {
101-
checkNotDestroyed();
103+
if (!mHybridData.isValid()) {
104+
Log.e("ExecuTorch", "Attempt to use a destroyed module");
105+
return new HashMap<String, Tensor>();
106+
}
102107
return namedGradientsNative(methodName);
103108
}
104109

105110
@DoNotStrip
106111
private native Map<String, Tensor> namedGradientsNative(String methodName);
107-
108-
@Override
109-
public void close() {
110-
if (mDestroyed) return;
111-
mDestroyed = true;
112-
mHybridData.resetNative();
113-
}
114112
}

extension/android/jni/jni_layer.cpp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -284,18 +284,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
284284
#else
285285
auto etdump_gen = nullptr;
286286
#endif
287-
try {
288-
module_ = std::make_unique<Module>(
289-
modelPath->toStdString(), load_mode, std::move(etdump_gen));
290-
} catch (const std::exception& e) {
291-
executorch::jni_helper::throwExecutorchException(
292-
static_cast<uint32_t>(Error::Internal),
293-
std::string("Failed to create Module: ") + e.what());
294-
} catch (...) {
295-
executorch::jni_helper::throwExecutorchException(
296-
static_cast<uint32_t>(Error::Internal),
297-
"Failed to create Module: unknown native error");
298-
}
287+
module_ = std::make_unique<Module>(
288+
modelPath->toStdString(), load_mode, std::move(etdump_gen));
299289

300290
#ifdef ET_USE_THREADPOOL
301291
// Default to using cores/2 threadpool threads. The long-term plan is to

0 commit comments

Comments
 (0)