Skip to content

Commit 1d38d85

Browse files
authored
Added Min/Max Between (#117)
* Fixed #90 * Added MinBetween and MaxBetween engine def * Added code to generate Min/MaxBetween * Moved example out from the generated file * Generated MinBetween and MaxBetween methods for StdEng * Added some compile time assertions * Added API for Min/Max between * Added better prep for min/max between of engine
1 parent da5e1e2 commit 1d38d85

13 files changed

+3738
-22
lines changed

api_minmax.go

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package tensor
2+
3+
import "github.com/pkg/errors"
4+
5+
func MinBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
6+
var minbetweener MinBetweener
7+
var oe standardEngine
8+
var ok bool
9+
switch at := a.(type) {
10+
case Tensor:
11+
oe = at.standardEngine()
12+
switch bt := b.(type) {
13+
case Tensor:
14+
if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition
15+
if oe != nil {
16+
return oe.MinBetween(at, bt, opts...)
17+
}
18+
if oe = bt.standardEngine(); oe != nil {
19+
return oe.MinBetween(at, bt, opts...)
20+
}
21+
if minbetweener, ok = at.Engine().(MinBetweener); ok {
22+
return minbetweener.MinBetween(at, bt, opts...)
23+
}
24+
if minbetweener, ok = bt.Engine().(MinBetweener); ok {
25+
return minbetweener.MinBetween(at, bt, opts...)
26+
}
27+
return nil, errors.New("Neither engines of either operand support MinBetween")
28+
29+
} else { // at least one of the operands is a scalar
30+
var leftTensor bool
31+
if !bt.Shape().IsScalar() {
32+
leftTensor = false // a Scalar-Tensor * b Tensor
33+
tmp := at
34+
at = bt
35+
bt = tmp
36+
} else {
37+
leftTensor = true // a Tensor * b Scalar-Tensor
38+
}
39+
40+
if oe != nil {
41+
return oe.MinBetweenScalar(at, bt, leftTensor, opts...)
42+
}
43+
if oe = bt.standardEngine(); oe != nil {
44+
return oe.MinBetweenScalar(at, bt, leftTensor, opts...)
45+
}
46+
if minbetweener, ok = at.Engine().(MinBetweener); ok {
47+
return minbetweener.MinBetweenScalar(at, bt, leftTensor, opts...)
48+
}
49+
if minbetweener, ok = bt.Engine().(MinBetweener); ok {
50+
return minbetweener.MinBetweenScalar(at, bt, leftTensor, opts...)
51+
}
52+
return nil, errors.New("Neither engines of either operand support MinBetween")
53+
}
54+
55+
default:
56+
if oe != nil {
57+
return oe.MinBetweenScalar(at, bt, true, opts...)
58+
}
59+
if minbetweener, ok = at.Engine().(MinBetweener); ok {
60+
return minbetweener.MinBetweenScalar(at, bt, true, opts...)
61+
}
62+
return nil, errors.New("Operand A's engine does not support MinBetween")
63+
}
64+
default:
65+
switch bt := b.(type) {
66+
case Tensor:
67+
if oe = bt.standardEngine(); oe != nil {
68+
return oe.MinBetweenScalar(bt, at, false, opts...)
69+
}
70+
if minbetweener, ok = bt.Engine().(MinBetweener); ok {
71+
return minbetweener.MinBetweenScalar(bt, at, false, opts...)
72+
}
73+
return nil, errors.New("Operand B's engine does not support MinBetween")
74+
default:
75+
return nil, errors.Errorf("Cannot perform MinBetween of %T and %T", a, b)
76+
}
77+
}
78+
panic("Unreachable")
79+
}
80+
81+
func MaxBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
82+
var maxbetweener MaxBetweener
83+
var oe standardEngine
84+
var ok bool
85+
switch at := a.(type) {
86+
case Tensor:
87+
oe = at.standardEngine()
88+
switch bt := b.(type) {
89+
case Tensor:
90+
if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition
91+
if oe != nil {
92+
return oe.MaxBetween(at, bt, opts...)
93+
}
94+
if oe = bt.standardEngine(); oe != nil {
95+
return oe.MaxBetween(at, bt, opts...)
96+
}
97+
if maxbetweener, ok = at.Engine().(MaxBetweener); ok {
98+
return maxbetweener.MaxBetween(at, bt, opts...)
99+
}
100+
if maxbetweener, ok = bt.Engine().(MaxBetweener); ok {
101+
return maxbetweener.MaxBetween(at, bt, opts...)
102+
}
103+
return nil, errors.New("Neither engines of either operand support MaxBetween")
104+
105+
} else { // at least one of the operands is a scalar
106+
var leftTensor bool
107+
if !bt.Shape().IsScalar() {
108+
leftTensor = false // a Scalar-Tensor * b Tensor
109+
tmp := at
110+
at = bt
111+
bt = tmp
112+
} else {
113+
leftTensor = true // a Tensor * b Scalar-Tensor
114+
}
115+
116+
if oe != nil {
117+
return oe.MaxBetweenScalar(at, bt, leftTensor, opts...)
118+
}
119+
if oe = bt.standardEngine(); oe != nil {
120+
return oe.MaxBetweenScalar(at, bt, leftTensor, opts...)
121+
}
122+
if maxbetweener, ok = at.Engine().(MaxBetweener); ok {
123+
return maxbetweener.MaxBetweenScalar(at, bt, leftTensor, opts...)
124+
}
125+
if maxbetweener, ok = bt.Engine().(MaxBetweener); ok {
126+
return maxbetweener.MaxBetweenScalar(at, bt, leftTensor, opts...)
127+
}
128+
return nil, errors.New("Neither engines of either operand support MaxBetween")
129+
}
130+
131+
default:
132+
if oe != nil {
133+
return oe.MaxBetweenScalar(at, bt, true, opts...)
134+
}
135+
if maxbetweener, ok = at.Engine().(MaxBetweener); ok {
136+
return maxbetweener.MaxBetweenScalar(at, bt, true, opts...)
137+
}
138+
return nil, errors.New("Operand A's engine does not support MaxBetween")
139+
}
140+
default:
141+
switch bt := b.(type) {
142+
case Tensor:
143+
if oe = bt.standardEngine(); oe != nil {
144+
return oe.MaxBetweenScalar(bt, at, false, opts...)
145+
}
146+
if maxbetweener, ok = bt.Engine().(MaxBetweener); ok {
147+
return maxbetweener.MaxBetweenScalar(bt, at, false, opts...)
148+
}
149+
return nil, errors.New("Operand B's engine does not support MaxBetween")
150+
default:
151+
return nil, errors.Errorf("Cannot perform MaxBetween of %T and %T", a, b)
152+
}
153+
}
154+
panic("Unreachable")
155+
}

0 commit comments

Comments
 (0)