This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

【DONT MERGE】 test softmax speed #1326

Open · wants to merge 24 commits into base: develop
6 changes: 6 additions & 0 deletions cinn/auto_schedule/cost_model/feature_extractor.cc
@@ -158,6 +158,12 @@ VisitCountMemberPattern(Alloc, mem_alloc);
VisitCountMemberPattern(Free, mem_free);
VisitCountMemberPattern(Load, mem_read);
VisitCountMemberPattern(Store, mem_write);
VisitCountMemberPattern(LocalTemp, bool_op);
VisitCountMemberPattern(Sqrt, bool_op);
VisitCountMemberPattern(LoadIndex, bool_op);
VisitCountMemberPattern(ReduceMax, bool_op);
VisitCountMemberPattern(BlockLoad, bool_op);
VisitCountMemberPattern(BlockStore, bool_op);

/* Visit for loops */

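The macro's definition is outside this hunk, but from the existing entries (Load counts as mem_read, Store as mem_write) it presumably expands to a visitor that increments one feature counter per IR node type; note that all six new nodes are simply lumped into the bool_op counter. A minimal stand-in under that assumption:

```cpp
// Sketch only: VisitCountMemberPattern's real definition lives elsewhere in
// feature_extractor.cc. Assumed shape: one visitor per IR node type that
// bumps a named counter on the current feature record.
#include <iostream>

struct Feature {
  int mem_read  = 0;
  int mem_write = 0;
  int bool_op   = 0;  // the six new node types above all count here
};

#define VisitCountMemberPattern(NodeType, member) \
  void Visit##NodeType(Feature& feat) { ++feat.member; }

VisitCountMemberPattern(ReduceMax, bool_op)
VisitCountMemberPattern(BlockLoad, bool_op)

int main() {
  Feature f;
  VisitReduceMax(f);
  VisitBlockLoad(f);
  std::cout << f.bool_op << "\n";  // prints 2
}
```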
119 changes: 89 additions & 30 deletions cinn/backends/codegen_c.cc
@@ -185,8 +185,40 @@ void CodeGenC::Visit(const ir::Not *op) {
IrPrinter::Print(op->v());
os() << ")";
}
void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v()); }
void CodeGenC::Visit(const ir::LocalTemp *op) { IrPrinter::Visit(op); }

void CodeGenC::Visit(const ir::Sqrt *op) { IrPrinter::Visit(op); }

void CodeGenC::Visit(const ir::LoadIndex *op) { IrPrinter::Visit(op); }

void CodeGenC::Visit(const ir::ReduceMax *op) { IrPrinter::Visit(op); }

void CodeGenC::Visit(const ir::BlockLoad *op) { IrPrinter::Visit(op); }

void CodeGenC::Visit(const ir::BlockStore *op) { IrPrinter::Visit(op); }



void CodeGenC::Visit(const ir::Cast *op) {
  // PrintCastExpr(op->type(), op->v());
  // IrPrinter::Visit(op);
  // Hardwired for this experiment: every cast is emitted as a cast to half.
  os() << "static_cast<half>(";
  Print(op->v());

  // auto v = op->v().As<ir::Load>();
  // Visit(v);

  os() << ")";
}
void CodeGenC::Visit(const ir::For *op) {
// std::cerr << "visit loop" << std::endl;
Expr extent = op->extent;
Expr min = op->min;
int num_task = 1;
@@ -209,8 +241,13 @@ void CodeGenC::Visit(const ir::For *op) {
extent = (task_id + 1) * n_per_task;
DoIndent();
}
  if (op->is_unrolled()) {
    os() << "#pragma unroll" << std::endl;
  }
os() << "for (";
// os() << GetTypeRepr(Int(32));
os() << "int";
os() << " " << op->loop_var->name;
os() << " = ";
Print(min);
@@ -321,7 +358,7 @@ void CodeGenC::Visit(const ir::Block *op) {
if (op->stmts.size() >= 1) {
DoIndent();
Print(op->stmts.back());
os() << ";";
os() << ";\n";
}

DecIndent();
@@ -438,33 +475,47 @@ void CodeGenC::Visit(const ir::_Module_ *op) { CINN_NOT_IMPLEMENTED }
void CodeGenC::Visit(const ir::_Var_ *op) { os() << op->name; }

void CodeGenC::Visit(const ir::Load *op) {
Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1);
if (dense_strided_ramp.defined()) { // Loading a continuous Ramp address.
CHECK(op->type().is_vector());
PrintStackVecType(op->type().ElementOf(), op->index().type().lanes());
os() << "::"
<< "Load(";
os() << op->tensor.As<ir::_Tensor_>()->name;
os() << ",";
Print(dense_strided_ramp);
os() << ")";
} else if (op->index().type().is_vector()) {
// gather
CHECK(op->type().is_vector());
PrintStackVecType(op->type().ElementOf(), op->index().type().lanes());
os() << "::Load(";
os() << op->tensor.As<ir::_Tensor_>()->name;
os() << ",";
Print(op->index());
os() << ")";
} else if (op->is_addr_tensor()) {
auto *tensor = op->tensor.As<ir::_Tensor_>();
os() << tensor->name << "[";
// Print(op->index());
os() << op->indices.front();
    for (size_t i = 1; i < op->indices.size(); ++i) {
      os() << "][" << op->indices[i];
    }
os() << "]";
} else {
IrPrinter::Visit(op);
}
// Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1);
// if (dense_strided_ramp.defined()) { // Loading a continuous Ramp address.
// CHECK(op->type().is_vector());
// PrintStackVecType(op->type().ElementOf(), op->index().type().lanes());
// os() << "::"
// << "Load(";
// os() << op->tensor.As<ir::_Tensor_>()->name;
// os() << ",";
// Print(dense_strided_ramp);
// os() << ")";
// } else if (op->index().type().is_vector()) {
// // gather
// CHECK(op->type().is_vector());
// PrintStackVecType(op->type().ElementOf(), op->index().type().lanes());
// os() << "::Load(";
// os() << op->tensor.As<ir::_Tensor_>()->name;
// os() << ",";
// Print(op->index());
// os() << ")";
// } else if (op->is_addr_tensor()) {
// auto *tensor = op->tensor.As<ir::_Tensor_>();
// os() << tensor->name << "[";
// // Print(op->index());
// os() << op->indices.front();
// for( int i = 1; i < op->indices.size(); ++i )
// {
// os() << "][" << op->indices[i];
// }
// os() << "]";
// } else {
// IrPrinter::Visit(op);
// }
}

void CodeGenC::Visit(const ir::Store *op) {
@@ -473,7 +524,12 @@ void CodeGenC::Visit(const ir::Store *op) {
auto *tensor = op->tensor.As<ir::_Tensor_>();
CHECK(tensor);
os() << tensor->name << "[";
// Print(op->index());
os() << op->indices.front();
  for (size_t i = 1; i < op->indices.size(); ++i) {
    os() << "][" << op->indices[i];
  }
os() << "]";
os() << " = ";
Print(op->value);
@@ -508,7 +564,10 @@ void CodeGenC::Visit(const ir::Let *op) {
os() << "auto";
is_vec = true;
} else {
    if (op->with_dtype) {
      os() << GetTypeRepr(op->type());
    }
}

os() << " ";
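Taken together, the Load, Store, and Cast changes switch the emitted C from a single flattened index per access to one subscript per axis, with every cast hardwired to half. A self-contained sketch of the resulting emission (half stands in for CINN's float16):

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Mirrors the new Visit(ir::Load/Store) logic: print each entry of
// op->indices as its own [] subscript instead of one flattened index.
void PrintSubscripts(std::ostream& os, const std::string& tensor,
                     const std::vector<std::string>& indices) {
  os << tensor << "[" << indices.front();
  for (size_t i = 1; i < indices.size(); ++i) os << "][" << indices[i];
  os << "]";
}

int main() {
  std::ostringstream os;
  PrintSubscripts(os, "B", {"i", "j"});
  os << " = static_cast<half>(";  // the Cast visitor now prints exactly this
  PrintSubscripts(os, "A", {"i", "j"});
  os << ");";
  std::cout << os.str() << "\n";  // B[i][j] = static_cast<half>(A[i][j]);
}
```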
4 changes: 2 additions & 2 deletions cinn/backends/codegen_c.h
@@ -55,7 +55,7 @@ class CodeGenC : public ir::IrPrinter {
//! Disable inlining the builtin codes (too large) for simpler string comparison.
void SetInlineBuiltinCodes(bool x = true) { inline_builtin_codes_ = x; }

protected:
public:
std::string Compile(const ir::LoweredFunc& function);
std::string Compile(const ir::Buffer& buffer);

@@ -111,7 +111,7 @@ class CodeGenC : public ir::IrPrinter {

friend class ExternFunctionEmitter;

protected:
public:
Target target_;
std::stringstream ss_;
bool inline_builtin_codes_{true};
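Both sections move from protected to public, presumably so a benchmark harness can drive code generation directly instead of going through the compiler pipeline. A hedged sketch of such a caller (the include path and the way the ir::LoweredFunc is obtained are assumptions):

```cpp
#include <iostream>

#include "cinn/backends/codegen_c.h"

// Sketch: dump the generated C for one lowered function.
// Compile(const ir::LoweredFunc&) was protected before this change, so a
// free function like this could not call it.
void DumpGeneratedC(const cinn::ir::LoweredFunc& fn, const cinn::common::Target& target) {
  cinn::backends::CodeGenC codegen(target);
  codegen.SetInlineBuiltinCodes(false);  // keep the dump readable
  std::cout << codegen.Compile(fn) << "\n";
}
```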
13 changes: 10 additions & 3 deletions cinn/backends/codegen_cuda_dev.cc
@@ -36,6 +36,7 @@ const std::string CodeGenCUDA_Dev::source_header_ =
#include "float16.h"
using cinn::common::float16;

#include <curand_kernel.h>
#include "cinn_cuda_runtime_source.cuh"
)";

@@ -74,6 +75,9 @@ void CodeGenCUDA_Dev::Compile(const ir::Module &module, const Outputs &outputs)
}

std::string CodeGenCUDA_Dev::Compile(const ir::LoweredFunc &func) {
// std::cerr << "fun " << func << std::endl;
std::cerr << "!!!==============\n";
std::cerr << func << std::endl;
Print(Expr(func));
return ss_.str();
}
@@ -223,7 +227,7 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module, CodeGenC::OutputK

PrintBuiltinCodes();

for (auto &func : module.functions()) {
Compile(func);
}
} else {
@@ -317,8 +321,11 @@ void CodeGenCUDA_Dev::Visit(const ir::Let *op) {
// with customized_type::kcuda_builtin_vector_t prefix, and save their names
if (op->type().is_customized() &&
utils::Startswith(op->type().customized_type(), common::customized_type::kcuda_builtin_vector_t)) {
    if (op->with_dtype) {
      os() << GetTypeRepr(op->type());
      os() << " ";
    }
Print(op->symbol);
vectorized_tensor_names_.insert(utils::GetStreamCnt(op->symbol));
os() << " = ";
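Both Let visitors now consult a with_dtype flag on the node (added to ir::Let elsewhere in this PR; not shown here) before printing the declared type, which lets the same node kind re-assign an already-declared variable. A sketch of the two emission modes:

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Sketch of the guarded emission: with_dtype selects between a declaration
// and a plain re-assignment of the same symbol.
void EmitLet(std::ostream& os, bool with_dtype, const std::string& type_repr,
             const std::string& symbol, const std::string& body) {
  if (with_dtype) os << type_repr << " ";
  os << symbol << " = " << body << ";";
}

int main() {
  std::ostringstream decl, assign;
  EmitLet(decl, true, "float16", "acc", "x[0][0]");
  EmitLet(assign, false, "", "acc", "acc + x[0][i]");
  std::cout << decl.str() << "\n"     // float16 acc = x[0][0];
            << assign.str() << "\n";  // acc = acc + x[0][i];
}
```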
3 changes: 2 additions & 1 deletion cinn/backends/codegen_cuda_dev.h
@@ -67,7 +67,8 @@ class CodeGenCUDA_Dev : public CodeGenC {

const std::string& GetSourceHeader() const;

protected:
public:

void Visit(const ir::_Var_* op) override;
void Visit(const ir::_LoweredFunc_* op) override;
void Visit(const ir::Min* op) override;
1 change: 1 addition & 0 deletions cinn/backends/compiler.cc
@@ -112,6 +112,7 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code)
VLOG(3) << "[CUDA] device module:\n" << device_module;
CodeGenCUDA_Dev codegen(target_);
auto source_code = codegen.Compile(device_module);
std::cerr << "source code here " << source_code << std::endl;

VLOG(3) << "[CUDA] C:\n" << source_code;
if (!code.empty()) source_code = code;
36 changes: 36 additions & 0 deletions cinn/backends/llvm/codegen_llvm.cc
@@ -224,6 +224,42 @@ llvm::Value *CodeGenLLVM::Visit(const ir::IntImm *op) {
return llvm::ConstantInt::get(type, op->value, true);
}

llvm::Value *CodeGenLLVM::Visit(const ir::LocalTemp *op) {
std::cerr << "not impl in llvm gen";
auto *type = b_->getIntNTy(op->type().bits());
return llvm::ConstantInt::get(type, 1, true);
}

llvm::Value *CodeGenLLVM::Visit(const ir::Sqrt *op) {
std::cerr << "not impl in llvm gen";
auto *type = b_->getIntNTy(op->type().bits());
return llvm::ConstantInt::get(type, 1, true);
}

llvm::Value *CodeGenLLVM::Visit(const ir::LoadIndex *op) {
std::cerr << "not impl in llvm gen";
auto *type = b_->getIntNTy(op->type().bits());
return llvm::ConstantInt::get(type, op->reduce_block, true);
}

llvm::Value *CodeGenLLVM::Visit(const ir::ReduceMax *op) {
std::cerr << "not impl in reduceMax llvm gen";
auto *type = b_->getIntNTy(op->type().bits());
return llvm::ConstantInt::get(type, op->axis, true);
}

llvm::Value *CodeGenLLVM::Visit(const ir::BlockLoad *op) {
std::cerr << "not impl in block load llvm gen";
auto *type = b_->getIntNTy(op->type().bits());
return llvm::ConstantInt::get(type, 1, true);
}

llvm::Value *CodeGenLLVM::Visit(const ir::BlockStore *op) {
std::cerr << "not impl in block store llvm gen";
auto *type = b_->getIntNTy(op->type().bits());
return llvm::ConstantInt::get(type, 1, true);
}

llvm::Value *CodeGenLLVM::Visit(const ir::UIntImm *op) {
if (op->type().is_bool()) {
auto *type = b_->getInt1Ty();
6 changes: 3 additions & 3 deletions cinn/backends/llvm/execution_engine.cc
@@ -172,9 +172,9 @@ void ExecutionEngine::Link(const ir::Module &module) {
LLVMModuleOptimizer optimize(machine.get(), 3, {}, true);
optimize(m.get());
CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid optimized module detected";
for (auto &f : *m) {
VLOG(5) << "function: " << DumpToString(f);
}
// for (auto &f : *m) {
// VLOG(5) << "function: " << DumpToString(f);
// }

llvm::raw_svector_ostream rawstream(buffer_);
llvm::legacy::PassManager pass_manager;
6 changes: 4 additions & 2 deletions cinn/common/type.cc
@@ -480,7 +480,8 @@ Type Str2Type(const std::string &type) {
std::string Type2Str(const Type &type) {
switch (type.type()) {
case Type::type_t::Int:
return "int" + std::to_string(type.bits());
return "int";
//return "int" + std::to_string(type.bits());

case Type::type_t::UInt:
if (type.bits() == 1) {
@@ -490,7 +491,8 @@ std::string Type2Str(const Type &type) {
}

case Type::type_t::Float:
return "float" + std::to_string(type.bits());
return "float";
//return "float" + std::to_string(type.bits());

case Type::type_t::Void:
return "void";
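Dropping the bit-width suffix makes every integer type print as plain int and every float as plain float, matching the hardcoded "int" loop variable in codegen_c.cc. A sketch of the before/after behavior (not the real Type class):

```cpp
#include <iostream>
#include <string>

// The old path appended the width; the patched Type2Str returns the bare
// keyword, so e.g. float16 and float32 become indistinguishable in the
// generated source.
std::string Type2StrOld(const std::string& base, int bits) { return base + std::to_string(bits); }
std::string Type2StrNew(const std::string& base, int /*bits*/) { return base; }

int main() {
  std::cout << Type2StrOld("float", 16) << " -> " << Type2StrNew("float", 16) << "\n";
  // prints: float16 -> float
}
```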
20 changes: 20 additions & 0 deletions cinn/frontend/net_builder.cc
@@ -80,6 +80,26 @@ void NetBuilder::InferShape(Instruction instr) const {
}
}

Variable NetBuilder::Load(const Variable& input, const Variable& slice) {
  Instruction instr("load", {input, slice});

  auto& outs = instr->outputs;
  outs.resize(1);
  outs[0]->shape = slice->shape;  // the gathered result takes the slice's shape

  AppendInstruction(instr);
  return instr.GetOutput(0);
}

Variable NetBuilder::Store(const Variable& output, const Variable& slice, const Variable& update_value) {
  Instruction instr("store", {output, slice, update_value});

  auto& outs = instr->outputs;
  outs.resize(1);
  outs[0]->shape = output->shape;  // the store yields the updated tensor

  AppendInstruction(instr);
  return instr.GetOutput(0);
}

const std::vector<Variable>& NetBuilder::CustomInstr(const std::string& type,
const std::vector<Variable>& inputs,
const AttributeMap& attrs) {
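A hedged usage sketch of the two new frontend ops. Only the signatures above are given by the diff; the gather/scatter reading of "load" and "store", and Store returning the updated tensor, are assumptions:

```cpp
#include "cinn/frontend/net_builder.h"

using cinn::frontend::NetBuilder;
using cinn::frontend::Variable;

// Sketch: gather the part of `input` addressed by `slice`, then write the
// (possibly transformed) values back. Per the diff, Load's output takes
// slice->shape.
Variable RoundTrip(NetBuilder* builder, const Variable& input, const Variable& slice) {
  Variable loaded = builder->Load(input, slice);
  return builder->Store(input, slice, loaded);
}
```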
2 changes: 2 additions & 0 deletions cinn/frontend/net_builder.h
@@ -152,6 +152,8 @@ class NetBuilder {
const std::vector<Variable>& inputs,
const AttributeMap& attrs);

  Variable Load(const Variable& input, const Variable& slice);
  Variable Store(const Variable& output, const Variable& slice, const Variable& update_value);

protected:
/**
* @brief Helper function of UnaryOp.