Skip to content

Commit 5102370

Browse files
heyihong authored and HyukjinKwon committed
[SPARK-51774][CONNECT] Add GRPC Status code to Python Connect GRPC Exception
### What changes were proposed in this pull request?

- Add the gRPC status code to the Python Connect gRPC exception. The default status code is UNKNOWN.

### Why are the changes needed?

- Users can use the gRPC status code to differentiate and handle errors.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests (e.g. test_connect_errors_conversion.py)

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #50564 from heyihong/SPARK-51774.

Authored-by: Yihong He <heyihong.cn@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent ece4bc0 commit 5102370

File tree

3 files changed: +40 -7 lines changed

Diff for: python/pyspark/errors/exceptions/connect.py

+19 -3
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,9 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17+
import grpc
1718
import json
19+
from grpc import StatusCode
1820
from typing import Dict, List, Optional, TYPE_CHECKING
1921

2022
from pyspark.errors.exceptions.base import (
@@ -57,8 +59,11 @@ def convert_exception(
5759
truncated_message: str,
5860
resp: Optional["pb2.FetchErrorDetailsResponse"],
5961
display_server_stacktrace: bool = False,
62+
grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN,
6063
) -> SparkConnectException:
61-
converted = _convert_exception(info, truncated_message, resp, display_server_stacktrace)
64+
converted = _convert_exception(
65+
info, truncated_message, resp, display_server_stacktrace, grpc_status_code
66+
)
6267
return recover_python_exception(converted)
6368

6469

@@ -67,6 +72,7 @@ def _convert_exception(
6772
truncated_message: str,
6873
resp: Optional["pb2.FetchErrorDetailsResponse"],
6974
display_server_stacktrace: bool = False,
75+
grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN,
7076
) -> SparkConnectException:
7177
import pyspark.sql.connect.proto as pb2
7278

@@ -102,8 +108,9 @@ def _convert_exception(
102108

103109
if "org.apache.spark.api.python.PythonException" in classes:
104110
return PythonException(
105-
"\n An exception was thrown from the Python worker. "
106-
"Please see the stack trace below.\n%s" % message
111+
message="\n An exception was thrown from the Python worker. "
112+
"Please see the stack trace below.\n%s" % message,
113+
grpc_status_code=grpc_status_code,
107114
)
108115

109116
# Return exception based on class mapping
@@ -126,6 +133,7 @@ def _convert_exception(
126133
server_stacktrace=stacktrace,
127134
display_server_stacktrace=display_server_stacktrace,
128135
contexts=contexts,
136+
grpc_status_code=grpc_status_code,
129137
)
130138

131139
# Return UnknownException if there is no matched exception class
@@ -138,6 +146,7 @@ def _convert_exception(
138146
server_stacktrace=stacktrace,
139147
display_server_stacktrace=display_server_stacktrace,
140148
contexts=contexts,
149+
grpc_status_code=grpc_status_code,
141150
)
142151

143152

@@ -183,6 +192,7 @@ def __init__(
183192
server_stacktrace: Optional[str] = None,
184193
display_server_stacktrace: bool = False,
185194
contexts: Optional[List[BaseQueryContext]] = None,
195+
grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN,
186196
) -> None:
187197
if contexts is None:
188198
contexts = []
@@ -210,6 +220,7 @@ def __init__(
210220
self._stacktrace: Optional[str] = server_stacktrace
211221
self._display_stacktrace: bool = display_server_stacktrace
212222
self._contexts: List[BaseQueryContext] = contexts
223+
self._grpc_status_code = grpc_status_code
213224
self._log_exception()
214225

215226
def getSqlState(self) -> Optional[str]:
@@ -227,6 +238,9 @@ def getMessage(self) -> str:
227238
desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace
228239
return desc
229240

241+
def getGrpcStatusCode(self) -> grpc.StatusCode:
242+
return self._grpc_status_code
243+
230244
def __str__(self) -> str:
231245
return self.getMessage()
232246

@@ -248,6 +262,7 @@ def __init__(
248262
server_stacktrace: Optional[str] = None,
249263
display_server_stacktrace: bool = False,
250264
contexts: Optional[List[BaseQueryContext]] = None,
265+
grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN,
251266
) -> None:
252267
super().__init__(
253268
message=message,
@@ -258,6 +273,7 @@ def __init__(
258273
server_stacktrace=server_stacktrace,
259274
display_server_stacktrace=display_server_stacktrace,
260275
contexts=contexts,
276+
grpc_status_code=grpc_status_code,
261277
)
262278

263279

Diff for: python/pyspark/errors/tests/test_connect_errors_conversion.py

+17 -3
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
)
2727
from pyspark.sql.connect.proto import FetchErrorDetailsResponse as pb2
2828
from google.rpc.error_details_pb2 import ErrorInfo
29+
from grpc import StatusCode
2930

3031

3132
class ConnectErrorsTest(unittest.TestCase):
@@ -42,13 +43,17 @@ def test_convert_exception_known_class(self):
4243
}
4344
truncated_message = "Analysis error occurred"
4445
exception = convert_exception(
45-
info=ErrorInfo(**info), truncated_message=truncated_message, resp=None
46+
info=ErrorInfo(**info),
47+
truncated_message=truncated_message,
48+
resp=None,
49+
grpc_status_code=StatusCode.INTERNAL,
4650
)
4751

4852
self.assertIsInstance(exception, AnalysisException)
4953
self.assertEqual(exception.getSqlState(), "42000")
5054
self.assertEqual(exception._errorClass, "ANALYSIS.ERROR")
5155
self.assertEqual(exception._messageParameters, {"param1": "value1"})
56+
self.assertEqual(exception.getGrpcStatusCode(), StatusCode.INTERNAL)
5257

5358
def test_convert_exception_python_exception(self):
5459
# Mock ErrorInfo for PythonException
@@ -60,11 +65,15 @@ def test_convert_exception_python_exception(self):
6065
}
6166
truncated_message = "Python worker error occurred"
6267
exception = convert_exception(
63-
info=ErrorInfo(**info), truncated_message=truncated_message, resp=None
68+
info=ErrorInfo(**info),
69+
truncated_message=truncated_message,
70+
resp=None,
71+
grpc_status_code=StatusCode.INTERNAL,
6472
)
6573

6674
self.assertIsInstance(exception, PythonException)
6775
self.assertIn("An exception was thrown from the Python worker", exception.getMessage())
76+
self.assertEqual(exception.getGrpcStatusCode(), StatusCode.INTERNAL)
6877

6978
def test_convert_exception_unknown_class(self):
7079
# Mock ErrorInfo with an unknown error class
@@ -74,13 +83,17 @@ def test_convert_exception_unknown_class(self):
7483
}
7584
truncated_message = "Unknown error occurred"
7685
exception = convert_exception(
77-
info=ErrorInfo(**info), truncated_message=truncated_message, resp=None
86+
info=ErrorInfo(**info),
87+
truncated_message=truncated_message,
88+
resp=None,
89+
grpc_status_code=StatusCode.INTERNAL,
7890
)
7991

8092
self.assertIsInstance(exception, SparkConnectGrpcException)
8193
self.assertEqual(
8294
exception.getMessage(), "(org.apache.spark.UnknownException) Unknown error occurred"
8395
)
96+
self.assertEqual(exception.getGrpcStatusCode(), StatusCode.INTERNAL)
8497

8598
def test_exception_class_mapping(self):
8699
# Ensure that all keys in EXCEPTION_CLASS_MAPPING are valid
@@ -154,6 +167,7 @@ def test_convert_exception_fallback(self):
154167
self.assertEqual(
155168
exception.getMessage(), "(org.apache.spark.UnknownReason) Fallback error occurred"
156169
)
170+
self.assertEqual(exception.getGrpcStatusCode(), StatusCode.UNKNOWN)
157171

158172

159173
if __name__ == "__main__":

Diff for: python/pyspark/sql/connect/client/core.py

+4 -1
Original file line number | Diff line number | Diff line change
@@ -1895,9 +1895,12 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
18951895
status.message,
18961896
self._fetch_enriched_error(info),
18971897
self._display_server_stack_trace(),
1898+
status.code,
18981899
) from None
18991900

1900-
raise SparkConnectGrpcException(status.message) from None
1901+
raise SparkConnectGrpcException(
1902+
message=status.message, grpc_status_code=status.code
1903+
) from None
19011904
else:
19021905
raise SparkConnectGrpcException(str(rpc_error)) from None
19031906

0 commit comments

Comments
 (0)