16 changes: 7 additions & 9 deletions paddle/phi/core/tensor_utils.cc
@@ -994,7 +994,7 @@ template phi::dtype::complex<double> GetValue(const DenseTensor* x);
template <typename T>
std::vector<T> GetVectorFromTensor(const DenseTensor* x) {
std::vector<T> vec_new_data;
- if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT32) {
+ if (x->dtype() == DataType::INT32) {
auto* data = x->data<int>();
DenseTensor cpu_attr_tensor;
if (x->place().GetType() != phi::AllocationType::CPU) {
@@ -1004,7 +1004,7 @@ std::vector<T> GetVectorFromTensor(const DenseTensor* x) {
data = cpu_attr_tensor.data<int>();
}
vec_new_data = std::vector<T>(data, data + x->numel());
- } else if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT64) {
+ } else if (x->dtype() == DataType::INT64) {
auto* data = x->data<int64_t>();
DenseTensor cpu_attr_tensor;
if (x->place().GetType() != phi::AllocationType::CPU) {
@@ -1018,7 +1018,7 @@ std::vector<T> GetVectorFromTensor(const DenseTensor* x) {
} else {
PADDLE_THROW(common::errors::InvalidArgument(
"The dtype of Tensor must be int32 or int64, but received: %s",
- phi::TransToProtoVarType(x->dtype())));
+ x->dtype()));
}
return vec_new_data;
}
@@ -1046,20 +1046,18 @@ std::vector<T> _GetVectorFromTensor(const DenseTensor* x) {

template <>
std::vector<float> GetVectorFromTensor<float>(const DenseTensor* x) {
- if (phi::TransToProtoVarType(x->dtype()) != ProtoDataType::FP32) {
+ if (x->dtype() != DataType::FLOAT32) {
PADDLE_THROW(common::errors::InvalidArgument(
"The dtype of Tensor must be float32, but received: %s",
phi::TransToProtoVarType(x->dtype())));
"The dtype of Tensor must be float32, but received: %s", x->dtype()));
}
return _GetVectorFromTensor<float>(x);
}

template <>
std::vector<double> GetVectorFromTensor<double>(const DenseTensor* x) {
- if (phi::TransToProtoVarType(x->dtype()) != ProtoDataType::FP64) {
+ if (x->dtype() != DataType::FLOAT64) {
PADDLE_THROW(common::errors::InvalidArgument(
"The dtype of Tensor must be float64, but received: %s",
phi::TransToProtoVarType(x->dtype())));
"The dtype of Tensor must be float64, but received: %s", x->dtype()));
}
return _GetVectorFromTensor<double>(x);
}
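For readers skimming the diff, here is a minimal standalone sketch of the dispatch pattern `GetVectorFromTensor` converges on: compare `x->dtype()` against `phi::DataType` directly instead of round-tripping through `TransToProtoVarType`. `MockTensor` and the enum below are simplified stand-ins, not the real phi API.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

enum class DataType { INT32, INT64 };

struct MockTensor {
  DataType dtype;
  const void* data;
  int64_t numel;
};

template <typename T>
std::vector<T> GetVectorFromTensor(const MockTensor& x) {
  if (x.dtype == DataType::INT32) {  // direct enum comparison, no proto round trip
    auto* p = static_cast<const int32_t*>(x.data);
    return std::vector<T>(p, p + x.numel);  // element-wise conversion to T
  }
  if (x.dtype == DataType::INT64) {
    auto* p = static_cast<const int64_t*>(x.data);
    return std::vector<T>(p, p + x.numel);
  }
  throw std::invalid_argument("The dtype of Tensor must be int32 or int64");
}

int main() {
  int32_t raw[3] = {1, 2, 3};
  MockTensor t{DataType::INT32, raw, 3};
  std::vector<int64_t> v = GetVectorFromTensor<int64_t>(t);  // widened copy
  return v.size() == 3 ? 0 : 1;
}
```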
7 changes: 4 additions & 3 deletions paddle/phi/infermeta/multiary.cc
@@ -2849,11 +2849,12 @@ void FusionGroupInferMeta(const std::vector<const MetaTensor*>& ins,
}

for (size_t j = 0; j < num_outs; ++j) {
- if (outs_dtype[j] == phi::TransToProtoVarType(DataType::FLOAT16)) {
+ DataType out_dtype = TransToPhiDataType(outs_dtype[j]);
+ if (out_dtype == DataType::FLOAT16) {
outs[j]->set_dtype(DataType::FLOAT16);
- } else if (outs_dtype[j] == phi::TransToProtoVarType(DataType::FLOAT32)) {
+ } else if (out_dtype == DataType::FLOAT32) {
outs[j]->set_dtype(DataType::FLOAT32);
- } else if (outs_dtype[j] == phi::TransToProtoVarType(DataType::FLOAT64)) {
+ } else if (out_dtype == DataType::FLOAT64) {
outs[j]->set_dtype(DataType::FLOAT64);
}
}
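A sketch of the shape this infermeta hunk moves to: convert the proto code once, then branch on `phi::DataType`. Since each of the three branches sets exactly the dtype it matched, they collapse to one guarded assignment; `TransToPhiDataTypeStub` and its codes are hypothetical stand-ins for `phi::TransToPhiDataType`, not the real mapping.

```cpp
#include <cstddef>
#include <vector>

enum class DataType { FLOAT16, FLOAT32, FLOAT64, UNDEFINED };

// Hypothetical stand-in for phi::TransToPhiDataType; the integer
// codes are illustrative only, not real proto::VarType values.
DataType TransToPhiDataTypeStub(int proto_code) {
  switch (proto_code) {
    case 4: return DataType::FLOAT16;
    case 5: return DataType::FLOAT32;
    case 6: return DataType::FLOAT64;
    default: return DataType::UNDEFINED;
  }
}

struct MetaTensorStub {
  DataType dtype = DataType::UNDEFINED;
  void set_dtype(DataType d) { dtype = d; }
};

int main() {
  std::vector<int> outs_dtype = {4, 5, 6};
  std::vector<MetaTensorStub> outs(outs_dtype.size());
  for (size_t j = 0; j < outs.size(); ++j) {
    DataType out_dtype = TransToPhiDataTypeStub(outs_dtype[j]);  // convert once
    if (out_dtype != DataType::UNDEFINED) {
      outs[j].set_dtype(out_dtype);  // the three identical branches collapse here
    }
  }
  return outs[0].dtype == DataType::FLOAT16 ? 0 : 1;
}
```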
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cpu/pyramid_hash_kernel.cc
@@ -228,8 +228,8 @@ void CPUPyramidHashOPKernel(const Context& dev_ctx,
if (iter != iter_end) {
exit(1);
}
- auto weight_type = TransToProtoVarType(_blobs_0->dtype());
- if (_is_training == 0 && weight_type != ProtoDataType::INT8) {
+ auto weight_type = _blobs_0->dtype();
+ if (_is_training == 0 && weight_type != DataType::INT8) {
funcs::axpy_noadd(
top_data, top_data, top->dims()[0] * top->dims()[1], _drop_out_percent);
}
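A minimal sketch of the guard this hunk simplifies, assuming the same semantics: at inference time (`_is_training == 0`) the output is rescaled in place unless the weights are int8. All names below are stand-ins for the phi/funcs types.

```cpp
#include <cstddef>

enum class DataType { INT8, FLOAT32 };

// Stand-in for funcs::axpy_noadd: y = a * x, elementwise.
void axpy_noadd(const float* x, float* y, size_t n, float a) {
  for (size_t i = 0; i < n; ++i) y[i] = a * x[i];
}

void maybe_rescale(float* top_data, size_t n, float drop_out_percent,
                   int is_training, DataType weight_type) {
  // Direct DataType comparison; no TransToProtoVarType round trip.
  if (is_training == 0 && weight_type != DataType::INT8) {
    axpy_noadd(top_data, top_data, n, drop_out_percent);
  }
}

int main() {
  float top[2] = {2.0f, 4.0f};
  maybe_rescale(top, 2, 0.5f, /*is_training=*/0, DataType::FLOAT32);
  return top[0] == 1.0f ? 0 : 1;  // rescaled in place
}
```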
28 changes: 12 additions & 16 deletions paddle/phi/kernels/cpu/tdm_sampler_kernel.cc
@@ -262,9 +262,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
DenseTensor *out,
DenseTensor *labels,
DenseTensor *mask) {
- const auto &input_type = TransToProtoVarType(x.dtype());
+ const auto &input_type = x.dtype();
bool input_type_match =
- input_type == ProtoDataType::INT32 || input_type == ProtoDataType::INT64;
+ input_type == DataType::INT32 || input_type == DataType::INT64;
PADDLE_ENFORCE_EQ(input_type_match,
true,
common::errors::InvalidArgument(
@@ -274,9 +274,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
DataTypeToString(DataType::INT32),
DataTypeToString(DataType::INT64)));

- const auto &travel_type = TransToProtoVarType(travel.dtype());
- bool travel_type_match = travel_type == ProtoDataType::INT32 ||
- travel_type == ProtoDataType::INT64;
+ const auto &travel_type = travel.dtype();
+ bool travel_type_match =
+ travel_type == DataType::INT32 || travel_type == DataType::INT64;
PADDLE_ENFORCE_EQ(travel_type_match,
true,
common::errors::InvalidArgument(
@@ -286,9 +286,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
DataTypeToString(DataType::INT32),
DataTypeToString(DataType::INT64)));

- const auto &layer_type = TransToProtoVarType(layer.dtype());
+ const auto &layer_type = layer.dtype();
bool layer_type_match =
- layer_type == ProtoDataType::INT32 || layer_type == ProtoDataType::INT64;
+ layer_type == DataType::INT32 || layer_type == DataType::INT64;
PADDLE_ENFORCE_EQ(layer_type_match,
true,
common::errors::InvalidArgument(
@@ -305,10 +305,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
DataTypeToString(travel.dtype()),
DataTypeToString(layer.dtype())));

- auto output_type = static_cast<ProtoDataType>(dtype);
+ auto output_type = TransToPhiDataType(dtype);

- if (travel_type == ProtoDataType::INT32 &&
- output_type == ProtoDataType::INT32) {
+ if (travel_type == DataType::INT32 && output_type == DataType::INT32) {
TDMSamplerInner<T, Context, int, int>(dev_ctx,
x,
travel,
@@ -320,8 +319,7 @@ void TDMSamplerKernel(const Context &dev_ctx,
out,
labels,
mask);
- } else if (travel_type == ProtoDataType::INT64 &&
- output_type == ProtoDataType::INT32) {
+ } else if (travel_type == DataType::INT64 && output_type == DataType::INT32) {
TDMSamplerInner<T, Context, int64_t, int>(dev_ctx,
x,
travel,
@@ -333,8 +331,7 @@ void TDMSamplerKernel(const Context &dev_ctx,
out,
labels,
mask);
- } else if (travel_type == ProtoDataType::INT32 &&
- output_type == ProtoDataType::INT64) {
+ } else if (travel_type == DataType::INT32 && output_type == DataType::INT64) {
TDMSamplerInner<T, Context, int, int64_t>(dev_ctx,
x,
travel,
@@ -346,8 +343,7 @@ void TDMSamplerKernel(const Context &dev_ctx,
out,
labels,
mask);
- } else if (travel_type == ProtoDataType::INT64 &&
- output_type == ProtoDataType::INT64) {
+ } else if (travel_type == DataType::INT64 && output_type == DataType::INT64) {
TDMSamplerInner<T, Context, int64_t, int64_t>(dev_ctx,
x,
travel,
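The four branches above form a 2x2 dispatch from (travel dtype, output dtype) to a template instantiation. A compressed standalone sketch of that dispatch, with a stub in place of the real `TDMSamplerInner`:

```cpp
#include <cstdint>
#include <iostream>

enum class DataType { INT32, INT64 };

// Stub for the real TDMSamplerInner<T, Context, TravelT, OutT>.
template <typename TravelT, typename OutT>
void SamplerInnerStub() {
  std::cout << "travel=" << sizeof(TravelT) * 8 << "-bit, out="
            << sizeof(OutT) * 8 << "-bit\n";
}

void Dispatch(DataType travel_type, DataType output_type) {
  if (travel_type == DataType::INT32 && output_type == DataType::INT32) {
    SamplerInnerStub<int32_t, int32_t>();
  } else if (travel_type == DataType::INT64 && output_type == DataType::INT32) {
    SamplerInnerStub<int64_t, int32_t>();
  } else if (travel_type == DataType::INT32 && output_type == DataType::INT64) {
    SamplerInnerStub<int32_t, int64_t>();
  } else {  // INT64 travel, INT64 output
    SamplerInnerStub<int64_t, int64_t>();
  }
}

int main() { Dispatch(DataType::INT64, DataType::INT32); }
```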
35 changes: 16 additions & 19 deletions paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu
@@ -28,15 +28,14 @@ static void MutableMultiTypeData(std::vector<DenseTensor*>* var,
const std::vector<int>& data_type,
const DeviceContext& dev_ctx) {
for (size_t i = 0; i < var->size(); i++) {
- if (data_type[i] == phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
+ DataType dtype = TransToPhiDataType(data_type[i]);
+ if (dtype == DataType::FLOAT32) {
dev_ctx.template Alloc<float>((*var)[i],
(*var)[i]->numel() * sizeof(float));
- } else if (data_type[i] ==
- phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
- dev_ctx.template Alloc<phi::float16>(
- (*var)[i], (*var)[i]->numel() * sizeof(phi::float16));
- } else if (data_type[i] ==
- phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
+ } else if (dtype == DataType::FLOAT16) {
+ dev_ctx.template Alloc<float16>((*var)[i],
+ (*var)[i]->numel() * sizeof(float16));
+ } else if (dtype == DataType::FLOAT64) {
dev_ctx.template Alloc<double>((*var)[i],
(*var)[i]->numel() * sizeof(double));
}
@@ -66,25 +65,23 @@ void FusionGroupKernel(const Context& dev_ctx,
args.push_back(&n);
std::vector<const void*> ptrs(num_ins + num_outs);
for (size_t i = 0; i < num_ins; ++i) {
- if (inputs_dtype[i] == phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
- ptrs[i] = ins[i]->data<phi::float16>();
- } else if (inputs_dtype[i] ==
- phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
+ DataType input_dtype = TransToPhiDataType(inputs_dtype[i]);
+ if (input_dtype == DataType::FLOAT16) {
+ ptrs[i] = ins[i]->data<float16>();
+ } else if (input_dtype == DataType::FLOAT32) {
ptrs[i] = ins[i]->data<float>();
- } else if (inputs_dtype[i] ==
- phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
+ } else if (input_dtype == DataType::FLOAT64) {
ptrs[i] = ins[i]->data<double>();
}
args.push_back(&ptrs[i]);
}
for (size_t j = 0; j < num_outs; ++j) {
- if (outs_dtype[j] == phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
- ptrs[num_ins + j] = outs[j]->data<phi::float16>();
- } else if (outs_dtype[j] ==
- phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
+ DataType out_dtype = TransToPhiDataType(outs_dtype[j]);
+ if (out_dtype == DataType::FLOAT16) {
+ ptrs[num_ins + j] = outs[j]->data<float16>();
+ } else if (out_dtype == DataType::FLOAT32) {
ptrs[num_ins + j] = outs[j]->data<float>();
- } else if (outs_dtype[j] ==
- phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
+ } else if (out_dtype == DataType::FLOAT64) {
ptrs[num_ins + j] = outs[j]->data<double>();
}
args.push_back(&ptrs[num_ins + j]);
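Both loops in this kernel follow the same shape after the change: convert the per-tensor dtype integer to `phi::DataType` once, then record a typed data pointer for the launch-argument list. A standalone sketch, with `TransToPhiDataTypeStub` as a hypothetical stand-in and a raw 16-bit type in place of `phi::float16`:

```cpp
#include <cstdint>
#include <vector>

enum class DataType { FLOAT16, FLOAT32, FLOAT64, UNDEFINED };

// Hypothetical stand-in for phi::TransToPhiDataType (codes illustrative).
DataType TransToPhiDataTypeStub(int code) {
  return code == 4 ? DataType::FLOAT16
       : code == 5 ? DataType::FLOAT32
       : code == 6 ? DataType::FLOAT64
                   : DataType::UNDEFINED;
}

struct TensorStub {
  const void* buffer;
  template <typename T>
  const T* data() const { return static_cast<const T*>(buffer); }
};

using float16_bits = uint16_t;  // storage stand-in for phi::float16

std::vector<const void*> CollectArgPtrs(const std::vector<TensorStub>& ins,
                                        const std::vector<int>& dtypes) {
  std::vector<const void*> ptrs(ins.size());
  for (size_t i = 0; i < ins.size(); ++i) {
    DataType dt = TransToPhiDataTypeStub(dtypes[i]);  // convert once per tensor
    if (dt == DataType::FLOAT16) {
      ptrs[i] = ins[i].data<float16_bits>();
    } else if (dt == DataType::FLOAT32) {
      ptrs[i] = ins[i].data<float>();
    } else if (dt == DataType::FLOAT64) {
      ptrs[i] = ins[i].data<double>();
    }
  }
  return ptrs;
}

int main() {
  float buf[1] = {1.0f};
  std::vector<TensorStub> ins = {{buf}};
  return CollectArgPtrs(ins, {5})[0] == buf ? 0 : 1;
}
```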
5 changes: 2 additions & 3 deletions paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc
@@ -489,14 +489,13 @@ void RunKernel(const phi::OneDNNContext& dev_ctx,
std::shared_ptr<dnnl::memory> h0_memory_p, weight_h_memory_p,
weight_x_memory_p;

- if (phi::TransToProtoVarType(weight_h.dtype()) == phi::ProtoDataType::FP32) {
+ if (weight_h.dtype() == DataType::FLOAT32) {
h0_memory_p = handler.template AcquireH0Memory<float>(h0.get_ptr());
weight_x_memory_p =
handler.template AcquireWeightXMemory<float>(&weight_x, origin_mode);
weight_h_memory_p =
handler.template AcquireWeightHMemory<float>(&weight_h, origin_mode);
- } else if (phi::TransToProtoVarType(weight_h.dtype()) ==
- phi::ProtoDataType::BF16) {
+ } else if (weight_h.dtype() == DataType::BFLOAT16) {
h0_memory_p = handler.template AcquireH0Memory<phi::bfloat16>(h0.get_ptr());
weight_x_memory_p = handler.template AcquireWeightXMemory<phi::bfloat16>(
&weight_x, origin_mode);
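A sketch of the weight-memory branch, assuming simplified stand-ins for the oneDNN handler API: the handler template is instantiated on `float` or a bfloat16 type depending on `weight_h.dtype()`.

```cpp
#include <cstdint>
#include <memory>

enum class DataType { FLOAT32, BFLOAT16 };

struct Memory {};                         // stand-in for dnnl::memory
struct bfloat16_stub { uint16_t bits; };  // stand-in for phi::bfloat16

struct HandlerStub {
  template <typename T>
  std::shared_ptr<Memory> AcquireWeightHMemory(bool origin_mode) {
    (void)origin_mode;  // the real handler reorders weights to T here
    return std::make_shared<Memory>();
  }
};

std::shared_ptr<Memory> AcquireByDtype(HandlerStub& handler, DataType dt,
                                       bool origin_mode) {
  if (dt == DataType::FLOAT32) {
    return handler.AcquireWeightHMemory<float>(origin_mode);
  }
  // BFLOAT16 is the only other dtype this path handles.
  return handler.AcquireWeightHMemory<bfloat16_stub>(origin_mode);
}

int main() {
  HandlerStub h;
  return AcquireByDtype(h, DataType::BFLOAT16, false) ? 0 : 1;
}
```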
4 changes: 2 additions & 2 deletions paddle/phi/kernels/onednn/matmul_kernel.cc
@@ -449,9 +449,9 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
const DenseTensor *input_y,
const engine &onednn_engine) {
std::string key = funcs::CreateKey(dev_ctx,
- phi::TransToProtoVarType(input_x->dtype()),
+ TransToProtoVarType(input_x->dtype()),
vectorize(input_x->dims()),
- phi::TransToProtoVarType(input_y->dtype()),
+ TransToProtoVarType(input_y->dtype()),
vectorize(input_y->dims()),
dev_ctx.GetOutputsName("Out")[0]);
key = funcs::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
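Unlike the other files, this hunk keeps `TransToProtoVarType` and only drops the `phi::` qualification: the dtype here feeds a string cache key rather than a comparison. A sketch of how such a key can be assembled; `CreateKeyStub` is a hypothetical stand-in for `funcs::CreateKey`, not its real signature.

```cpp
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for funcs::CreateKey: fold dtype codes,
// shapes, and the output name into one string key.
std::string CreateKeyStub(int x_dtype_code, const std::vector<int64_t>& x_dims,
                          int y_dtype_code, const std::vector<int64_t>& y_dims,
                          const std::string& out_name) {
  std::ostringstream key;
  key << x_dtype_code;
  for (int64_t d : x_dims) key << 'x' << d;
  key << '-' << y_dtype_code;
  for (int64_t d : y_dims) key << 'x' << d;
  key << '@' << out_name;
  return key.str();
}

int main() {
  // e.g. "5x2x3-5x3x4@Out" for two operands (dtype code 5 illustrative).
  std::string k = CreateKeyStub(5, {2, 3}, 5, {3, 4}, "Out");
  return k.empty() ? 1 : 0;
}
```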