You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

375 lines
14 KiB

import torch
import torch.nn as nn
import math
class Embeddings(nn.Module):
    """Convolutional encoder producing a bottleneck map and two skip tensors.

    Stage 1 (stride 1, 64 ch): conv + LeakyReLU, then three residual
    conv pairs. Stage 2 (stride 2, 128 ch): strided conv + LeakyReLU, then
    three residual conv pairs. Stage 3 (stride 2, 320 ch): one strided conv
    + LeakyReLU. Every stride-2 conv (kernel 3, padding 1) halves H and W,
    rounding up.

    ``forward`` returns ``(bottleneck, residual_1, residual_2)`` where
    ``residual_1`` is the stage-1 output (64 ch) and ``residual_2`` the
    stage-2 output (128 ch).
    """

    def __init__(self):
        super().__init__()
        act = nn.LeakyReLU(0.2, True)
        # One in-place LeakyReLU instance, shared by every stage below.
        self.activation = act

        def conv_act(c_in, c_out, stride):
            # 3x3 conv followed by the shared activation.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                act,
            )

        def res_pair(channels):
            # Conv -> LeakyReLU -> Conv; the skip-add happens in forward().
            return nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                act,
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            )

        # Construction order matters for seeded weight initialisation.
        self.en_layer1_1 = conv_act(3, 64, 1)
        self.en_layer1_2 = res_pair(64)
        self.en_layer1_3 = res_pair(64)
        self.en_layer1_4 = res_pair(64)
        self.en_layer2_1 = conv_act(64, 128, 2)
        self.en_layer2_2 = res_pair(128)
        self.en_layer2_3 = res_pair(128)
        self.en_layer2_4 = res_pair(128)
        self.en_layer3_1 = conv_act(128, 320, 2)

    def forward(self, x):
        """Encode ``x``; returns ``(bottleneck, residual_1, residual_2)``."""
        feat = self.en_layer1_1(x)
        for block in (self.en_layer1_2, self.en_layer1_3, self.en_layer1_4):
            feat = self.activation(block(feat) + feat)
        skip_full = feat

        feat = self.en_layer2_1(feat)
        for block in (self.en_layer2_2, self.en_layer2_3, self.en_layer2_4):
            feat = self.activation(block(feat) + feat)
        skip_half = feat

        return self.en_layer3_1(feat), skip_full, skip_half
class Embeddings_output(nn.Module):
    """Decoder mapping the 320-channel bottleneck back to a 3-channel output.

    Upper stage: a stride-2 transposed conv (320 -> 192 ch) is concatenated
    with ``residual_2`` (128 ch) and fused by a 1x1 conv. Six alternating
    ``Intra_SA`` / ``Inter_SA`` blocks (dim 192, 3 heads) then refine it.

    Lower stage: a stride-2 transposed conv (192 -> 64 ch) is concatenated
    with ``residual_1`` (64 ch) inside a residual block. One more residual
    conv pair follows, then a final 3x3 conv to 3 channels + LeakyReLU.

    NOTE(review): concatenation with the skips presumably requires the
    encoder input's H and W to be multiples of 4 — TODO confirm.
    """

    def __init__(self):
        super().__init__()
        act = nn.LeakyReLU(0.2, True)
        # One in-place LeakyReLU instance, shared by every layer below.
        self.activation = act
        self.de_layer3_1 = nn.Sequential(
            nn.ConvTranspose2d(320, 192, kernel_size=4, stride=2, padding=1),
            act,
        )
        num_heads = 3
        width = 192
        # 1x1 fusion of the upsampled features with residual_2 (128 ch).
        self.de_layer2_2 = nn.Sequential(
            nn.Conv2d(width + 128, width, kernel_size=1, padding=0),
            act,
        )
        # de_block_1 .. de_block_6: odd indices Intra_SA, even Inter_SA.
        for idx in range(1, 7):
            block_cls = Intra_SA if idx % 2 == 1 else Inter_SA
            setattr(self, f"de_block_{idx}", block_cls(width, num_heads))
        self.de_layer2_1 = nn.Sequential(
            nn.ConvTranspose2d(width, 64, kernel_size=4, stride=2, padding=1),
            act,
        )
        # Input is cat(upsampled 64 ch, residual_1 64 ch) = 128 ch.
        self.de_layer1_3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=1, padding=0),
            act,
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
        )
        self.de_layer1_2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            act,
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
        )
        self.de_layer1_1 = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            act,
        )

    def forward(self, x, residual_1, residual_2):
        """Decode ``x`` using the encoder skips; returns a 3-channel map."""
        out = self.de_layer3_1(x)
        out = self.de_layer2_2(torch.cat((out, residual_2), dim=1))
        for idx in range(1, 7):
            out = getattr(self, f"de_block_{idx}")(out)
        out = self.de_layer2_1(out)
        fused = self.de_layer1_3(torch.cat((out, residual_1), dim=1))
        out = self.activation(fused + out)
        out = self.activation(self.de_layer1_2(out) + out)
        return self.de_layer1_1(out)
class Attention(nn.Module):
    """Parameter-free multi-head scaled dot-product attention.

    The caller supplies already-projected query/key/value tensors of shape
    (B, N, C). The last dimension is split into ``head_num`` heads.
    ``softmax(Q K^T / sqrt(d)) V`` is applied per head, with ``d`` the
    per-head width, and the heads are merged back into a (B, N, C) tensor.

    Args:
        head_num: Number of heads. The feature dimension ``C`` of every
            input must be divisible by it.
    """

    def __init__(self, head_num):
        super().__init__()
        self.num_attention_heads = head_num
        self.softmax = nn.Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Reshape (B, N, C) into contiguous (B, heads, N, C // heads).

        Raises:
            ValueError: if ``C`` is not divisible by the number of heads.
                This is clearer than the shape error ``view`` would raise.
        """
        B, N, C = x.size()
        heads = self.num_attention_heads
        if C % heads != 0:
            raise ValueError(
                f"feature dimension {C} is not divisible by "
                f"num_attention_heads={heads}"
            )
        # Integer division: the old int(C / heads) went through a float.
        head_size = C // heads
        x = x.view(B, N, heads, head_size)
        return x.permute(0, 2, 1, 3).contiguous()

    def forward(self, query_layer, key_layer, value_layer):
        """Attend ``query_layer`` over ``key_layer``/``value_layer``.

        Args:
            query_layer: (B, N, C) queries.
            key_layer: (B, M, C) keys.
            value_layer: (B, M, C) values.

        Returns:
            (B, N, C) tensor with the heads concatenated back in order.
        """
        C = query_layer.size(-1)
        query_layer = self.transpose_for_scores(query_layer)
        key_layer = self.transpose_for_scores(key_layer)
        value_layer = self.transpose_for_scores(value_layer)
        head_size = query_layer.size(-1)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(head_size)
        attention_probs = self.softmax(attention_scores)
        context_layer = torch.matmul(attention_probs, value_layer)
        # (B, heads, N, d) -> (B, N, heads, d) -> (B, N, C)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        return context_layer.view(*context_layer.size()[:-2], C)
class Mlp(nn.Module):
    """Position-wise feed-forward network: Linear(4x expand) -> GELU -> Linear.

    Args:
        hidden_size: Size of the last dimension of the input and output.
    """

    def __init__(self, hidden_size):
        super().__init__()
        expanded = 4 * hidden_size
        self.fc1 = nn.Linear(hidden_size, expanded)
        self.fc2 = nn.Linear(expanded, hidden_size)
        self.act_fn = torch.nn.functional.gelu
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights, then near-zero normal biases (std 1e-6)."""
        layers = (self.fc1, self.fc2)
        # Order matters for seeded init: both weights first, then both biases.
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        """Apply fc1 -> GELU -> fc2 over the last dimension of ``x``."""
        return self.fc2(self.act_fn(self.fc1(x)))
# CPE (Conditional Positional Embedding)
class PEG(nn.Module):
    """Conditional positional encoding: a residual 3x3 depth-wise convolution.

    Args:
        hidden_size: Channel count; the conv uses ``groups=hidden_size``, so
            each channel is filtered independently.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.PEG = nn.Conv2d(
            hidden_size,
            hidden_size,
            kernel_size=3,
            padding=1,
            groups=hidden_size,
        )

    def forward(self, x):
        """Return ``x`` plus its depth-wise conv; spatial size is preserved."""
        return x + self.PEG(x)
class Intra_SA(nn.Module):
    """Intra-strip self-attention block operating on a (B, C, H, W) map.

    Structure:
    - A channel LayerNorm and a 1x1 conv; the channels are then split in half.
    - The first half attends along each row (``B*H`` sequences of ``W``
      tokens). The second half attends along each column (``B*W`` sequences
      of ``H`` tokens). Both use per-token linear QKV projections.
    - The two results are fused by a 1x1 conv and added to the block input.
    - A pre-norm MLP with a residual connection and a ``PEG`` positional
      conv follow.

    The output has the same shape as the input.

    Args:
        dim: Channel count ``C`` of the input. NOTE(review): presumably must
            be even, since channels are split with ``chunk(2)`` / ``C // 2``.
        head_num: Number of attention heads; ``dim // 2`` must be divisible
            by it.
    """

    def __init__(self, dim, head_num):
        super().__init__()
        self.hidden_size = dim // 2
        self.head_num = head_num
        self.attention_norm = nn.LayerNorm(dim)
        self.conv_input = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
        self.qkv_local_h = nn.Linear(self.hidden_size, self.hidden_size * 3)  # qkv_h
        self.qkv_local_v = nn.Linear(self.hidden_size, self.hidden_size * 3)  # qkv_v
        self.fuse_out = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn = Mlp(dim)
        self.attn = Attention(head_num=self.head_num)
        self.PEG = PEG(dim)

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, C, H, W) and return the same shape.

        When ``H == W`` both strip directions have the same sequence length
        and token width. They then share one batched attention call;
        otherwise they run separately. Both paths give the same result
        mathematically.
        """
        h = x
        B, C, H, W = x.size()
        # Channel LayerNorm: (B, C, H, W) -> (B, HW, C) -> norm -> back.
        # reshape (not view) so non-contiguous inputs, e.g. channels_last
        # tensors, are accepted; for contiguous inputs it returns the same view.
        x = x.reshape(B, C, H * W).permute(0, 2, 1).contiguous()
        x = self.attention_norm(x).permute(0, 2, 1).contiguous()
        x = x.view(B, C, H, W)

        x_input = torch.chunk(self.conv_input(x), 2, dim=1)
        # Rows: B*H sequences of W tokens, each token C//2 wide.
        feature_h = x_input[0].permute(0, 2, 3, 1).contiguous().view(B * H, W, C // 2)
        # Columns: B*W sequences of H tokens, each token C//2 wide.
        feature_v = x_input[1].permute(0, 3, 2, 1).contiguous().view(B * W, H, C // 2)
        q_h, k_h, v_h = torch.chunk(self.qkv_local_h(feature_h), 3, dim=2)
        q_v, k_v, v_v = torch.chunk(self.qkv_local_v(feature_v), 3, dim=2)

        if H == W:
            query = torch.cat((q_h, q_v), dim=0)
            key = torch.cat((k_h, k_v), dim=0)
            value = torch.cat((v_h, v_v), dim=0)
            attention_output_h, attention_output_v = torch.chunk(
                self.attn(query, key, value), 2, dim=0
            )
        else:
            attention_output_h = self.attn(q_h, k_h, v_h)
            attention_output_v = self.attn(q_v, k_v, v_v)

        # Back to (B, C//2, H, W) for each direction, then fuse.
        attention_output_h = attention_output_h.view(B, H, W, C // 2).permute(0, 3, 1, 2).contiguous()
        attention_output_v = attention_output_v.view(B, W, H, C // 2).permute(0, 3, 2, 1).contiguous()
        attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))

        x = attn_out + h
        # Pre-norm MLP with residual over tokens (B, HW, C).
        x = x.reshape(B, C, H * W).permute(0, 2, 1).contiguous()
        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(B, C, H, W)
        x = self.PEG(x)
        return x
class Inter_SA(nn.Module):
    """Inter-strip self-attention block operating on a (B, C, H, W) map.

    Structure:
    - A channel LayerNorm and a 1x1 conv; the channels are split in half and
      each half is projected to Q/K/V by a 1x1 conv.
    - The first half treats every row as one token (``H`` tokens of width
      ``C//2 * W``). The second half treats every column as one token (``W``
      tokens of width ``C//2 * H``). Attention therefore mixes whole strips
      with each other.
    - The results are fused by a 1x1 conv and added to the input.
    - A pre-norm MLP with a residual connection and a ``PEG`` positional
      conv follow.

    The output has the same shape as the input.

    Args:
        dim: Channel count ``C`` of the input. NOTE(review): presumably must
            be even, since channels are split with ``chunk(2)`` / ``C // 2``.
        head_num: Number of attention heads; the flattened strip width must
            be divisible by it.
    """

    def __init__(self, dim, head_num):
        super().__init__()
        self.hidden_size = dim
        self.head_num = head_num
        self.attention_norm = nn.LayerNorm(self.hidden_size)
        self.conv_input = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0)
        self.conv_h = nn.Conv2d(self.hidden_size // 2, 3 * (self.hidden_size // 2), kernel_size=1, padding=0)  # qkv_h
        self.conv_v = nn.Conv2d(self.hidden_size // 2, 3 * (self.hidden_size // 2), kernel_size=1, padding=0)  # qkv_v
        self.ffn_norm = nn.LayerNorm(self.hidden_size)
        self.ffn = Mlp(self.hidden_size)
        self.fuse_out = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0)
        self.attn = Attention(head_num=self.head_num)
        self.PEG = PEG(dim)

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, C, H, W) and return the same shape.

        When ``H == W`` both directions have the same token count and token
        width. They then share one batched attention call; otherwise they
        run separately. Both paths give the same result mathematically.
        """
        h = x
        B, C, H, W = x.size()
        # Channel LayerNorm: (B, C, H, W) -> (B, HW, C) -> norm -> back.
        # reshape (not view) so non-contiguous inputs, e.g. channels_last
        # tensors, are accepted; for contiguous inputs it returns the same view.
        x = x.reshape(B, C, H * W).permute(0, 2, 1).contiguous()
        x = self.attention_norm(x).permute(0, 2, 1).contiguous()
        x = x.view(B, C, H, W)

        x_input = torch.chunk(self.conv_input(x), 2, dim=1)
        feature_h = torch.chunk(self.conv_h(x_input[0]), 3, dim=1)
        feature_v = torch.chunk(self.conv_v(x_input[1]), 3, dim=1)

        # Rows as tokens: (3B, C//2, H, W) -> (3B, H, C//2 * W), split into q/k/v.
        horizontal_groups = torch.cat(feature_h, dim=0)
        horizontal_groups = horizontal_groups.permute(0, 2, 1, 3).contiguous().view(3 * B, H, -1)
        query_h, key_h, value_h = torch.chunk(horizontal_groups, 3, dim=0)
        # Columns as tokens: (3B, C//2, H, W) -> (3B, W, C//2 * H), split into q/k/v.
        vertical_groups = torch.cat(feature_v, dim=0)
        vertical_groups = vertical_groups.permute(0, 3, 1, 2).contiguous().view(3 * B, W, -1)
        query_v, key_v, value_v = torch.chunk(vertical_groups, 3, dim=0)

        if H == W:
            query = torch.cat((query_h, query_v), dim=0)
            key = torch.cat((key_h, key_v), dim=0)
            value = torch.cat((value_h, value_v), dim=0)
            attention_output_h, attention_output_v = torch.chunk(
                self.attn(query, key, value), 2, dim=0
            )
        else:
            attention_output_h = self.attn(query_h, key_h, value_h)
            attention_output_v = self.attn(query_v, key_v, value_v)

        # Back to (B, C//2, H, W) for each direction, then fuse.
        attention_output_h = attention_output_h.view(B, H, C // 2, W).permute(0, 2, 1, 3).contiguous()
        attention_output_v = attention_output_v.view(B, W, C // 2, H).permute(0, 2, 3, 1).contiguous()
        attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))

        x = attn_out + h
        # Pre-norm MLP with residual over tokens (B, HW, C).
        x = x.reshape(B, C, H * W).permute(0, 2, 1).contiguous()
        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(B, C, H, W)
        x = self.PEG(x)
        return x
class Stripformer(nn.Module):
    """Full network: encoder -> 12 strip-attention blocks -> decoder + global skip.

    The bottleneck (dim 320, 5 heads) is processed by ``Trans_block_1`` ..
    ``Trans_block_12``, alternating ``Intra_SA`` (odd) and ``Inter_SA``
    (even). The decoder output is added to the network input, so the model
    predicts a residual over ``x``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Embeddings()
        num_heads = 5
        width = 320
        # Registration order (and hence seeded init) matches block numbering.
        for idx in range(1, 13):
            block_cls = Intra_SA if idx % 2 == 1 else Inter_SA
            setattr(self, f"Trans_block_{idx}", block_cls(width, num_heads))
        self.decoder = Embeddings_output()

    def forward(self, x):
        """Run the network on ``x`` and return ``decoder_output + x``."""
        features, residual_1, residual_2 = self.encoder(x)
        for idx in range(1, 13):
            features = getattr(self, f"Trans_block_{idx}")(features)
        restored = self.decoder(features, residual_1, residual_2)
        return restored + x