Skip to content

Commit 4a27b5c

Browse files
authored
[Fix] Add ctx to the original ndarray and revise the usage of context to ctx (apache#16819)
* try to fix warning try to fix warning try to fix all warnings use ctx * try to fix warnings * try fo fix warnings
1 parent b972406 commit 4a27b5c

File tree

4 files changed

+40
-22
lines changed

4 files changed

+40
-22
lines changed

python/mxnet/gluon/block.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _gather_type_ctx_info(args):
117117
Context of the first appeared NDArray (for backward-compatibility)
118118
"""
119119
if isinstance(args, NDArray):
120-
return False, True, {args.context}, args.context
120+
return False, True, {args.ctx}, args.ctx
121121
elif isinstance(args, Symbol):
122122
return True, False, set(), None
123123
elif isinstance(args, (list, tuple)):
@@ -1141,7 +1141,7 @@ def forward(self, x, *args):
11411141
if len(ctx_set) > 1:
11421142
raise ValueError('Find multiple contexts in the input, '
11431143
'After hybridized, the HybridBlock only supports one input '
1144-
'context. You can print the ele.context in the '
1144+
'context. You can print the ele.ctx in the '
11451145
'input arguments to inspect their contexts. '
11461146
'Find all contexts = {}'.format(ctx_set))
11471147
with ctx:
@@ -1324,7 +1324,7 @@ def __init__(self, outputs, inputs, params=None):
13241324

13251325
def forward(self, x, *args):
13261326
if isinstance(x, NDArray):
1327-
with x.context:
1327+
with x.ctx:
13281328
return self._call_cached_op(x, *args)
13291329

13301330
assert isinstance(x, Symbol), \

python/mxnet/gluon/parameter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,10 +369,10 @@ def _init_grad(self):
369369
if self._grad_stype != 'default':
370370
raise ValueError("mxnet.numpy.zeros does not support stype = {}"
371371
.format(self._grad_stype))
372-
self._grad = [_mx_np.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context)
372+
self._grad = [_mx_np.zeros(shape=i.shape, dtype=i.dtype, ctx=i.ctx)
373373
for i in self._data]
374374
else:
375-
self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
375+
self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.ctx,
376376
stype=self._grad_stype) for i in self._data]
377377

378378
autograd.mark_variables(self._check_and_get(self._data, list),
@@ -522,7 +522,7 @@ def row_sparse_data(self, row_id):
522522
raise RuntimeError("Cannot return a copy of Parameter %s via row_sparse_data() " \
523523
"because its storage type is %s. Please use data() instead." \
524524
%(self.name, self._stype))
525-
return self._get_row_sparse(self._data, row_id.context, row_id)
525+
return self._get_row_sparse(self._data, row_id.ctx, row_id)
526526

527527
def list_row_sparse_data(self, row_id):
528528
"""Returns copies of the 'row_sparse' parameter on all contexts, in the same order
@@ -897,7 +897,7 @@ def zero_grad(self):
897897
if g.stype == 'row_sparse':
898898
ndarray.zeros_like(g, out=g)
899899
else:
900-
arrays[g.context].append(g)
900+
arrays[g.ctx].append(g)
901901

902902
if len(arrays) == 0:
903903
return

python/mxnet/ndarray/ndarray.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def __repr__(self):
245245
shape_info = 'x'.join(['%d' % x for x in self.shape])
246246
return '\n%s\n<%s %s @%s>' % (str(self.asnumpy()),
247247
self.__class__.__name__,
248-
shape_info, self.context)
248+
shape_info, self.ctx)
249249

250250
def __reduce__(self):
251251
return NDArray, (None,), self.__getstate__()
@@ -729,14 +729,14 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None):
729729
`squeeze_axes`: a sequence of axes to squeeze in the value array.
730730
"""
731731
if isinstance(value, numeric_types):
732-
value_nd = full(bcast_shape, value, ctx=self.context, dtype=self.dtype)
732+
value_nd = full(bcast_shape, value, ctx=self.ctx, dtype=self.dtype)
733733
elif type(value) == self.__class__: # pylint: disable=unidiomatic-typecheck
734-
value_nd = value.as_in_context(self.context)
734+
value_nd = value.as_in_context(self.ctx)
735735
if value_nd.dtype != self.dtype:
736736
value_nd = value_nd.astype(self.dtype)
737737
else:
738738
try:
739-
value_nd = array(value, ctx=self.context, dtype=self.dtype)
739+
value_nd = array(value, ctx=self.ctx, dtype=self.dtype)
740740
except:
741741
raise TypeError('{} does not support assignment with non-array-like '
742742
'object {} of type {}'.format(self.__class__, value, type(value)))
@@ -1220,7 +1220,7 @@ def _get_index_nd(self, key):
12201220

12211221
shape_nd_permut = tuple(self.shape[ax] for ax in axs_nd_permut)
12221222
converted_idcs_short = [
1223-
self._advanced_index_to_array(idx, ax_len, self.context)
1223+
self._advanced_index_to_array(idx, ax_len, self.ctx)
12241224
for idx, ax_len in zip(idcs_permut_short, shape_nd_permut)
12251225
]
12261226
bcast_idcs_permut_short = self._broadcast_advanced_indices(
@@ -1229,7 +1229,7 @@ def _get_index_nd(self, key):
12291229

12301230
# Get the ndim of advanced indexing subspace
12311231
converted_advanced_idcs = [
1232-
self._advanced_index_to_array(idx, ax_len, self.context)
1232+
self._advanced_index_to_array(idx, ax_len, self.ctx)
12331233
for idx, ax_len in zip(adv_idcs_nd, [self.shape[ax] for ax in adv_axs_nd])
12341234
]
12351235
bcast_advanced_shape = _broadcast_shapes(converted_advanced_idcs)
@@ -2433,6 +2433,23 @@ def context(self):
24332433
self.handle, ctypes.byref(dev_typeid), ctypes.byref(dev_id)))
24342434
return Context(Context.devtype2str[dev_typeid.value], dev_id.value)
24352435

2436+
@property
2437+
def ctx(self):
2438+
"""Device context of the array. Has the same meaning as context.
2439+
2440+
Examples
2441+
--------
2442+
>>> x = mx.nd.array([1, 2, 3, 4])
2443+
>>> x.ctx
2444+
cpu(0)
2445+
>>> type(x.ctx)
2446+
<class 'mxnet.context.Context'>
2447+
>>> y = mx.nd.zeros((2,3), mx.gpu(0))
2448+
>>> y.ctx
2449+
gpu(0)
2450+
"""
2451+
return self.context
2452+
24362453
@property
24372454
def dtype(self):
24382455
"""Data-type of the array's elements.
@@ -2580,7 +2597,7 @@ def astype(self, dtype, copy=True):
25802597
if not copy and np.dtype(dtype) == self.dtype:
25812598
return self
25822599

2583-
res = empty(self.shape, ctx=self.context, dtype=dtype)
2600+
res = empty(self.shape, ctx=self.ctx, dtype=dtype)
25842601
self.copyto(res)
25852602
return res
25862603

@@ -2646,7 +2663,7 @@ def copy(self):
26462663
array([[ 1., 1., 1.],
26472664
[ 1., 1., 1.]], dtype=float32)
26482665
"""
2649-
return self.copyto(self.context)
2666+
return self.copyto(self.ctx)
26502667

26512668
def slice_assign_scalar(self, value, begin, end, step):
26522669
"""
@@ -2904,7 +2921,7 @@ def _full(self, value):
29042921
"""
29052922
This is added as an NDArray class method in order to support polymorphism in NDArray and numpy.ndarray indexing
29062923
"""
2907-
return _internal._full(self.shape, value=value, ctx=self.context, dtype=self.dtype, out=self)
2924+
return _internal._full(self.shape, value=value, ctx=self.ctx, dtype=self.dtype, out=self)
29082925

29092926
def _scatter_set_nd(self, value_nd, indices):
29102927
"""
@@ -4542,7 +4559,7 @@ def concatenate(arrays, axis=0, always_copy=True):
45424559
assert shape_rest2 == arr.shape[axis+1:]
45434560
assert dtype == arr.dtype
45444561
ret_shape = shape_rest1 + (shape_axis,) + shape_rest2
4545-
ret = empty(ret_shape, ctx=arrays[0].context, dtype=dtype)
4562+
ret = empty(ret_shape, ctx=arrays[0].ctx, dtype=dtype)
45464563

45474564
idx = 0
45484565
begin = [0 for _ in ret_shape]

python/mxnet/numpy/multiarray.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -921,15 +921,15 @@ def __repr__(self):
921921
elif dtype not in (_np.float32, _np.bool_):
922922
array_str = array_str[:-1] + ', dtype={})'.format(dtype)
923923

924-
context = self.context
924+
context = self.ctx
925925
if context.device_type == 'cpu':
926926
return array_str
927927
return array_str[:-1] + ', ctx={})'.format(str(context))
928928

929929
def __str__(self):
930930
"""Returns a string representation of the array."""
931931
array_str = self.asnumpy().__str__()
932-
context = self.context
932+
context = self.ctx
933933
if context.device_type == 'cpu' or self.ndim == 0:
934934
return array_str
935935
return '{array} @{ctx}'.format(array=array_str, ctx=context)
@@ -994,7 +994,7 @@ def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ,unused-ar
994994
if not copy and _np.dtype(dtype) == self.dtype:
995995
return self
996996

997-
res = empty(self.shape, dtype=dtype, ctx=self.context)
997+
res = empty(self.shape, dtype=dtype, ctx=self.ctx)
998998
self.copyto(res)
999999
return res
10001000

@@ -1051,7 +1051,8 @@ def argmax(self, axis=None, out=None): # pylint: disable=arguments-differ
10511051

10521052
def as_in_context(self, context):
10531053
"""This function has been deprecated. Please refer to ``ndarray.as_in_ctx``."""
1054-
warnings.warn('ndarray.context has been renamed to ndarray.ctx', DeprecationWarning)
1054+
warnings.warn('ndarray.as_in_context has been renamed to'
1055+
' ndarray.as_in_ctx', DeprecationWarning)
10551056
return self.as_nd_ndarray().as_in_context(context).as_np_ndarray()
10561057

10571058
def as_in_ctx(self, ctx):
@@ -1864,7 +1865,7 @@ def _full(self, value):
18641865
Currently for internal use only. Implemented for __setitem__.
18651866
Assign to self an array of self's same shape and type, filled with value.
18661867
"""
1867-
return _mx_nd_np.full(self.shape, value, ctx=self.context, dtype=self.dtype, out=self)
1868+
return _mx_nd_np.full(self.shape, value, ctx=self.ctx, dtype=self.dtype, out=self)
18681869

18691870
# pylint: disable=redefined-outer-name
18701871
def _scatter_set_nd(self, value_nd, indices):

0 commit comments

Comments
 (0)