Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
// MARK: - Some differentiable array manipulation functions used in the algorithms.

/// Non-mutating array operations with hand-written VJPs, used as differentiable building
/// blocks by the sorting algorithms.
///
/// `Array.DifferentiableView.zero` is an empty array, so a pullback may be handed an empty
/// cotangent that stands for "all zeros". Every pullback here handles that case explicitly
/// instead of indexing into the empty cotangent.
extension Array where Element: Differentiable {
    /// Returns a zero cotangent with `count` explicit `.zero` entries.
    fileprivate static func _zeroTangent(count: Int) -> TangentVector {
        return TangentVector(Array<Element.TangentVector>(repeating: .zero, count: count))
    }

    /// Returns a copy of the array with the elements at `i` and `j` exchanged.
    @differentiable(vjp: _vjpSwappedAt)
    func swappedAt(_ i: Int, _ j: Int) -> Array {
        var tmp = self
        tmp.swapAt(i, j)
        return tmp
    }

    /// VJP of `swappedAt(_:_:)`. A swap is its own inverse, so the pullback swaps the
    /// cotangent entries at `i` and `j` back.
    func _vjpSwappedAt(_ i: Int, _ j: Int) -> (Array, (TangentVector) -> TangentVector) {
        let n = count
        return (swappedAt(i, j), { v in
            // An empty cotangent is the zero tangent; swapping inside it would trap.
            guard !v.base.isEmpty else { return Array._zeroTangent(count: n) }
            var cotangent = v.base
            cotangent.swapAt(i, j)
            return TangentVector(cotangent)
        })
    }

    /// Returns the array without its first element (empty if the array is empty).
    @differentiable(vjp: _vjpDroppedFirst)
    func droppedFirst() -> Array {
        return Array(self.dropFirst())
    }

    /// VJP of `droppedFirst()`. The dropped element does not reach the output, so its
    /// cotangent is zero; the remaining cotangent entries pass through unchanged.
    func _vjpDroppedFirst() -> (Array, (TangentVector) -> TangentVector) {
        let n = count
        return (droppedFirst(), { v in
            // Also covers `n == 0`, where prepending a zero would yield one entry too many.
            guard !v.base.isEmpty else { return Array._zeroTangent(count: n) }
            return TangentVector([Element.TangentVector.zero] + v.base)
        })
    }

    /// Returns a copy of the array with `element` appended.
    @differentiable(vjp: _vjpAppending)
    func appending(_ element: Element) -> Array {
        var tmp = self
        tmp.append(element)
        return tmp
    }

    /// VJP of `appending(_:)`. The last cotangent entry belongs to `element`; the
    /// preceding entries belong to the original array.
    func _vjpAppending(_ element: Element) -> (Array, (TangentVector) -> (TangentVector, Element.TangentVector)) {
        let n = count
        func pb(_ v: TangentVector) -> (TangentVector, Element.TangentVector) {
            // An empty cotangent is the zero tangent; it has no last entry to split off.
            guard let last = v.base.last else {
                return (Array._zeroTangent(count: n), .zero)
            }
            return (TangentVector(Array<Element.TangentVector>(v.base.dropLast())), last)
        }
        return (appending(element), pb)
    }

    /// Returns a one-element array containing `element`.
    @differentiable(vjp: _vjpMakeSingle)
    static func makeSingle(_ element: Element) -> Array {
        return [element]
    }

    /// VJP of `makeSingle(_:)`. The pullback returns the cotangent's single entry, or zero
    /// when handed the empty zero cotangent.
    static func _vjpMakeSingle(_ element: Element) -> (Array, (TangentVector) -> Element.TangentVector) {
        return ([element], { v in
            guard let only = v.base.first else { return .zero }
            precondition(v.base.count == 1, "makeSingle(_:) cotangent must have exactly one element")
            return only
        })
    }
}
// MARK: - Custom VJP for stdlib sort.

/// Returns the elements of `array` in ascending order.
///
/// The derivative is not computed automatically; it comes from the `_vjpSorted` VJP
/// registered in the attribute.
@differentiable(vjp: _vjpSorted)
func sorted(_ array: [Double]) -> [Double] {
    var ascending = array
    ascending.sort()
    return ascending
}
/// VJP of `sorted(_:)`.
///
/// Sorting applies a permutation, so the pullback routes the cotangent at output position
/// `i` back to the input position that element originally came from.
func _vjpSorted(_ array: [Double]) -> ([Double], (Array<Double>.DifferentiableView) -> Array<Double>.DifferentiableView) {
    // Sort (offset, element) pairs with the same `<` ordering used by `Array.sorted()`,
    // remembering where every output element originated.
    let indexed = array.enumerated().sorted(by: { $0.element < $1.element })
    let sortedValues = indexed.map { $0.element }
    let permutation = indexed.map { $0.offset }
    return (sortedValues, { v in
        // `DifferentiableView.zero` is an empty array; indexing into it would trap.
        guard !v.base.isEmpty else {
            return Array<Double>.DifferentiableView(Array(repeating: 0.0, count: permutation.count))
        }
        precondition(v.base.count == permutation.count,
                     "cotangent count must match the number of sorted elements")
        var result = Array(repeating: 0.0, count: permutation.count)
        for (outputIndex, inputIndex) in permutation.enumerated() {
            result[inputIndex] = v.base[outputIndex]
        }
        return Array<Double>.DifferentiableView(result)
    })
}
let arrayToSort: [Double] = [7, 2, 4, 1, 8, 3, 0, 9]

// One-hot cotangents: pulling back the k-th basis vector yields the k-th row of the Jacobian.
let vectorsToPullBack: [[Double]] = arrayToSort.indices.map { hot in
    arrayToSort.indices.map { $0 == hot ? 1.0 : 0.0 }
}

let (value, pb) = valueWithPullback(at: arrayToSort, in: sorted)
print("USING CUSTOM DERIVATIVE FOR SORT")
print(value)
vectorsToPullBack.forEach { print(pb(Array.DifferentiableView($0))) }
print("")
// MARK: - Selection sort.

/// Returns the index of the largest element of `array`.
///
/// Ties resolve to the earliest index. Elements are compared with `<`, so a NaN never
/// replaces the running maximum (a NaN at index 0 is therefore returned as-is).
///
/// - Precondition: `array` is non-empty.
func argMax(_ array: [Double]) -> Int {
    // `max(by:)` only advances when the running best is strictly less than the candidate,
    // so equal maxima keep the first index, exactly like a strict `>` scan.
    guard let best = array.indices.max(by: { array[$0] < array[$1] }) else {
        preconditionFailure("argMax(_:) requires a non-empty array")
    }
    return best
}
/// Sorts `array` in ascending order by moving the maximum to the end, recursively.
///
/// Built from the differentiable `swappedAt`, `droppedFirst` and `appending` helpers so
/// its derivative can be computed automatically.
func selectionSort(_ array: [Double]) -> [Double] {
    guard array.count > 1 else { return array }
    // Which index holds the maximum is a discrete choice, so it is found without derivatives.
    let maxIndex = argMax(array.withoutDerivative())
    let swapped = array.swappedAt(0, maxIndex)
    let largest = swapped[0]
    let sortedRest = selectionSort(swapped.droppedFirst())
    return sortedRest.appending(largest)
}
let (value2, pb2) = valueWithPullback(at: arrayToSort, in: selectionSort)
print("USING AUTOMATICALLY COMPUTED DERIVATIVE OF SELECTION SORT")
print(value2)
if value2 != value {
    print(" oh no, that one is wrong")
}
for v in vectorsToPullBack {
    let cotangent = Array<Double>.DifferentiableView(v)
    // Evaluate the automatic pullback once and reuse it for printing and for the
    // comparison against the custom-derivative reference `pb`.
    let gradient = pb2(cotangent)
    print(gradient)
    if gradient != pb(cotangent) {
        print(" oh no, that one is wrong")
    }
}
print("")
// MARK: - Quicksort.

extension Array where Element : Differentiable {
    /// Returns the elements at index `start` and later that satisfy `predicate`, in their
    /// original order.
    ///
    /// Written recursively in terms of `makeSingle(_:)` and array `+` so that it remains
    /// differentiable. An out-of-range `start` (at or past the end) yields an empty array.
    func filter(_ predicate: (Element) -> Bool, _ start: Int) -> Array {
        guard start < count else { return [] }
        if predicate(self[start]) {
            // Prepend (rather than append) the kept element so the input order is preserved.
            return Array.makeSingle(self[start]) + filter(predicate, start + 1)
        } else {
            return filter(predicate, start + 1)
        }
    }
}
/// Sorts `array` in ascending order with a non-in-place quicksort that pivots on the
/// first element.
func qsort(_ array: [Double]) -> [Double] {
    guard array.count > 1 else { return array }
    let pivot = array[0]
    // Partitioning is a discrete decision, so compare against a derivative-free pivot copy.
    let threshold = pivot.withoutDerivative()
    let smaller = array.filter({ $0 < threshold }, 1)
    let notSmaller = array.filter({ $0 >= threshold }, 1)
    return qsort(smaller) + Array.makeSingle(pivot) + qsort(notSmaller)
}
let (value3, pb3) = valueWithPullback(at: arrayToSort, in: qsort)
print("USING AUTOMATICALLY COMPUTED DERIVATIVE OF QUICK SORT")
print(value3)
if value3 != value {
    print(" oh no, that one is wrong")
}
for v in vectorsToPullBack {
    let cotangent = Array<Double>.DifferentiableView(v)
    // Evaluate the automatic pullback once and reuse it for printing and for the
    // comparison against the custom-derivative reference `pb`.
    let gradient = pb3(cotangent)
    print(gradient)
    if gradient != pb(cotangent) {
        print(" oh no, that one is wrong")
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement