polish fused cpu and gpu op
cjld committed Mar 22, 2022
1 parent 1987728 commit 5062b2d
Showing 6 changed files with 28 additions and 35 deletions.
python/jittor/src/executor.cc: 5 additions & 0 deletions
@@ -509,6 +509,11 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
var->alloc(cpu_allocator);
}
}
} else {
for (Var* v : op->inputs()) {
if (!v->allocator->is_cuda())
migrate_to_gpu(v, allocator);
}
}
#endif
#ifdef NODE_MEMCHECK
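The added else branch covers a CUDA op whose inputs were allocated while running on the CPU: any input not held by a CUDA allocator is migrated to the GPU before the op executes. A minimal way to hit this path from Python, mirroring the test added in this commit (a sketch, assuming a CUDA build of Jittor):

import jittor as jt

a = jt.array([1, 2, 3])
a.sync()                         # `a` is computed and stored in CPU memory
with jt.flag_scope(use_cuda=1):
    b = (a + a) * 2              # fused CUDA op: the executor migrates `a` to the GPU first
    print(b.data)                # expected: [4, 8, 12]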
python/jittor/src/fused_op.cc: 16 additions & 9 deletions
@@ -9,6 +9,7 @@
#include "op_compiler.h"
#include "profiler/profiler.h"
#include "misc/fast_shared_ptr.h"
#include "misc/cuda_flags.h"

namespace jittor {

@@ -42,6 +43,7 @@ void FusedOp::update_ops() {
loop_options_tuned.clear();
loop_options = loop_options_origin = nullptr;

_inputs.clear();
_outputs.clear();
for (Op* op : ops) {
for (Var* o : op->outputs()) {
@@ -101,6 +103,7 @@ void FusedOp::update_ops() {
if (!(c&2)) {
c += 2 + vars.size()*4;
vars.push_back({i, 0});
_inputs.emplace_back((Node*)i);
}
}
for (Var* o : opi->outputs()) {
@@ -135,6 +138,7 @@ FusedOp::FusedOp(const FusedOp& other) {
}

FusedOp::~FusedOp() {
_inputs.clear();
_outputs.clear();
Op::number_of_lived_ops++;
}
@@ -159,20 +163,15 @@ void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {

void FusedOp::do_jit_prepare(JK& jk) {
jk.clear();
int8 flags = 3;
for (uint i=0; i<ops.size(); i++) {
Op* op = ops[i];
jk << "[opkey" << i << JK::val;
op->do_jit_prepare(jk);
jk << op->name();
op->jit_prepare(jk);
jk << JK::end;
if (op->flags.get(NodeFlags::_cpu))
flags &= 1; // only cpu
else
flags &= 2; // only gpu
}
ASSERT(flags) << "FusedOp cannot contain both cpu and cuda ops.";
jk << _CS("[JIT:1]");
if (flags==1) {
if (!use_cuda) {
// only cpu
jk << _CS("[JIT_cpu:1]");
this->flags.set(NodeFlags::_cuda, 0);
@@ -189,9 +188,17 @@ void FusedOp::do_jit_prepare(JK& jk) {
jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ',';
}
jk << _CS("][var_info:") << JK::val;
for (auto& vi : vars)
bool use_int64_t = false;
for (auto& vi : vars) {
jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size());
if (vi.type != 1 && vi.var->num >= std::numeric_limits<int32_t>::max())
use_int64_t = true;
}
jk << JK::end;
if (use_int64_t)
jk << _CS("[index_t:int64]");
else
jk << _CS("[index_t:int32]");
if (loop_options->size()) {
if (get_loop_option("compile_shapes")) {
jk << _CS("[shapes:");
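Taken together, FusedOp now tracks its input vars in _inputs, picks the backend from the global use_cuda flag instead of scanning per-op cpu/cuda flags (the old ASSERT against mixing is gone), and writes the kernel index width into the JIT key: [index_t:int64] when a var in its var_info table, other than the type-1 entries, reaches the int32 element limit, and [index_t:int32] otherwise. A user-level sketch of the backend switch (the int64 case needs a multi-gigabyte var and is not triggered here; assumes a CUDA build):

import jittor as jt

x = jt.array([1.0, 2.0, 3.0])
y = x * 2 + 1
y.sync()                         # fused op compiled under the [JIT_cpu:1] key
if jt.compiler.has_cuda:
    with jt.flag_scope(use_cuda=1):
        z = x * 2 + 1
        z.sync()                 # same graph, recompiled under the CUDA key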
python/jittor/src/op.cc: 1 addition & 20 deletions
@@ -123,43 +123,24 @@ void Op::do_jit_prepare(JK& jk) {
if (has_cuda && has_cpu && !use_cuda)
flags.set(NodeFlags::_cuda, 0);
} else {
// check use int64_t as index_t if array is too big
int in_id=0, out_id=0;
bool use_int64_t = false;
// TODO: fused op do not have inputs,
// check use_cuda_op from outputs may not be enough
bool use_cuda_op = use_cuda;
for (Var* var : inputs()) {
if (var->mem_ptr) {
/* jit key don't include here, because
parallel compiler don't known
jk << JK::key << "alloc_i" << JK::hex1(in_id)
<< JK::hex1(var->allocator->flags()) << JK::end;
*/
use_cuda_op &= var->allocator->is_cuda();
}
if (var->num >= std::numeric_limits<int32_t>::max())
use_int64_t = true;
in_id ++;
}
for (Var* var : outputs()) {
if (var->mem_ptr) {
/*
jk << JK::key << "alloc_o" << JK::hex1(in_id)
<< JK::hex1(var->allocator->flags()) << JK::end;
*/
use_cuda_op &= var->allocator->is_cuda();
}
if (var->num >= std::numeric_limits<int32_t>::max())
use_int64_t = true;
out_id ++;
}
jk << _CS("[JIT:1]");
if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
jk << _CS("[JIT_cuda:1]");
flags.set(NodeFlags::_cpu, 0);
// TODO: 64bit index in CUDA
use_int64_t = false;
// use_int64_t = false;
} else {
if (use_cuda==2) {
if (flags.get(NodeFlags::_cuda))
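On the single-op side, do_jit_prepare drops the unused in_id/out_id counters and the commented-out allocator-flag key entries, and no longer forces use_int64_t back to false on the CUDA path, so the index width chosen from the var sizes is no longer pinned to 32 bits for CUDA kernels. A guarded illustration (left disabled on purpose: a var this large needs roughly 8 GiB of device memory):

import jittor as jt

if False:                        # enable only on a GPU with enough memory
    with jt.flag_scope(use_cuda=1):
        big = jt.zeros([2 ** 31])    # num >= 2**31 - 1 elements
        (big + 1).sync()             # the generated kernel can now use a 64-bit index type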
python/jittor/src/ops/copy_op.cc: 0 additions & 5 deletions
@@ -40,12 +40,7 @@ void CopyOp::run() {
auto y_ptr = outputs().front()->mem_ptr;
#ifdef HAS_CUDA
if (flags.get(NodeFlags::_cuda)) {
// TODO: check why cpu allocator in x
#ifdef IS_CUDA
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0));
#else
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDeviceToDevice, 0));
#endif
} else
#endif
{
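CopyOp now always issues cudaMemcpyAsync with cudaMemcpyDefault on the CUDA path, letting the driver infer the transfer direction from the pointers (via unified virtual addressing) instead of assuming device-to-device. A user-level sketch, assuming a.copy() is the Python binding generated for CopyOp (an assumption, not confirmed by this diff) and a CUDA build:

import jittor as jt

a = jt.array([1.0, 2.0, 3.0])
a.sync()                         # source buffer starts out in host memory
if jt.compiler.has_cuda:
    with jt.flag_scope(use_cuda=1):
        b = a.copy()             # assumed CopyOp binding; cudaMemcpyDefault
        print(b.data)            # lets the driver resolve the src/dst spaces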
python/jittor/test/test_cuda.py: 6 additions & 0 deletions
@@ -101,6 +101,12 @@ def test_cuda_custom_op(self):
assert a.shape == [3,4,5] and a.dtype == 'float'
assert (-na.flatten() == range(3*4*5)).all(), na

def test_cuda_fused_op(self):
a = jt.array([1,2,3])
a.sync()
with jt.flag_scope(use_cuda=1):
((a+a)*2).data


@unittest.skipIf(jt.compiler.has_cuda, "Only test without CUDA")
class TestNoCuda(unittest.TestCase):
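The new test only checks that the fused CUDA op runs after its CPU-resident input is migrated; a hypothetical companion test in the same class (not part of this commit) that also verifies the result could look like:

    def test_cuda_fused_op_value(self):
        # hypothetical companion to test_cuda_fused_op: also check the value
        a = jt.array([1, 2, 3])
        a.sync()
        with jt.flag_scope(use_cuda=1):
            assert ((a + a) * 2).data.tolist() == [4, 8, 12]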
python/jittor/test/test_fp16.py: 0 additions & 1 deletion
@@ -162,7 +162,6 @@ def check_share():
}
}
kernel<<<1024,16*16>>>(in0_p, out0_p);
LOGir << "aaa";
""").sync()
jt.sync_all(True)
# print(a[0]+1)
