2
2
import json
3
3
from typing import Tuple , List , Dict
4
4
from qwen_vl_utils import process_vision_info
5
- from transformers import AutoProcessor , AutoModelForImageTextToText
5
+ from transformers import AutoProcessor , AutoModelForImageTextToText , Qwen2VLForConditionalGeneration , Qwen2VLProcessor
6
6
from swiftannotate .image .base import BaseImageClassification
7
7
from swiftannotate .constants import BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT , BASE_IMAGE_CLASSIFICATION_PROMPT
8
8
@@ -13,8 +13,8 @@ class ImageClassificationQwen2VL(BaseImageClassification):
13
13
"""
14
14
def __init__ (
15
15
self ,
16
- model : AutoModelForImageTextToText ,
17
- processor : AutoProcessor ,
16
+ model : AutoModelForImageTextToText | Qwen2VLForConditionalGeneration ,
17
+ processor : AutoProcessor | Qwen2VLProcessor ,
18
18
classification_labels : List [str ],
19
19
classification_prompt : str = BASE_IMAGE_CLASSIFICATION_PROMPT ,
20
20
validation : bool = True ,
@@ -62,7 +62,13 @@ def __init__(
62
62
Notes:
63
63
`validation_prompt` should specify the rules for validating the class label and the range of validation score to be generated example (0-1).
64
64
Your `validation_threshold` should be within this specified range.
65
- """
65
+ """
66
+
67
+ if not isinstance (model , Qwen2VLForConditionalGeneration ):
68
+ raise ValueError ("Model should be an instance of Qwen2VLForConditionalGeneration." )
69
+ if not isinstance (processor , Qwen2VLProcessor ):
70
+ raise ValueError ("Processor should be an instance of Qwen2VLProcessor." )
71
+
66
72
self .model = model
67
73
self .processor = processor
68
74
0 commit comments