
Trying to understand the Dense Prediction Transformer architecture #78

Description

@tholmb

This is not really an issue; rather, I'm trying to understand the DPT architecture. I'm quite new to transformers, but I have prior experience with CNNs.

Let's take an example where I would like to build a DPT-Large depth model with a patch size of 16 and an image size of 384. I can see that the pretrained weights for the original model variants are loaded from timm, but in this example I would like to build the model from scratch.

I gathered all the essential functions into the code snippet further below (hopefully). I don't fully understand where and how I should apply the functions forward_flex() and _resize_pos_embed(); my current reading is sketched right after this paragraph. The second question: if I wanted to compute a multi-scale loss, for which layers should it be computed? Thanks in advance!
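For context on the first question, my current understanding is that forward_flex() is the encoder's forward pass for arbitrary input sizes, and that it calls _resize_pos_embed() at the start to bilinearly interpolate the grid of patch position embeddings whenever H // patch_size or W // patch_size differs from the grid the embedding was created for (24x24 here). A rough sketch of that interpolation, reconstructed from my reading of vit.py (not the repo's exact code):

import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, gs_h, gs_w, start_index=1):
    # pos_embed: (1, start_index + gs_old * gs_old, C); the first start_index tokens are cls/readout.
    posemb_tok = pos_embed[:, :start_index]
    posemb_grid = pos_embed[0, start_index:]

    gs_old = int(posemb_grid.shape[0] ** 0.5)
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    # Bilinearly resample the 2-D grid of patch position embeddings to the new grid size.
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    return torch.cat([posemb_tok, posemb_grid], dim=1)

If I always train and run at a fixed 384x384, I assume I can skip it entirely. The full model code I put together follows.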

import torch
import torch.nn as nn
import torch.nn.functional as F

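# Wrapper so a tensor transpose can be used inside nn.Sequential (part of the reassemble stage).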
class Transpose(nn.Module):
    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        x = x.transpose(self.dim0, self.dim1)
        return x

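# Wrapper so bilinear upsampling can be used inside nn.Sequential.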
class Interpolate(nn.Module):
    def __init__(self, scale_factor):
        super(Interpolate, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=True)
        return x

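# Pre-activation residual conv block used inside the RefineNet-style fusion decoder.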
class ResidualConvUnit(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)
        return out + x

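# Fuses the coarser decoder path with an encoder skip (when given) and upsamples by 2.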
class FeatureFusionBlock(nn.Module):
    def __init__(self, features):
        super(FeatureFusionBlock, self).__init__()
        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        output = xs[0]
        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])
        output = self.resConfUnit2(output)
        output = F.interpolate(output, scale_factor=2, mode="bilinear", align_corners=True)
        return output

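# Cuts the image into non-overlapping patches and projects each one to an embed_dim token.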
class PatchEmbed(nn.Module):
    def __init__(self, img_size=384, patch_size=16, in_chans=3, embed_dim=1024):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size // patch_size, img_size // patch_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x

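# Standard multi-head self-attention over the token sequence.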
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

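# Per-channel residual scaling from timm; note it is not wired into the Block below.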
class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

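# Two-layer feed-forward network applied after attention in each transformer block.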
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

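# Pre-norm transformer encoder block: self-attention and MLP, each with a residual connection.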
class Block(nn.Module):

    def __init__(self, embed_dim=1024, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(embed_dim)
        self.attn = Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)

        self.norm2 = norm_layer(embed_dim)
        self.mlp = Mlp(in_features=embed_dim, hidden_features=int(embed_dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

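# Readout handling: concatenates the cls token to every patch token and projects back to in_features.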
class ProjectReadout(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
        features = torch.cat((x[:, self.start_index :], readout), -1)

        return self.project(features)

class Encoder(nn.Module):
    def __init__(self, embed_dim=1024, num_heads=16, depth=24, norm_layer=nn.LayerNorm):
        super(Encoder, self).__init__()

        self.patch_embed = PatchEmbed(embed_dim=embed_dim)
        self.num_patches = self.patch_embed.num_patches

        # One position per patch plus one for the cls token, which ProjectReadout expects at index 0.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=0.1)

        # 24 independent blocks; reusing a single Block instance in a loop would tie one set of
        # weights across all layers.
        self.blocks = nn.ModuleList([Block(embed_dim=embed_dim, num_heads=num_heads) for _ in range(depth)])
        # Blocks after which intermediate features are tapped (the ViT-L/16 hooks used by DPT).
        self.hooks = [5, 11, 17, 23]

        self.norm = norm_layer(embed_dim)

    def forward(self, x):
        x = self.patch_embed(x)

        # Prepend the cls token and add the positional embedding (pos_embed is a Parameter,
        # not a callable, so it must be added rather than called).
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)

        features = []
        for i, block in enumerate(self.blocks):
            x = block(x)
            if i in self.hooks:
                features.append(x)

        layer1, layer2, layer3, layer4 = features
        return layer1, layer2, layer3, self.norm(layer4)

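# DPT-Large (ViT-L/16, 384x384): backbone, reassemble operations, fusion decoder, and depth head.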
class DPT_VITL_16_384(nn.Module):
    def __init__(self, img_size=384, features=256, embed_dim=1024, in_shape=[256, 512, 1024, 1024]):
        super(DPT_VITL_16_384, self).__init__()

        self.encoder = Encoder(embed_dim=embed_dim)

        self.refinenet1 = FeatureFusionBlock(features)
        self.refinenet2 = FeatureFusionBlock(features)
        self.refinenet3 = FeatureFusionBlock(features)
        self.refinenet4 = FeatureFusionBlock(features)

        self.layer1_rn = nn.Conv2d(in_shape[0], features, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer2_rn = nn.Conv2d(in_shape[1], features, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer3_rn = nn.Conv2d(in_shape[2], features, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer4_rn = nn.Conv2d(in_shape[3], features, kernel_size=3, stride=1, padding=1, bias=False)

        self.readout_oper0 = ProjectReadout(embed_dim, start_index=1)
        self.readout_oper1 = ProjectReadout(embed_dim, start_index=1)
        self.readout_oper2 = ProjectReadout(embed_dim, start_index=1)
        self.readout_oper3 = ProjectReadout(embed_dim, start_index=1)

        self.act_postprocess1 = nn.Sequential(self.readout_oper0, Transpose(1, 2), nn.Unflatten(2, torch.Size([img_size // 16, img_size // 16])),
                                              nn.Conv2d(in_channels=embed_dim, out_channels=in_shape[0], kernel_size=1, stride=1, padding=0),
                                              nn.ConvTranspose2d(in_channels=in_shape[0], out_channels=in_shape[0], kernel_size=4, stride=4, padding=0, bias=True, dilation=1, groups=1))

        self.act_postprocess2 = nn.Sequential(self.readout_oper1, Transpose(1, 2), nn.Unflatten(2, torch.Size([img_size // 16, img_size // 16])),
                                              nn.Conv2d(in_channels=embed_dim, out_channels=in_shape[1],kernel_size=1,stride=1,padding=0),
                                              nn.ConvTranspose2d(in_channels=in_shape[1], out_channels=in_shape[1], kernel_size=2, stride=2, padding=0, bias=True, dilation=1, groups=1))

        self.act_postprocess3 = nn.Sequential(self.readout_oper2, Transpose(1, 2), nn.Unflatten(2, torch.Size([img_size // 16, img_size // 16])),
                                              nn.Conv2d(in_channels=embed_dim, out_channels=in_shape[2], kernel_size=1, stride=1, padding=0))

        self.act_postprocess4 = nn.Sequential(self.readout_oper3, Transpose(1, 2), nn.Unflatten(2, torch.Size([img_size // 16, img_size // 16])),
                                              nn.Conv2d(in_channels=embed_dim, out_channels=in_shape[3], kernel_size=1, stride=1, padding=0),
                                              nn.Conv2d(in_channels=in_shape[3], out_channels=in_shape[3], kernel_size=3, stride=2, padding=1))


        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True)
        )

    def forward(self, x):

        layer_1, layer_2, layer_3, layer_4 = self.encoder(x)

        # Each tapped layer goes through its own reassemble operation.
        layer_1 = self.act_postprocess1(layer_1)
        layer_2 = self.act_postprocess2(layer_2)
        layer_3 = self.act_postprocess3(layer_3)
        layer_4 = self.act_postprocess4(layer_4)

        layer_1_rn = self.layer1_rn(layer_1)
        layer_2_rn = self.layer2_rn(layer_2)
        layer_3_rn = self.layer3_rn(layer_3)
        layer_4_rn = self.layer4_rn(layer_4)

        path_4 = self.refinenet4(layer_4_rn)
        path_3 = self.refinenet3(path_4, layer_3_rn)
        path_2 = self.refinenet2(path_3, layer_2_rn)
        path_1 = self.refinenet1(path_2, layer_1_rn)

        out = self.head(path_1)

        return out

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
model = DPT_VITL_16_384().to(device)
print(model)
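For the second question, my working assumption (not something I found in the paper) is to tap the decoder outputs path_4 ... path_1, run each through its own small prediction head, upsample every prediction to the ground-truth resolution, and sum weighted per-scale losses, with the finest scale weighted highest. The weights and the L1 criterion below are my own placeholder choices:

import torch.nn.functional as F

def multiscale_loss(preds, target, weights=(1.0, 0.5, 0.25, 0.125)):
    # preds: per-scale depth predictions, finest first, each of shape (B, 1, h, w).
    # target: ground-truth depth of shape (B, 1, H, W).
    loss = 0.0
    for pred, weight in zip(preds, weights):
        # Bring each scale up to the ground-truth resolution before comparing.
        pred = F.interpolate(pred, size=target.shape[-2:], mode="bilinear", align_corners=True)
        loss = loss + weight * F.l1_loss(pred, target)
    return loss

To get the per-scale predictions, forward() would need to return path_1..path_4 (each passed through a copy of self.head, or through a 1x1 conv) instead of only the final output. Is that the intended way to do it?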
