flow_test.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. package flow_test
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "log"
  6. "reflect"
  7. "testing"
  8. "dev.hexasoftware.com/x/flow"
  9. vecasm "github.com/gohxs/vec-benchmark/asm"
  10. )
  11. func TestSerialize(t *testing.T) {
  12. f := flow.New()
  13. var1 := f.Variable([]float32{4, 4, 4})
  14. c1 := f.Const([]float32{1, 2, 3})
  15. c2 := f.Const([]float32{2, 2, 2})
  16. op1 := f.Op("vecmul", // op:0 - expected: [12,16,20,24]
  17. f.Variable([]float32{4, 4, 4, 4}),
  18. f.Op("vecadd", // op:1 - expected: [3,4,5,6]
  19. f.Const([]float32{1, 2, 3, 4}),
  20. f.Const([]float32{2, 2, 2, 2}),
  21. ),
  22. )
  23. mul1 := f.Op("vecmul", c1, op1) // op:2 - expected 12, 32, 60, 0
  24. mul2 := f.Op("vecmul", mul1, var1) // op:3 - expected 48, 128, 240, 0
  25. mul3 := f.Op("vecmul", c2, mul2) // op:4 - expected 96, 256, 480, 0
  26. mul4 := f.Op("vecmul", mul3, f.In(0)) // op:5 - expected 96, 512, 1440,0
  27. t.Log(f.Analyse([]float32{1, 2, 3, 4}))
  28. res := mul4.Process([]float32{1, 2, 3, 4})
  29. t.Log("Res:", res)
  30. t.Log("Flow:\n", f)
  31. ret := bytes.NewBuffer(nil)
  32. e := json.NewEncoder(ret)
  33. e.Encode(f)
  34. t.Log("Flow:", ret)
  35. }
  36. func TestFlow(t *testing.T) {
  37. f := flow.New()
  38. a := f.Const([]float32{1, 2, 3})
  39. b := f.Const([]float32{2, 2, 2})
  40. f.Op("mul", a, b)
  41. log.Println(f)
  42. }
  43. func TestBuild(t *testing.T) {
  44. var err error
  45. err = flow.Register("vecmul", vecmul)
  46. if err != nil {
  47. t.Fatal(err)
  48. }
  49. err = flow.Register("vecadd", vecadd)
  50. if err != nil {
  51. t.Fatal(err)
  52. }
  53. f := flow.New()
  54. a := f.Const([]float32{1, 2, 3})
  55. b := f.Const([]float32{2, 2, 2})
  56. m := f.Op("vecmul", []float32{1, 2, 3}, b)
  57. lastOp := f.Op("vecadd", a, m)
  58. {
  59. res, err := f.Run(m, lastOp)
  60. errcheck(t, err)
  61. test := []interface{}{
  62. []float32{2, 4, 6},
  63. []float32{3, 6, 9},
  64. }
  65. if !reflect.DeepEqual(test, res) {
  66. t.Fatal("Arrays does not match:", test, res)
  67. }
  68. t.Log("Result:", res)
  69. }
  70. }
  71. func TestImplementation(t *testing.T) {
  72. f := flow.New()
  73. in := f.Variable([]float32{2, 2, 2, 2})
  74. v1 := f.Const([]float32{1, 2, 3, 4})
  75. op1 := f.Op("vecmul", v1, v1) // 1 2 9 16
  76. op2 := f.Op("vecmul", in, op1) // 2 4 18 32
  77. res, err := f.Run(op1, op2)
  78. errcheck(t, err)
  79. t.Log("Res:", res)
  80. test := []interface{}{
  81. []float32{1, 4, 9, 16},
  82. []float32{2, 8, 18, 32},
  83. }
  84. if !reflect.DeepEqual(test, res) {
  85. t.Fatal("Arrays does not match:", test, res)
  86. }
  87. }
  88. func TestVariable(t *testing.T) {
  89. f := flow.New()
  90. v := f.Variable(1)
  91. res, _ := f.Run(v)
  92. t.Log("res", res)
  93. if !reflect.DeepEqual(res, []interface{}{1}) {
  94. t.Fatal("Result mismatch")
  95. }
  96. v.Set(2)
  97. res, _ = f.Run(v)
  98. t.Log("res", res)
  99. if !reflect.DeepEqual(res, []interface{}{2}) {
  100. t.Fatal("Result mismatch")
  101. }
  102. }
  103. func init() {
  104. flow.Register("vecmul", vecmul)
  105. flow.Register("vecadd", vecadd)
  106. }
  107. // Some funcs
  108. func vecmul(a, b []float32) []float32 {
  109. sz := min(len(a), len(b))
  110. out := make([]float32, sz)
  111. vecasm.VecMulf32x8(a, b, out)
  112. return out
  113. }
  // vecadd returns a new slice holding a[i] + b[i] for every index of the
  // output, whose length is min(len(a), len(b)).
  func vecadd(a, b []float32) []float32 {
  	sum := make([]float32, min(len(a), len(b)))
  	for i := range sum {
  		sum[i] = a[i] + b[i]
  	}
  	return sum
  }
  // min returns the smallest of the given ints, or 0 when called with no
  // arguments. The previous version kept the larger value on each comparison
  // and so returned the maximum, which made vecadd index past the end of its
  // shorter input.
  func min(p ...int) int {
  	if len(p) == 0 {
  		return 0
  	}
  	m := p[0]
  	for _, v := range p[1:] {
  		if v < m {
  			m = v
  		}
  	}
  	return m
  }
  // errcheck stops the current test via t.Fatal when err is non-nil. It is
  // marked as a test helper so the failure is reported at the caller's line
  // rather than inside errcheck.
  func errcheck(t *testing.T, err error) {
  	t.Helper()
  	if err != nil {
  		t.Fatal(err)
  	}
  }