Skip to content

Commit 8c436c2

Browse files
committed
Change subset
1 parent cefff87 commit 8c436c2

File tree

13 files changed

+76
-85
lines changed

13 files changed

+76
-85
lines changed

src/function/algebra/sylvester.js

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ export const createSylvester = /* #__PURE__ */ factory(name, dependencies, (
110110

111111
for (let k = 0; k < n; k++) {
112112
if (k < (n - 1) && abs(subset(G, index(k + 1, k))) > 1e-5) {
113-
let RHS = vc(subset(D, index(all, k)), subset(D, index(all, k + 1)))
113+
let RHS = vc(subset(D, index(all, [k])), subset(D, index(all, [k + 1])))
114114
for (let j = 0; j < k; j++) {
115115
RHS = add(RHS,
116116
vc(multiply(y[j], subset(G, index(j, k))), multiply(y[j], subset(G, index(j, k + 1))))
@@ -125,11 +125,11 @@ export const createSylvester = /* #__PURE__ */ factory(name, dependencies, (
125125
hc(gkm, add(F, gmm))
126126
)
127127
const yAux = lusolve(LHS, RHS)
128-
y[k] = yAux.subset(index(range(0, m), 0))
129-
y[k + 1] = yAux.subset(index(range(m, 2 * m), 0))
128+
y[k] = yAux.subset(index(range(0, m), [0]))
129+
y[k + 1] = yAux.subset(index(range(m, 2 * m), [0]))
130130
k++
131131
} else {
132-
let RHS = subset(D, index(all, k))
132+
let RHS = subset(D, index(all, [k]))
133133
for (let j = 0; j < k; j++) { RHS = add(RHS, multiply(y[j], subset(G, index(j, k)))) }
134134
const gkk = subset(G, index(k, k))
135135
const LHS = subtract(F, multiply(gkk, identity(m)))

src/function/matrix/column.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ export const createColumn = /* #__PURE__ */ factory(name, dependencies, ({ typed
5151
validateIndex(column, value.size()[1])
5252

5353
const rowRange = range(0, value.size()[0])
54-
const index = new Index(rowRange, column)
54+
const index = new Index(rowRange, [column])
5555
const result = value.subset(index)
5656
return isMatrix(result)
5757
? result

src/function/matrix/row.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ export const createRow = /* #__PURE__ */ factory(name, dependencies, ({ typed, I
5151
validateIndex(row, value.size()[0])
5252

5353
const columnRange = range(0, value.size()[1])
54-
const index = new Index(row, columnRange)
54+
const index = new Index([row], columnRange)
5555
const result = value.subset(index)
5656
return isMatrix(result)
5757
? result

src/function/matrix/subset.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ export const createSubset = /* #__PURE__ */ factory(name, dependencies, ({ typed
115115
if (typeof replacement === 'string') {
116116
throw new Error('can\'t boradcast a string')
117117
}
118-
if (index._isScalar) {
118+
if (index.isScalar()) {
119119
return replacement
120120
}
121121

src/type/matrix/DenseMatrix.js

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ export const createDenseMatrixClass = /* #__PURE__ */ factory(name, dependencies
238238
}
239239

240240
// retrieve submatrix
241-
const returnMatrix = new DenseMatrix([])
241+
const returnMatrix = new DenseMatrix()
242242
const submatrix = _getSubmatrix(matrix._data, index)
243243
returnMatrix._size = submatrix.size
244244
returnMatrix._datatype = matrix._datatype
@@ -259,11 +259,10 @@ export const createDenseMatrixClass = /* #__PURE__ */ factory(name, dependencies
259259
function _getSubmatrix (data, index) {
260260
const maxDepth = index.size().length - 1
261261
const size = Array(maxDepth)
262-
return { data: getSubmatrixRecursive(data), size }
262+
return { data: getSubmatrixRecursive(data), size: size.filter(x => x !== null) }
263263

264264
function getSubmatrixRecursive (data, depth = 0) {
265265
const ranges = index.dimension(depth)
266-
if (ranges === null) { console.log('null range') }
267266
function callback (rangeIndex) {
268267
validateIndex(rangeIndex, data.length)
269268
return getSubmatrixRecursive(data[rangeIndex], depth + 1)
@@ -272,16 +271,16 @@ export const createDenseMatrixClass = /* #__PURE__ */ factory(name, dependencies
272271
validateIndex(rangeIndex, data.length)
273272
return data[rangeIndex]
274273
}
275-
if (Number.isInteger(ranges)) {
276-
size[depth] = 1
274+
if (isNumber(ranges)) {
275+
size[depth] = null
277276
} else {
278277
size[depth] = ranges.size()[0]
279278
}
280279
if (depth < maxDepth) {
281-
if (Number.isInteger(ranges)) return [callback(ranges)]
280+
if (isNumber(ranges)) return callback(ranges)
282281
else return ranges.map(callback).valueOf()
283282
} else {
284-
if (Number.isInteger(ranges)) return [finalCallback(ranges)]
283+
if (isNumber(ranges)) return finalCallback(ranges)
285284
else return ranges.map(finalCallback).valueOf()
286285
}
287286
}
@@ -309,34 +308,35 @@ export const createDenseMatrixClass = /* #__PURE__ */ factory(name, dependencies
309308
const isScalar = index.isScalar()
310309

311310
// calculate the size of the submatrix, and convert it into an Array if needed
312-
let sSize
311+
let submatrixSize
313312
if (isMatrix(submatrix)) {
314-
sSize = submatrix.size()
313+
submatrixSize = submatrix.size()
315314
submatrix = submatrix.valueOf()
316315
} else {
317-
sSize = arraySize(submatrix)
316+
submatrixSize = arraySize(submatrix)
318317
}
319318

320319
if (isScalar) {
321320
// set a scalar
322321

323322
// check whether submatrix is a scalar
324-
if (sSize.length !== 0) {
323+
if (submatrixSize.length !== 0) {
325324
throw new TypeError('Scalar expected')
326325
}
327326
matrix.set(index.min(), submatrix, defaultValue)
328327
} else {
329328
// set a submatrix
330329

331330
// broadcast submatrix
332-
if (!deepStrictEqual(sSize, iSize)) {
331+
if (!deepStrictEqual(submatrixSize, iSize)) {
332+
// TODO: remove try catch if possible
333333
try {
334-
if (sSize.length === 0) {
334+
if (submatrixSize.length === 0) {
335335
submatrix = broadcastTo([submatrix], iSize)
336336
} else {
337337
submatrix = broadcastTo(submatrix, iSize)
338338
}
339-
sSize = arraySize(submatrix)
339+
submatrixSize = arraySize(submatrix)
340340
} catch {
341341
}
342342
}
@@ -346,11 +346,11 @@ export const createDenseMatrixClass = /* #__PURE__ */ factory(name, dependencies
346346
throw new DimensionError(iSize.length, matrix._size.length, '<')
347347
}
348348

349-
if (sSize.length < iSize.length) {
349+
if (submatrixSize.length < iSize.length) {
350350
// calculate number of missing outer dimensions
351351
let i = 0
352352
let outer = 0
353-
while (iSize[i] === 1 && sSize[i] === 1) {
353+
while (iSize[i] === 1 && submatrixSize[i] === 1) {
354354
i++
355355
}
356356
while (iSize[i] === 1) {
@@ -359,12 +359,12 @@ export const createDenseMatrixClass = /* #__PURE__ */ factory(name, dependencies
359359
}
360360

361361
// unsqueeze both outer and inner dimensions
362-
submatrix = unsqueeze(submatrix, iSize.length, outer, sSize)
362+
submatrix = unsqueeze(submatrix, iSize.length, outer, submatrixSize)
363363
}
364364

365365
// check whether the size of the submatrix matches the index size
366-
if (!deepStrictEqual(iSize, sSize)) {
367-
throw new DimensionError(iSize, sSize, '>')
366+
if (!deepStrictEqual(iSize, submatrixSize)) {
367+
throw new DimensionError(iSize, submatrixSize, '>')
368368
}
369369

370370
// enlarge matrix when needed
@@ -405,10 +405,10 @@ export const createDenseMatrixClass = /* #__PURE__ */ factory(name, dependencies
405405
}
406406

407407
if (depth < maxDepth) {
408-
if (Number.isInteger(range)) recursiveCallback(range, [0])
408+
if (isNumber(range)) recursiveCallback(range, [0])
409409
else range.forEach(recursiveCallback)
410410
} else {
411-
if (Number.isInteger(range)) finalCallback(range, [0])
411+
if (isNumber(range)) finalCallback(range, [0])
412412
else range.forEach(finalCallback)
413413
}
414414
}

src/type/matrix/MatrixIndex.js

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { isArray, isMatrix, isRange } from '../../utils/is.js'
1+
import { isArray, isMatrix, isRange, isNumber, isString } from '../../utils/is.js'
22
import { clone } from '../../utils/object.js'
33
import { isInteger } from '../../utils/number.js'
44
import { factory } from '../../utils/factory.js'
@@ -36,7 +36,6 @@ export const createIndexClass = /* #__PURE__ */ factory(name, dependencies, ({ I
3636

3737
this._dimensions = []
3838
this._sourceSize = []
39-
this._isScalar = true
4039

4140
for (let i = 0, ii = ranges.length; i < ii; i++) {
4241
const arg = ranges[i]
@@ -46,7 +45,6 @@ export const createIndexClass = /* #__PURE__ */ factory(name, dependencies, ({ I
4645
let sourceSize = null
4746
if (isRange(arg)) {
4847
this._dimensions.push(arg)
49-
this._isScalar = false
5048
} else if (argIsArray || argIsMatrix) {
5149
// create matrix
5250
let m
@@ -60,12 +58,6 @@ export const createIndexClass = /* #__PURE__ */ factory(name, dependencies, ({ I
6058
}
6159

6260
this._dimensions.push(m)
63-
// size
64-
const size = m.size()
65-
// scalar
66-
if (size.length !== 1 || size[0] !== 1 || sourceSize !== null) {
67-
this._isScalar = false
68-
}
6961
} else if (argType === 'number') {
7062
this._dimensions.push(arg)
7163
} else if (argType === 'bigint') {
@@ -109,7 +101,6 @@ export const createIndexClass = /* #__PURE__ */ factory(name, dependencies, ({ I
109101
Index.prototype.clone = function () {
110102
const index = new Index()
111103
index._dimensions = clone(this._dimensions)
112-
index._isScalar = this._isScalar
113104
index._sourceSize = this._sourceSize
114105
return index
115106
}
@@ -228,7 +219,7 @@ export const createIndexClass = /* #__PURE__ */ factory(name, dependencies, ({ I
228219
* @return {boolean} isScalar
229220
*/
230221
Index.prototype.isScalar = function () {
231-
return this._isScalar
222+
return this._dimensions.every(dim => isNumber(dim) || isString(dim))
232223
}
233224

234225
/**

test/unit-tests/expression/node/AccessorNode.test.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ describe('AccessorNode', function () {
7373
const scope = {
7474
a: [[1, 2], [3, 4]]
7575
}
76-
assert.deepStrictEqual(expr.evaluate(scope), [[3, 4]])
76+
assert.deepStrictEqual(expr.evaluate(scope), [3, 4])
7777
})
7878

7979
it('should compile a AccessorNode with "end" in an expression', function () {
@@ -185,7 +185,7 @@ describe('AccessorNode', function () {
185185
const scope = {
186186
a: [[1, 2], [3, 4]]
187187
}
188-
assert.deepStrictEqual(expr.evaluate(scope), [[4, 3]])
188+
assert.deepStrictEqual(expr.evaluate(scope), [4, 3])
189189
})
190190

191191
it('should compile a AccessorNode with "end" both as value and in a range', function () {
@@ -203,7 +203,7 @@ describe('AccessorNode', function () {
203203
const scope = {
204204
a: [[1, 2], [3, 4]]
205205
}
206-
assert.deepStrictEqual(expr.evaluate(scope), [[3, 4]])
206+
assert.deepStrictEqual(expr.evaluate(scope), [3, 4])
207207
})
208208

209209
it('should use the inner context when using "end" in a nested index', function () {

test/unit-tests/expression/parse.test.js

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -697,18 +697,18 @@ describe('parse', function () {
697697
[7, 8, 9]
698698
])
699699
}
700-
assert.deepStrictEqual(parseAndEval('a[2, :]', scope), math.matrix([[4, 5, 6]]))
701-
assert.deepStrictEqual(parseAndEval('a[2, :2]', scope), math.matrix([[4, 5]]))
702-
assert.deepStrictEqual(parseAndEval('a[2, :end-1]', scope), math.matrix([[4, 5]]))
703-
assert.deepStrictEqual(parseAndEval('a[2, 2:]', scope), math.matrix([[5, 6]]))
704-
assert.deepStrictEqual(parseAndEval('a[2, 2:3]', scope), math.matrix([[5, 6]]))
705-
assert.deepStrictEqual(parseAndEval('a[2, 1:2:3]', scope), math.matrix([[4, 6]]))
706-
assert.deepStrictEqual(parseAndEval('a[:, 2]', scope), math.matrix([[2], [5], [8]]))
707-
assert.deepStrictEqual(parseAndEval('a[:2, 2]', scope), math.matrix([[2], [5]]))
708-
assert.deepStrictEqual(parseAndEval('a[:end-1, 2]', scope), math.matrix([[2], [5]]))
709-
assert.deepStrictEqual(parseAndEval('a[2:, 2]', scope), math.matrix([[5], [8]]))
710-
assert.deepStrictEqual(parseAndEval('a[2:3, 2]', scope), math.matrix([[5], [8]]))
711-
assert.deepStrictEqual(parseAndEval('a[1:2:3, 2]', scope), math.matrix([[2], [8]]))
700+
assert.deepStrictEqual(parseAndEval('a[2, :]', scope), math.matrix([4, 5, 6]))
701+
assert.deepStrictEqual(parseAndEval('a[2, :2]', scope), math.matrix([4, 5]))
702+
assert.deepStrictEqual(parseAndEval('a[2, :end-1]', scope), math.matrix([4, 5]))
703+
assert.deepStrictEqual(parseAndEval('a[2, 2:]', scope), math.matrix([5, 6]))
704+
assert.deepStrictEqual(parseAndEval('a[2, 2:3]', scope), math.matrix([5, 6]))
705+
assert.deepStrictEqual(parseAndEval('a[2, 1:2:3]', scope), math.matrix([4, 6]))
706+
assert.deepStrictEqual(parseAndEval('a[:, 2]', scope), math.matrix([2, 5, 8]))
707+
assert.deepStrictEqual(parseAndEval('a[:2, [2]]', scope), math.matrix([[2], [5]]))
708+
assert.deepStrictEqual(parseAndEval('a[:end-1, [2]]', scope), math.matrix([[2], [5]]))
709+
assert.deepStrictEqual(parseAndEval('a[2:, [2]]', scope), math.matrix([[5], [8]]))
710+
assert.deepStrictEqual(parseAndEval('a[2:3, [2]]', scope), math.matrix([[5], [8]]))
711+
assert.deepStrictEqual(parseAndEval('a[1:2:3, [2]]', scope), math.matrix([[2], [8]]))
712712
})
713713

714714
it('should get a matrix subset of a matrix subset', function () {
@@ -719,7 +719,7 @@ describe('parse', function () {
719719
[7, 8, 9]
720720
])
721721
}
722-
assert.deepStrictEqual(parseAndEval('a[2, :][1,1]', scope), 4)
722+
assert.deepStrictEqual(parseAndEval('a[[2], :][1,1]', scope), 4)
723723
})
724724

725725
it('should get BigNumber value from an array', function () {
@@ -768,7 +768,7 @@ describe('parse', function () {
768768
assert.deepStrictEqual(parseAndEval('a[1:3,1:2]', scope), math.matrix([[100, 2], [3, 10], [0, 12]]))
769769

770770
scope.b = [[1, 2], [3, 4]]
771-
assert.deepStrictEqual(parseAndEval('b[1,:]', scope), [[1, 2]])
771+
assert.deepStrictEqual(parseAndEval('b[1,:]', scope), [1, 2])
772772
})
773773

774774
it('should get/set the matrix correctly for 3d matrices', function () {
@@ -789,10 +789,10 @@ describe('parse', function () {
789789
]))
790790

791791
assert.deepStrictEqual(parseAndEval('size(f)', scope), math.matrix([2, 2, 2], 'dense', 'number'))
792-
assert.deepStrictEqual(parseAndEval('f[:,:,1]', scope), math.matrix([[[1], [2]], [[3], [4]]]))
793-
assert.deepStrictEqual(parseAndEval('f[:,:,2]', scope), math.matrix([[[5], [6]], [[7], [8]]]))
794-
assert.deepStrictEqual(parseAndEval('f[:,2,:]', scope), math.matrix([[[2, 6]], [[4, 8]]]))
795-
assert.deepStrictEqual(parseAndEval('f[2,:,:]', scope), math.matrix([[[3, 7], [4, 8]]]))
792+
assert.deepStrictEqual(parseAndEval('f[:,:,1]', scope), math.matrix([[1, 2], [3, 4]]))
793+
assert.deepStrictEqual(parseAndEval('f[:,:,2]', scope), math.matrix([[5, 6], [7, 8]]))
794+
assert.deepStrictEqual(parseAndEval('f[:,2,:]', scope), math.matrix([[2, 6], [4, 8]]))
795+
assert.deepStrictEqual(parseAndEval('f[2,:,:]', scope), math.matrix([[3, 7], [4, 8]]))
796796

797797
parseAndEval('a=diag([1,2,3,4])', scope)
798798
assert.deepStrictEqual(parseAndEval('a[3:end, 3:end]', scope), math.matrix([[3, 0], [0, 4]]))

test/unit-tests/function/matrix/map.test.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ describe('map', function () {
281281
it('should operate from the parser with multiple inputs that need broadcasting and one based indices and the broadcasted arrays', function () {
282282
// this is a convoluted way of calculating f(a,b,idx) = 2a+2b+index
283283
// 2(1) + 2([3,4]) + [1, 2] # yields [9, 12]
284-
const arr2 = math.evaluate('map([1],[3,4], f(a,b,idx,A,B)= a + A[idx] + b + B[idx] + idx[1])')
284+
const arr2 = math.evaluate('map([1],[3,4], f(a,b,idx,A,B)= a + A[idx[1]] + b + B[idx[1]] + idx[1])')
285285
const expected = math.matrix([9, 12])
286286
assert.deepStrictEqual(arr2, expected)
287287
})

test/unit-tests/function/matrix/subset.test.js

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ describe('subset', function () {
1111
const b = math.matrix(a)
1212

1313
it('should get the right subset of an array', function () {
14-
assert.deepStrictEqual(subset(a, index(new Range(0, 2), 1)), [[2], [4]])
14+
assert.deepStrictEqual(subset(a, index(new Range(0, 2), 1)), [2, 4])
1515
assert.deepStrictEqual(subset(a, index(1, 0)), 3)
1616
assert.deepStrictEqual(subset([math.bignumber(2)], index(0)), math.bignumber(2))
1717
})
1818

1919
it('should get the right subset of an array of booleans', function () {
20-
assert.deepStrictEqual(subset(a, index([true, true], 1)), [[2], [4]])
20+
assert.deepStrictEqual(subset(a, index([true, true], [1])), [[2], [4]])
2121
assert.deepStrictEqual(subset(a, index([false, true], [true, false])), [[3]])
2222
assert.deepStrictEqual(subset([math.bignumber(2)], index([true])), [math.bignumber(2)])
2323
})
@@ -32,7 +32,7 @@ describe('subset', function () {
3232
})
3333

3434
it('should get the right subset of an array of booleans in the parser', function () {
35-
assert.deepStrictEqual(math.evaluate('a[[true, true], 2]', { a }), [[2], [4]])
35+
assert.deepStrictEqual(math.evaluate('a[[true, true], 2]', { a }), [2, 4])
3636
assert.deepStrictEqual(math.evaluate('a[[false, true], [true, false]]', { a }), [[3]])
3737
assert.deepStrictEqual(math.evaluate('[bignumber(2)][[true]]'), math.matrix([math.bignumber(2)]))
3838
})
@@ -75,7 +75,7 @@ describe('subset', function () {
7575
})
7676

7777
it('should get the right subset of a matrix', function () {
78-
assert.deepStrictEqual(subset(b, index(new Range(0, 2), 1)), matrix([[2], [4]]))
78+
assert.deepStrictEqual(subset(b, index(new Range(0, 2), 1)), matrix([2, 4]))
7979
assert.deepStrictEqual(subset(b, index(1, 0)), 3)
8080
})
8181

0 commit comments

Comments
 (0)