Skip to content

Commit ba14edb

Browse files
committed
Overload math functions
1 parent f0c4375 commit ba14edb

File tree

10 files changed

+135
-166
lines changed

10 files changed

+135
-166
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ struct Point {
9797
}
9898

9999
length(p: Point) -> f32 {
100-
sqrtf(p.x * p.x + p.y * p.y)
100+
sqrt(p.x * p.x + p.y * p.y)
101101
}
102102

103103
main {
@@ -300,8 +300,8 @@ struct Biquad {
300300

301301
lpf(fc: f32, fs: f32, q: f32) -> Biquad {
302302
var w0 = 2.0 * 3.14159265 * fc / fs
303-
var alpha = sinf(w0) / (2.0 * q)
304-
var cs = cosf(w0)
303+
var alpha = sin(w0) / (2.0 * q)
304+
var cs = cos(w0)
305305
var a0 = 1.0 + alpha
306306
var inv = 1.0 / a0
307307

benchmark/biquad.lyte

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ struct Biquad {
1111
// Design a 2nd-order lowpass (RBJ cookbook)
1212
lpf(fc: f32, fs: f32, q: f32) -> Biquad {
1313
var w0 = 2.0 * 3.14159265 * fc / fs
14-
var alpha = sinf(w0) / (2.0 * q)
15-
var cs = cosf(w0)
14+
var alpha = sin(w0) / (2.0 * q)
15+
var cs = cos(w0)
1616

1717
var a0 = 1.0 + alpha
1818
var inv = 1.0 / a0
@@ -41,7 +41,7 @@ main {
4141

4242
// Process a 440 Hz sine wave through the filter
4343
for i in 0 .. n {
44-
var x = sinf(phase * two_pi)
44+
var x = sin(phase * two_pi)
4545
phase = phase + freq
4646
if phase > 1.0 { phase = phase - 1.0 }
4747

src/compiler.rs

Lines changed: 54 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -61,100 +61,62 @@ fn builtin_decls() -> Vec<Decl> {
6161
}),
6262
];
6363

64-
// Math builtins: unary f32 and f64 variants.
65-
let unary_math = [
66-
("sinf", "sind"),
67-
("cosf", "cosd"),
68-
("tanf", "tand"),
69-
("lnf", "lnd"),
70-
("expf", "expd"),
71-
("sqrtf", "sqrtd"),
72-
("absf", "absd"),
73-
("floorf", "floord"),
74-
("ceilf", "ceild"),
75-
];
76-
for (f32_name, f64_name) in unary_math {
77-
decls.push(Decl::Func(FuncDecl {
78-
name: Name::new(f32_name.into()),
79-
typevars: vec![],
80-
size_vars: vec![],
81-
params: vec![Param {
82-
name: Name::new("x".into()),
83-
ty: Some(mk_type(Type::Float32)),
84-
}],
85-
body: None,
86-
ret: mk_type(Type::Float32),
87-
constraints: vec![],
88-
loc: test_loc(),
89-
arena: ExprArena::new(),
90-
types: vec![],
91-
closure_vars: vec![],
92-
}));
93-
decls.push(Decl::Func(FuncDecl {
94-
name: Name::new(f64_name.into()),
95-
typevars: vec![],
96-
size_vars: vec![],
97-
params: vec![Param {
98-
name: Name::new("x".into()),
99-
ty: Some(mk_type(Type::Float64)),
100-
}],
101-
body: None,
102-
ret: mk_type(Type::Float64),
103-
constraints: vec![],
104-
loc: test_loc(),
105-
arena: ExprArena::new(),
106-
types: vec![],
107-
closure_vars: vec![],
108-
}));
64+
// Math builtins: unary, overloaded for f32 and f64.
65+
let unary_math = ["sin", "cos", "tan", "ln", "exp", "sqrt", "abs", "floor", "ceil"];
66+
for name in unary_math {
67+
for (ty, ret_ty) in [
68+
(mk_type(Type::Float32), mk_type(Type::Float32)),
69+
(mk_type(Type::Float64), mk_type(Type::Float64)),
70+
] {
71+
decls.push(Decl::Func(FuncDecl {
72+
name: Name::new(name.into()),
73+
typevars: vec![],
74+
size_vars: vec![],
75+
params: vec![Param {
76+
name: Name::new("x".into()),
77+
ty: Some(ty),
78+
}],
79+
body: None,
80+
ret: ret_ty,
81+
constraints: vec![],
82+
loc: test_loc(),
83+
arena: ExprArena::new(),
84+
types: vec![],
85+
closure_vars: vec![],
86+
}));
87+
}
10988
}
11089

111-
// Math builtins: binary f32 and f64 variants.
112-
let binary_math = [("powf", "powd"), ("atan2f", "atan2d")];
113-
for (f32_name, f64_name) in binary_math {
114-
decls.push(Decl::Func(FuncDecl {
115-
name: Name::new(f32_name.into()),
116-
typevars: vec![],
117-
size_vars: vec![],
118-
params: vec![
119-
Param {
120-
name: Name::new("x".into()),
121-
ty: Some(mk_type(Type::Float32)),
122-
},
123-
Param {
124-
name: Name::new("y".into()),
125-
ty: Some(mk_type(Type::Float32)),
126-
},
127-
],
128-
body: None,
129-
ret: mk_type(Type::Float32),
130-
constraints: vec![],
131-
loc: test_loc(),
132-
arena: ExprArena::new(),
133-
types: vec![],
134-
closure_vars: vec![],
135-
}));
136-
decls.push(Decl::Func(FuncDecl {
137-
name: Name::new(f64_name.into()),
138-
typevars: vec![],
139-
size_vars: vec![],
140-
params: vec![
141-
Param {
142-
name: Name::new("x".into()),
143-
ty: Some(mk_type(Type::Float64)),
144-
},
145-
Param {
146-
name: Name::new("y".into()),
147-
ty: Some(mk_type(Type::Float64)),
148-
},
149-
],
150-
body: None,
151-
ret: mk_type(Type::Float64),
152-
constraints: vec![],
153-
loc: test_loc(),
154-
arena: ExprArena::new(),
155-
types: vec![],
156-
closure_vars: vec![],
157-
}));
90+
// Math builtins: binary, overloaded for f32 and f64.
91+
let binary_math = ["pow", "atan2"];
92+
for name in binary_math {
93+
for (ty, ret_ty) in [
94+
(mk_type(Type::Float32), mk_type(Type::Float32)),
95+
(mk_type(Type::Float64), mk_type(Type::Float64)),
96+
] {
97+
decls.push(Decl::Func(FuncDecl {
98+
name: Name::new(name.into()),
99+
typevars: vec![],
100+
size_vars: vec![],
101+
params: vec![
102+
Param {
103+
name: Name::new("x".into()),
104+
ty: Some(ty),
105+
},
106+
Param {
107+
name: Name::new("y".into()),
108+
ty: Some(ty),
109+
},
110+
],
111+
body: None,
112+
ret: ret_ty,
113+
constraints: vec![],
114+
loc: test_loc(),
115+
arena: ExprArena::new(),
116+
types: vec![],
117+
closure_vars: vec![],
118+
}));
119+
}
158120
}
159121

160122
decls

src/jit.rs

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,11 @@ impl JIT {
210210
panic!()
211211
};
212212

213+
// Skip builtins — they have no body and are called via raw function pointers.
214+
if decl.body.is_none() {
215+
continue;
216+
}
217+
213218
self.compile_function(decls, decl)?;
214219
}
215220

@@ -2070,9 +2075,11 @@ extern "C" fn lyte_atan2d(x: f64, y: f64) -> f64 {
20702075
}
20712076

20722077
const BUILTIN_NAMES: &[&str] = &[
2073-
"assert", "print", "putc", "sinf", "sind", "cosf", "cosd", "tanf", "tand", "lnf", "lnd",
2074-
"expf", "expd", "sqrtf", "sqrtd", "absf", "absd", "floorf", "floord", "ceilf", "ceild", "powf",
2075-
"powd", "atan2f", "atan2d",
2078+
"assert", "print", "putc",
2079+
"sin$f32", "sin$f64", "cos$f32", "cos$f64", "tan$f32", "tan$f64",
2080+
"ln$f32", "ln$f64", "exp$f32", "exp$f64", "sqrt$f32", "sqrt$f64",
2081+
"abs$f32", "abs$f64", "floor$f32", "floor$f64", "ceil$f32", "ceil$f64",
2082+
"pow$f32$f32", "pow$f64$f64", "atan2$f32$f32", "atan2$f64$f64",
20762083
];
20772084

20782085
fn is_builtin_name(name: &Name) -> bool {
@@ -2081,28 +2088,28 @@ fn is_builtin_name(name: &Name) -> bool {
20812088

20822089
fn math_builtin_ptr(name: &Name) -> Option<i64> {
20832090
let pairs: &[(&str, i64)] = &[
2084-
("sinf", lyte_sinf as *const () as i64),
2085-
("sind", lyte_sind as *const () as i64),
2086-
("cosf", lyte_cosf as *const () as i64),
2087-
("cosd", lyte_cosd as *const () as i64),
2088-
("tanf", lyte_tanf as *const () as i64),
2089-
("tand", lyte_tand as *const () as i64),
2090-
("lnf", lyte_lnf as *const () as i64),
2091-
("lnd", lyte_lnd as *const () as i64),
2092-
("expf", lyte_expf as *const () as i64),
2093-
("expd", lyte_expd as *const () as i64),
2094-
("sqrtf", lyte_sqrtf as *const () as i64),
2095-
("sqrtd", lyte_sqrtd as *const () as i64),
2096-
("absf", lyte_absf as *const () as i64),
2097-
("absd", lyte_absd as *const () as i64),
2098-
("floorf", lyte_floorf as *const () as i64),
2099-
("floord", lyte_floord as *const () as i64),
2100-
("ceilf", lyte_ceilf as *const () as i64),
2101-
("ceild", lyte_ceild as *const () as i64),
2102-
("powf", lyte_powf as *const () as i64),
2103-
("powd", lyte_powd as *const () as i64),
2104-
("atan2f", lyte_atan2f as *const () as i64),
2105-
("atan2d", lyte_atan2d as *const () as i64),
2091+
("sin$f32", lyte_sinf as *const () as i64),
2092+
("sin$f64", lyte_sind as *const () as i64),
2093+
("cos$f32", lyte_cosf as *const () as i64),
2094+
("cos$f64", lyte_cosd as *const () as i64),
2095+
("tan$f32", lyte_tanf as *const () as i64),
2096+
("tan$f64", lyte_tand as *const () as i64),
2097+
("ln$f32", lyte_lnf as *const () as i64),
2098+
("ln$f64", lyte_lnd as *const () as i64),
2099+
("exp$f32", lyte_expf as *const () as i64),
2100+
("exp$f64", lyte_expd as *const () as i64),
2101+
("sqrt$f32", lyte_sqrtf as *const () as i64),
2102+
("sqrt$f64", lyte_sqrtd as *const () as i64),
2103+
("abs$f32", lyte_absf as *const () as i64),
2104+
("abs$f64", lyte_absd as *const () as i64),
2105+
("floor$f32", lyte_floorf as *const () as i64),
2106+
("floor$f64", lyte_floord as *const () as i64),
2107+
("ceil$f32", lyte_ceilf as *const () as i64),
2108+
("ceil$f64", lyte_ceild as *const () as i64),
2109+
("pow$f32$f32", lyte_powf as *const () as i64),
2110+
("pow$f64$f64", lyte_powd as *const () as i64),
2111+
("atan2$f32$f32", lyte_atan2f as *const () as i64),
2112+
("atan2$f64$f64", lyte_atan2d as *const () as i64),
21062113
];
21072114
for &(n, ptr) in pairs {
21082115
if *name == Name::str(n) {

src/vm_codegen.rs

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,26 +1585,26 @@ impl<'a> FunctionTranslator<'a> {
15851585

15861586
// Unary math builtins (f32 and f64).
15871587
let unary_math_f32: &[(&str, fn(Reg, Reg) -> Opcode)] = &[
1588-
("sinf", |dst, src| Opcode::SinF32 { dst, src }),
1589-
("cosf", |dst, src| Opcode::CosF32 { dst, src }),
1590-
("tanf", |dst, src| Opcode::TanF32 { dst, src }),
1591-
("lnf", |dst, src| Opcode::LnF32 { dst, src }),
1592-
("expf", |dst, src| Opcode::ExpF32 { dst, src }),
1593-
("sqrtf", |dst, src| Opcode::SqrtF32 { dst, src }),
1594-
("absf", |dst, src| Opcode::AbsF32 { dst, src }),
1595-
("floorf", |dst, src| Opcode::FloorF32 { dst, src }),
1596-
("ceilf", |dst, src| Opcode::CeilF32 { dst, src }),
1588+
("sin$f32", |dst, src| Opcode::SinF32 { dst, src }),
1589+
("cos$f32", |dst, src| Opcode::CosF32 { dst, src }),
1590+
("tan$f32", |dst, src| Opcode::TanF32 { dst, src }),
1591+
("ln$f32", |dst, src| Opcode::LnF32 { dst, src }),
1592+
("exp$f32", |dst, src| Opcode::ExpF32 { dst, src }),
1593+
("sqrt$f32", |dst, src| Opcode::SqrtF32 { dst, src }),
1594+
("abs$f32", |dst, src| Opcode::AbsF32 { dst, src }),
1595+
("floor$f32", |dst, src| Opcode::FloorF32 { dst, src }),
1596+
("ceil$f32", |dst, src| Opcode::CeilF32 { dst, src }),
15971597
];
15981598
let unary_math_f64: &[(&str, fn(Reg, Reg) -> Opcode)] = &[
1599-
("sind", |dst, src| Opcode::SinF64 { dst, src }),
1600-
("cosd", |dst, src| Opcode::CosF64 { dst, src }),
1601-
("tand", |dst, src| Opcode::TanF64 { dst, src }),
1602-
("lnd", |dst, src| Opcode::LnF64 { dst, src }),
1603-
("expd", |dst, src| Opcode::ExpF64 { dst, src }),
1604-
("sqrtd", |dst, src| Opcode::SqrtF64 { dst, src }),
1605-
("absd", |dst, src| Opcode::AbsF64 { dst, src }),
1606-
("floord", |dst, src| Opcode::FloorF64 { dst, src }),
1607-
("ceild", |dst, src| Opcode::CeilF64 { dst, src }),
1599+
("sin$f64", |dst, src| Opcode::SinF64 { dst, src }),
1600+
("cos$f64", |dst, src| Opcode::CosF64 { dst, src }),
1601+
("tan$f64", |dst, src| Opcode::TanF64 { dst, src }),
1602+
("ln$f64", |dst, src| Opcode::LnF64 { dst, src }),
1603+
("exp$f64", |dst, src| Opcode::ExpF64 { dst, src }),
1604+
("sqrt$f64", |dst, src| Opcode::SqrtF64 { dst, src }),
1605+
("abs$f64", |dst, src| Opcode::AbsF64 { dst, src }),
1606+
("floor$f64", |dst, src| Opcode::FloorF64 { dst, src }),
1607+
("ceil$f64", |dst, src| Opcode::CeilF64 { dst, src }),
16081608
];
16091609
for (n, mk_op) in unary_math_f32.iter().chain(unary_math_f64.iter()) {
16101610
if **name == *n {
@@ -1617,10 +1617,10 @@ impl<'a> FunctionTranslator<'a> {
16171617

16181618
// Binary math builtins (f32 and f64).
16191619
let binary_math: &[(&str, fn(Reg, Reg, Reg) -> Opcode)] = &[
1620-
("powf", |dst, a, b| Opcode::PowF32 { dst, a, b }),
1621-
("powd", |dst, a, b| Opcode::PowF64 { dst, a, b }),
1622-
("atan2f", |dst, a, b| Opcode::Atan2F32 { dst, a, b }),
1623-
("atan2d", |dst, a, b| Opcode::Atan2F64 { dst, a, b }),
1620+
("pow$f32$f32", |dst, a, b| Opcode::PowF32 { dst, a, b }),
1621+
("pow$f64$f64", |dst, a, b| Opcode::PowF64 { dst, a, b }),
1622+
("atan2$f32$f32", |dst, a, b| Opcode::Atan2F32 { dst, a, b }),
1623+
("atan2$f64$f64", |dst, a, b| Opcode::Atan2F64 { dst, a, b }),
16241624
];
16251625
for (n, mk_op) in binary_math {
16261626
if **name == *n {

tests/cases/biquad_bytecode.lyte

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ struct Biquad {
144144

145145
lpf(fc: f32, fs: f32, q: f32) -> Biquad {
146146
var w0 = 2.0 * 3.14159265 * fc / fs
147-
var alpha = sinf(w0) / (2.0 * q)
148-
var cs = cosf(w0)
147+
var alpha = sin(w0) / (2.0 * q)
148+
var cs = cos(w0)
149149

150150
var a0 = 1.0 + alpha
151151
var inv = 1.0 / a0
@@ -173,7 +173,7 @@ main {
173173
var two_pi = 2.0 * 3.14159265
174174

175175
for i in 0 .. n {
176-
var x = sinf(phase * two_pi)
176+
var x = sin(phase * two_pi)
177177
phase = phase + freq
178178
if phase > 1.0 { phase = phase - 1.0 }
179179

tests/cases/examples/biquad.lyte

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ struct Biquad {
1616

1717
lpf(fc: f32, fs: f32, q: f32) -> Biquad {
1818
var w0 = 2.0 * 3.14159265 * fc / fs
19-
var alpha = sinf(w0) / (2.0 * q)
20-
var cs = cosf(w0)
19+
var alpha = sin(w0) / (2.0 * q)
20+
var cs = cos(w0)
2121
var a0 = 1.0 + alpha
2222
var inv = 1.0 / a0
2323

tests/cases/examples/structs.lyte

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct Point {
1313
}
1414

1515
length(p: Point) -> f32 {
16-
sqrtf(p.x * p.x + p.y * p.y)
16+
sqrt(p.x * p.x + p.y * p.y)
1717
}
1818

1919
main {

tests/cases/iir_filter.lyte

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ struct Biquad {
2020
// Design a 2nd-order lowpass (RBJ cookbook).
2121
lpf(fc: f32, fs: f32, q: f32) -> Biquad {
2222
var w0 = 2.0 * 3.14159265 * fc / fs
23-
var alpha = sinf(w0) / (2.0 * q)
24-
var cs = cosf(w0)
23+
var alpha = sin(w0) / (2.0 * q)
24+
var cs = cos(w0)
2525

2626
var a0 = 1.0 + alpha
2727
var inv = 1.0 / a0

0 commit comments

Comments
 (0)