Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -384,18 +384,19 @@ static VectorDims makeDummyInputDims(const Shape& inShape, const Shape& wShape)
}

static VectorDims makeDummyOutputDims(const VectorDims& inShape, const VectorDims& wShape, const size_t out_rank) {
size_t activationRank = inShape.size();
size_t channelRank = wShape.size() - 1;
const auto wShape2D = reshapeDownToRank<2>(wShape);
const size_t activationRank = inShape.size();
const size_t channelRank = 1;
// activation weight output_shape
// NCHW CoCHW NCo
// TNC CoC TNCo
// NC CoC NCo
VectorDims outputShape(out_rank, 1);
// set Co
outputShape.back() = wShape[0];
outputShape.back() = wShape2D[0];
// set batch dims
size_t batchRank = activationRank - channelRank;
size_t startIdx = out_rank - batchRank - 1;
const size_t batchRank = activationRank - channelRank;
const size_t startIdx = out_rank - batchRank - 1;
for (size_t i = 0; i < batchRank; i++) {
outputShape[i + startIdx] = inShape[i];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ const std::vector<ExecutorImplementation<FCAttrs>>& getImplementations() {
fcMappingNotation);
},
AcceptsAnyShape<FCAttrs>,
CreateDnnlDefault<DnnlFCPrimitive, FCAttrs>{false, true}
CreateDnnlDefault<DnnlFCPrimitive, FCAttrs>{true, true}
)
};

Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,8 @@ void FullyConnected::needSplitMemoryForTensorParallel() {
}
memory[ARG_BIAS] = tp_cfg.cached_splited_bias;
// dst
memory[ARG_DST] = getDstMemoryAtPort(0);
tp_cfg.cached_dst = split_horizontal(context->getEngine(), dst, -1, tp_cfg.w_rank, tp_cfg.w_size, false);
memory[ARG_DST] = tp_cfg.cached_dst;

if (auto it = memory.find(ARG_DST | ARG_ATTR_SCALES); it != memory.end()) {
it->second = split_horizontal(context->getEngine(), it->second, 0, tp_cfg.w_rank, tp_cfg.w_size);
Expand Down
Loading