diff --git a/models/face_fed.py b/models/face_fed.py index 0952bd2..66c5f8f 100644 --- a/models/face_fed.py +++ b/models/face_fed.py @@ -592,7 +592,7 @@ class scale_kernel_conf(nn.Module): x4=self.conv3(x3) x4 = self.trans_block4(x4) #print(x4.size()) - residual = self.sig(self.conv_refin(self.sig(F.avg_pool2d(x4,16)))) + residual = self.sig(self.conv_refin(self.sig(F.adaptive_avg_pool2d(x4,(1,1))))) #print(residual) residual = F.interpolate(residual, size=x.size()[2:]) #print(residual.size())