Skip to content

Commit 427f6b6

Browse files
committed
Added model validation
1 parent 0cc1c50 commit 427f6b6

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

swiftannotate/image/captioning/qwen.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
from typing import Tuple, List, Dict
44
from qwen_vl_utils import process_vision_info
5-
from transformers import AutoProcessor, AutoModelForImageTextToText
5+
from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2VLForConditionalGeneration, Qwen2VLProcessor
66
from swiftannotate.image.base import BaseImageCaptioning
77
from swiftannotate.constants import BASE_IMAGE_CAPTION_VALIDATION_PROMPT, BASE_IMAGE_CAPTION_PROMPT
88

@@ -13,8 +13,8 @@ class ImageCaptioningQwen2VL(BaseImageCaptioning):
1313
"""
1414
def __init__(
1515
self,
16-
model: AutoModelForImageTextToText,
17-
processor: AutoProcessor,
16+
model: AutoModelForImageTextToText | Qwen2VLForConditionalGeneration,
17+
processor: AutoProcessor | Qwen2VLProcessor,
1818
caption_prompt: str = BASE_IMAGE_CAPTION_PROMPT,
1919
validation: bool = True,
2020
validation_prompt: str = BASE_IMAGE_CAPTION_VALIDATION_PROMPT,
@@ -59,7 +59,13 @@ def __init__(
5959
Notes:
6060
`validation_prompt` should specify the rules for validating the caption and the range of validation score to be generated example (0-1).
6161
Your `validation_threshold` should be within this specified range.
62-
"""
62+
"""
63+
64+
if not isinstance(model, Qwen2VLForConditionalGeneration):
65+
raise ValueError("Model should be an instance of Qwen2VLForConditionalGeneration.")
66+
if not isinstance(processor, Qwen2VLProcessor):
67+
raise ValueError("Processor should be an instance of Qwen2VLProcessor.")
68+
6369
self.model = model
6470
self.processor = processor
6571

swiftannotate/image/classification/qwen.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
from typing import Tuple, List, Dict
44
from qwen_vl_utils import process_vision_info
5-
from transformers import AutoProcessor, AutoModelForImageTextToText
5+
from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2VLForConditionalGeneration, Qwen2VLProcessor
66
from swiftannotate.image.base import BaseImageClassification
77
from swiftannotate.constants import BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT, BASE_IMAGE_CLASSIFICATION_PROMPT
88

@@ -13,8 +13,8 @@ class ImageClassificationQwen2VL(BaseImageClassification):
1313
"""
1414
def __init__(
1515
self,
16-
model: AutoModelForImageTextToText,
17-
processor: AutoProcessor,
16+
model: AutoModelForImageTextToText | Qwen2VLForConditionalGeneration,
17+
processor: AutoProcessor | Qwen2VLProcessor,
1818
classification_labels: List[str],
1919
classification_prompt: str = BASE_IMAGE_CLASSIFICATION_PROMPT,
2020
validation: bool = True,
@@ -62,7 +62,13 @@ def __init__(
6262
Notes:
6363
`validation_prompt` should specify the rules for validating the class label and the range of validation score to be generated example (0-1).
6464
Your `validation_threshold` should be within this specified range.
65-
"""
65+
"""
66+
67+
if not isinstance(model, Qwen2VLForConditionalGeneration):
68+
raise ValueError("Model should be an instance of Qwen2VLForConditionalGeneration.")
69+
if not isinstance(processor, Qwen2VLProcessor):
70+
raise ValueError("Processor should be an instance of Qwen2VLProcessor.")
71+
6672
self.model = model
6773
self.processor = processor
6874

0 commit comments

Comments
 (0)