// Package flow_test exercises the flow package through its public API:
// building flows out of constants, variables and operations, running them,
// and encoding them as JSON.
package flow_test

import (
	"bytes"
	"encoding/json"
	"log"
	"reflect"
	"testing"

	"dev.hexasoftware.com/x/flow"

	vecasm "github.com/gohxs/vec-benchmark/asm"
)

// TestSerialize builds a nested flow of vecadd/vecmul operations, processes
// it with a sample input and logs the flow, both through its default
// formatting and as JSON. No result values are asserted; the test fails only
// when the JSON encoding itself reports an error.
func TestSerialize(t *testing.T) {
	f := flow.New()
	var1 := f.Variable([]float32{4, 4, 4})

	c1 := f.Const([]float32{1, 2, 3})
	c2 := f.Const([]float32{2, 2, 2})

	op1 := f.Op("vecmul", // op:0 - expected: [12,16,20,24]
		f.Variable([]float32{4, 4, 4, 4}),
		f.Op("vecadd", // op:1 - expected: [3,4,5,6]
			f.Const([]float32{1, 2, 3, 4}),
			f.Const([]float32{2, 2, 2, 2}),
		),
	)
	// From here on the operands differ in length (3 vs 4); vecmul truncates
	// to the shorter operand, so every result has three elements.
	mul1 := f.Op("vecmul", c1, op1)       // op:2 - expected 12, 32, 60
	mul2 := f.Op("vecmul", mul1, var1)    // op:3 - expected 48, 128, 240
	mul3 := f.Op("vecmul", c2, mul2)      // op:4 - expected 96, 256, 480
	mul4 := f.Op("vecmul", mul3, f.In(0)) // op:5 - expected 96, 512, 1440

	t.Log(f.Analyse([]float32{1, 2, 3, 4}))
	res := mul4.Process([]float32{1, 2, 3, 4})

	t.Log("Res:", res)
	t.Log("Flow:\n", f)

	ret := bytes.NewBuffer(nil)
	e := json.NewEncoder(ret)
	// A serialization test must not pass when encoding fails.
	errcheck(t, e.Encode(f))

	t.Log("Flow:", ret)
}

// TestFlow builds a minimal flow with two constants and one operation and
// prints it through the standard logger. It asserts nothing.
func TestFlow(t *testing.T) {
	f := flow.New()

	a := f.Const([]float32{1, 2, 3})
	b := f.Const([]float32{2, 2, 2})
	f.Op("mul", a, b)

	log.Println(f)
}

// TestBuild registers the vector operations, builds a flow that mixes a raw
// slice argument with constants, and checks the results of running two of its
// operations.
func TestBuild(t *testing.T) {
	if err := flow.Register("vecmul", vecmul); err != nil {
		t.Fatal(err)
	}
	if err := flow.Register("vecadd", vecadd); err != nil {
		t.Fatal(err)
	}

	f := flow.New()
	a := f.Const([]float32{1, 2, 3})
	b := f.Const([]float32{2, 2, 2})
	m := f.Op("vecmul", []float32{1, 2, 3}, b) // 2 4 6
	lastOp := f.Op("vecadd", a, m)             // 3 6 9

	res, err := f.Run(m, lastOp)
	errcheck(t, err)

	test := []interface{}{
		[]float32{2, 4, 6},
		[]float32{3, 6, 9},
	}
	if !reflect.DeepEqual(test, res) {
		t.Fatal("Arrays does not match:", test, res)
	}
	t.Log("Result:", res)
}

// TestImplementation runs two chained vecmul operations, one of which reads a
// variable, and compares both results against the expected vectors.
func TestImplementation(t *testing.T) {
	f := flow.New()

	in := f.Variable([]float32{2, 2, 2, 2})
	v1 := f.Const([]float32{1, 2, 3, 4})

	op1 := f.Op("vecmul", v1, v1)  // 1 4 9 16
	op2 := f.Op("vecmul", in, op1) // 2 8 18 32

	res, err := f.Run(op1, op2)
	errcheck(t, err)
	t.Log("Res:", res)

	test := []interface{}{
		[]float32{1, 4, 9, 16},
		[]float32{2, 8, 18, 32},
	}
	if !reflect.DeepEqual(test, res) {
		t.Fatal("Arrays does not match:", test, res)
	}
}

// TestVariable checks that running a variable yields its initial value and
// that a later Set is reflected by the next run.
func TestVariable(t *testing.T) {
	f := flow.New()
	v := f.Variable(1)

	res, err := f.Run(v)
	errcheck(t, err)
	t.Log("res", res)
	if !reflect.DeepEqual(res, []interface{}{1}) {
		t.Fatal("Result mismatch")
	}

	v.Set(2)
	res, err = f.Run(v)
	errcheck(t, err)
	t.Log("res", res)
	if !reflect.DeepEqual(res, []interface{}{2}) {
		t.Fatal("Result mismatch")
	}
}

// init registers the vector operations used by the tests in this file.
// Every test depends on them, so a failed registration aborts the test binary
// instead of surfacing later as an unrelated failure.
func init() {
	if err := flow.Register("vecmul", vecmul); err != nil {
		panic(err)
	}
	if err := flow.Register("vecadd", vecadd); err != nil {
		panic(err)
	}
}

// Some funcs

// vecmul returns the element-wise product of a and b. The result has the
// length of the shorter operand; extra elements of the longer one are ignored.
func vecmul(a, b []float32) []float32 {
	sz := min(len(a), len(b))
	out := make([]float32, sz)
	// Hand the kernel three slices of identical length so it cannot read or
	// write past the shorter operand, whichever length it iterates over.
	vecasm.VecMulf32x8(a[:sz], b[:sz], out)
	return out
}

// vecadd returns the element-wise sum of a and b. The result has the length
// of the shorter operand; extra elements of the longer one are ignored.
func vecadd(a, b []float32) []float32 {
	sz := min(len(a), len(b))
	out := make([]float32, sz)
	for i := range out {
		out[i] = a[i] + b[i]
	}
	return out
}

// min returns the smallest of the given values. It panics when called without
// arguments. On Go 1.21+ this declaration shadows the builtin of the same
// name within this package, which is legal.
func min(p ...int) int {
	smallest := p[0]
	for _, v := range p[1:] {
		if v < smallest {
			smallest = v
		}
	}
	return smallest
}

// errcheck fails the test immediately when err is non-nil.
func errcheck(t *testing.T, err error) {
	if err != nil {
		t.Fatal(err)
	}
}