@@ -10,10 +10,11 @@ import { optimizeCallback } from '../../utils/optimizeCallback.js'
1010
1111const name = 'DenseMatrix'
1212const dependencies = [
13- 'Matrix'
13+ 'Matrix' ,
14+ 'config'
1415]
1516
16- export const createDenseMatrixClass = /* #__PURE__ */ factory ( name , dependencies , ( { Matrix } ) => {
17+ export const createDenseMatrixClass = /* #__PURE__ */ factory ( name , dependencies , ( { Matrix, config } ) => {
1718 /**
1819 * Dense Matrix implementation. A regular, dense matrix, supporting multi-dimensional matrices. This is the default matrix type.
1920 * @class DenseMatrix
@@ -218,7 +219,9 @@ export const createDenseMatrixClass = /* #__PURE__ */ factory(name, dependencies
218219 throw new TypeError ( 'Invalid index' )
219220 }
220221
221- const isScalar = index . isScalar ( )
222+ const isScalar = config . legacySubset
223+ ? index . size ( ) . every ( idx => idx === 1 )
224+ : index . isScalar ( )
222225 if ( isScalar ) {
223226 // return a scalar
224227 return matrix . get ( index . min ( ) )
@@ -243,7 +246,7 @@ export const createDenseMatrixClass = /* #__PURE__ */ factory(name, dependencies
243246 returnMatrix . _size = submatrix . size
244247 returnMatrix . _datatype = matrix . _datatype
245248 returnMatrix . _data = submatrix . data
246- return returnMatrix
249+ return config . legacySubset ? returnMatrix . reshape ( index . size ( ) ) : returnMatrix
247250 }
248251 }
249252
@@ -262,27 +265,27 @@ export const createDenseMatrixClass = /* #__PURE__ */ factory(name, dependencies
262265 return { data : getSubmatrixRecursive ( data ) , size : size . filter ( x => x !== null ) }
263266
264267 function getSubmatrixRecursive ( data , depth = 0 ) {
265- const ranges = index . dimension ( depth )
266- function _mapIndex ( range , callback ) {
268+ const dims = index . dimension ( depth )
269+ function _mapIndex ( dim , callback ) {
267270 // applies a callback for when the index is a Number or a Matrix
268- if ( isNumber ( range ) ) return callback ( range )
269- else return range . map ( callback ) . valueOf ( )
271+ if ( isNumber ( dim ) ) return callback ( dim )
272+ else return dim . map ( callback ) . valueOf ( )
270273 }
271274
272- if ( isNumber ( ranges ) ) {
275+ if ( isNumber ( dims ) ) {
273276 size [ depth ] = null
274277 } else {
275- size [ depth ] = ranges . size ( ) [ 0 ]
278+ size [ depth ] = dims . size ( ) [ 0 ]
276279 }
277280 if ( depth < maxDepth ) {
278- return _mapIndex ( ranges , rangeIndex => {
279- validateIndex ( rangeIndex , data . length )
280- return getSubmatrixRecursive ( data [ rangeIndex ] , depth + 1 )
281+ return _mapIndex ( dims , dimIndex => {
282+ validateIndex ( dimIndex , data . length )
283+ return getSubmatrixRecursive ( data [ dimIndex ] , depth + 1 )
281284 } )
282285 } else {
283- return _mapIndex ( ranges , rangeIndex => {
284- validateIndex ( rangeIndex , data . length )
285- return data [ rangeIndex ]
286+ return _mapIndex ( dims , dimIndex => {
287+ validateIndex ( dimIndex , data . length )
288+ return data [ dimIndex ]
286289 } )
287290 }
288291 }
0 commit comments