Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
/// Copies a tensor's contents into a flat Swift array and provides
/// multi-dimensional element access over that copy.
///
/// The shape comes from `tensor.shape.dimensions`. The element data comes from
/// `tensor.data.toArray(type:)`, which is defined outside this file.
/// NOTE(review): assumes the tensor's stored element type really is `Element`;
/// confirm at call sites.
public struct TensorToArray<Element: AdditiveArithmetic>
{
    /// Flattened tensor contents, in the order returned by `toArray(type:)`.
    var array: [Element]
    /// Size of each axis, copied from the tensor's shape.
    var dimensions: [Int]

    /// Captures the tensor's shape and copies its data as `Element` values.
    /// - Parameter tensor: The tensor to read. Only its shape and a copy of its
    ///   data are kept; no reference to the tensor is stored.
    init(tensor: Tensor)
    {
        dimensions = tensor.shape.dimensions
        array = tensor.data.toArray(type: Element.self)
    }

    /// Converts a multi-dimensional index into an offset into `array`.
    ///
    /// The last axis varies fastest: offset = ((i0 * d1 + i1) * d2 + i2) ...
    /// - Parameter index: One component per axis.
    /// - Returns: The flat offset corresponding to `index`.
    /// - Note: Traps with `fatalError` if the number of components differs from
    ///   the rank, or if any component is negative or not smaller than its
    ///   axis size.
    private func flatIndex(_ index: [Int]) -> Int
    {
        guard index.count == dimensions.count
        else
        {
            fatalError("Invalid index: got \(index.count) index(es) for \(dimensions.count) dimension(s).")
        }

        var offset = 0
        for axis in dimensions.indices
        {
            let position = index[axis]
            let size = dimensions[axis]
            // Reject negatives too: a negative component can yield a wrong
            // offset that still falls inside `array` and would not trap there.
            guard position >= 0 && position < size
            else
            {
                fatalError("Invalid index: \(position) is out of bounds for axis \(axis) of size \(size).")
            }
            offset = offset * size + position
        }
        return offset
    }

    /// Accesses the element at a multi-dimensional index (one component per axis).
    ///
    /// Invalid indices trap (see `flatIndex`). Setting a value changes only this
    /// struct's `array`.
    subscript(_ index: Int...) -> Element
    {
        get { return array[flatIndex(index)] }
        set { array[flatIndex(index)] = newValue }
    }

    /// Returns the flat index of the largest element in `array`.
    ///
    /// If the maximum occurs more than once, the first occurrence wins. NaN
    /// elements are ignored. Returns 0 when `array` is empty or contains only NaNs.
    /// - Note: Requires `Element` to be `Float`. A non-empty array of any other
    ///   element type traps with a descriptive message.
    public func getMaximumIndex() -> Int
    {
        guard let values = array as? [Float]
        else
        {
            fatalError("getMaximumIndex() requires Float elements, but Element is \(Element.self).")
        }

        // Seed from the data itself rather than from 0, so all-negative inputs
        // still report their true maximum.
        var best: (index: Int, value: Float)?
        for (index, value) in values.enumerated() where !value.isNaN
        {
            // Only a strictly larger value replaces the current best, which keeps
            // the earliest index on ties.
            if let current = best, value <= current.value
            {
                continue
            }
            best = (index, value)
        }
        return best?.index ?? 0
    }
}
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement