The deformable convolutional layer is typically used in object detection and instance segmentation tasks. It extends the standard convolutional layer: instead of sampling the input on a fixed regular grid, each kernel tap is shifted by a learned offset that can differ at every spatial location of the output.
To create a deformable convolutional layer in MATLAB, you can use the custom layer API of the Deep Learning Toolbox. Here is an example implementation of a deformable convolution layer based on the paper "Deformable Convolutional Networks" (Dai et al., ICCV 2017):
classdef DeformableConvolutionalLayer < nnet.layer.Layer
    % DeformableConvolutionalLayer  2-D deformable convolution (DCN v1).
    %
    %   The layer has two inputs:
    %     X      - feature map, inH-by-inW-by-C-by-N
    %     offset - sampling offsets, outH-by-outW-by-(2*kH*kW*G)-by-N,
    %              where G is the number of offset groups (G must divide C)
    %   and one output Z of size outH-by-outW-by-NumFilters-by-N.
    %
    %   Offset channel layout adopted by this implementation: for offset
    %   group g and kernel tap k (taps enumerated row-major over the
    %   kernel, k = (i-1)*kW + j), channel (g-1)*2*kH*kW + 2*k - 1 holds the
    %   vertical (row) displacement and channel (g-1)*2*kH*kW + 2*k holds
    %   the horizontal (column) displacement, both in input pixels.
    %
    %   Input values at fractional positions are obtained by bilinear
    %   interpolation; positions outside the input contribute zero.
    %
    %   No backward method is defined: predict only uses operations that
    %   dlarray can trace, so gradients with respect to X, offset, Weights
    %   and Bias are expected to come from automatic differentiation.
    %   NOTE(review): this needs a Deep Learning Toolbox release that
    %   supports custom layers without a backward function - confirm for
    %   the release you target.

    properties
        NumFilters          % number of output channels F
        KernelSize          % [kH kW] (a scalar is expanded to a square kernel)
        Stride              % [sH sW] (scalar allowed)
        DilationFactor      % [dH dW] (scalar allowed)
        OffsetNumChannels   % expected channel count of the offset input, 2*kH*kW*G
        PaddingMode         % 'same' computes the padding; anything else uses PaddingSize
        PaddingSize         % scalar, [h w] or [top bottom left right]
    end

    properties (Learnable)
        Weights             % kH-by-kW-by-C-by-NumFilters
        Bias                % 1-by-1-by-NumFilters
    end

    methods
        function layer = DeformableConvolutionalLayer(numFilters, kernelSize, stride, dilationFactor, offsetNumChannels, paddingMode, paddingSize, name, numChannels)
            % Construct the layer.
            %
            %   The first eight arguments populate the properties of the
            %   same name. numChannels is optional: when given, Weights and
            %   Bias are initialised immediately for that many input
            %   channels; otherwise they stay empty until initialize runs
            %   or they are assigned by hand.
            layer.Name = name;
            layer.Description = "Deformable Convolutional Layer";
            % predict consumes the feature map and the offsets.
            layer.NumInputs = 2;
            layer.NumFilters = numFilters;
            layer.KernelSize = kernelSize;
            layer.Stride = stride;
            layer.DilationFactor = dilationFactor;
            layer.OffsetNumChannels = offsetNumChannels;
            layer.PaddingMode = paddingMode;
            layer.PaddingSize = paddingSize;
            if nargin >= 9 && ~isempty(numChannels)
                layer = initializeParameters(layer, numChannels);
            end
        end

        function layer = initialize(layer, layout, varargin)
            % Initialise empty learnable parameters from the data layout of
            % the first input. Parameters that are already set are kept.
            % varargin absorbs the layouts of any further inputs.
            if ~isempty(layer.Weights) && ~isempty(layer.Bias)
                return
            end
            channelDim = finddim(layout, "C");
            if isempty(channelDim)
                error('DeformableConvolutionalLayer:noChannelDimension', ...
                    'The first input has no channel dimension.');
            end
            layer = initializeParameters(layer, layout.Size(channelDim));
        end

        function Z = predict(layer, X, offset)
            % Forward pass of the deformable convolution.
            %
            %   X      - inH-by-inW-by-C-by-N feature map
            %   offset - outH-by-outW-by-(2*kH*kW*G)-by-N sampling offsets
            %   Z      - outH-by-outW-by-NumFilters-by-N response
            kernel   = DeformableConvolutionalLayer.expandPair(layer.KernelSize, 'KernelSize');
            stride   = DeformableConvolutionalLayer.expandPair(layer.Stride, 'Stride');
            dilation = DeformableConvolutionalLayer.expandPair(layer.DilationFactor, 'DilationFactor');
            kH = kernel(1);
            kW = kernel(2);
            numTaps = kH * kW;
            numFilters = layer.NumFilters;

            inH = size(X, 1);
            inW = size(X, 2);
            numChannels = size(X, 3);
            batchSize = size(X, 4);

            % ---- validate parameters ---------------------------------
            if isempty(layer.Weights) || isempty(layer.Bias)
                error('DeformableConvolutionalLayer:uninitialized', ...
                    'Weights and Bias must be initialized before the layer is evaluated.');
            end
            weightSize = [size(layer.Weights, 1), size(layer.Weights, 2), ...
                          size(layer.Weights, 3), size(layer.Weights, 4)];
            if ~isequal(weightSize, [kH, kW, numChannels, numFilters])
                error('DeformableConvolutionalLayer:weightSize', ...
                    'Weights must be %d-by-%d-by-%d-by-%d.', kH, kW, numChannels, numFilters);
            end
            if numel(layer.Bias) ~= numFilters
                error('DeformableConvolutionalLayer:biasSize', ...
                    'Bias must have %d elements.', numFilters);
            end

            % ---- output geometry -------------------------------------
            extentH = (kH - 1) * dilation(1) + 1;   % dilated kernel extent
            extentW = (kW - 1) * dilation(2) + 1;
            if strcmpi(layer.PaddingMode, 'same')
                outH = ceil(inH / stride(1));
                outW = ceil(inW / stride(2));
                totalPadH = max((outH - 1) * stride(1) + extentH - inH, 0);
                totalPadW = max((outW - 1) * stride(2) + extentW - inW, 0);
                % Any odd remainder goes to the bottom/right edge.
                padTop = floor(totalPadH / 2);
                padLeft = floor(totalPadW / 2);
            else
                pad = DeformableConvolutionalLayer.expandPadding(layer.PaddingSize);
                padTop = pad(1);
                padLeft = pad(3);
                outH = floor((inH + pad(1) + pad(2) - extentH) / stride(1)) + 1;
                outW = floor((inW + pad(3) + pad(4) - extentW) / stride(2)) + 1;
            end
            if outH < 1 || outW < 1
                error('DeformableConvolutionalLayer:inputTooSmall', ...
                    'The input is smaller than the dilated kernel.');
            end

            % ---- validate offsets ------------------------------------
            offsetChannels = size(offset, 3);
            numGroups = offsetChannels / (2 * numTaps);
            if numGroups < 1 || numGroups ~= floor(numGroups)
                error('DeformableConvolutionalLayer:offsetChannels', ...
                    'The offset input must have a multiple of %d channels.', 2 * numTaps);
            end
            if ~isempty(layer.OffsetNumChannels) && offsetChannels ~= layer.OffsetNumChannels
                error('DeformableConvolutionalLayer:offsetChannels', ...
                    'The offset input has %d channels but OffsetNumChannels is %d.', ...
                    offsetChannels, layer.OffsetNumChannels);
            end
            if mod(numChannels, numGroups) ~= 0
                error('DeformableConvolutionalLayer:offsetGroups', ...
                    'The number of offset groups (%d) must divide the number of input channels (%d).', ...
                    numGroups, numChannels);
            end
            if size(offset, 1) ~= outH || size(offset, 2) ~= outW || size(offset, 4) ~= batchSize
                error('DeformableConvolutionalLayer:offsetSize', ...
                    'The offset input must be %d-by-%d-by-%d-by-%d.', ...
                    outH, outW, offsetChannels, batchSize);
            end

            % Plain-array view of the offsets: used for type/device
            % matching and for everything that must not be traced.
            proto = DeformableConvolutionalLayer.stripdl(offset);
            invalid = isnan(proto);
            if any(invalid(:))
                % NaN offsets are treated as "no displacement".
                offset(invalid) = 0;
            end

            % ---- sampling positions (1-based input coordinates) ------
            tap = 0:numTaps-1;
            tapRow = floor(tap / kW);   % 0-based kernel row, taps row-major
            tapCol = mod(tap, kW);      % 0-based kernel column
            rowBase = ((0:outH-1)' * stride(1) - padTop + 1) ...
                    + reshape(tapRow * dilation(1), 1, 1, numTaps);   % outH x 1 x K
            colBase = ((0:outW-1) * stride(2) - padLeft + 1) ...
                    + reshape(tapCol * dilation(2), 1, 1, numTaps);   % 1 x outW x K
            rowBase = cast(rowBase, 'like', proto);
            colBase = cast(colBase, 'like', proto);

            offset = reshape(offset, outH, outW, 2, numTaps, numGroups, batchSize);
            rowPos = rowBase + reshape(offset(:, :, 1, :, :, :), outH, outW, numTaps, numGroups, batchSize);
            colPos = colBase + reshape(offset(:, :, 2, :, :, :), outH, outW, numTaps, numGroups, batchSize);

            % The integer part is piecewise constant, so it is taken on
            % untraced data; the fractional part stays traced and carries
            % the gradient with respect to the offsets.
            rowFloor = double(floor(DeformableConvolutionalLayer.stripdl(rowPos)));
            colFloor = double(floor(DeformableConvolutionalLayer.stripdl(colPos)));
            rowFrac = rowPos - rowFloor;
            colFrac = colPos - colFloor;

            % ---- weights as a (K*C)-by-F matrix ----------------------
            % Swapping the two kernel dimensions first makes the flattened
            % tap order row-major, matching the order of the sampled data.
            weightMatrix = reshape(permute(layer.Weights, [2 1 3 4]), numTaps * numChannels, numFilters);
            bias = reshape(layer.Bias, 1, numFilters);

            % ---- sample, then contract with the weights --------------
            numPositions = outH * outW;
            numSamples = numPositions * numTaps;
            chPerGroup = numChannels / numGroups;
            flatX = reshape(X, inH * inW, numChannels, batchSize);

            perImage = cell(1, batchSize);
            for n = 1:batchSize
                perGroup = cell(1, numGroups);
                for g = 1:numGroups
                    channels = (g - 1) * chPerGroup + (1:chPerGroup);
                    source = flatX(:, channels, n);             % (inH*inW) x Cg
                    fr = rowFrac(:, :, :, g, n);
                    fc = colFrac(:, :, :, g, n);
                    r0 = rowFloor(:, :, :, g, n);
                    c0 = colFloor(:, :, :, g, n);

                    sampled = 0;
                    for dRow = 0:1
                        if dRow == 1
                            wRow = fr;
                        else
                            wRow = 1 - fr;
                        end
                        for dCol = 0:1
                            if dCol == 1
                                wCol = fc;
                            else
                                wCol = 1 - fc;
                            end
                            r = r0 + dRow;
                            c = c0 + dCol;
                            inside = (r >= 1) & (r <= inH) & (c >= 1) & (c <= inW);
                            % Clamp only to keep the index legal; corners
                            % outside the input are zeroed by the mask.
                            linearIdx = min(max(r, 1), inH) + inH * (min(max(c, 1), inW) - 1);
                            w = wRow .* wCol .* cast(inside, 'like', proto);
                            sampled = sampled + reshape(w, numSamples, 1) .* source(linearIdx(:), :);
                        end
                    end
                    perGroup{g} = reshape(sampled, numPositions, numTaps, chPerGroup);
                end
                columns = reshape(cat(3, perGroup{:}), numPositions, numTaps * numChannels);
                response = columns * weightMatrix + bias;        % M x F
                perImage{n} = reshape(response, outH, outW, numFilters);
            end
            Z = cat(4, perImage{:});
        end
    end

    methods (Access = private)
        function layer = initializeParameters(layer, numChannels)
            % Fill empty Weights (He-style normal) and Bias (zeros).
            kernel = DeformableConvolutionalLayer.expandPair(layer.KernelSize, 'KernelSize');
            if isempty(layer.Weights)
                fanIn = kernel(1) * kernel(2) * numChannels;
                layer.Weights = sqrt(2 / fanIn) * ...
                    randn([kernel(1), kernel(2), numChannels, layer.NumFilters], 'single');
            end
            if isempty(layer.Bias)
                layer.Bias = zeros([1, 1, layer.NumFilters], 'single');
            end
        end
    end

    methods (Static, Access = private)
        function pair = expandPair(value, label)
            % Return value as a 1-by-2 double row; scalars are duplicated.
            value = double(value(:)');
            if isscalar(value)
                pair = [value, value];
            elseif numel(value) == 2
                pair = value;
            else
                error('DeformableConvolutionalLayer:badSize', ...
                    '%s must be a scalar or a two-element vector.', label);
            end
        end

        function pad = expandPadding(value)
            % Return padding as [top bottom left right].
            value = double(value(:)');
            switch numel(value)
                case 0
                    pad = [0, 0, 0, 0];
                case 1
                    pad = value * [1, 1, 1, 1];
                case 2
                    pad = [value(1), value(1), value(2), value(2)];
                case 4
                    pad = value;
                otherwise
                    error('DeformableConvolutionalLayer:badSize', ...
                        'PaddingSize must have 1, 2 or 4 elements.');
            end
        end

        function data = stripdl(value)
            % Return the underlying array of a dlarray; pass others through.
            if isa(value, 'dlarray')
                data = extractdata(value);
            else
                data = value;
            end
        end
    end
end