File tree 3 files changed +6
-8
lines changed
mmseg/models/decode_heads
3 files changed +6
-8
lines changed Original file line number Diff line number Diff line change 8
8
from mmengine .model import BaseModule
9
9
from torch import Tensor
10
10
11
+ from mmseg .registry import MODELS
11
12
from mmseg .structures import build_pixel_sampler
12
13
from mmseg .utils import ConfigType , SampleList
13
- from ..builder import build_loss
14
14
from ..losses import accuracy
15
15
from ..utils import resize
16
16
@@ -140,11 +140,11 @@ def __init__(self,
140
140
self .threshold = threshold
141
141
142
142
if isinstance (loss_decode , dict ):
143
- self .loss_decode = build_loss (loss_decode )
143
+ self .loss_decode = MODELS . build (loss_decode )
144
144
elif isinstance (loss_decode , (list , tuple )):
145
145
self .loss_decode = nn .ModuleList ()
146
146
for loss in loss_decode :
147
- self .loss_decode .append (build_loss (loss ))
147
+ self .loss_decode .append (MODELS . build (loss ))
148
148
else :
149
149
raise TypeError (f'loss_decode must be a dict or sequence of dict,\
150
150
but got { type (loss_decode )} ' )
Original file line number Diff line number Diff line change 9
9
10
10
from mmseg .registry import MODELS
11
11
from mmseg .utils import ConfigType , SampleList
12
- from ..builder import build_loss
13
12
from ..utils import Encoding , resize
14
13
from .decode_head import BaseDecodeHead
15
14
@@ -128,7 +127,7 @@ def __init__(self,
128
127
norm_cfg = self .norm_cfg ,
129
128
act_cfg = self .act_cfg )
130
129
if self .use_se_loss :
131
- self .loss_se_decode = build_loss (loss_se_decode )
130
+ self .loss_se_decode = MODELS . build (loss_se_decode )
132
131
self .se_layer = nn .Linear (self .channels , self .num_classes )
133
132
134
133
def forward (self , inputs ):
Original file line number Diff line number Diff line change 10
10
11
11
from mmseg .registry import MODELS
12
12
from mmseg .utils import SampleList
13
- from ..builder import build_loss
14
13
from ..utils import resize
15
14
from .decode_head import BaseDecodeHead
16
15
@@ -184,11 +183,11 @@ def __init__(
184
183
185
184
# build loss
186
185
if isinstance (loss_decode , dict ):
187
- self .loss_decode = build_loss (loss_decode )
186
+ self .loss_decode = MODELS . build (loss_decode )
188
187
elif isinstance (loss_decode , (list , tuple )):
189
188
self .loss_decode = nn .ModuleList ()
190
189
for loss in loss_decode :
191
- self .loss_decode .append (build_loss (loss ))
190
+ self .loss_decode .append (MODELS . build (loss ))
192
191
else :
193
192
raise TypeError (f'loss_decode must be a dict or sequence of dict,\
194
193
but got { type (loss_decode )} ' )
You can’t perform that action at this time.
0 commit comments