implement fullyconnectedlayer() from scratch in matlab

Here is how to implement a fully connected (dense) layer in MATLAB using matrix multiplication:

FullyConnectedLayer.m
classdef FullyConnectedLayer
    properties
        inputSize    % number of input features
        outputSize   % number of output neurons
        weights      % outputSize-by-inputSize weight matrix
        bias         % outputSize-by-1 bias vector
        dLdW         % gradient of the loss w.r.t. weights
        dLdB         % gradient of the loss w.r.t. bias
    end
    
    methods
        function obj = FullyConnectedLayer(inputSize, outputSize)
            obj.inputSize = inputSize;
            obj.outputSize = outputSize;
            
            % Initialize weights with Xavier initialization
            variance = 2/(inputSize+outputSize);
            obj.weights = sqrt(variance)*randn(outputSize, inputSize);
            
            % Initialize bias to zero
            obj.bias = zeros(outputSize, 1);
        end
        
        function output = forward(obj, input)
            % Affine transform: weights*input plus bias. With a batch
            % (columns = samples), the bias vector broadcasts across
            % columns via implicit expansion (MATLAB R2016b or later).
            output = obj.weights*input + obj.bias;
        end
        
        function [dLdInput, obj] = backward(obj, input, gradOutput, learningRate)
            % Compute gradients of loss with respect to weights and bias
            obj.dLdW = gradOutput*input';
            obj.dLdB = sum(gradOutput, 2);
            
            % Compute gradient of loss with respect to input
            dLdInput = obj.weights'*gradOutput;
            
            % Update weights and bias using gradient descent
            obj.weights = obj.weights - learningRate*obj.dLdW;
            obj.bias = obj.bias - learningRate*obj.dLdB;
        end
    end
end

The FullyConnectedLayer class has a constructor that takes the layer's input size and output size as arguments. The weights are drawn from a zero-mean Gaussian with variance 2/(inputSize + outputSize) (Xavier/Glorot initialization), and the bias is set to zero.

Forward propagation computes the matrix product of the layer's weights and input, plus the bias term. Because samples are stored as columns, a whole batch can be passed in one call; the bias vector is broadcast across the columns.
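As a quick sketch of the forward pass (assuming FullyConnectedLayer.m above is on the MATLAB path; the sizes here are arbitrary):

```matlab
% Forward pass on a batch of 4 samples with 3 features each.
layer = FullyConnectedLayer(3, 2);
X = randn(3, 4);        % columns are samples
Y = layer.forward(X);   % 2-by-4 output; bias broadcasts across columns
disp(size(Y))           % [2 4]
```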

Backward propagation computes the weight and bias gradients from the incoming gradient gradOutput, updates the weights and bias with a gradient-descent step, and returns the gradient of the loss with respect to the input so it can be propagated to the preceding layer.
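A hedged usage sketch of one training step (the gradOutput here is random, standing in for the gradient that a loss function or the next layer would supply; note that the updated layer must be captured as an output, since MATLAB value classes are copied on assignment):

```matlab
% One gradient-descent step on a batch of 4 samples.
layer = FullyConnectedLayer(3, 2);
X = randn(3, 4);
gradOutput = randn(2, 4);            % placeholder for dL/dY from the next layer
[dLdX, layer] = layer.backward(X, gradOutput, 0.01);
disp(size(dLdX))                     % [3 4], same shape as X
% layer.weights and layer.bias now hold the updated parameters
```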
