Skip to content

Commit 427bf35

Browse files
committed
add superinstructions
1 parent 0749152 commit 427bf35

File tree

1 file changed

+196
-30
lines changed

1 file changed

+196
-30
lines changed

src/vm.rs

Lines changed: 196 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ pub enum Opcode {
6868
/// Integer negate: dst = -src
6969
INeg { dst: Reg, src: Reg },
7070

71+
/// Integer add immediate: dst = src + imm
72+
IAddImm { dst: Reg, src: Reg, imm: i32 },
73+
7174
// ============ Floating Point Arithmetic ============
7275

7376
/// Float32 add: dst = a + b
@@ -85,6 +88,12 @@ pub enum Opcode {
8588
/// Float32 negate: dst = -src
8689
FNeg { dst: Reg, src: Reg },
8790

91+
/// Float32 fused multiply-add: dst = a * b + c
92+
FMulAdd { dst: Reg, a: Reg, b: Reg, c: Reg },
93+
94+
/// Float32 fused multiply-subtract: dst = a * b - c
95+
FMulSub { dst: Reg, a: Reg, b: Reg, c: Reg },
96+
8897
// ============ Float64 Arithmetic ============
8998

9099
/// Float64 add: dst = a + b
@@ -102,6 +111,12 @@ pub enum Opcode {
102111
/// Float64 negate: dst = -src
103112
DNeg { dst: Reg, src: Reg },
104113

114+
/// Float64 fused multiply-add: dst = a * b + c
115+
DMulAdd { dst: Reg, a: Reg, b: Reg, c: Reg },
116+
117+
/// Float64 fused multiply-subtract: dst = a * b - c
118+
DMulSub { dst: Reg, a: Reg, b: Reg, c: Reg },
119+
105120
// ============ Bitwise Operations ============
106121

107122
/// Bitwise AND: dst = a & b
@@ -244,6 +259,14 @@ pub enum Opcode {
244259
/// Jump if register is non-zero (true)
245260
JumpIfNotZero { cond: Reg, offset: Offset },
246261

262+
// ============ Superinstructions: Compare and Branch ============
263+
264+
/// Jump if a < b (signed): if !(a < b) jump
265+
ILtJump { a: Reg, b: Reg, offset: Offset },
266+
267+
/// Jump if a >= b (signed): if !(a >= b) jump
268+
IGeJump { a: Reg, b: Reg, offset: Offset },
269+
247270
/// Call function by index, args in registers starting at `args_start`
248271
/// Result (if any) goes in register 0
249272
Call { func: FuncIdx, args_start: Reg, arg_count: u8 },
@@ -326,7 +349,9 @@ impl VMFunction {
326349
match &mut self.code[idx] {
327350
Opcode::Jump { offset: off } |
328351
Opcode::JumpIfZero { offset: off, .. } |
329-
Opcode::JumpIfNotZero { offset: off, .. } => {
352+
Opcode::JumpIfNotZero { offset: off, .. } |
353+
Opcode::ILtJump { offset: off, .. } |
354+
Opcode::IGeJump { offset: off, .. } => {
330355
*off = offset;
331356
}
332357
_ => panic!("patch_jump called on non-jump instruction"),
@@ -588,6 +613,10 @@ impl VM {
588613
self.set_i64(dst, -self.get_i64(src));
589614
}
590615

616+
Opcode::IAddImm { dst, src, imm } => {
617+
self.set_i64(dst, self.get_i64(src).wrapping_add(imm as i64));
618+
}
619+
591620
// Float32 arithmetic
592621
Opcode::FAdd { dst, a, b } => {
593622
self.set_f32(dst, self.get_f32(a) + self.get_f32(b));
@@ -609,6 +638,14 @@ impl VM {
609638
self.set_f32(dst, -self.get_f32(src));
610639
}
611640

641+
Opcode::FMulAdd { dst, a, b, c } => {
642+
self.set_f32(dst, self.get_f32(a).mul_add(self.get_f32(b), self.get_f32(c)));
643+
}
644+
645+
Opcode::FMulSub { dst, a, b, c } => {
646+
self.set_f32(dst, self.get_f32(a).mul_add(self.get_f32(b), -self.get_f32(c)));
647+
}
648+
612649
// Float64 arithmetic
613650
Opcode::DAdd { dst, a, b } => {
614651
self.set_f64(dst, self.get_f64(a) + self.get_f64(b));
@@ -630,6 +667,14 @@ impl VM {
630667
self.set_f64(dst, -self.get_f64(src));
631668
}
632669

670+
Opcode::DMulAdd { dst, a, b, c } => {
671+
self.set_f64(dst, self.get_f64(a).mul_add(self.get_f64(b), self.get_f64(c)));
672+
}
673+
674+
Opcode::DMulSub { dst, a, b, c } => {
675+
self.set_f64(dst, self.get_f64(a).mul_add(self.get_f64(b), -self.get_f64(c)));
676+
}
677+
633678
// Bitwise operations
634679
Opcode::And { dst, a, b } => {
635680
self.set_u64(dst, self.get_u64(a) & self.get_u64(b));
@@ -830,6 +875,19 @@ impl VM {
830875
}
831876
}
832877

878+
// Superinstructions: compare and branch
879+
Opcode::ILtJump { a, b, offset } => {
880+
if self.get_i64(a) >= self.get_i64(b) {
881+
self.ip = (self.ip as i32 + offset) as usize;
882+
}
883+
}
884+
885+
Opcode::IGeJump { a, b, offset } => {
886+
if self.get_i64(a) < self.get_i64(b) {
887+
self.ip = (self.ip as i32 + offset) as usize;
888+
}
889+
}
890+
833891
Opcode::Call { func, args_start, arg_count } => {
834892
// Save current frame
835893
let frame = CallFrame {
@@ -1022,45 +1080,37 @@ pub fn create_biquad_program() -> VMProgram {
10221080
biquad.emit(Opcode::Load32Off { dst: 13, base: 2, offset: 12 }); // r13 = a1
10231081
biquad.emit(Opcode::Load32Off { dst: 14, base: 2, offset: 16 }); // r14 = a2
10241082

1025-
// Compute output:
1083+
// Compute output using FMulAdd/FMulSub superinstructions:
10261084
// y0 = b0*x0 + b1*x1 + b2*x2 - a1*y1 - a2*y2
10271085

10281086
// r20 = b0 * x0
10291087
biquad.emit(Opcode::FMul { dst: 20, a: 10, b: 0 });
10301088

1031-
// r21 = b1 * x1
1032-
biquad.emit(Opcode::FMul { dst: 21, a: 11, b: 3 });
1033-
1034-
// r22 = b2 * x2
1035-
biquad.emit(Opcode::FMul { dst: 22, a: 12, b: 4 });
1036-
1037-
// r23 = a1 * y1
1038-
biquad.emit(Opcode::FMul { dst: 23, a: 13, b: 5 });
1089+
// r20 = b1 * x1 + r20 (b0*x0 + b1*x1)
1090+
biquad.emit(Opcode::FMulAdd { dst: 20, a: 11, b: 3, c: 20 });
10391091

1040-
// r24 = a2 * y2
1041-
biquad.emit(Opcode::FMul { dst: 24, a: 14, b: 6 });
1092+
// r20 = b2 * x2 + r20 (b0*x0 + b1*x1 + b2*x2)
1093+
biquad.emit(Opcode::FMulAdd { dst: 20, a: 12, b: 4, c: 20 });
10421094

1043-
// r25 = r20 + r21 (b0*x0 + b1*x1)
1044-
biquad.emit(Opcode::FAdd { dst: 25, a: 20, b: 21 });
1095+
// r20 = r20 - a1 * y1 (using FMulSub: r20 = a1*y1 - r20, then negate... or use different approach)
1096+
// Actually FMulSub is: dst = a * b - c, so we need: r20 - a1*y1
1097+
// Let's compute a1*y1 first, then subtract
1098+
biquad.emit(Opcode::FMul { dst: 21, a: 13, b: 5 }); // r21 = a1 * y1
1099+
biquad.emit(Opcode::FSub { dst: 20, a: 20, b: 21 }); // r20 = r20 - a1*y1
10451100

1046-
// r25 = r25 + r22 (+ b2*x2)
1047-
biquad.emit(Opcode::FAdd { dst: 25, a: 25, b: 22 });
1048-
1049-
// r25 = r25 - r23 (- a1*y1)
1050-
biquad.emit(Opcode::FSub { dst: 25, a: 25, b: 23 });
1051-
1052-
// r25 = r25 - r24 (- a2*y2) -> this is y0
1053-
biquad.emit(Opcode::FSub { dst: 25, a: 25, b: 24 });
1101+
// r20 = r20 - a2 * y2
1102+
biquad.emit(Opcode::FMul { dst: 21, a: 14, b: 6 }); // r21 = a2 * y2
1103+
biquad.emit(Opcode::FSub { dst: 20, a: 20, b: 21 }); // r20 = r20 - a2*y2 -> this is y0
10541104

10551105
// Update state:
10561106
// x2 = x1, x1 = x0, y2 = y1, y1 = y0
10571107
biquad.emit(Opcode::Store32Off { base: 1, offset: 4, src: 3 }); // x2 = x1
10581108
biquad.emit(Opcode::Store32Off { base: 1, offset: 0, src: 0 }); // x1 = x0
10591109
biquad.emit(Opcode::Store32Off { base: 1, offset: 12, src: 5 }); // y2 = y1
1060-
biquad.emit(Opcode::Store32Off { base: 1, offset: 8, src: 25 }); // y1 = y0
1110+
biquad.emit(Opcode::Store32Off { base: 1, offset: 8, src: 20 }); // y1 = y0
10611111

10621112
// Return y0
1063-
biquad.emit(Opcode::ReturnReg { src: 25 });
1113+
biquad.emit(Opcode::ReturnReg { src: 20 });
10641114

10651115
let biquad_idx = program.add_function(biquad);
10661116

@@ -1101,9 +1151,8 @@ pub fn create_biquad_program() -> VMProgram {
11011151
// Loop start (instruction index for jump target)
11021152
let loop_start = main.code.len();
11031153

1104-
// Check i < N
1105-
main.emit(Opcode::ILt { dst: 22, a: 20, b: 21 });
1106-
let jump_end = main.emit(Opcode::JumpIfZero { cond: 22, offset: 0 }); // patched later
1154+
// Check i < N using superinstruction (jumps if NOT i < N)
1155+
let jump_end = main.emit(Opcode::ILtJump { a: 20, b: 21, offset: 0 }); // patched later
11071156

11081157
// Generate input: simple sawtooth wave
11091158
// input = (i % 100) / 100.0 - 0.5
@@ -1123,9 +1172,8 @@ pub fn create_biquad_program() -> VMProgram {
11231172
// Accumulate output
11241173
main.emit(Opcode::FAdd { dst: 30, a: 30, b: 0 }); // sum += output
11251174

1126-
// i++
1127-
main.emit(Opcode::LoadImm { dst: 23, value: 1 });
1128-
main.emit(Opcode::IAdd { dst: 20, a: 20, b: 23 });
1175+
// i++ using superinstruction
1176+
main.emit(Opcode::IAddImm { dst: 20, src: 20, imm: 1 });
11291177

11301178
// Jump back to loop start
11311179
let loop_end = main.code.len();
@@ -1373,4 +1421,122 @@ mod tests {
13731421
let result = vm.run(&program);
13741422
assert_eq!(result, 42);
13751423
}
1424+
1425+
#[test]
1426+
fn test_fmul_add() {
1427+
let mut func = VMFunction::new("test");
1428+
func.emit(Opcode::LoadF32 { dst: 0, value: 2.0 });
1429+
func.emit(Opcode::LoadF32 { dst: 1, value: 3.0 });
1430+
func.emit(Opcode::LoadF32 { dst: 2, value: 4.0 });
1431+
// dst = 2.0 * 3.0 + 4.0 = 10.0
1432+
func.emit(Opcode::FMulAdd { dst: 0, a: 0, b: 1, c: 2 });
1433+
func.emit(Opcode::Return);
1434+
1435+
let mut program = VMProgram::new();
1436+
program.entry = program.add_function(func);
1437+
1438+
let mut vm = VM::new();
1439+
let result = vm.run_f32(&program);
1440+
assert!((result - 10.0).abs() < 0.0001);
1441+
}
1442+
1443+
#[test]
1444+
fn test_fmul_sub() {
1445+
let mut func = VMFunction::new("test");
1446+
func.emit(Opcode::LoadF32 { dst: 0, value: 2.0 });
1447+
func.emit(Opcode::LoadF32 { dst: 1, value: 3.0 });
1448+
func.emit(Opcode::LoadF32 { dst: 2, value: 4.0 });
1449+
// dst = 2.0 * 3.0 - 4.0 = 2.0
1450+
func.emit(Opcode::FMulSub { dst: 0, a: 0, b: 1, c: 2 });
1451+
func.emit(Opcode::Return);
1452+
1453+
let mut program = VMProgram::new();
1454+
program.entry = program.add_function(func);
1455+
1456+
let mut vm = VM::new();
1457+
let result = vm.run_f32(&program);
1458+
assert!((result - 2.0).abs() < 0.0001);
1459+
}
1460+
1461+
#[test]
1462+
fn test_iadd_imm() {
1463+
let mut func = VMFunction::new("test");
1464+
func.emit(Opcode::LoadImm { dst: 0, value: 40 });
1465+
func.emit(Opcode::IAddImm { dst: 0, src: 0, imm: 2 });
1466+
func.emit(Opcode::Return);
1467+
1468+
let mut program = VMProgram::new();
1469+
program.entry = program.add_function(func);
1470+
1471+
let mut vm = VM::new();
1472+
let result = vm.run(&program);
1473+
assert_eq!(result, 42);
1474+
}
1475+
1476+
#[test]
1477+
fn test_ilt_jump() {
1478+
// Test the fused compare-and-branch: ILtJump jumps if NOT (a < b)
1479+
let mut func = VMFunction::new("test");
1480+
func.emit(Opcode::LoadImm { dst: 0, value: 5 });
1481+
func.emit(Opcode::LoadImm { dst: 1, value: 10 });
1482+
// 5 < 10 is true, so we should NOT jump
1483+
func.emit(Opcode::ILtJump { a: 0, b: 1, offset: 2 });
1484+
func.emit(Opcode::LoadImm { dst: 0, value: 100 }); // should execute
1485+
func.emit(Opcode::Jump { offset: 1 });
1486+
func.emit(Opcode::LoadImm { dst: 0, value: 200 }); // should skip
1487+
func.emit(Opcode::Return);
1488+
1489+
let mut program = VMProgram::new();
1490+
program.entry = program.add_function(func);
1491+
1492+
let mut vm = VM::new();
1493+
let result = vm.run(&program);
1494+
assert_eq!(result, 100);
1495+
}
1496+
1497+
#[test]
1498+
fn test_ilt_jump_taken() {
1499+
// Test when jump IS taken: a >= b
1500+
let mut func = VMFunction::new("test");
1501+
func.emit(Opcode::LoadImm { dst: 0, value: 10 });
1502+
func.emit(Opcode::LoadImm { dst: 1, value: 5 });
1503+
// 10 < 5 is false, so we SHOULD jump
1504+
func.emit(Opcode::ILtJump { a: 0, b: 1, offset: 2 });
1505+
func.emit(Opcode::LoadImm { dst: 0, value: 100 }); // should skip
1506+
func.emit(Opcode::Jump { offset: 1 });
1507+
func.emit(Opcode::LoadImm { dst: 0, value: 200 }); // should execute
1508+
func.emit(Opcode::Return);
1509+
1510+
let mut program = VMProgram::new();
1511+
program.entry = program.add_function(func);
1512+
1513+
let mut vm = VM::new();
1514+
let result = vm.run(&program);
1515+
assert_eq!(result, 200);
1516+
}
1517+
1518+
#[test]
1519+
fn test_loop_with_superinstructions() {
1520+
// Sum 1..10 using IAddImm and ILtJump
1521+
let mut func = VMFunction::new("test");
1522+
func.emit(Opcode::LoadImm { dst: 0, value: 0 }); // sum = 0
1523+
func.emit(Opcode::LoadImm { dst: 1, value: 1 }); // i = 1
1524+
func.emit(Opcode::LoadImm { dst: 2, value: 11 }); // limit = 11
1525+
1526+
// Loop start (index 3)
1527+
// ILtJump: if !(i < 11) jump to end
1528+
func.emit(Opcode::ILtJump { a: 1, b: 2, offset: 3 });
1529+
func.emit(Opcode::IAdd { dst: 0, a: 0, b: 1 }); // sum += i
1530+
func.emit(Opcode::IAddImm { dst: 1, src: 1, imm: 1 }); // i++
1531+
func.emit(Opcode::Jump { offset: -4 }); // back to loop start
1532+
1533+
func.emit(Opcode::Return);
1534+
1535+
let mut program = VMProgram::new();
1536+
program.entry = program.add_function(func);
1537+
1538+
let mut vm = VM::new();
1539+
let result = vm.run(&program);
1540+
assert_eq!(result, 55); // 1+2+3+...+10 = 55
1541+
}
13761542
}

0 commit comments

Comments
 (0)