Strassen Algorithm in Swift

Here's an implementation of the Strassen Algorithm in Swift:

main.swift
/// Multiplies two integer matrices with Strassen's divide-and-conquer algorithm.
///
/// `a` is treated as a `rows × inner` matrix and `b` as an `inner × cols` matrix.
/// Both operands are zero-padded to the smallest power-of-two square that holds
/// every dimension. They are then multiplied recursively, and the result is
/// cropped back to `rows × cols`. The padding rows and columns are zero, so they
/// contribute nothing to the product and the crop is exact. This lets sizes such
/// as 3 × 3, or rectangular operands, be multiplied correctly.
///
/// - Parameters:
///   - a: Left operand; all rows must have the same length.
///   - b: Right operand; its row count must equal the row length of `a`,
///     and all of its rows must have the same length.
/// - Returns: The `rows × cols` product `a · b`, or `[]` when `a` has no rows.
/// - Precondition: Both operands are rectangular and their inner dimensions
///   agree. Arithmetic traps on `Int` overflow. Strassen's intermediate
///   sums can overflow even when the final product would fit.
func strassen(_ a: [[Int]], _ b: [[Int]]) -> [[Int]] {
    // A left operand with no rows yields a product with no rows. This guard also
    // prevents endless recursion on zero-sized quadrants.
    guard let firstRow = a.first else { return [] }

    let rows = a.count
    let inner = firstRow.count
    let cols = b.first?.count ?? 0
    precondition(a.allSatisfy { $0.count == inner }, "strassen: every row of a must have the same length")
    precondition(b.count == inner, "strassen: column count of a must equal row count of b")
    precondition(b.allSatisfy { $0.count == cols }, "strassen: every row of b must have the same length")

    // The recursion halves the size at every level, so it needs a power-of-two square.
    let maxDimension = max(rows, inner, cols)
    var size = 1
    while size < maxDimension {
        size *= 2
    }

    let product = strassenSquare(strassenPadded(a, to: size), strassenPadded(b, to: size))
    return product.prefix(rows).map { Array($0.prefix(cols)) }
}

/// Returns a `size × size` copy of `m`; cells outside `m` are filled with zero.
private func strassenPadded(_ m: [[Int]], to size: Int) -> [[Int]] {
    return (0..<size).map { i in
        (0..<size).map { j in
            i < m.count && j < m[i].count ? m[i][j] : 0
        }
    }
}

/// Splits an `n × n` matrix (`n` even) into its quadrants `(q11, q12, q21, q22)`,
/// each `n/2 × n/2`: top-left, top-right, bottom-left, bottom-right.
private func strassenQuadrants(_ m: [[Int]]) -> ([[Int]], [[Int]], [[Int]], [[Int]]) {
    let half = m.count / 2
    let top = m[..<half]
    let bottom = m[half...]
    return (
        top.map { Array($0[..<half]) },
        top.map { Array($0[half...]) },
        bottom.map { Array($0[..<half]) },
        bottom.map { Array($0[half...]) }
    )
}

/// One recursive Strassen step for two `n × n` matrices, `n` a power of two.
///
/// Each operand is split into four quadrants. Seven recursive products of
/// quadrant sums and differences (rather than eight plain quadrant products)
/// are combined into the four quadrants of the result.
private func strassenSquare(_ a: [[Int]], _ b: [[Int]]) -> [[Int]] {
    if a.count == 1 {
        return [[a[0][0] * b[0][0]]]
    }

    let (a11, a12, a21, a22) = strassenQuadrants(a)
    let (b11, b12, b21, b22) = strassenQuadrants(b)

    let p1 = strassenSquare(a11, sub(b12, b22))
    let p2 = strassenSquare(add(a11, a12), b22)
    let p3 = strassenSquare(add(a21, a22), b11)
    let p4 = strassenSquare(a22, sub(b21, b11))
    let p5 = strassenSquare(add(a11, a22), add(b11, b22))
    let p6 = strassenSquare(sub(a12, a22), add(b21, b22))
    // p7 is built from (a11 - a21), so it enters c22 with a minus sign.
    let p7 = strassenSquare(sub(a11, a21), add(b11, b12))

    let c11 = add(sub(add(p5, p4), p2), p6)
    let c12 = add(p1, p2)
    let c21 = add(p3, p4)
    let c22 = sub(sub(add(p5, p1), p3), p7)

    // Reassemble the result from its quadrants.
    // Each top row is a c11 row followed by the matching c12 row.
    // Each bottom row is a c21 row followed by the matching c22 row.
    let top = zip(c11, c12).map { $0 + $1 }
    let bottom = zip(c21, c22).map { $0 + $1 }
    return top + bottom
}

/// Returns the element-wise sum of two integer matrices.
///
/// Works for any rectangular shape, not only square matrices.
///
/// - Parameters:
///   - a: Left operand.
///   - b: Right operand; must have exactly the same shape as `a`.
/// - Returns: A matrix of the same shape with `c[i][j] = a[i][j] + b[i][j]`.
/// - Precondition: `a` and `b` have the same number of rows, and each pair of
///   corresponding rows has the same length. Traps on `Int` overflow.
func add(_ a: [[Int]], _ b: [[Int]]) -> [[Int]] {
    // Check the shape explicitly: zip alone would silently truncate mismatched input.
    precondition(a.count == b.count && zip(a, b).allSatisfy { $0.count == $1.count },
                 "add: matrices must have the same shape")
    return zip(a, b).map { rowA, rowB in
        zip(rowA, rowB).map { $0 + $1 }
    }
}

/// Returns the element-wise difference of two integer matrices.
///
/// Works for any rectangular shape, not only square matrices.
///
/// - Parameters:
///   - a: Minuend matrix.
///   - b: Subtrahend matrix; must have exactly the same shape as `a`.
/// - Returns: A matrix of the same shape with `c[i][j] = a[i][j] - b[i][j]`.
/// - Precondition: `a` and `b` have the same number of rows, and each pair of
///   corresponding rows has the same length. Traps on `Int` overflow.
func sub(_ a: [[Int]], _ b: [[Int]]) -> [[Int]] {
    // Check the shape explicitly: zip alone would silently truncate mismatched input.
    precondition(a.count == b.count && zip(a, b).allSatisfy { $0.count == $1.count },
                 "sub: matrices must have the same shape")
    return zip(a, b).map { rowA, rowB in
        zip(rowA, rowB).map { $0 - $1 }
    }
}
2682 chars
108 lines

This implementation assumes that the matrices passed in are square and that their size is a power of 2. It recursively splits each matrix into four quadrants until it reaches 1x1 matrices, which are multiplied directly. At each level, Strassen's method computes seven recursive products of quadrant sums and differences, instead of the eight products ordinary block multiplication needs. It then combines those seven products into the four quadrants of the result.

gistlib by LogSnag