Guest User

Untitled

a guest
Jul 16th, 2018
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.21 KB | None | 0 0
  1. // RUN: %target-swift-frontend -Xllvm -tf-dump-intermediates -O -emit-sil -verify %s | %FileCheck %s
  2.  
  3. import TensorFlow
  4.  
  5. // This test verifies that all of the operations end up in the graph, i.e.
  6. // that no host/accelerator copies are generated. It exercises the
  7. // partitioning pass's ability to recognize various forms, and also checks
  8. // that certain op implementations are promotable.
  9.  
  10. // Please keep the functions in this file free of errors and warnings
  11. // when adding or changing code.
  12.  
  13. /// b/76222306
      /// NOTE(review): "b/76222306" looks like a bug-tracker ID; this struct
      /// appears to be a minimal reproducer for that issue — confirm.
      ///
      /// A small dense classifier. `w1`/`b1` map 784 input features to 30
      /// units, and `w2`/`b2` map 30 units to 10 outputs.
      /// In the reduced reproducer, live code reads only `w1` and `b1`.
      /// `w2` and `b2` are referenced only by the commented-out training code.
  14. struct Classifier {
  15. // Parameters
  16. var w1 = Tensor<Float>(randomUniform: [784, 30])
  17. var w2 = Tensor<Float>(randomUniform: [30, 10])
  18. var b1 = Tensor<Float>(zeros: [1, 30])
  19. var b2 = Tensor<Float>(zeros: [1, 10])
  20.  
      /// Runs the (reduced) training loop and returns the last loss value.
      ///
      /// - Parameters:
      ///   - images: Input batch multiplied by `w1`. Presumably shaped
      ///     [batch, 784]; `mnist()` passes [10, 784]. Confirm for other callers.
      ///   - labels: Unused by the live code; only the commented-out backward
      ///     pass reads it.
      ///   - learningRate: Unused by the live code; only the commented-out
      ///     gradient-descent step reads it.
      ///   - epochCount: Number of loop iterations. The body always runs at
      ///     least once, even when `epochCount <= 0`.
      /// - Returns: The loss from the final iteration. This is currently always
      ///   1, because the real loss computation is commented out.
      ///
      /// `mutating` is kept from the original test, where the gradient-descent
      /// step updates `w1`, `b1`, `w2` and `b2`. The live code only reads `self`.
  21. mutating func train(images: Tensor<Float>, labels: Tensor<Float>,
  22. learningRate: Float, epochCount: Int) -> Float {
      // Declared without an initial value. It is assigned inside the
      // repeat-while body, which always runs at least once, so it is
      // initialized before `return loss`.
  23. var loss: Float
      // Shadow the parameter with a mutable copy so the loop can count down.
  24. var epochCount = epochCount
  25. repeat {
  26. // Forward pass
      // NOTE(review): `•` is presumably the TensorFlow module's matrix-multiply
      // operator; the commented-out code uses it alongside `matmul`. Confirm.
      // `z1` is never read again in the reduced reproducer.
  27. let z1 = images • w1 + b1
  28.  
  29. // Code below from original test is commented for minimal reproducer.
      // It is the original two-layer sigmoid network's forward pass, backward
      // pass, gradient-descent update and mean-of-squares loss. It is kept
      // for reference only and is not compiled.
  30. /*
  31. let h1 = sigmoid(z1)
  32. let z2 = h1 • w2 + b2
  33. let pred = sigmoid(z2)
  34.  
  35. // Backward pass
  36. let dz2 = pred - labels
  37. let dw2 = h1.transposed(withPermutations: 1, 0) • dz2
  38. let db2 = dz2.sum(squeezingAxes: 0)
  39. let dz1 = matmul(dz2, w2.transposed(withPermutations: 1, 0)) * h1 * (1 - h1)
  40. let dw1 = images.transposed(withPermutations: 1, 0) • dz1
  41. let db1 = dz1.sum(squeezingAxes: 0)
  42.  
  43. // Gradient descent
  44. w1 -= dw1 * learningRate
  45. b1 -= db1 * learningRate
  46. w2 -= dw2 * learningRate
  47. b2 -= db2 * learningRate
  48.  
  49. loss = dz2.squared().mean(squeezingAxes: 1, 0).scalarized()
  50. */
      // Placeholder for the commented-out loss computation above. Because of
      // it, train() always returns 1.
  51. loss = 1
  52.  
  53. epochCount -= 1
  54. } while epochCount > 0
  55.  
  56. return loss
  57. }
  58. }
  59.  
      /// Drives the classifier: builds random inputs, trains a `Classifier` for
      /// 100 epochs with learning rate 0.3, and prints the returned loss.
      ///
      /// `images` is [10, 784], matching the 784 rows of `Classifier.w1`.
      /// `labels` is [10, 10].
  60. public func mnist() {
  61. // Training data
      // The `@+1` offset binds the expected-warning directive below to the very
      // next source line (`let images = ...`). Keep those two lines adjacent,
      // or -verify will look for the warning on the wrong line.
  62. // expected-warning @+1 {{'Tensor<Float>' implicitly copied to the accelerator, use .toAccelerator}}
  63. let images = Tensor<Float>(randomNormal: [10, 784])
  64. let labels = Tensor<Float>(randomNormal: [10, 10])
  65. var classifier = Classifier()
      // `train` is `mutating`, so `classifier` must be declared with `var`.
  66. let loss = classifier.train(images: images, labels: labels,
  67. learningRate: 0.3, epochCount: 100)
  68. print(loss)
  69. }
Add Comment
Please, Sign In to add comment