123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- package flow_test
- import (
- "bytes"
- "encoding/json"
- "log"
- "reflect"
- "testing"
- "dev.hexasoftware.com/x/flow"
- vecasm "github.com/gohxs/vec-benchmark/asm"
- )
- func TestSerialize(t *testing.T) {
- f := flow.New()
- var1 := f.Variable([]float32{4, 4, 4})
- c1 := f.Const([]float32{1, 2, 3})
- c2 := f.Const([]float32{2, 2, 2})
- op1 := f.Op("vecmul", // op:0 - expected: [12,16,20,24]
- f.Variable([]float32{4, 4, 4, 4}),
- f.Op("vecadd", // op:1 - expected: [3,4,5,6]
- f.Const([]float32{1, 2, 3, 4}),
- f.Const([]float32{2, 2, 2, 2}),
- ),
- )
- mul1 := f.Op("vecmul", c1, op1) // op:2 - expected 12, 32, 60, 0
- mul2 := f.Op("vecmul", mul1, var1) // op:3 - expected 48, 128, 240, 0
- mul3 := f.Op("vecmul", c2, mul2) // op:4 - expected 96, 256, 480, 0
- mul4 := f.Op("vecmul", mul3, f.In(0)) // op:5 - expected 96, 512, 1440,0
- t.Log(f.Analyse([]float32{1, 2, 3, 4}))
- res := mul4.Process([]float32{1, 2, 3, 4})
- t.Log("Res:", res)
- t.Log("Flow:\n", f)
- ret := bytes.NewBuffer(nil)
- e := json.NewEncoder(ret)
- e.Encode(f)
- t.Log("Flow:", ret)
- }
- func TestFlow(t *testing.T) {
- f := flow.New()
- a := f.Const([]float32{1, 2, 3})
- b := f.Const([]float32{2, 2, 2})
- f.Op("mul", a, b)
- log.Println(f)
- }
- func TestBuild(t *testing.T) {
- var err error
- err = flow.Register("vecmul", vecmul)
- if err != nil {
- t.Fatal(err)
- }
- err = flow.Register("vecadd", vecadd)
- if err != nil {
- t.Fatal(err)
- }
- f := flow.New()
- a := f.Const([]float32{1, 2, 3})
- b := f.Const([]float32{2, 2, 2})
- m := f.Op("vecmul", []float32{1, 2, 3}, b)
- lastOp := f.Op("vecadd", a, m)
- {
- res, err := f.Run(m, lastOp)
- errcheck(t, err)
- test := []interface{}{
- []float32{2, 4, 6},
- []float32{3, 6, 9},
- }
- if !reflect.DeepEqual(test, res) {
- t.Fatal("Arrays does not match:", test, res)
- }
- t.Log("Result:", res)
- }
- }
- func TestImplementation(t *testing.T) {
- f := flow.New()
- in := f.Variable([]float32{2, 2, 2, 2})
- v1 := f.Const([]float32{1, 2, 3, 4})
- op1 := f.Op("vecmul", v1, v1) // 1 2 9 16
- op2 := f.Op("vecmul", in, op1) // 2 4 18 32
- res, err := f.Run(op1, op2)
- errcheck(t, err)
- t.Log("Res:", res)
- test := []interface{}{
- []float32{1, 4, 9, 16},
- []float32{2, 8, 18, 32},
- }
- if !reflect.DeepEqual(test, res) {
- t.Fatal("Arrays does not match:", test, res)
- }
- }
- func TestVariable(t *testing.T) {
- f := flow.New()
- v := f.Variable(1)
- res, _ := f.Run(v)
- t.Log("res", res)
- if !reflect.DeepEqual(res, []interface{}{1}) {
- t.Fatal("Result mismatch")
- }
- v.Set(2)
- res, _ = f.Run(v)
- t.Log("res", res)
- if !reflect.DeepEqual(res, []interface{}{2}) {
- t.Fatal("Result mismatch")
- }
- }
- func init() {
- flow.Register("vecmul", vecmul)
- flow.Register("vecadd", vecadd)
- }
// Helper functions used by the tests in this file.
- func vecmul(a, b []float32) []float32 {
- sz := min(len(a), len(b))
- out := make([]float32, sz)
- vecasm.VecMulf32x8(a, b, out)
- return out
- }
// vecadd returns a new slice holding a[i] + b[i] for every index of the
// result, whose length is given by this file's min helper applied to len(a)
// and len(b).
func vecadd(a, b []float32) []float32 {
	out := make([]float32, min(len(a), len(b)))
	for i := range out {
		out[i] = a[i] + b[i]
	}
	return out
}
// min returns the smallest of the given values, or 0 when called with no
// arguments.
//
// The previous version replaced the running value whenever it was smaller
// than the next element, so it returned the maximum; callers that size an
// output slice from it then index past the end of the shorter input.
func min(p ...int) int {
	if len(p) == 0 {
		// Nothing to compare; 0 is a safe length for callers that use the
		// result to size a slice.
		return 0
	}
	lowest := p[0]
	for _, v := range p[1:] {
		if v < lowest {
			lowest = v
		}
	}
	return lowest
}
// errcheck fails the test immediately when err is non-nil and does nothing
// otherwise.
func errcheck(t *testing.T, err error) {
	// Mark this function as a helper so a failure is attributed to the
	// calling test's line rather than to the t.Fatal call below.
	t.Helper()
	if err != nil {
		t.Fatal(err)
	}
}
|