Skip to content

Commit 639cf10

Browse files
matthew29tang authored and
copybara-github committed
feat: Add serializer.register_custom_command()
PiperOrigin-RevId: 570206347
1 parent a36daa7 commit 639cf10

File tree

4 files changed

+46
-11
lines changed

4 files changed

+46
-11
lines changed

Diff for: tests/unit/vertexai/test_remote_training.py

+16 -8
Original file line number | Diff line number | Diff line change
@@ -492,25 +492,29 @@ def mock_any_serializer_serialize_sklearn():
492492
{
493493
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
494494
f"scikit-learn=={sklearn.__version__}"
495-
]
495+
],
496+
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
496497
},
497498
{
498499
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
499500
f"numpy=={np.__version__}",
500501
f"cloudpickle=={cloudpickle.__version__}",
501-
]
502+
],
503+
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
502504
},
503505
{
504506
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
505507
f"numpy=={np.__version__}",
506508
f"cloudpickle=={cloudpickle.__version__}",
507-
]
509+
],
510+
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
508511
},
509512
{
510513
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
511514
f"numpy=={np.__version__}",
512515
f"cloudpickle=={cloudpickle.__version__}",
513-
]
516+
],
517+
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
514518
},
515519
],
516520
) as mock_any_serializer_serialize:
@@ -575,25 +579,29 @@ def mock_any_serializer_serialize_keras():
575579
{
576580
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
577581
f"tensorflow=={tf.__version__}"
578-
]
582+
],
583+
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
579584
},
580585
{
581586
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
582587
f"numpy=={np.__version__}",
583588
f"cloudpickle=={cloudpickle.__version__}",
584-
]
589+
],
590+
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
585591
},
586592
{
587593
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
588594
f"numpy=={np.__version__}",
589595
f"cloudpickle=={cloudpickle.__version__}",
590-
]
596+
],
597+
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
591598
},
592599
{
593600
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
594601
f"numpy=={np.__version__}",
595602
f"cloudpickle=={cloudpickle.__version__}",
596-
]
603+
],
604+
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: [],
597605
},
598606
],
599607
) as mock_any_serializer_serialize:

Diff for: vertexai/preview/_workflow/executor/training.py

+11 -2
Original file line number | Diff line number | Diff line change
@@ -513,6 +513,7 @@ def remote_training(invokable: shared._Invokable, rewrapper: Any):
513513
]
514514

515515
requirements = []
516+
custom_commands = []
516517

517518
enable_cuda = config.enable_cuda
518519

@@ -641,8 +642,16 @@ def remote_training(invokable: shared._Invokable, rewrapper: Any):
641642

642643
requirements = _add_indirect_dependency_versions(requirements)
643644
command = ["export PIP_ROOT_USER_ACTION=ignore &&"]
644-
if config.custom_commands:
645-
custom_commands = [f"{command} &&" for command in config.custom_commands]
645+
646+
# Combine user custom_commands and serializer custom_commands
647+
custom_commands += serialization_metadata[
648+
serializers_base.SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY
649+
]
650+
custom_commands += config.custom_commands
651+
custom_commands = list(dict.fromkeys(custom_commands))
652+
653+
if custom_commands:
654+
custom_commands = [f"{command} &&" for command in custom_commands]
646655
command.extend(custom_commands)
647656
if requirements:
648657
command.append("pip install --upgrade pip &&")

Diff for: vertexai/preview/_workflow/serialization_engine/serializers.py

+5 -1
Original file line number | Diff line number | Diff line change
@@ -1119,7 +1119,11 @@ def serialize(
11191119
) -> str:
11201120
# All bigframe serializers will be identical (bigframes.dataframe.DataFrame --> parquet)
11211121
# Record the framework in metadata for deserialization
1122-
BigframeSerializer._metadata.framework = kwargs.get("framework")
1122+
detected_framework = kwargs.get("framework")
1123+
BigframeSerializer._metadata.framework = detected_framework
1124+
if detected_framework == "torch":
1125+
self.register_custom_command("pip install torchdata")
1126+
self.register_custom_command("pip install torcharrow")
11231127
if not _is_valid_gcs_path(gcs_path):
11241128
raise ValueError(f"Invalid gcs path: {gcs_path}")
11251129

Diff for: vertexai/preview/_workflow/serialization_engine/serializers_base.py

+14
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,7 @@
105105
SERIALIZATION_METADATA_FILENAME = "serialization_metadata"
106106
SERIALIZATION_METADATA_SERIALIZER_KEY = "serializer"
107107
SERIALIZATION_METADATA_DEPENDENCIES_KEY = "dependencies"
108+
SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY = "custom_commands"
108109

109110

110111
@dataclasses.dataclass
@@ -133,11 +134,13 @@ def deserialize(self, gcs_path):
133134

134135
serializer: Optional[str] = None
135136
dependencies: List[str] = dataclasses.field(default_factory=list)
137+
custom_commands: List[str] = dataclasses.field(default_factory=list)
136138

137139
def to_dict(self):
138140
return {
139141
SERIALIZATION_METADATA_SERIALIZER_KEY: self.serializer,
140142
SERIALIZATION_METADATA_DEPENDENCIES_KEY: self.dependencies,
143+
SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY: self.custom_commands,
141144
}
142145

143146

@@ -322,6 +325,12 @@ def _dedupe_deps(cls):
322325
# the version if version is not specified.
323326
cls._metadata.dependencies = list(dict.fromkeys(cls._metadata.dependencies))
324327

328+
@classmethod
329+
def _dedupe_custom_commands(cls):
330+
cls._metadata.custom_commands = list(
331+
dict.fromkeys(cls._metadata.custom_commands)
332+
)
333+
325334
@classmethod
326335
def register_requirement(cls, required_package: str):
327336
# TODO(b/280648121) Consider allowing the user to register the
@@ -334,3 +343,8 @@ def register_requirement(cls, required_package: str):
334343
def register_requirements(cls, requirements: List[str]):
335344
cls._metadata.dependencies.extend(requirements)
336345
cls._dedupe_deps()
346+
347+
@classmethod
348+
def register_custom_command(cls, custom_command: str):
349+
cls._metadata.custom_commands.append(custom_command)
350+
cls._dedupe_custom_commands()

0 commit comments

Comments
 (0)