@@ -12,12 +12,20 @@ struct IndexConstraint {
1212 pub max : Option < i64 > ,
1313}
1414
15- #[ derive( Copy , Clone , Debug , Default ) ]
15+ #[ derive( Copy , Clone , Debug , Default , PartialEq , Eq ) ]
1616struct IndexInterval {
1717 pub min : i64 ,
1818 pub max : i64 ,
1919}
2020
21+ impl std:: ops:: Add < IndexInterval > for IndexInterval {
22+ type Output = IndexInterval ;
23+
24+ fn add ( self , rhs : IndexInterval ) -> IndexInterval {
25+ IndexInterval { min : self . min + rhs. min , max : self . max + rhs. max }
26+ }
27+ }
28+
2129fn enclose ( a : IndexInterval , b : IndexInterval ) -> IndexInterval {
2230 IndexInterval {
2331 min : a. min . min ( b. min ) ,
@@ -55,6 +63,11 @@ impl ArrayChecker {
5563 self . constraints . push ( IndexConstraint { name, min, max } )
5664 }
5765
66+ fn replace ( & mut self , name : Name , min : Option < i64 > , max : Option < i64 > ) {
67+ self . constraints . retain ( |c| c. name != name) ;
68+ self . add ( name, min, max) ;
69+ }
70+
5871 fn find ( & self , name : Name ) -> Option < IndexConstraint > {
5972 self . constraints . iter ( ) . find ( |c| c. name == name) . cloned ( )
6073 }
@@ -195,12 +208,30 @@ impl ArrayChecker {
195208 IndexInterval :: default ( )
196209 }
197210 Expr :: While ( cond, body) => {
198- let initial_constraint_count = self . constraints . len ( ) ;
211+ let saved_constraints = self . constraints . clone ( ) ;
199212 self . match_expr ( * cond, decl, decls) ;
200213
201214 self . check_expr ( * body, decl, decls) ;
202- while self . constraints . len ( ) > initial_constraint_count {
203- self . constraints . pop ( ) ;
215+ self . constraints = saved_constraints;
216+
217+ IndexInterval :: default ( )
218+ }
219+ Expr :: Binop ( op, lhs, rhs) => {
220+
221+ if * op == Binop :: Plus {
222+ let lhs_range = self . check_expr ( * lhs, decl, decls) ;
223+ let rhs_range = self . check_expr ( * rhs, decl, decls) ;
224+ return lhs_range + rhs_range
225+ }
226+
227+ if * op == Binop :: Assign {
228+ let rhs_range = self . check_expr ( * rhs, decl, decls) ;
229+
230+ if rhs_range != IndexInterval :: default ( ) {
231+ if let Expr :: Id ( name) = & decl. arena [ * lhs] {
232+ self . replace ( * name, Some ( rhs_range. min ) , Some ( rhs_range. max ) ) ;
233+ }
234+ }
204235 }
205236
206237 IndexInterval :: default ( )
@@ -329,4 +360,22 @@ mod tests {
329360 let errors = check ( s) ;
330361 assert ! ( errors. is_empty( ) ) ;
331362 }
363+
364+ #[ test]
365+ pub fn test_while_mutate ( ) {
366+ let s = "
367+ f {
368+ var i: u32
369+ var a: [i32; 50]
370+ while i < 50u {
371+ a[i]
372+ i = i + 1u
373+ a[i]
374+ }
375+ }
376+ " ;
377+
378+ let errors = check ( s) ;
379+ assert_eq ! ( errors. len( ) , 1 ) ;
380+ }
332381}
0 commit comments