Skip to content

Commit 5465118

Browse files
authored
[Fix] update build loss api (#3587)
## Motivation Use `MODELS.build` instead of `build_loss` ## Modification Please briefly describe what modification is made in this PR.
1 parent be687fc commit 5465118

File tree

3 files changed

+6
-8
lines changed

3 files changed

+6
-8
lines changed

mmseg/models/decode_heads/decode_head.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from mmengine.model import BaseModule
99
from torch import Tensor
1010

11+
from mmseg.registry import MODELS
1112
from mmseg.structures import build_pixel_sampler
1213
from mmseg.utils import ConfigType, SampleList
13-
from ..builder import build_loss
1414
from ..losses import accuracy
1515
from ..utils import resize
1616

@@ -140,11 +140,11 @@ def __init__(self,
140140
self.threshold = threshold
141141

142142
if isinstance(loss_decode, dict):
143-
self.loss_decode = build_loss(loss_decode)
143+
self.loss_decode = MODELS.build(loss_decode)
144144
elif isinstance(loss_decode, (list, tuple)):
145145
self.loss_decode = nn.ModuleList()
146146
for loss in loss_decode:
147-
self.loss_decode.append(build_loss(loss))
147+
self.loss_decode.append(MODELS.build(loss))
148148
else:
149149
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
150150
but got {type(loss_decode)}')

mmseg/models/decode_heads/enc_head.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from mmseg.registry import MODELS
1111
from mmseg.utils import ConfigType, SampleList
12-
from ..builder import build_loss
1312
from ..utils import Encoding, resize
1413
from .decode_head import BaseDecodeHead
1514

@@ -128,7 +127,7 @@ def __init__(self,
128127
norm_cfg=self.norm_cfg,
129128
act_cfg=self.act_cfg)
130129
if self.use_se_loss:
131-
self.loss_se_decode = build_loss(loss_se_decode)
130+
self.loss_se_decode = MODELS.build(loss_se_decode)
132131
self.se_layer = nn.Linear(self.channels, self.num_classes)
133132

134133
def forward(self, inputs):

mmseg/models/decode_heads/vpd_depth_head.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from mmseg.registry import MODELS
1212
from mmseg.utils import SampleList
13-
from ..builder import build_loss
1413
from ..utils import resize
1514
from .decode_head import BaseDecodeHead
1615

@@ -184,11 +183,11 @@ def __init__(
184183

185184
# build loss
186185
if isinstance(loss_decode, dict):
187-
self.loss_decode = build_loss(loss_decode)
186+
self.loss_decode = MODELS.build(loss_decode)
188187
elif isinstance(loss_decode, (list, tuple)):
189188
self.loss_decode = nn.ModuleList()
190189
for loss in loss_decode:
191-
self.loss_decode.append(build_loss(loss))
190+
self.loss_decode.append(MODELS.build(loss))
192191
else:
193192
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
194193
but got {type(loss_decode)}')

0 commit comments

Comments
 (0)