The img2col function for 2D arrays in Swift

The following is an implementation of an img2col function that works with 2D arrays:

main.swift
func img2col(_ input: [[Double]], kernelSize: Int, stride step: Int, padding: Int) -> [[Double]] {
    // The internal name `step` keeps the parameter from shadowing
    // Swift's global stride(from:to:by:) function, which is called below.
    precondition(step > 0, "stride must be positive")
    let paddedInput = pad(input, padding: padding)
    guard let firstRow = paddedInput.first else { return [] }
    let (height, width) = (paddedInput.count, firstRow.count)
    var output = [[Double]]()
    // Slide the kernel window over every valid top-left position.
    for row in stride(from: 0, to: height - kernelSize + 1, by: step) {
        for col in stride(from: 0, to: width - kernelSize + 1, by: step) {
            // Flatten the kernelSize x kernelSize patch in row-major order.
            var patch = [Double]()
            for kRow in 0..<kernelSize {
                for kCol in 0..<kernelSize {
                    patch.append(paddedInput[row + kRow][col + kCol])
                }
            }
            output.append(patch)
        }
    }
    return output
}

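func pad(_ input: [[Double]], padding: Int) -> [[Double]] {
    // An empty input has no rows to measure, so there is nothing to pad.
    guard let firstRow = input.first else { return [] }
    let padWidth = firstRow.count + 2 * padding
    let padHeight = input.count + 2 * padding

    // Start from an all-zero grid of the padded size.
    var padded = Array(repeating: Array(repeating: 0.0, count: padWidth), count: padHeight)

    // Copy the original values into the centre, offset by `padding`.
    for row in 0..<input.count {
        for col in 0..<firstRow.count {
            padded[row + padding][col + padding] = input[row][col]
        }
    }

    return padded
}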

The img2col function takes a 2D input array, a kernel size, a stride, and a padding amount. It zero-pads the input, then, for each position of the kernel window, flattens the values under the window into a single vector in row-major order. Finally, it returns an array of these vectors, one per window position.

The pad function surrounds the input array with zeros, adding as many rows and columns on each side as the padding argument specifies.
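
The number of vectors img2col returns follows from how many kernel positions fit along each dimension of the padded input. As a minimal sketch, the count along one dimension can be computed as below; outputLength is a hypothetical helper introduced here for illustration and is not part of the snippet above.

```swift
// Hypothetical helper (not part of the original snippet): number of kernel
// positions along one dimension of length `size`.
func outputLength(size: Int, kernelSize: Int, stride: Int, padding: Int) -> Int {
    let padded = size + 2 * padding
    // No position fits if the kernel is larger than the padded input.
    guard stride > 0, padded >= kernelSize else { return 0 }
    // Integer division rounds down, matching the loop bounds in img2col.
    return (padded - kernelSize) / stride + 1
}

// For a 3x3 input with kernelSize 2, stride 1, padding 0:
let positionsPerRow = outputLength(size: 3, kernelSize: 2, stride: 1, padding: 0)
let positionsPerCol = outputLength(size: 3, kernelSize: 2, stride: 1, padding: 0)
print(positionsPerRow * positionsPerCol) // 4
```

Each returned vector has kernelSize * kernelSize elements, so the full result holds positionsPerRow * positionsPerCol vectors of that length.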

Example usage:

main.swift
let input: [[Double]] = [[1, 2, 3],
                         [4, 5, 6],
                         [7, 8, 9]]

let output = img2col(input, kernelSize: 2, stride: 1, padding: 0)

print(output)
// [[1.0, 2.0, 4.0, 5.0], [2.0, 3.0, 5.0, 6.0], [4.0, 5.0, 7.0, 8.0], [5.0, 6.0, 8.0, 9.0]]

In the example above, the input is a 3x3 array, the kernel is 2x2, the stride is 1, and the padding is 0. The kernel fits at 2 x 2 = 4 positions, so the output contains four vectors of four elements each, one flattened 2x2 patch per position.
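
A common reason to flatten patches this way is that applying a kernel then reduces to one dot product per patch. The sketch below illustrates this using the patches printed in the example; the kernel values and the applyKernel helper are assumptions made up for illustration, and the kernel is applied without flipping (cross-correlation).

```swift
// Patches as printed by the example (kernelSize 2, stride 1, padding 0).
let patches: [[Double]] = [[1, 2, 4, 5], [2, 3, 5, 6], [4, 5, 7, 8], [5, 6, 8, 9]]

// Hypothetical 2x2 kernel [[1, 0], [0, 1]], flattened row by row so its
// element order matches the patch layout.
let kernel: [Double] = [1, 0, 0, 1]

// Hypothetical helper: dot product of each patch with the flattened kernel.
func applyKernel(_ patches: [[Double]], _ kernel: [Double]) -> [Double] {
    return patches.map { patch in
        zip(patch, kernel).reduce(0.0) { $0 + $1.0 * $1.1 }
    }
}

let result = applyKernel(patches, kernel)
print(result)
// [6.0, 8.0, 12.0, 14.0]
```

The four results correspond to the four kernel positions in row-major order, so they can be read as a 2x2 output.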
