Advertisement
Guest User

TensorToArray

a guest
Apr 11th, 2022
207
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Swift 1.61 KB | None | 0 0
  1. public struct TensorToArray<Element: AdditiveArithmetic>
  2.     {
  3.         var array: [Element]
  4.         var dimensions: [Int]
  5.  
  6.         init(tensor: Tensor)
  7.         {
  8.             dimensions = tensor.shape.dimensions
  9.             array = tensor.data.toArray(type: Element.self)
  10.         }
  11.  
  12.         private func flatIndex(_ index: [Int]) -> Int
  13.         {
  14.             guard index.count == dimensions.count
  15.             else
  16.             {
  17.                 fatalError("Invalid index: got \(index.count) index(es) for \(dimensions.count) index(es).")
  18.             }
  19.  
  20.             var result = 0
  21.             for i in 0..<dimensions.count
  22.             {
  23.                 guard dimensions[i] > index[i]
  24.                 else
  25.                 {
  26.                     fatalError("Invalid index: \(index[i]) is bigger than \(dimensions[i])")
  27.                 }
  28.                 result = dimensions[i] * result + index[i]
  29.             }
  30.             return result
  31.         }
  32.  
  33.         subscript(_ index: Int...) -> Element
  34.         {
  35.             get { return array[flatIndex(index)] }
  36.             set(newValue) { array[flatIndex(index)] = newValue }
  37.         }
  38.        
  39.         public func getMaximumIndex()-> Int
  40.         {
  41.             var index = 0
  42.             var maxValue: Float = 0
  43.             var res = 0
  44.             while index < array.count
  45.             {
  46.                 let temp = array[index] as! Float
  47.                 if temp > maxValue
  48.                 {
  49.                     maxValue = temp
  50.                     res = index
  51.                 }
  52.                 index += 1
  53.             }
  54.             return res
  55.         }
  56.     }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement