Skip to content

Commit 28549e3

Browse files
committed
Reduce KV-cache memory size by half
Signed-off-by: Min, Byungil <byungil.min@intel.com>
1 parent 7b8fe75 commit 28549e3

File tree

7 files changed

+49
-21
lines changed

7 files changed

+49
-21
lines changed

src/plugins/intel_gpu/include/intel_gpu/plugin/multi_tensor_variable_state.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ class VariableStateIndirectKVCacheCompressed : public VariableStateIndirectKVCac
5555
const std::vector<cldnn::layout>& output_layouts,
5656
size_t beam_idx,
5757
size_t concat_idx,
58-
bool has_zp_state);
58+
bool has_zp_state,
59+
bool is_4bit_kv_cache = false);
5960
using Ptr = std::shared_ptr<VariableStateIndirectKVCacheCompressed>;
6061

6162
void set_state(const ov::SoPtr<ov::ITensor>& state) override;
@@ -70,5 +71,6 @@ class VariableStateIndirectKVCacheCompressed : public VariableStateIndirectKVCac
7071

7172
private:
7273
bool m_has_zp_state = false;
74+
bool m_is_4bit_kv_cache = false;
7375
};
7476
} // namespace ov::intel_gpu

src/plugins/intel_gpu/include/intel_gpu/plugin/variable_state.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ class VariableState : public VariableStateBase {
7272
return m_initial_layout;
7373
}
7474

75+
void set_alloc_inner_dim_divisor(size_t divisor) { m_alloc_inner_dim_divisor = divisor; }
76+
7577
ov::element::Type get_user_specified_type() const;
7678

7779
protected:
@@ -82,6 +84,7 @@ class VariableState : public VariableStateBase {
8284
cldnn::memory::ptr m_memory = nullptr;
8385
bool m_transpose_required = false;
8486
size_t actual_size = 0;
87+
size_t m_alloc_inner_dim_divisor = 1;
8588

8689
const cldnn::layout m_initial_layout;
8790

src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa_opt.cl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,10 @@ KERNEL(sdpa_opt)(
264264
#if IS_INT4_COMPRESSED && !defined(BEAM_TABLE_TYPE)
265265
#ifdef INPUT1_DIMS_ORDER
266266
const uint key_base_p0 = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 0, 0);
267-
const uint key_packed_pitch_p0 = (FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 1, 0) - key_base_p0) / 2;
267+
const uint key_packed_pitch_p0 = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 1, 0) - key_base_p0;
268268
#else
269269
const uint key_base_p0 = INPUT1_GET_INDEX(b0_idx, b1_idx, 0, 0);
270-
const uint key_packed_pitch_p0 = K_HEAD_SIZE / 2;
270+
const uint key_packed_pitch_p0 = K_HEAD_SIZE;
271271
#endif
272272
#endif
273273
for (uint seq_len = sgid; seq_len < partition_seq_len; seq_len += SUBGROUPS_PER_WG) {
@@ -713,13 +713,13 @@ KERNEL(sdpa_opt)(
713713
uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 0, 0);
714714
uint value_offset_next_seq = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 1, 0);
715715
#if IS_INT4_COMPRESSED
716-
const uint value_pitch = (value_offset_next_seq - value_offset) / 2;
716+
const uint value_pitch = value_offset_next_seq - value_offset;
717717
#else
718718
const uint value_pitch = value_offset_next_seq - value_offset;
719719
#endif
720720
#else
721721
#if IS_INT4_COMPRESSED
722-
const uint value_pitch = V_HEAD_SIZE / 2;
722+
const uint value_pitch = V_HEAD_SIZE;
723723
#else
724724
const uint value_pitch = V_HEAD_SIZE;
725725
#endif
@@ -1296,10 +1296,10 @@ KERNEL(sdpa_opt)(
12961296
#if IS_INT4_COMPRESSED && !defined(IS_PAGED_ATTENTION) && !defined(BEAM_TABLE_TYPE)
12971297
#ifdef INPUT1_DIMS_ORDER
12981298
const uint key_base_s1 = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 0, 0);
1299-
const uint key_packed_pitch_s1 = (FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 1, 0) - key_base_s1) / 2;
1299+
const uint key_packed_pitch_s1 = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 1, 0) - key_base_s1;
13001300
#else
13011301
const uint key_base_s1 = INPUT1_GET_INDEX(b0_idx, b1_idx, 0, 0);
1302-
const uint key_packed_pitch_s1 = K_HEAD_SIZE / 2;
1302+
const uint key_packed_pitch_s1 = K_HEAD_SIZE;
13031303
#endif
13041304
#endif
13051305

@@ -1373,7 +1373,7 @@ KERNEL(sdpa_opt)(
13731373
// INT4: process 2*SUBGROUP_SIZE logical head dims per iteration (one packed byte per lane per token row)
13741374
#define KEY_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT1_TYPE, 1, ptr, offset);
13751375
#define QUERY_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE)
1376-
const uint key_pitch_int4 = K_HEAD_SIZE / 2;
1376+
const uint key_pitch_int4 = K_HEAD_SIZE;
13771377
for (uint hi = 0; hi < K_HEAD_SIZE; hi += 2 * SUBGROUP_SIZE) {
13781378
QUERY_VEC qvec_lo, qvec_hi;
13791379
uint qlo = hi * TARGET_SEQ_LEN_BLOCK_SIZE + sglid;
@@ -1470,7 +1470,7 @@ KERNEL(sdpa_opt)(
14701470
// INT4 partial block: process 2*SUBGROUP_SIZE logical head dims per iteration
14711471
#define KEY_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT1_TYPE, 1, ptr, offset)
14721472
#define QUERY_VEC_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE)
1473-
const uint key_pitch_int4 = K_HEAD_SIZE / 2;
1473+
const uint key_pitch_int4 = K_HEAD_SIZE;
14741474
for (uint hi = 0; hi < K_HEAD_SIZE; hi += 2 * SUBGROUP_SIZE) {
14751475
QUERY_VEC_TYPE qvec_lo, qvec_hi;
14761476
uint qlo = hi * TARGET_SEQ_LEN_BLOCK_SIZE + sglid;
@@ -1856,13 +1856,13 @@ KERNEL(sdpa_opt)(
18561856
uint value_offset_base = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 0, 0);
18571857
uint value_offset_next_seq = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 1, 0);
18581858
#if IS_INT4_COMPRESSED
1859-
const uint value_pitch = (value_offset_next_seq - value_offset_base) / 2;
1859+
const uint value_pitch = value_offset_next_seq - value_offset_base;
18601860
#else
18611861
const uint value_pitch = value_offset_next_seq - value_offset_base;
18621862
#endif
18631863
#else
18641864
#if IS_INT4_COMPRESSED
1865-
const uint value_pitch = V_HEAD_SIZE / 2;
1865+
const uint value_pitch = V_HEAD_SIZE;
18661866
#else
18671867
const uint value_pitch = V_HEAD_SIZE;
18681868
#endif

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_kv_cache.cl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,10 @@ KERNEL(dynamic_quantize_gpu_kv_cache)(
104104
ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE)((UINT4_RANGE) / diff_value);
105105
ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_value * scale_tmp); // maps min -> 0, max -> UINT4_RANGE
106106

107-
// INT4 packed buffer: the output layout uses i8 with full head_size shape,
108-
// so divide by 2 to get the correct packed byte offset (2 INT4 values per byte).
109-
const uint output_offset = OUTPUT_GET_INDEX(b, f, y, x) / 2;
107+
// INT4 packed buffer: the output layout uses i8 with full head_size shape.
108+
// Use element-level offset directly (same stride as layout) so that SDPA
109+
// can address rows with the standard GET_INDEX pitch.
110+
const uint output_offset = OUTPUT_GET_INDEX(b, f, y, x);
110111
// Pairs of consecutive SUBGROUP_SIZE blocks are packed together.
111112
unroll_for (uint i = 0; i < INNERMOST_DIM_VALUE / SUBGROUP_SIZE; i += 2) {
112113
uchar q0 = (uchar)clamp(convert_int_rte((float)val[i] * scale_tmp + zp_tmp), 0, UINT4_RANGE);

src/plugins/intel_gpu/src/plugin/multi_tensor_variable_state.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,11 @@ VariableStateIndirectKVCacheCompressed::VariableStateIndirectKVCacheCompressed(
165165
const std::vector<cldnn::layout>& output_layouts,
166166
size_t beam_idx,
167167
size_t concat_idx,
168-
bool has_zp_state = false)
168+
bool has_zp_state,
169+
bool is_4bit_kv_cache)
169170
: VariableStateIndirectKVCache(info, context, shape_predictor, beam_idx, concat_idx),
170-
m_has_zp_state(has_zp_state) {
171+
m_has_zp_state(has_zp_state),
172+
m_is_4bit_kv_cache(is_4bit_kv_cache) {
171173
OPENVINO_ASSERT((has_zp_state && output_layouts.size() == 3) ||
172174
(!has_zp_state && output_layouts.size() == 2),
173175
"[GPU] Unexpected number of output layouts for VariableStateIndirectKVCacheCompressed");
@@ -185,6 +187,12 @@ VariableStateIndirectKVCacheCompressed::VariableStateIndirectKVCacheCompressed(
185187
OPENVINO_ASSERT((!m_has_zp_state && m_hidden_states.size() == 3) || (m_has_zp_state && m_hidden_states.size() == 4),
186188
"[GPU] VariableStateIndirectKVCacheCompressed expects 3 or 4 internal states to be initialized, "
187189
"actual number is ", m_hidden_states.size());
190+
191+
// For 4-bit KV-cache, two INT4 values are packed per byte.
192+
// Halve the innermost dim of the allocation to reduce physical memory usage.
193+
if (m_is_4bit_kv_cache) {
194+
m_hidden_states[0]->set_alloc_inner_dim_divisor(2);
195+
}
188196
}
189197

190198
VariableState::Ptr VariableStateIndirectKVCacheCompressed::get_compression_scale_state() const {

src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -667,13 +667,16 @@ void SyncInferRequest::allocate_states() {
667667
}
668668

669669
if (compressed) {
670+
const auto kv_precision = m_graph->get_config().get_kv_cache_precision();
671+
const bool is_4bit_kv_cache = ov::element::Type(kv_precision).bitwidth() == 4;
670672
m_variables.emplace(vi.first, std::make_shared<VariableStateIndirectKVCacheCompressed>(vi.second,
671673
m_context,
672674
m_shape_predictor,
673675
states_layouts,
674676
beam_axis,
675677
concat_axis,
676-
has_zp_state));
678+
has_zp_state,
679+
is_4bit_kv_cache));
677680
} else if (indirect_kv_cache) {
678681
m_variables.emplace(vi.first, std::make_shared<VariableStateIndirectKVCache>(vi.second,
679682
m_context,

src/plugins/intel_gpu/src/plugin/variable_state.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,21 @@ void VariableState::update_device_buffer() {
119119
const auto alloc_type = m_context->get_engine().use_unified_shared_memory() ? cldnn::allocation_type::usm_device : cldnn::allocation_type::cl_mem;
120120
const auto current_buf_size = m_layout.get_padded_dims();
121121
ov::Shape current_shape(current_buf_size.begin(), current_buf_size.end());
122-
const auto alloc_shape = predict_shape(m_name, cldnn::layout(current_shape, m_layout.data_type, m_layout.format), *m_shape_predictor);
123-
const auto alloc_layout = cldnn::layout(alloc_shape, m_layout.data_type, m_layout.format);
124-
m_memory = m_context->get_engine().allocate_memory(alloc_layout, alloc_type, false);
125-
actual_size = std::max(actual_size, alloc_layout.bytes_count());
122+
auto alloc_shape = predict_shape(m_name, cldnn::layout(current_shape, m_layout.data_type, m_layout.format), *m_shape_predictor);
123+
124+
// For INT4 packed KV-cache, halve the innermost dim to reduce physical allocation.
125+
// actual_size tracks LOGICAL capacity (un-halved) for correct max_pad calculations.
126+
if (m_alloc_inner_dim_divisor > 1 && !alloc_shape.empty()) {
127+
auto logical_alloc_shape = alloc_shape;
128+
alloc_shape.back() /= m_alloc_inner_dim_divisor;
129+
const auto alloc_layout = cldnn::layout(alloc_shape, m_layout.data_type, m_layout.format);
130+
m_memory = m_context->get_engine().allocate_memory(alloc_layout, alloc_type, false);
131+
actual_size = std::max(actual_size, cldnn::layout(logical_alloc_shape, m_layout.data_type, m_layout.format).bytes_count());
132+
} else {
133+
const auto alloc_layout = cldnn::layout(alloc_shape, m_layout.data_type, m_layout.format);
134+
m_memory = m_context->get_engine().allocate_memory(alloc_layout, alloc_type, false);
135+
actual_size = std::max(actual_size, alloc_layout.bytes_count());
136+
}
126137
}
127138

128139
OPENVINO_ASSERT(m_memory != nullptr, "m_memory is nullptr!!!");

0 commit comments

Comments (0)