Advertisement
Guest User

Untitled

a guest
Jun 26th, 2017
64
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.10 KB | None | 0 0
  1. package org.bubblecloud.logi.analysis
  2.  
  3. import org.nd4j.linalg.api.ndarray.INDArray
  4. import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
  5. import java.io.File
  6. import javax.imageio.ImageIO
  7. import java.awt.image.BufferedImage
  8. import java.awt.image.DataBufferByte
  9.  
  /**
   * Visualizes every example produced by [dataSetIterator] as a pair of grayscale PNG images
   * written into [imageDirectory].
   *
   * For each example (row of a mini-batch) a feature image `<index>F.png` and a label image
   * `<index>L.png` are written, where `<index>` is a running counter zero-padded to 10 digits.
   * The counter advances by exactly one per example, so file names stay contiguous across
   * mini-batches.
   *
   * The iterator is consumed from its current position; it is not reset before or after.
   *
   * @param imageDirectory directory to write images into; created (including missing parents) if absent.
   * @param dataSetIterator source of mini-batches whose feature and label rows are rendered.
   * @param maxFeatureValue feature value rendered as white; features are scaled by it and clamped to `[0, 1]`.
   * @param maxLabelValue label value rendered as white; labels are scaled by it and clamped to `[0, 1]`.
   * @param featureImageWidth pixel width of each feature image.
   * @param labelImageWidth pixel width of each label image.
   * @throws IllegalArgumentException if [imageDirectory] is not a directory and could not be created as one.
   */
  fun visualize(
      imageDirectory: File,
      dataSetIterator: DataSetIterator,
      maxFeatureValue: Float,
      maxLabelValue: Float,
      featureImageWidth: Int,
      labelImageWidth: Int
  ) {
      if (!imageDirectory.exists()) {
          // mkdirs() (not mkdir()) so a missing parent path does not make creation fail silently.
          imageDirectory.mkdirs()
      }
      require(imageDirectory.isDirectory) { "Given path is not a directory: $imageDirectory" }

      var imageIndex = 0
      while (dataSetIterator.hasNext()) {
          val dataSet = dataSetIterator.next()
          val features = dataSet.features
          // NOTE(review): assumes every batch carries labels; a null label matrix would throw here.
          val labels = dataSet.labels

          for (row in 0 until features.rows()) {
              val imageIndexLabel = imageIndex.toString().padStart(10, '0')
              visualize(
                  imageFile = File(imageDirectory, "${imageIndexLabel}F.png"),
                  intArray = features.getRow(row),
                  maxValue = maxFeatureValue,
                  maxWidth = featureImageWidth
              )
              visualize(
                  imageFile = File(imageDirectory, "${imageIndexLabel}L.png"),
                  intArray = labels.getRow(row),
                  maxValue = maxLabelValue,
                  maxWidth = labelImageWidth
              )
              imageIndex++
          }
      }
  }
  42.  
  /**
   * Renders [intArray] as a grayscale PNG written to [imageFile].
   *
   * Elements are laid out row-major, [maxWidth] pixels per row. Each element is divided by
   * [maxValue], clamped to `[0, 1]` and scaled to an 8-bit gray level (0 = black, 255 = white).
   * If the element count is not a multiple of [maxWidth], the last row is padded with black
   * pixels instead of failing.
   *
   * @param imageFile destination file; written via `ImageIO.write(..., "png", ...)`.
   * @param intArray values to render, read by linear index `0 until intArray.columns()`;
   *   presumably a single row vector (one example of a batch) — TODO confirm for other shapes.
   * @param maxValue input value rendered as white.
   * @param maxWidth image width in pixels; must be positive.
   * @throws IllegalArgumentException if [maxWidth] is not positive or [intArray] has no columns.
   */
  fun visualize(imageFile: File, intArray: INDArray, maxValue: Float, maxWidth: Int) {
      require(maxWidth > 0) { "maxWidth must be positive: $maxWidth" }
      val pixelCount = intArray.columns()
      require(pixelCount > 0) { "Cannot visualize an empty array." }

      // Round up so a partial last row is kept rather than overflowing the raster.
      val height = (pixelCount + maxWidth - 1) / maxWidth
      val bufferedImage = BufferedImage(maxWidth, height, BufferedImage.TYPE_BYTE_GRAY)
      // A freshly constructed TYPE_BYTE_GRAY image stores one byte per pixel, row-major with a
      // scanline stride equal to the width, so input element i maps to backing-array index i.
      val pixels = (bufferedImage.raster.dataBuffer as DataBufferByte).data

      for (i in 0 until pixelCount) {
          val normalized = (intArray.getFloat(i) / maxValue).coerceIn(0f, 1f)
          // Convert through Int: Double.toByte() is deprecated. Gray levels above 127 wrap to
          // negative bytes, which the raster interprets as unsigned 128..255.
          pixels[i] = (255.0 * normalized).toInt().toByte()
      }

      ImageIO.write(bufferedImage, "png", imageFile)
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement