Skip to content

Commit 8b54ad1

Browse files
authored
Supports other domain for light API (#54)
* ut * first sketch * finalize other domain epxressions * docuemntation * extend the support of translate to other domain * documentation
1 parent 06a15a9 commit 8b54ad1

File tree

12 files changed

+333
-10
lines changed

12 files changed

+333
-10
lines changed

Diff for: _doc/api/light_api.rst

+17-6
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,15 @@ translate
1919
Classes for the Light API
2020
=========================
2121

22-
ProtoType
23-
+++++++++
22+
domain
23+
++++++
2424

25-
.. autoclass:: onnx_array_api.light_api.model.ProtoType
25+
..autofunction:: onnx_array_api.light_api.domain
26+
27+
BaseVar
28+
+++++++
29+
30+
.. autoclass:: onnx_array_api.light_api.var.BaseVar
2631
:members:
2732

2833
OnnxGraph
@@ -31,10 +36,16 @@ OnnxGraph
3136
.. autoclass:: onnx_array_api.light_api.OnnxGraph
3237
:members:
3338

34-
BaseVar
35-
+++++++
39+
ProtoType
40+
+++++++++
3641

37-
.. autoclass:: onnx_array_api.light_api.var.BaseVar
42+
.. autoclass:: onnx_array_api.light_api.model.ProtoType
43+
:members:
44+
45+
SubDomain
46+
+++++++++
47+
48+
.. autoclass:: onnx_array_api.light_api.var.SubDomain
3849
:members:
3950

4051
Var

Diff for: _doc/tutorial/light_api.rst

+29
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,32 @@ operator `+` to be available as well and that the case. They are
7676
defined in class :class:`Var <onnx_array_api.light_api.Var>` or
7777
:class:`Vars <onnx_array_api.light_api.Vars>` depending on the number of
7878
inputs they require. Their name starts with a lower letter.
79+
80+
Other domains
81+
=============
82+
83+
The following example uses operator *Normalizer* from domain
84+
*ai.onnx.ml*. The operator name is called with the syntax
85+
`<domain>.<operator name>`. The domain may have dots in its name
86+
but it must follow the python definition of a variable.
87+
The operator *Normalizer* becomes `ai.onnx.ml.Normalizer`.
88+
89+
.. runpython::
90+
:showcode:
91+
92+
import numpy as np
93+
from onnx_array_api.light_api import start
94+
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
95+
96+
model = (
97+
start(opset=19, opsets={"ai.onnx.ml": 3})
98+
.vin("X")
99+
.reshape((-1, 1))
100+
.rename("USE")
101+
.ai.onnx.ml.Normalizer(norm="MAX")
102+
.rename("Y")
103+
.vout()
104+
.to_onnx()
105+
)
106+
107+
print(onnx_simple_text_plot(model))

Diff for: _unittests/ut_light_api/test_light_api.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import unittest
23
from typing import Callable, Optional
34
import numpy as np
@@ -12,6 +13,7 @@
1213
from onnx.reference import ReferenceEvaluator
1314
from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
1415
from onnx_array_api.light_api import start, OnnxGraph, Var, g
16+
from onnx_array_api.light_api.var import SubDomain
1517
from onnx_array_api.light_api._op_var import OpsVar
1618
from onnx_array_api.light_api._op_vars import OpsVars
1719

@@ -472,7 +474,43 @@ def test_if(self):
472474
got = ref.run(None, {"X": -x})
473475
self.assertEqualArray(np.array([0], dtype=np.int64), got[0])
474476

477+
def test_domain(self):
478+
onx = start(opsets={"ai.onnx.ml": 3}).vin("X").reshape((-1, 1)).rename("USE")
479+
480+
class A:
481+
def g(self):
482+
return True
483+
484+
def ah(self):
485+
return True
486+
487+
setattr(A, "h", ah)
488+
489+
self.assertTrue(A().h())
490+
self.assertIn("(self)", str(inspect.signature(A.h)))
491+
self.assertTrue(issubclass(onx._ai, SubDomain))
492+
self.assertIsInstance(onx.ai, SubDomain)
493+
self.assertIsInstance(onx.ai.parent, Var)
494+
self.assertTrue(issubclass(onx._ai._onnx, SubDomain))
495+
self.assertIsInstance(onx.ai.onnx, SubDomain)
496+
self.assertIsInstance(onx.ai.onnx.parent, Var)
497+
self.assertTrue(issubclass(onx._ai._onnx._ml, SubDomain))
498+
self.assertIsInstance(onx.ai.onnx.ml, SubDomain)
499+
self.assertIsInstance(onx.ai.onnx.ml.parent, Var)
500+
self.assertIn("(self,", str(inspect.signature(onx._ai._onnx._ml.Normalizer)))
501+
onx = onx.ai.onnx.ml.Normalizer(norm="MAX")
502+
onx = onx.rename("Y").vout().to_onnx()
503+
self.assertIsInstance(onx, ModelProto)
504+
self.assertIn("Normalizer", str(onx))
505+
self.assertIn('domain: "ai.onnx.ml"', str(onx))
506+
self.assertIn('input: "USE"', str(onx))
507+
ref = ReferenceEvaluator(onx)
508+
a = np.arange(10).astype(np.float32)
509+
got = ref.run(None, {"X": a})[0]
510+
expected = (a > 0).astype(int).astype(np.float32).reshape((-1, 1))
511+
self.assertEqualArray(expected, got)
512+
475513

476514
if __name__ == "__main__":
477-
TestLightApi().test_if()
515+
TestLightApi().test_domain()
478516
unittest.main(verbosity=2)

Diff for: _unittests/ut_light_api/test_translate.py

+33
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,39 @@ def test_export_if(self):
185185
self.maxDiff = None
186186
self.assertEqual(expected, code)
187187

188+
def test_aionnxml(self):
189+
onx = (
190+
start(opset=19, opsets={"ai.onnx.ml": 3})
191+
.vin("X")
192+
.reshape((-1, 1))
193+
.rename("USE")
194+
.ai.onnx.ml.Normalizer(norm="MAX")
195+
.rename("Y")
196+
.vout()
197+
.to_onnx()
198+
)
199+
code = translate(onx)
200+
expected = dedent(
201+
"""
202+
(
203+
start(opset=19, opsets={'ai.onnx.ml': 3})
204+
.cst(np.array([-1, 1], dtype=np.int64))
205+
.rename('r')
206+
.vin('X', elem_type=TensorProto.FLOAT)
207+
.bring('X', 'r')
208+
.Reshape()
209+
.rename('USE')
210+
.bring('USE')
211+
.ai.onnx.ml.Normalizer(norm='MAX')
212+
.rename('Y')
213+
.bring('Y')
214+
.vout(elem_type=TensorProto.FLOAT)
215+
.to_onnx()
216+
)"""
217+
).strip("\n")
218+
self.maxDiff = None
219+
self.assertEqual(expected, code)
220+
188221

189222
if __name__ == "__main__":
190223
TestTranslate().test_export_if()

Diff for: _unittests/ut_light_api/test_translate_classic.py

+66
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,72 @@ def test_fft(self):
252252
)
253253
raise AssertionError(f"ERROR {e}\n{new_code}")
254254

255+
def test_aionnxml(self):
256+
onx = (
257+
start(opset=19, opsets={"ai.onnx.ml": 3})
258+
.vin("X")
259+
.reshape((-1, 1))
260+
.rename("USE")
261+
.ai.onnx.ml.Normalizer(norm="MAX")
262+
.rename("Y")
263+
.vout()
264+
.to_onnx()
265+
)
266+
code = translate(onx, api="onnx")
267+
print(code)
268+
expected = dedent(
269+
"""
270+
opset_imports = [
271+
make_opsetid('', 19),
272+
make_opsetid('ai.onnx.ml', 3),
273+
]
274+
inputs = []
275+
outputs = []
276+
nodes = []
277+
initializers = []
278+
sparse_initializers = []
279+
functions = []
280+
initializers.append(
281+
from_array(
282+
np.array([-1, 1], dtype=np.int64),
283+
name='r'
284+
)
285+
)
286+
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
287+
nodes.append(
288+
make_node(
289+
'Reshape',
290+
['X', 'r'],
291+
['USE']
292+
)
293+
)
294+
nodes.append(
295+
make_node(
296+
'Normalizer',
297+
['USE'],
298+
['Y'],
299+
domain='ai.onnx.ml',
300+
norm='MAX'
301+
)
302+
)
303+
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
304+
graph = make_graph(
305+
nodes,
306+
'light_api',
307+
inputs,
308+
outputs,
309+
initializers,
310+
sparse_initializer=sparse_initializers,
311+
)
312+
model = make_model(
313+
graph,
314+
functions=functions,
315+
opset_imports=opset_imports
316+
)"""
317+
).strip("\n")
318+
self.maxDiff = None
319+
self.assertEqual(expected, code)
320+
255321

256322
if __name__ == "__main__":
257323
# TestLightApi().test_topk()

Diff for: onnx_array_api/light_api/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Dict, Optional
22
from onnx import ModelProto
3+
from .annotations import domain
34
from .model import OnnxGraph, ProtoType
45
from .translate import Translater
56
from .var import Var, Vars

Diff for: onnx_array_api/light_api/_op_var.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Optional, Union
2+
from .annotations import AI_ONNX_ML, domain
23

34

45
class OpsVar:
@@ -319,6 +320,10 @@ def Transpose(self, perm: Optional[List[int]] = None) -> "Var":
319320
perm = perm or []
320321
return self.make_node("Transpose", self, perm=perm)
321322

323+
@domain(AI_ONNX_ML)
324+
def Normalizer(self, norm: str = "MAX"):
325+
return self.make_node("Normalizer", self, norm=norm, domain=AI_ONNX_ML)
326+
322327

323328
def _complete():
324329
ops_to_add = [

Diff for: onnx_array_api/light_api/annotations.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple, Union
1+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
22
import numpy as np
33
from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, TensorShapeProto
44
from onnx.helper import np_dtype_to_tensor_dtype
@@ -9,12 +9,47 @@
99
VAR_CONSTANT_TYPE = Union["Var", TensorProto, np.ndarray]
1010
GRAPH_PROTO = Union[FunctionProto, GraphProto, ModelProto]
1111

12+
AI_ONNX_ML = "ai.onnx.ml"
13+
1214
ELEMENT_TYPE_NAME = {
1315
getattr(TensorProto, k): k
1416
for k in dir(TensorProto)
1517
if isinstance(getattr(TensorProto, k), int) and "_" not in k
1618
}
1719

20+
21+
class SubDomain:
22+
pass
23+
24+
25+
def domain(domain: str, op_type: Optional[str] = None) -> Callable:
26+
"""
27+
Registers one operator into a sub domain. It should be used as a
28+
decorator. One example:
29+
30+
.. code-block:: python
31+
32+
@domain("ai.onnx.ml")
33+
def Normalizer(self, norm: str = "MAX"):
34+
return self.make_node("Normalizer", self, norm=norm, domain="ai.onnx.ml")
35+
"""
36+
names = [op_type]
37+
38+
def decorate(op_method: Callable) -> Callable:
39+
if names[0] is None:
40+
names[0] = op_method.__name__
41+
42+
def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
43+
return op_method(self.parent, *args, **kwargs)
44+
45+
wrapper.__qual__name__ = f"[{domain}]{names[0]}"
46+
wrapper.__name__ = f"[{domain}]{names[0]}"
47+
wrapper.__domain__ = domain
48+
return wrapper
49+
50+
return decorate
51+
52+
1853
_type_numpy = {
1954
np.float32: TensorProto.FLOAT,
2055
np.float64: TensorProto.DOUBLE,

Diff for: onnx_array_api/light_api/emitter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
241241
outputs = kwargs["outputs"]
242242
if kwargs.get("domain", "") != "":
243243
domain = kwargs["domain"]
244-
raise NotImplementedError(f"domain={domain!r} not supported yet.")
244+
op_type = f"{domain}.{op_type}"
245245
atts = kwargs.get("atts", {})
246246
args = []
247247
for k, v in atts.items():

Diff for: onnx_array_api/light_api/inner_emitter.py

-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
120120
outputs = kwargs["outputs"]
121121
if kwargs.get("domain", "") != "":
122122
domain = kwargs["domain"]
123-
raise NotImplementedError(f"domain={domain!r} not supported yet.")
124123

125124
before_lines = []
126125
lines = [

Diff for: onnx_array_api/light_api/model.py

+3
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ def make_node(
248248

249249
node = make_node(op_type, input_names, output_names, domain=domain, **kwargs)
250250
self.nodes.append(node)
251+
if domain != "":
252+
if not self.opsets or domain not in self.opsets:
253+
raise RuntimeError(f"No opset value was given for domain {domain!r}.")
251254
return node
252255

253256
def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var":

0 commit comments

Comments
 (0)