Advertisement
Guest User

Untitled

a guest
Jun 15th, 2019
55
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.48 KB | None | 0 0
  // MARK: - Some differentiable array manipulation functions used in the algorithms.

  /// Non-mutating array operations with hand-written VJPs, used as the
  /// differentiable building blocks of the sorting algorithms in this file.
  ///
  /// Every pullback returns early when handed an empty tangent view. An empty
  /// `Array.DifferentiableView` can stand for a zero tangent. Without the guard,
  /// the pullbacks fail in different ways:
  /// - `swappedAt` traps in `swapAt`.
  /// - `appending` traps on an out-of-range subscript.
  /// - `makeSingle` trips its precondition.
  /// - `droppedFirst` produces a tangent whose length does not match its input.
  /// NOTE(review): relies on `Array.DifferentiableView.zero` being the empty
  /// view in the targeted toolchain — confirm.
  extension Array where Element: Differentiable {
    /// Returns a copy with the elements at `i` and `j` exchanged.
    @differentiable(vjp: _vjpSwappedAt)
    func swappedAt(_ i: Int, _ j: Int) -> Array {
      var tmp = self
      tmp.swapAt(i, j)
      return tmp
    }

    /// VJP of `swappedAt(_:_:)`. A swap is its own inverse, so the pullback
    /// swaps the same two positions of the incoming tangent.
    func _vjpSwappedAt(_ i: Int, _ j: Int) -> (Array, (TangentVector) -> TangentVector) {
      return (swappedAt(i, j), { v in
        // Zero (empty) tangent: nothing to permute.
        guard !v.base.isEmpty else { return v }
        return TangentVector(v.base.swappedAt(i, j))
      })
    }

    /// Returns the array without its first element (an empty array stays empty).
    @differentiable(vjp: _vjpDroppedFirst)
    func droppedFirst() -> Array {
      return Array(dropFirst())
    }

    /// VJP of `droppedFirst()`. The dropped element receives a zero tangent and
    /// the incoming tangent is shifted back by one position.
    func _vjpDroppedFirst() -> (Array, (TangentVector) -> TangentVector) {
      return (droppedFirst(), { v in
        // Zero (empty) tangent maps to a zero tangent; prepending a component
        // here would give an empty input a length-1 tangent.
        guard !v.base.isEmpty else { return v }
        return TangentVector([Element.TangentVector.zero] + v.base)
      })
    }

    /// Returns a copy with `element` appended at the end.
    @differentiable(vjp: _vjpAppending)
    func appending(_ element: Element) -> Array {
      var tmp = self
      tmp.append(element)
      return tmp
    }

    /// VJP of `appending(_:)`. The last tangent component belongs to `element`;
    /// all preceding components belong to `self`.
    func _vjpAppending(_ element: Element) -> (Array, (TangentVector) -> (TangentVector, Element.TangentVector)) {
      func pb(_ v: TangentVector) -> (TangentVector, Element.TangentVector) {
        // Zero (empty) tangent: both the array and the element get zero.
        guard let last = v.base.last else { return (v, .zero) }
        return (TangentVector([Element.TangentVector](v.base.dropLast())), last)
      }
      return (appending(element), pb)
    }

    /// Returns the one-element array `[element]`.
    @differentiable(vjp: _vjpMakeSingle)
    static func makeSingle(_ element: Element) -> Array {
      return [element]
    }

    /// VJP of `makeSingle(_:)`. The single tangent component is the element's
    /// tangent.
    static func _vjpMakeSingle(_ element: Element) -> (Array, (TangentVector) -> Element.TangentVector) {
      return ([element], { v in
        // Zero (empty) tangent: the element's tangent is zero.
        guard let only = v.base.first else { return .zero }
        precondition(v.base.count == 1)
        return only
      })
    }
  }
  50.  
  // MARK: - Custom VJP for stdlib sort.

  /// Differentiable wrapper around the standard library's ascending sort.
  /// Its derivative is supplied by the hand-written `_vjpSorted(_:)`.
  @differentiable(vjp: _vjpSorted)
  func sorted(_ array: [Double]) -> [Double] {
    return array.sorted(by: <)
  }
  57.  
  /// VJP of `sorted(_:)`.
  ///
  /// Sorting applies a permutation to its input. The pullback therefore routes
  /// component `k` of the incoming tangent back to the input index of the
  /// `k`-th smallest element. The result always has `array.count` components,
  /// and an empty (zero) incoming tangent is returned unchanged instead of
  /// being indexed out of range.
  func _vjpSorted(_ array: [Double]) -> ([Double], (Array<Double>.DifferentiableView) -> Array<Double>.DifferentiableView) {
    // Pair each value with its original offset; after sorting,
    // `permutation[k]` is the input index of the k-th smallest value.
    let order = array.enumerated().sorted(by: { $0.element < $1.element })
    let sortedValues = order.map { $0.element }
    let permutation = order.map { $0.offset }
    return (sortedValues, { v in
      guard !v.base.isEmpty else { return v }
      var result = [Double](repeating: 0, count: permutation.count)
      for (k, source) in permutation.enumerated() {
        result[source] = v.base[k]
      }
      return Array<Double>.DifferentiableView(result)
    })
  }
  70.  
  let arrayToSort: [Double] = [7, 2, 4, 1, 8, 3, 0, 9]

  /// One-hot seed vectors. Entry `i` has a 1 at position `i` and 0 elsewhere,
  /// so pulling it back yields the gradient of output element `i`.
  var vectorsToPullBack: [[Double]] = arrayToSort.indices.map { hot in
    var basis = [Double](repeating: 0, count: arrayToSort.count)
    basis[hot] = 1
    return basis
  }
  78.  
  // Reference run: the hand-written VJP of `sorted` supplies the value and
  // gradients that the automatically differentiated sorts are checked against.
  let (value, pb) = valueWithPullback(at: arrayToSort, in: sorted)
  print("USING CUSTOM DERIVATIVE FOR SORT")
  print(value)
  vectorsToPullBack.forEach { seed in
    print(pb(Array<Double>.DifferentiableView(seed)))
  }
  print("")
  86.  
  // MARK: - Selection sort.

  /// Returns the index of the largest element of `array`.
  ///
  /// Ties go to the earliest index, because a later element must compare
  /// strictly greater to win. A NaN never compares greater, so a NaN at index 0
  /// keeps index 0. `array` must be non-empty, because element 0 is read
  /// unconditionally.
  func argMax(_ array: [Double]) -> Int {
    var winner = (index: 0, value: array[0])
    for index in array.indices.dropFirst() where array[index] > winner.value {
      winner = (index: index, value: array[index])
    }
    return winner.index
  }
  100.  
  /// Sorts `array` ascending by moving the largest element to the back, then
  /// recursively sorting what remains in front of it.
  ///
  /// It is written with the differentiable `Array` helpers (`swappedAt`,
  /// `droppedFirst`, `appending`) so that its pullback can be derived
  /// automatically. The position of the maximum is discrete, so `argMax` runs
  /// on a derivative-free copy.
  func selectionSort(_ array: [Double]) -> [Double] {
    guard array.count > 1 else { return array }
    let largestIndex = argMax(array.withoutDerivative())
    let largestFirst = array.swappedAt(0, largestIndex)
    let rest = largestFirst.droppedFirst()
    return selectionSort(rest).appending(largestFirst[0])
  }
  109.  
  // Differentiate selection sort automatically and compare its value and
  // gradients against the reference results `value` / `pb` from `sorted`.
  let (value2, pb2) = valueWithPullback(at: arrayToSort, in: selectionSort)
  print("USING AUTOMATICALLY COMPUTED DERIVATIVE OF SELECTION SORT")
  print(value2)
  if value2 != value {
    print(" oh no, that one is wrong")
  }
  for seed in vectorsToPullBack {
    let cotangent = Array<Double>.DifferentiableView(seed)
    // Evaluate the pullback once and reuse it for printing and comparison.
    let gradient = pb2(cotangent)
    print(gradient)
    if gradient != pb(cotangent) {
      print(" oh no, that one is wrong")
    }
  }
  print("")
  123.  
  // MARK: - Quicksort.

  extension Array where Element : Differentiable {
    /// Returns the elements at indices `start..<count` that satisfy `predicate`.
    ///
    /// It is built recursively on the differentiable `appending(_:)`, so the
    /// result takes part in automatic differentiation.
    /// - Note: Each kept element is appended *after* the recursive call returns,
    ///   so the result comes out in reverse index order. For example, keeping
    ///   everything from `[a, b, c]` yields `[c, b, a]`. `qsort` only relies on
    ///   which elements are kept, not on their order.
    /// - Precondition: `0 <= start <= count`; any other `start` traps on the
    ///   subscript.
    func filter(_ predicate: (Element) -> Bool, _ start: Int) -> Array {
      guard start != count else { return [] }
      let candidate = self[start]
      let keep = predicate(candidate)
      let remainder = filter(predicate, start + 1)
      if keep {
        return remainder.appending(candidate)
      }
      return remainder
    }
  }
  138.  
  /// Recursive quicksort that uses the first element as its pivot.
  ///
  /// Elements below the pivot go left and all others go right. It is written
  /// with differentiable array operations so that its pullback can be derived
  /// automatically. The comparisons use a derivative-free copy of the pivot
  /// because the partition choice is discrete; the pivot value itself still
  /// flows into the output.
  func qsort(_ array: [Double]) -> [Double] {
    guard array.count > 1 else { return array }
    let pivot = array[0]
    let cutoff = pivot.withoutDerivative()
    let below = array.filter({ $0 < cutoff }, 1)
    let atOrAbove = array.filter({ $0 >= cutoff }, 1)
    return qsort(below) + Array.makeSingle(pivot) + qsort(atOrAbove)
  }
  149.  
  // Differentiate quicksort automatically and compare its value and gradients
  // against the reference results `value` / `pb` from `sorted`.
  let (value3, pb3) = valueWithPullback(at: arrayToSort, in: qsort)
  print("USING AUTOMATICALLY COMPUTED DERIVATIVE OF QUICK SORT")
  print(value3)
  if value3 != value {
    print(" oh no, that one is wrong")
  }
  for seed in vectorsToPullBack {
    let cotangent = Array<Double>.DifferentiableView(seed)
    // Evaluate the pullback once and reuse it for printing and comparison.
    let gradient = pb3(cotangent)
    print(gradient)
    if gradient != pb(cotangent) {
      print(" oh no, that one is wrong")
    }
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement