main.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package main
import (
	"flag"
	"fmt"
	"log"
	"os"
	"runtime"
	"runtime/pprof"

	"github.com/hexasoftware/flow"
	"gonum.org/v1/gonum/mat"

	"demos/ops/ml"
)
  12. var cpuprofile = flag.String("cpuprofile", "", "write cpu profile to file")
  13. var memprofile = flag.String("memprofile", "", "write mem profile to file")
  14. func main() {
  15. flag.Parse()
  16. if *cpuprofile != "" {
  17. f, err := os.Create(*cpuprofile)
  18. if err != nil {
  19. log.Fatal(err)
  20. }
  21. pprof.StartCPUProfile(f)
  22. defer pprof.StopCPUProfile()
  23. }
  24. if *memprofile != "" {
  25. f, err := os.Create(*memprofile)
  26. if err != nil {
  27. log.Fatal(err)
  28. }
  29. defer pprof.WriteHeapProfile(f)
  30. }
  31. // Registry for machine learning
  32. r := ml.New()
  33. f := flow.New()
  34. f.UseRegistry(r)
  35. samples := []float64{
  36. 0, 0,
  37. 0, 1,
  38. 1, 0,
  39. 1, 1,
  40. }
  41. labels := []float64{
  42. 0,
  43. 1,
  44. 1,
  45. 0,
  46. }
  47. learningRate := float64(0.3)
  48. nInputs := 2
  49. nHidden := 5
  50. nOutput := 1
  51. nSamples := 4
  52. matSamples := mat.NewDense(nSamples, 2, samples)
  53. matLabels := mat.NewDense(nSamples, 1, labels)
  54. // Define input
  55. // Make a matrix out of the input and output
  56. x := f.In(0)
  57. y := f.In(1)
  58. // [ 1, 2, 3, 4, 5]
  59. // [ 1, 2, 3, 4, 5]
  60. wHidden := f.Var("wHidden", f.Op("matNewRand", nInputs, nHidden))
  61. // [ 1 ]
  62. // [ 2 ]
  63. // [ 3 ]
  64. // [ 4 ]
  65. // [ 5 ]
  66. wOut := f.Var("wOut", f.Op("matNewRand", nHidden, nOutput))
  67. // Forward process
  68. hiddenLayerInput := f.Op("matMul", x, wHidden)
  69. hiddenLayerActivations := f.Op("matSigmoid", hiddenLayerInput)
  70. outputLayerInput := f.Op("matMul", hiddenLayerActivations, wOut)
  71. // Activations
  72. output := f.Op("matSigmoid", outputLayerInput)
  73. // Back propagation
  74. // output weights
  75. networkError := f.Op("matSub", y, output)
  76. slopeOutputLayer := f.Op("matSigmoidPrime", output)
  77. dOutput := f.Op("matMulElem", networkError, slopeOutputLayer)
  78. wOutAdj := f.Op("matScale",
  79. learningRate,
  80. f.Op("matMul", f.Op("matTranspose", hiddenLayerActivations), dOutput),
  81. )
  82. // hidden weights
  83. errorAtHiddenLayer := f.Op("matMul", dOutput, f.Op("matTranspose", wOut))
  84. slopeHiddenLayer := f.Op("matSigmoidPrime", hiddenLayerActivations)
  85. dHiddenLayer := f.Op("matMulElem", errorAtHiddenLayer, slopeHiddenLayer)
  86. wHiddenAdj := f.Op("matScale",
  87. learningRate,
  88. f.Op("matMul", f.Op("matTranspose", x), dHiddenLayer),
  89. )
  90. // Adjust the parameters
  91. setwOut := f.SetVar("wOut", f.Op("matAdd", wOut, wOutAdj))
  92. setwHidden := f.SetVar("wHidden", f.Op("matAdd", wHidden, wHiddenAdj))
  93. // Training
  94. for i := 0; i < 5000; i++ {
  95. sess := f.NewSession()
  96. sess.Inputs(matSamples, matLabels)
  97. _, err := sess.Run(setwOut, setwHidden)
  98. if err != nil {
  99. log.Fatal(err)
  100. }
  101. }
  102. // Same as above because its simple
  103. testSamples := matSamples
  104. testLabels := matLabels
  105. res, err := output.Process(testSamples)
  106. if err != nil {
  107. log.Fatal(err)
  108. }
  109. predictions := res.(mat.Matrix)
  110. log.Println("Predictions", predictions)
  111. var rights int
  112. numPreds, _ := predictions.Dims()
  113. log.Println("Number of predictions:", numPreds)
  114. for i := 0; i < numPreds; i++ {
  115. if predictions.At(i, 0) > 0.5 && testLabels.At(i, 0) == 1.0 ||
  116. predictions.At(i, 0) < 0.5 && testLabels.At(i, 0) == 0 {
  117. rights++
  118. }
  119. }
  120. accuracy := float64(rights) / float64(numPreds)
  121. fmt.Printf("\nAccuracy = %0.2f\n\n", accuracy)
  122. }