Create a deformable convolutional layer in MATLAB

The deformable convolutional layer is typically used in object detection and instance segmentation tasks. It is an extension of the standard convolutional layer where the kernel positions are adjusted according to a learnable offset for each spatial location.

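To create a deformable convolutional layer in MATLAB, you can use the custom layer API of the Deep Learning Toolbox. Here's an example implementation of a deformable convolution (DCN) layer based on the paper "Deformable Convolutional Networks" (Dai et al., 2017). Note that the example also calls vl_nnconv, which is provided by the MatConvNet toolbox rather than by the Deep Learning Toolbox, so MatConvNet must be on the MATLAB path.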

main.m
classdef DeformableConvolutionalLayer < nnet.layer.Layer
    %DeformableConvolutionalLayer  Custom layer sketching a deformable convolution.
    %
    %   predict takes a feature map X and an offset array, overwrites parts
    %   of a local copy of the kernel with values of X gathered at
    %   offset-shifted positions, and then delegates the actual convolution
    %   to vl_nnconv. backward mirrors that computation to produce gradients.
    %
    %   NOTE(review): vl_nnconv is not defined in this file; it appears to be
    %   MatConvNet's convolution function, so MatConvNet presumably has to be
    %   on the path - confirm.
    %   NOTE(review): Weights and Bias are never initialized by the
    %   constructor and are declared as ordinary properties. If they are
    %   meant to be trained they presumably belong in a
    %   "properties (Learnable)" block with an initial value - confirm.
    %   NOTE(review): predict takes two data inputs (X and offset) but the
    %   constructor sets neither NumInputs nor InputNames; the framework
    %   presumably needs one of them to pass a second input - confirm.

    properties
        NumFilters          % number of filters; loop bound and last index of Weights
        KernelSize          % two-element size; (1) and (2) are read as H and W
        Stride              % forwarded to vl_nnconv as 'stride'
        DilationFactor      % scales the offset shift when computing sample positions
        OffsetNumChannels   % channel count used throughout the offset index arithmetic
        PaddingMode         % stored by the constructor; not read anywhere in this class
        PaddingSize         % forwarded to vl_nnconv as 'pad'
        Weights             % kernel array passed to vl_nnconv (not initialized here)
        Bias                % bias passed to vl_nnconv (not initialized here)
    end

    methods
        function layer = DeformableConvolutionalLayer(numFilters, kernelSize, stride, dilationFactor, offsetNumChannels, paddingMode, paddingSize, name)
            % Construct the layer and store the hyper-parameters verbatim.
            % No validation is performed and Weights/Bias are left empty.
            layer.Name = name;
            layer.Description = "Deformable Convolutional Layer";
            layer.NumFilters = numFilters;
            layer.KernelSize = kernelSize;
            layer.Stride = stride;
            layer.DilationFactor = dilationFactor;
            layer.OffsetNumChannels = offsetNumChannels;
            layer.PaddingMode = paddingMode;
            layer.PaddingSize = paddingSize;
        end

        function Z = predict(layer, X, offset)
            % Forward pass.
            %   X      - 4-D input; its four sizes are read as [O, P, Q, C].
            %            NOTE(review): which axes are spatial/channel/batch is
            %            not established by this file - confirm against caller.
            %   offset - offset array indexed as offset(:,:,:,k,:).
            %   Z      - output of vl_nnconv, permuted with [4 3 2 1].
            [O, P, Q, C] = size(X);
            H = layer.KernelSize(1);
            W = layer.KernelSize(2);
            % Only the local copy of the layer is modified (value semantics).
            layer.Weights(:,:,:,1:layer.OffsetNumChannels,:) = 0;
            for q = 1:layer.NumFilters
                for c = layer.OffsetNumChannels+1:C
                    index = layer.sampleIndices(offset, q, c, O, P, Q, H, W);
                    X_ = permute(X, [4, 3, 2, 1]);
                    % NOTE(review): this reshape only succeeds when
                    % numel(X) == C*H*W*O*P, i.e. Q == H*W - confirm intent.
                    X_ = reshape(X_, [C, H*W, O*P]);
                    for h = 1:H
                        for w = 1:W
                            ind_start = (h-1)*W + w;
                            ind_end = (h-1)*W + w + layer.OffsetNumChannels*(H*W);
                            layer.Weights(h, w, :, ind_start:ind_end, q) = X_(:, ind_start*ones(1, layer.OffsetNumChannels) + index-1);
                        end
                    end
                end
            end
            X = permute(X, [4, 3, 2, 1]);
            Z = vl_nnconv(X, layer.Weights, layer.Bias, 'pad', layer.PaddingSize, 'stride', layer.Stride);
            Z = permute(Z, [4, 3, 2, 1]);
        end

        function [dLdX, dLdWeights, dLdBias, dLdOffset] = backward(layer, X, offset, dLdZ, ~)
            % Backward pass: gradients w.r.t. X, Weights, Bias and offset.
            %
            % NOTE(review): for a layer with two inputs and one output the
            % custom-layer contract presumably expects
            %   backward(layer, X, offset, Z, dLdZ, memory)
            % returning input gradients first ([dLdX, dLdOffset, ...]); the
            % signature here has neither Z nor that output order - confirm
            % before wiring this into a network.
            [O, P, Q, C] = size(X);
            H = layer.KernelSize(1);
            W = layer.KernelSize(2);
            dLdX = zeros(O, P, Q, C, 'like', X);
            dLdWeights = zeros(H, W, C, (H*W+layer.OffsetNumChannels*H*W*(layer.NumFilters-1)), layer.NumFilters, 'like', X);
            dLdZ = permute(dLdZ, [4, 3, 2, 1]);
            % NOTE(review): predict permutes X with [4 3 2 1] before calling
            % vl_nnconv, but X is passed unpermuted here - confirm.
            % Only the input derivative of vl_nnconv is used below.
            [dLdX_, ~, ~] = vl_nnconv(X, layer.Weights, layer.Bias, dLdZ, 'pad', layer.PaddingSize, 'stride', layer.Stride);

            % --- gradient w.r.t. Weights ---
            for q = 1:layer.NumFilters
                for c = layer.OffsetNumChannels+1:C
                    index = layer.sampleIndices(offset, q, c, O, P, Q, H, W);
                    dLdX__ = permute(dLdX_(:,:,:,q), [3, 2, 1]);
                    dLdX__ = reshape(dLdX__, [1, 1, O*P, H*W]);
                    for h = 1:H
                        for w = 1:W
                            ind_start = (h-1)*W + w;
                            ind_end = (h-1)*W + w + layer.OffsetNumChannels*(H*W);
                            % NOTE(review): the inner bsxfun(@minus, ...) is
                            % given a single array operand, but bsxfun needs
                            % two. The intended second operand cannot be
                            % recovered from this file, so the expression is
                            % left as written and will error when reached.
                            dLdWeights(h, w, :, ind_start:ind_end, q) = sum(bsxfun(@times, dLdX__(1, 1, index'+(ind_start-1)*ones(size(index)), :), bsxfun(@minus, permute(layer.Weights(h, w, :, ind_start:ind_end, q), [4, 3, 1, 2])), 4);
                        end
                    end
                end
            end

            % --- gradient w.r.t. X ---
            dLdX_ = reshape(dLdX_, [O, P, Q, C, layer.NumFilters]);
            for q = 1:layer.NumFilters
                tmp = dLdX_(:,:,:,:,q);
                tmp = permute(tmp, [4, 3, 2, 1]);
                dLdX(:,:,:,layer.OffsetNumChannels*(q-1)+1:layer.OffsetNumChannels*q) = sum(tmp, 5);
            end

            % --- gradient w.r.t. Bias ---
            dLdBias = sum(sum(sum(dLdZ, 1), 2), 4);

            % --- gradient w.r.t. offset ---
            % The sampling indices are not needed in this section; the loop
            % over c is kept because the destination channel depends on it.
            dLdOffset = zeros(O, P, Q, layer.OffsetNumChannels*layer.NumFilters, 'like', X);
            for q = 1:layer.NumFilters
                for c = layer.OffsetNumChannels+1:C
                    dLdX_ = permute(dLdX(:,:,:,layer.OffsetNumChannels*(q-1)+1:layer.OffsetNumChannels*q), [4, 3, 2, 1]);
                    dLdX_ = reshape(dLdX_, [layer.OffsetNumChannels, 1, H*W, O*P]);
                    for h = 1:H
                        for w = 1:W
                            ind_start = (h-1)*W + w;
                            ind_end = (h-1)*W + w + layer.OffsetNumChannels*(H*W);
                            dLdOffset_ = bsxfun(@times, permute(layer.Weights(h,w,:,ind_start:ind_end,q), [3, 1, 2]), dLdX_);
                            dLdOffset_ = sum(reshape(dLdOffset_, [layer.OffsetNumChannels, H*W, O*P]), 2);
                            dLdOffset_ = permute(dLdOffset_, [2, 1, 3]);
                            dLdOffset_ = reshape(dLdOffset_, [P, Q, layer.OffsetNumChannels]);
                            tmp = zeros(P, Q, 2, 'like', X);
                            % NOTE(review): h is centred with W and w with H;
                            % possibly swapped - confirm.
                            tmp(:,:,1) = bsxfun(@times, dLdOffset_(:,:,1), (h-1-W/2));  % dU
                            tmp(:,:,2) = bsxfun(@times, dLdOffset_(:,:,2), (w-1-H/2));  % dV
                            % DilationFactor is a property, so it has to be
                            % read through the layer object.
                            dLdOffset_ = layer.DilationFactor*sum(tmp, 3);
                            % Accumulate into the offset channel for (q, c)
                            dLdOffset(:,:,:,layer.OffsetNumChannels*(q-1)+(c-layer.OffsetNumChannels)+1) = dLdOffset(:,:,:,layer.OffsetNumChannels*(q-1)+(c-layer.OffsetNumChannels)+1) + dLdOffset_;
                        end
                    end
                end
            end
        end
    end

    methods (Access = private)
        function index = sampleIndices(layer, offset, q, c, O, P, Q, H, W)
            % Linear indices into a P-by-Q grid of the offset-shifted,
            % rounded and clamped sample positions for filter q, channel c.
            %   offset  - offset array; channel
            %             OffsetNumChannels*(q-1)+(c-OffsetNumChannels)+1
            %             along the 4th dimension is used.
            %   O,P,Q   - first three sizes of the layer input X.
            %   H,W     - kernel height and width.
            %   index   - (O*P)-by-(H*W) array of linear indices.
            offset_ = squeeze(offset(:,:,:,layer.OffsetNumChannels*(q-1)+(c-layer.OffsetNumChannels)+1,:));
            offset_(isnan(offset_)) = 0;   % treat missing offsets as "no shift"
            [U, V] = meshgrid(1:Q, 1:P);
            U = repmat(U, [1, 1, H*W]);
            V = repmat(V, [1, 1, H*W]);
            % DilationFactor must be read from the layer object; a bare
            % "dilationFactor" is not defined inside the methods.
            % NOTE(review): U is P-by-Q-by-(H*W), so reshaping to
            % [O*P, H*W] only succeeds when O == Q - confirm intent.
            U_ = reshape(U-layer.DilationFactor*(offset_(:,:,1)+1), [O*P, H*W]);
            V_ = reshape(V-layer.DilationFactor*(offset_(:,:,2)+1), [O*P, H*W]);
            % Nearest-neighbour sampling, clamped to the valid grid range.
            U_ = min(max(round(U_), 1), Q);
            V_ = min(max(round(V_), 1), P);
            index = sub2ind([P, Q], V_, U_);
        end
    end

end
7518 chars
143 lines

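You can then use this layer in your network definition, either by adding it to a layer graph with the addLayers function or, as in the example below, by placing it directly in a layer array.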

main.m
% Example classification network for 32x32 RGB images that stacks two
% DeformableConvolutionalLayer instances.
% Each layer gets its own name ('dcn1', 'dcn2'): layer names in a network
% must be unique, so reusing one name for both layers is rejected.
% NOTE(review): DeformableConvolutionalLayer.predict takes a second input
% (offset). A plain layer array only chains single-input layers, so the
% offset branch presumably has to be wired up in a layer graph with
% connectLayers - confirm.
layers = [
    imageInputLayer([32 32 3])
    DeformableConvolutionalLayer(16, [3 3], 1, 1, 3, 'same', 1, 'dcn1')
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    DeformableConvolutionalLayer(32, [3 3], 1, 1, 3, 'same', 1, 'dcn2')
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
321 chars
11 lines

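Note that the offset input to the predict and backward methods should be the output of another convolutional layer applied to the same feature map. The offsets themselves are therefore activations, not weights; the network can still be trained end-to-end with backpropagation, because the weights of the offset-producing convolution are regular learnable parameters and receive gradients through the offset input.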

gistlib by LogSnag