"""Encoder / strip-attention / decoder network (``Stripformer``) and its parts.

Every sub-module keeps its attribute name and registration order, so the
``state_dict`` keys of this module layout are stable.
"""

import math

import torch
import torch.nn as nn


def _conv3x3(in_channels, out_channels, stride=1):
    """Return a 3x3 convolution with padding 1 and the given stride."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1)


def _residual_body(channels, activation):
    """Return ``conv3x3 -> activation -> conv3x3``; the caller adds the skip."""
    return nn.Sequential(
        _conv3x3(channels, channels),
        activation,
        _conv3x3(channels, channels),
    )


def _layer_norm_channels(norm, x):
    """Apply ``norm`` over the channel axis of a ``(B, C, H, W)`` tensor."""
    B, C, H, W = x.size()
    # LayerNorm normalises the last axis, so move channels last: (B, H*W, C).
    tokens = x.reshape(B, C, H * W).permute(0, 2, 1).contiguous()
    tokens = norm(tokens)
    return tokens.permute(0, 2, 1).contiguous().view(B, C, H, W)


def _ffn_residual(norm, ffn, x):
    """Return ``x + ffn(norm(x))`` computed per pixel on a ``(B, C, H, W)`` map."""
    B, C, H, W = x.size()
    tokens = x.reshape(B, C, H * W).permute(0, 2, 1).contiguous()
    tokens = ffn(norm(tokens)) + tokens
    return tokens.permute(0, 2, 1).contiguous().view(B, C, H, W)


class Embeddings(nn.Module):
    """Convolutional encoder: 3 -> 64 -> 128 -> 320 channels.

    ``forward`` returns ``(features, residual_1, residual_2)`` where
    ``features`` has 320 channels after two stride-2 convolutions,
    ``residual_1`` has 64 channels at the input resolution and ``residual_2``
    has 128 channels after the first stride-2 convolution.
    """

    def __init__(self):
        super(Embeddings, self).__init__()

        # One in-place LeakyReLU instance is shared by every stage.
        self.activation = nn.LeakyReLU(0.2, True)

        self.en_layer1_1 = nn.Sequential(_conv3x3(3, 64), self.activation)
        self.en_layer1_2 = _residual_body(64, self.activation)
        self.en_layer1_3 = _residual_body(64, self.activation)
        self.en_layer1_4 = _residual_body(64, self.activation)

        self.en_layer2_1 = nn.Sequential(
            _conv3x3(64, 128, stride=2), self.activation)
        self.en_layer2_2 = _residual_body(128, self.activation)
        self.en_layer2_3 = _residual_body(128, self.activation)
        self.en_layer2_4 = _residual_body(128, self.activation)

        self.en_layer3_1 = nn.Sequential(
            _conv3x3(128, 320, stride=2), self.activation)

    def forward(self, x):
        hx = self.en_layer1_1(x)
        for body in (self.en_layer1_2, self.en_layer1_3, self.en_layer1_4):
            hx = self.activation(body(hx) + hx)
        residual_1 = hx

        hx = self.en_layer2_1(hx)
        for body in (self.en_layer2_2, self.en_layer2_3, self.en_layer2_4):
            hx = self.activation(body(hx) + hx)
        residual_2 = hx

        hx = self.en_layer3_1(hx)
        return hx, residual_1, residual_2


class Embeddings_output(nn.Module):
    """Decoder: upsamples 320-channel features back to a 3-channel map.

    ``forward(x, residual_1, residual_2)`` expects the three tensors produced
    by :class:`Embeddings`; the residuals are concatenated on the channel
    axis after each transposed convolution.
    """

    def __init__(self):
        super(Embeddings_output, self).__init__()

        self.activation = nn.LeakyReLU(0.2, True)

        self.de_layer3_1 = nn.Sequential(
            nn.ConvTranspose2d(320, 192, kernel_size=4, stride=2, padding=1),
            self.activation,
        )
        head_num = 3
        dim = 192

        # 1x1 conv fuses the upsampled features with the 128-channel skip.
        self.de_layer2_2 = nn.Sequential(
            nn.Conv2d(192 + 128, 192, kernel_size=1, padding=0),
            self.activation,
        )

        self.de_block_1 = Intra_SA(dim, head_num)
        self.de_block_2 = Inter_SA(dim, head_num)
        self.de_block_3 = Intra_SA(dim, head_num)
        self.de_block_4 = Inter_SA(dim, head_num)
        self.de_block_5 = Intra_SA(dim, head_num)
        self.de_block_6 = Inter_SA(dim, head_num)

        self.de_layer2_1 = nn.Sequential(
            nn.ConvTranspose2d(192, 64, kernel_size=4, stride=2, padding=1),
            self.activation,
        )

        self.de_layer1_3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=1, padding=0),
            self.activation,
            _conv3x3(64, 64),
        )
        self.de_layer1_2 = _residual_body(64, self.activation)
        self.de_layer1_1 = nn.Sequential(_conv3x3(64, 3), self.activation)

    def forward(self, x, residual_1, residual_2):
        hx = self.de_layer3_1(x)
        hx = self.de_layer2_2(torch.cat((hx, residual_2), dim=1))

        for block in (self.de_block_1, self.de_block_2, self.de_block_3,
                      self.de_block_4, self.de_block_5, self.de_block_6):
            hx = block(hx)

        hx = self.de_layer2_1(hx)
        fused = self.de_layer1_3(torch.cat((hx, residual_1), dim=1))
        hx = self.activation(fused + hx)
        hx = self.activation(self.de_layer1_2(hx) + hx)
        return self.de_layer1_1(hx)


class Attention(nn.Module):
    """Parameter-free multi-head scaled dot-product attention.

    Inputs are ``(B, N, C)`` tensors; ``C`` is split evenly across
    ``head_num`` heads and the output has the query's ``(B, N, C)`` shape.
    """

    def __init__(self, head_num):
        super(Attention, self).__init__()
        self.num_attention_heads = head_num
        self.softmax = nn.Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Reshape ``(B, N, C)`` to ``(B, heads, N, C // heads)``.

        Raises:
            ValueError: if ``C`` is not divisible by the number of heads.
        """
        B, N, C = x.size()
        heads = self.num_attention_heads
        if C % heads != 0:
            raise ValueError(
                "channel dimension {} is not divisible by the number of "
                "attention heads {}".format(C, heads))
        x = x.view(B, N, heads, C // heads)
        return x.permute(0, 2, 1, 3).contiguous()

    def forward(self, query_layer, key_layer, value_layer):
        B, N, C = query_layer.size()
        query = self.transpose_for_scores(query_layer)
        key = self.transpose_for_scores(key_layer)
        value = self.transpose_for_scores(value_layer)

        head_size = query.size(-1)
        scores = torch.matmul(query, key.transpose(-1, -2))
        probs = self.softmax(scores / math.sqrt(head_size))

        context = torch.matmul(probs, value)
        # (B, heads, N, d) -> (B, N, heads * d)
        context = context.permute(0, 2, 1, 3).contiguous()
        return context.view(B, N, C)


class Mlp(nn.Module):
    """Two-layer feed-forward network: ``hidden -> 4 * hidden -> hidden``.

    Weights use Xavier-uniform initialisation; biases are drawn from a normal
    distribution with ``std=1e-6``.
    """

    def __init__(self, hidden_size):
        super(Mlp, self).__init__()
        self.fc1 = nn.Linear(hidden_size, 4 * hidden_size)
        self.fc2 = nn.Linear(4 * hidden_size, hidden_size)
        self.act_fn = torch.nn.functional.gelu
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        return self.fc2(self.act_fn(self.fc1(x)))


# CPE (Conditional Positional Embedding)
class PEG(nn.Module):
    """Depthwise 3x3 convolution added back onto its input."""

    def __init__(self, hidden_size):
        super(PEG, self).__init__()
        self.PEG = nn.Conv2d(hidden_size, hidden_size, kernel_size=3,
                             padding=1, groups=hidden_size)

    def forward(self, x):
        return self.PEG(x) + x


class Intra_SA(nn.Module):
    """Attention between the pixels inside each row and inside each column.

    The normalised input is split in two along channels. One half is viewed
    as ``B * H`` sequences of ``W`` tokens (rows), the other as ``B * W``
    sequences of ``H`` tokens (columns). Both results are fused by a 1x1
    convolution, followed by a residual feed-forward block and :class:`PEG`.
    Input and output are ``(B, dim, H, W)``.
    """

    def __init__(self, dim, head_num):
        super(Intra_SA, self).__init__()
        self.hidden_size = dim // 2
        self.head_num = head_num
        self.attention_norm = nn.LayerNorm(dim)
        self.conv_input = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
        # Joint query/key/value projections for the row and column halves.
        self.qkv_local_h = nn.Linear(self.hidden_size, self.hidden_size * 3)
        self.qkv_local_v = nn.Linear(self.hidden_size, self.hidden_size * 3)
        self.fuse_out = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn = Mlp(dim)
        self.attn = Attention(head_num=self.head_num)
        self.PEG = PEG(dim)

    def forward(self, x):
        shortcut = x
        B, C, H, W = x.size()
        half = C // 2

        x = _layer_norm_channels(self.attention_norm, x)
        in_h, in_v = torch.chunk(self.conv_input(x), 2, dim=1)

        # Rows as sequences: (B*H, W, C/2); columns as sequences: (B*W, H, C/2).
        feature_h = in_h.permute(0, 2, 3, 1).contiguous().view(B * H, W, half)
        feature_v = in_v.permute(0, 3, 2, 1).contiguous().view(B * W, H, half)

        q_h, k_h, v_h = torch.chunk(self.qkv_local_h(feature_h), 3, dim=2)
        q_v, k_v, v_v = torch.chunk(self.qkv_local_v(feature_v), 3, dim=2)

        if H == W:
            # Row and column sequences have the same shape, so both
            # directions go through a single batched attention call.
            merged = self.attn(torch.cat((q_h, q_v), dim=0),
                               torch.cat((k_h, k_v), dim=0),
                               torch.cat((v_h, v_v), dim=0))
            out_h, out_v = torch.chunk(merged, 2, dim=0)
        else:
            out_h = self.attn(q_h, k_h, v_h)
            out_v = self.attn(q_v, k_v, v_v)

        # Back to (B, C/2, H, W) for both directions.
        out_h = out_h.view(B, H, W, half).permute(0, 3, 1, 2).contiguous()
        out_v = out_v.view(B, W, H, half).permute(0, 3, 2, 1).contiguous()
        attn_out = self.fuse_out(torch.cat((out_h, out_v), dim=1))

        x = _ffn_residual(self.ffn_norm, self.ffn, attn_out + shortcut)
        return self.PEG(x)


class Inter_SA(nn.Module):
    """Attention between whole rows and between whole columns.

    The normalised input is split in two along channels. In one half every
    row is flattened into a single token (``H`` tokens of size ``C/2 * W``);
    in the other every column is (``W`` tokens of size ``C/2 * H``). Both
    results are fused by a 1x1 convolution, followed by a residual
    feed-forward block and :class:`PEG`. Input and output are
    ``(B, dim, H, W)``.
    """

    def __init__(self, dim, head_num):
        super(Inter_SA, self).__init__()
        self.hidden_size = dim
        self.head_num = head_num
        half = self.hidden_size // 2
        self.attention_norm = nn.LayerNorm(self.hidden_size)
        self.conv_input = nn.Conv2d(self.hidden_size, self.hidden_size,
                                    kernel_size=1, padding=0)
        # Joint query/key/value projections for the row and column halves.
        self.conv_h = nn.Conv2d(half, 3 * half, kernel_size=1, padding=0)
        self.conv_v = nn.Conv2d(half, 3 * half, kernel_size=1, padding=0)
        self.ffn_norm = nn.LayerNorm(self.hidden_size)
        self.ffn = Mlp(self.hidden_size)
        self.fuse_out = nn.Conv2d(self.hidden_size, self.hidden_size,
                                  kernel_size=1, padding=0)
        self.attn = Attention(head_num=self.head_num)
        self.PEG = PEG(dim)

    def forward(self, x):
        shortcut = x
        B, C, H, W = x.size()
        half = C // 2

        x = _layer_norm_channels(self.attention_norm, x)
        in_h, in_v = torch.chunk(self.conv_input(x), 2, dim=1)

        # One token per row: (B, C/2, H, W) -> (B, H, C/2 * W).
        query_h, key_h, value_h = (
            part.permute(0, 2, 1, 3).reshape(B, H, -1)
            for part in torch.chunk(self.conv_h(in_h), 3, dim=1)
        )
        # One token per column: (B, C/2, H, W) -> (B, W, C/2 * H).
        query_v, key_v, value_v = (
            part.permute(0, 3, 1, 2).reshape(B, W, -1)
            for part in torch.chunk(self.conv_v(in_v), 3, dim=1)
        )

        if H == W:
            # Row and column tokens have the same shape, so both directions
            # go through a single batched attention call.
            merged = self.attn(torch.cat((query_h, query_v), dim=0),
                               torch.cat((key_h, key_v), dim=0),
                               torch.cat((value_h, value_v), dim=0))
            out_h, out_v = torch.chunk(merged, 2, dim=0)
        else:
            out_h = self.attn(query_h, key_h, value_h)
            out_v = self.attn(query_v, key_v, value_v)

        # Back to (B, C/2, H, W) for both directions.
        out_h = out_h.view(B, H, half, W).permute(0, 2, 1, 3).contiguous()
        out_v = out_v.view(B, W, half, H).permute(0, 2, 3, 1).contiguous()
        attn_out = self.fuse_out(torch.cat((out_h, out_v), dim=1))

        x = _ffn_residual(self.ffn_norm, self.ffn, attn_out + shortcut)
        return self.PEG(x)


class Stripformer(nn.Module):
    """Full network: encoder, twelve alternating attention blocks, decoder.

    ``forward`` takes a ``(B, 3, H, W)`` tensor and returns a tensor of the
    same shape: the decoder output added to the input.
    """

    # Number of ``Trans_block_<i>`` modules (odd: Intra_SA, even: Inter_SA).
    _NUM_BLOCKS = 12
    # The encoder halves the resolution twice, so H and W are padded to a
    # multiple of this value before running the network.
    _SIZE_MULTIPLE = 4

    def __init__(self):
        super(Stripformer, self).__init__()

        self.encoder = Embeddings()
        head_num = 5
        dim = 320
        # Registered as Trans_block_1 ... Trans_block_12 so parameter names
        # stay exactly as they were with one explicit attribute per block.
        for idx in range(1, self._NUM_BLOCKS + 1):
            block_cls = Intra_SA if idx % 2 == 1 else Inter_SA
            setattr(self, "Trans_block_{}".format(idx),
                    block_cls(dim, head_num))
        self.decoder = Embeddings_output()

    def forward(self, x):
        height, width = x.shape[-2:]
        pad_h = -height % self._SIZE_MULTIPLE
        pad_w = -width % self._SIZE_MULTIPLE
        padded = pad_h > 0 or pad_w > 0

        net_in = x
        if padded:
            # Without padding the skip connections in the decoder would not
            # match the upsampled feature maps.  Reflect padding needs the
            # pad to be smaller than the padded dimension; fall back to
            # replicate for very small inputs.
            mode = "reflect" if (pad_h < height and pad_w < width) \
                else "replicate"
            net_in = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=mode)

        hx, residual_1, residual_2 = self.encoder(net_in)
        for idx in range(1, self._NUM_BLOCKS + 1):
            hx = getattr(self, "Trans_block_{}".format(idx))(hx)
        hx = self.decoder(hx, residual_1, residual_2)

        if padded:
            hx = hx[..., :height, :width]
        return hx + x