14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
16
#
17
+ import grpc
17
18
import json
19
+ from grpc import StatusCode
18
20
from typing import Dict , List , Optional , TYPE_CHECKING
19
21
20
22
from pyspark .errors .exceptions .base import (
@@ -57,8 +59,11 @@ def convert_exception(
57
59
truncated_message : str ,
58
60
resp : Optional ["pb2.FetchErrorDetailsResponse" ],
59
61
display_server_stacktrace : bool = False ,
62
+ grpc_status_code : grpc .StatusCode = StatusCode .UNKNOWN ,
60
63
) -> SparkConnectException :
61
- converted = _convert_exception (info , truncated_message , resp , display_server_stacktrace )
64
+ converted = _convert_exception (
65
+ info , truncated_message , resp , display_server_stacktrace , grpc_status_code
66
+ )
62
67
return recover_python_exception (converted )
63
68
64
69
@@ -67,6 +72,7 @@ def _convert_exception(
67
72
truncated_message : str ,
68
73
resp : Optional ["pb2.FetchErrorDetailsResponse" ],
69
74
display_server_stacktrace : bool = False ,
75
+ grpc_status_code : grpc .StatusCode = StatusCode .UNKNOWN ,
70
76
) -> SparkConnectException :
71
77
import pyspark .sql .connect .proto as pb2
72
78
@@ -102,8 +108,9 @@ def _convert_exception(
102
108
103
109
if "org.apache.spark.api.python.PythonException" in classes :
104
110
return PythonException (
105
- "\n An exception was thrown from the Python worker. "
106
- "Please see the stack trace below.\n %s" % message
111
+ message = "\n An exception was thrown from the Python worker. "
112
+ "Please see the stack trace below.\n %s" % message ,
113
+ grpc_status_code = grpc_status_code ,
107
114
)
108
115
109
116
# Return exception based on class mapping
@@ -126,6 +133,7 @@ def _convert_exception(
126
133
server_stacktrace = stacktrace ,
127
134
display_server_stacktrace = display_server_stacktrace ,
128
135
contexts = contexts ,
136
+ grpc_status_code = grpc_status_code ,
129
137
)
130
138
131
139
# Return UnknownException if there is no matched exception class
@@ -138,6 +146,7 @@ def _convert_exception(
138
146
server_stacktrace = stacktrace ,
139
147
display_server_stacktrace = display_server_stacktrace ,
140
148
contexts = contexts ,
149
+ grpc_status_code = grpc_status_code ,
141
150
)
142
151
143
152
@@ -183,6 +192,7 @@ def __init__(
183
192
server_stacktrace : Optional [str ] = None ,
184
193
display_server_stacktrace : bool = False ,
185
194
contexts : Optional [List [BaseQueryContext ]] = None ,
195
+ grpc_status_code : grpc .StatusCode = StatusCode .UNKNOWN ,
186
196
) -> None :
187
197
if contexts is None :
188
198
contexts = []
@@ -210,6 +220,7 @@ def __init__(
210
220
self ._stacktrace : Optional [str ] = server_stacktrace
211
221
self ._display_stacktrace : bool = display_server_stacktrace
212
222
self ._contexts : List [BaseQueryContext ] = contexts
223
+ self ._grpc_status_code = grpc_status_code
213
224
self ._log_exception ()
214
225
215
226
def getSqlState (self ) -> Optional [str ]:
@@ -227,6 +238,9 @@ def getMessage(self) -> str:
227
238
desc += "\n \n JVM stacktrace:\n %s" % self ._stacktrace
228
239
return desc
229
240
241
+ def getGrpcStatusCode (self ) -> grpc .StatusCode :
242
+ return self ._grpc_status_code
243
+
230
244
def __str__ (self ) -> str :
231
245
return self .getMessage ()
232
246
@@ -248,6 +262,7 @@ def __init__(
248
262
server_stacktrace : Optional [str ] = None ,
249
263
display_server_stacktrace : bool = False ,
250
264
contexts : Optional [List [BaseQueryContext ]] = None ,
265
+ grpc_status_code : grpc .StatusCode = StatusCode .UNKNOWN ,
251
266
) -> None :
252
267
super ().__init__ (
253
268
message = message ,
@@ -258,6 +273,7 @@ def __init__(
258
273
server_stacktrace = server_stacktrace ,
259
274
display_server_stacktrace = display_server_stacktrace ,
260
275
contexts = contexts ,
276
+ grpc_status_code = grpc_status_code ,
261
277
)
262
278
263
279
0 commit comments