main.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. package main
import (
	"flag"
	"fmt"
	"log"
	"os"
	"runtime"
	"runtime/pprof"

	"demos/ops/ml"
	"flow"

	"gonum.org/v1/gonum/mat"
)
  12. var cpuprofile = flag.String("cpuprofile", "", "write cpu profile to file")
  13. var memprofile = flag.String("memprofile", "", "write mem profile to file")
  14. func main() {
  15. flag.Parse()
  16. if *cpuprofile != "" {
  17. f, err := os.Create(*cpuprofile)
  18. if err != nil {
  19. log.Fatal(err)
  20. }
  21. pprof.StartCPUProfile(f)
  22. defer pprof.StopCPUProfile()
  23. }
  24. if *memprofile != "" {
  25. f, err := os.Create(*memprofile)
  26. if err != nil {
  27. log.Fatal(err)
  28. }
  29. defer pprof.WriteHeapProfile(f)
  30. }
  31. // Registry for machine learning
  32. r := ml.New()
  33. f := flow.New()
  34. f.UseRegistry(r)
  35. samples := []float64{
  36. 0, 0,
  37. 0, 1,
  38. 1, 0,
  39. 1, 1,
  40. }
  41. labels := []float64{
  42. 0,
  43. 1,
  44. 1,
  45. 0,
  46. }
  47. learningRate := float64(0.3)
  48. nInputs := 2
  49. nHidden := 5
  50. nOutput := 1
  51. nSamples := 4
  52. matSamples := mat.NewDense(nSamples, 2, samples)
  53. matLabels := mat.NewDense(nSamples, 1, labels)
  54. // Define input
  55. // Make a matrix out of the input and output
  56. x := f.In(0)
  57. y := f.In(1)
  58. wHidden := f.Var("wHidden", f.Op("matNewRand", nInputs, nHidden))
  59. wOut := f.Var("wOut", f.Op("matNewRand", nHidden, nOutput))
  60. // Forward process
  61. hiddenLayerInput := f.Op("matMul", x, wHidden)
  62. hiddenLayerActivations := f.Op("matSigmoid", hiddenLayerInput)
  63. outputLayerInput := f.Op("matMul", hiddenLayerActivations, wOut)
  64. output := f.Op("matSigmoid", outputLayerInput)
  65. // Back propagation
  66. // output weights
  67. networkError := f.Op("matSub", y, output)
  68. slopeOutputLayer := f.Op("matSigmoidPrime", output)
  69. dOutput := f.Op("matMulElem", networkError, slopeOutputLayer)
  70. wOutAdj := f.Op("matMul", f.Op("matTranspose", hiddenLayerActivations), dOutput)
  71. wOutAdj = f.Op("matScale", learningRate, wOutAdj)
  72. // hidden weights
  73. errorAtHiddenLayer := f.Op("matMul", dOutput, f.Op("matTranspose", wOut))
  74. slopeHiddenLayer := f.Op("matSigmoidPrime", hiddenLayerActivations)
  75. dHiddenLayer := f.Op("matMulElem", errorAtHiddenLayer, slopeHiddenLayer)
  76. wHiddenAdj := f.Op("matMul", f.Op("matTranspose", x), dHiddenLayer)
  77. wHiddenAdj = f.Op("matScale", learningRate, wHiddenAdj)
  78. // Adjust the parameters
  79. setwOut := f.SetVar("wOut", f.Op("matAdd", wOut, wOutAdj))
  80. setwHidden := f.SetVar("wHidden", f.Op("matAdd", wHidden, wHiddenAdj))
  81. // Training
  82. for i := 0; i < 5000; i++ {
  83. sess := f.NewSession()
  84. sess.Inputs(matSamples, matLabels)
  85. _, err := sess.Run(setwOut, setwHidden)
  86. if err != nil {
  87. log.Fatal(err)
  88. }
  89. }
  90. // Same as above because its simple
  91. testSamples := matSamples
  92. testLabels := matLabels
  93. res, err := output.Process(testSamples)
  94. if err != nil {
  95. log.Fatal(err)
  96. }
  97. predictions := res.(mat.Matrix)
  98. log.Println("Predictions", predictions)
  99. var rights int
  100. numPreds, _ := predictions.Dims()
  101. log.Println("Number of predictions:", numPreds)
  102. for i := 0; i < numPreds; i++ {
  103. if predictions.At(i, 0) > 0.5 && testLabels.At(i, 0) == 1.0 ||
  104. predictions.At(i, 0) < 0.5 && testLabels.At(i, 0) == 0 {
  105. rights++
  106. }
  107. }
  108. accuracy := float64(rights) / float64(numPreds)
  109. fmt.Printf("\nAccuracy = %0.2f\n\n", accuracy)
  110. }