Skip to content

Commit 1c6b998

Browse files
ivanmkcsasha-gitg
andauthored
feat: Added create_training_pipeline_custom_job_sample and create_training_pipeline_custom_training_managed_dataset_sample and fixed create_training_pipeline_image_classification_sample (#343)
* Added create_and_import_dataset_tabular_gcs_sample.py * Added create_and_import_dataset_tabular_bigquery_sample.py * Added create_training_pipeline_custom_job_sample.py and tweaked other tests * Added create_training_pipeline_custom_training_managed_dataset_sample and fixed unmanaged sample * Fixed args * Deleted duplicated samples * Added more args to samples * Ran black * Ran linter Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com>
1 parent c39eab2 commit 1c6b998

8 files changed

+301
-2
lines changed

Diff for: samples/model-builder/conftest.py

+13
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,19 @@ def mock_run_automl_image_training_job():
152152
yield mock
153153

154154

155+
@pytest.fixture
156+
def mock_init_custom_training_job():
157+
with patch.object(aiplatform.training_jobs.CustomTrainingJob, "__init__") as mock:
158+
mock.return_value = None
159+
yield mock
160+
161+
162+
@pytest.fixture
163+
def mock_run_custom_training_job():
164+
with patch.object(aiplatform.training_jobs.CustomTrainingJob, "run") as mock:
165+
yield mock
166+
167+
155168
"""
156169
----------------------------------------------------------------------------
157170
Model Fixtures
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Optional, Union
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_create_training_pipeline_custom_job_sample]
21+
def create_training_pipeline_custom_job_sample(
22+
project: str,
23+
location: str,
24+
display_name: str,
25+
script_path: str,
26+
container_uri: str,
27+
model_serving_container_image_uri: str,
28+
model_display_name: Optional[str] = None,
29+
args: Optional[List[Union[str, float, int]]] = None,
30+
replica_count: int = 0,
31+
machine_type: str = "n1-standard-4",
32+
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
33+
accelerator_count: int = 0,
34+
training_fraction_split: float = 0.8,
35+
validation_fraction_split: float = 0.1,
36+
test_fraction_split: float = 0.1,
37+
sync: bool = True,
38+
):
39+
aiplatform.init(project=project, location=location)
40+
41+
job = aiplatform.CustomTrainingJob(
42+
display_name=display_name,
43+
script_path=script_path,
44+
container_uri=container_uri,
45+
model_serving_container_image_uri=model_serving_container_image_uri,
46+
)
47+
48+
model = job.run(
49+
model_display_name=model_display_name,
50+
args=args,
51+
replica_count=replica_count,
52+
machine_type=machine_type,
53+
accelerator_type=accelerator_type,
54+
accelerator_count=accelerator_count,
55+
training_fraction_split=training_fraction_split,
56+
validation_fraction_split=validation_fraction_split,
57+
test_fraction_split=test_fraction_split,
58+
sync=sync,
59+
)
60+
61+
model.wait()
62+
63+
print(model.display_name)
64+
print(model.resource_name)
65+
print(model.uri)
66+
return model
67+
68+
69+
# [END aiplatform_sdk_create_training_pipeline_custom_job_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import create_training_pipeline_custom_job_sample
17+
import test_constants as constants
18+
19+
20+
def test_create_training_pipeline_custom_job_sample(
21+
mock_sdk_init, mock_init_custom_training_job, mock_run_custom_training_job,
22+
):
23+
24+
create_training_pipeline_custom_job_sample.create_training_pipeline_custom_job_sample(
25+
project=constants.PROJECT,
26+
location=constants.LOCATION,
27+
display_name=constants.DISPLAY_NAME,
28+
args=constants.ARGS,
29+
script_path=constants.SCRIPT_PATH,
30+
container_uri=constants.CONTAINER_URI,
31+
model_serving_container_image_uri=constants.CONTAINER_URI,
32+
model_display_name=constants.DISPLAY_NAME_2,
33+
replica_count=constants.REPLICA_COUNT,
34+
machine_type=constants.MACHINE_TYPE,
35+
accelerator_type=constants.ACCELERATOR_TYPE,
36+
accelerator_count=constants.ACCELERATOR_COUNT,
37+
training_fraction_split=constants.TRAINING_FRACTION_SPLIT,
38+
validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT,
39+
test_fraction_split=constants.TEST_FRACTION_SPLIT,
40+
)
41+
42+
mock_sdk_init.assert_called_once_with(
43+
project=constants.PROJECT, location=constants.LOCATION
44+
)
45+
mock_init_custom_training_job.assert_called_once_with(
46+
display_name=constants.DISPLAY_NAME,
47+
script_path=constants.SCRIPT_PATH,
48+
container_uri=constants.CONTAINER_URI,
49+
model_serving_container_image_uri=constants.CONTAINER_URI,
50+
)
51+
mock_run_custom_training_job.assert_called_once_with(
52+
model_display_name=constants.DISPLAY_NAME_2,
53+
replica_count=constants.REPLICA_COUNT,
54+
machine_type=constants.MACHINE_TYPE,
55+
accelerator_type=constants.ACCELERATOR_TYPE,
56+
accelerator_count=constants.ACCELERATOR_COUNT,
57+
args=constants.ARGS,
58+
training_fraction_split=constants.TRAINING_FRACTION_SPLIT,
59+
validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT,
60+
test_fraction_split=constants.TEST_FRACTION_SPLIT,
61+
sync=True,
62+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Optional, Union
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_create_training_pipeline_custom_job_sample]
21+
def create_training_pipeline_custom_training_managed_dataset_sample(
22+
project: str,
23+
location: str,
24+
display_name: str,
25+
script_path: str,
26+
container_uri: str,
27+
model_serving_container_image_uri: str,
28+
dataset_id: int,
29+
model_display_name: Optional[str] = None,
30+
args: Optional[List[Union[str, float, int]]] = None,
31+
replica_count: int = 0,
32+
machine_type: str = "n1-standard-4",
33+
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
34+
accelerator_count: int = 0,
35+
training_fraction_split: float = 0.8,
36+
validation_fraction_split: float = 0.1,
37+
test_fraction_split: float = 0.1,
38+
sync: bool = True,
39+
):
40+
aiplatform.init(project=project, location=location)
41+
42+
job = aiplatform.CustomTrainingJob(
43+
display_name=display_name,
44+
script_path=script_path,
45+
container_uri=container_uri,
46+
model_serving_container_image_uri=model_serving_container_image_uri,
47+
)
48+
49+
my_image_ds = aiplatform.ImageDataset(dataset_id)
50+
51+
model = job.run(
52+
dataset=my_image_ds,
53+
model_display_name=model_display_name,
54+
args=args,
55+
replica_count=replica_count,
56+
machine_type=machine_type,
57+
accelerator_type=accelerator_type,
58+
accelerator_count=accelerator_count,
59+
training_fraction_split=training_fraction_split,
60+
validation_fraction_split=validation_fraction_split,
61+
test_fraction_split=test_fraction_split,
62+
sync=sync,
63+
)
64+
65+
model.wait()
66+
67+
print(model.display_name)
68+
print(model.resource_name)
69+
print(model.uri)
70+
return model
71+
72+
73+
# [END aiplatform_sdk_create_training_pipeline_custom_job_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import create_training_pipeline_custom_training_managed_dataset_sample
17+
import test_constants as constants
18+
19+
20+
def test_create_training_pipeline_custom_job_sample(
21+
mock_sdk_init,
22+
mock_image_dataset,
23+
mock_init_custom_training_job,
24+
mock_run_custom_training_job,
25+
mock_get_image_dataset,
26+
):
27+
28+
create_training_pipeline_custom_training_managed_dataset_sample.create_training_pipeline_custom_training_managed_dataset_sample(
29+
project=constants.PROJECT,
30+
location=constants.LOCATION,
31+
display_name=constants.DISPLAY_NAME,
32+
args=constants.ARGS,
33+
script_path=constants.SCRIPT_PATH,
34+
container_uri=constants.CONTAINER_URI,
35+
model_serving_container_image_uri=constants.CONTAINER_URI,
36+
dataset_id=constants.RESOURCE_ID,
37+
model_display_name=constants.DISPLAY_NAME_2,
38+
replica_count=constants.REPLICA_COUNT,
39+
machine_type=constants.MACHINE_TYPE,
40+
accelerator_type=constants.ACCELERATOR_TYPE,
41+
accelerator_count=constants.ACCELERATOR_COUNT,
42+
training_fraction_split=constants.TRAINING_FRACTION_SPLIT,
43+
validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT,
44+
test_fraction_split=constants.TEST_FRACTION_SPLIT,
45+
)
46+
47+
mock_get_image_dataset.assert_called_once_with(constants.RESOURCE_ID)
48+
49+
mock_sdk_init.assert_called_once_with(
50+
project=constants.PROJECT, location=constants.LOCATION
51+
)
52+
mock_init_custom_training_job.assert_called_once_with(
53+
display_name=constants.DISPLAY_NAME,
54+
script_path=constants.SCRIPT_PATH,
55+
container_uri=constants.CONTAINER_URI,
56+
model_serving_container_image_uri=constants.CONTAINER_URI,
57+
)
58+
mock_run_custom_training_job.assert_called_once_with(
59+
dataset=mock_image_dataset,
60+
model_display_name=constants.DISPLAY_NAME_2,
61+
args=constants.ARGS,
62+
replica_count=constants.REPLICA_COUNT,
63+
machine_type=constants.MACHINE_TYPE,
64+
accelerator_type=constants.ACCELERATOR_TYPE,
65+
accelerator_count=constants.ACCELERATOR_COUNT,
66+
training_fraction_split=constants.TRAINING_FRACTION_SPLIT,
67+
validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT,
68+
test_fraction_split=constants.TEST_FRACTION_SPLIT,
69+
sync=True,
70+
)

Diff for: samples/model-builder/create_training_pipeline_image_classification_sample.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Optional
16+
1517
from google.cloud import aiplatform
1618

1719

1820
# [START aiplatform_sdk_create_training_pipeline_image_classification_sample]
1921
def create_training_pipeline_image_classification_sample(
2022
project: str,
23+
location: str,
2124
display_name: str,
2225
dataset_id: int,
23-
location: str = "us-central1",
24-
model_display_name: str = None,
26+
model_display_name: Optional[str] = None,
2527
training_fraction_split: float = 0.8,
2628
validation_fraction_split: float = 0.1,
2729
test_fraction_split: float = 0.1,

Diff for: samples/model-builder/create_training_pipeline_image_classification_sample_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def test_create_training_pipeline_image_classification_sample(
2727

2828
create_training_pipeline_image_classification_sample.create_training_pipeline_image_classification_sample(
2929
project=constants.PROJECT,
30+
location=constants.LOCATION,
3031
display_name=constants.DISPLAY_NAME,
3132
dataset_id=constants.RESOURCE_ID,
3233
model_display_name=constants.DISPLAY_NAME_2,

Diff for: samples/model-builder/test_constants.py

+9
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
TRAINING_JOB_NAME = f"{PARENT}/trainingJobs/{RESOURCE_ID}"
4141

4242
GCS_SOURCES = ["gs://bucket1/source1.jsonl", "gs://bucket7/source4.jsonl"]
43+
BIGQUERY_SOURCE = "bq://bigquery-public-data.ml_datasets.iris"
4344
GCS_DESTINATION = "gs://bucket3/output-dir/"
4445

4546
TRAINING_FRACTION_SPLIT = 0.7
@@ -51,3 +52,11 @@
5152
ENCRYPTION_SPEC_KEY_NAME = f"{PARENT}/keyRings/{RESOURCE_ID}/cryptoKeys/{RESOURCE_ID_2}"
5253

5354
PREDICTION_TEXT_INSTANCE = "This is some text for testing NLP prediction output"
55+
56+
SCRIPT_PATH = "task.py"
57+
CONTAINER_URI = "gcr.io/my_project/my_image:latest"
58+
ARGS = ["--tfds", "tf_flowers:3.*.*"]
59+
REPLICA_COUNT = 0
60+
MACHINE_TYPE = "n1-standard-4"
61+
ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED"
62+
ACCELERATOR_COUNT = 0

0 commit comments

Comments
 (0)