Skip to content

Commit

Permalink
Rollup merge of #130820 - 91khr:fix-coroutine-unit-arg, r=compiler-er…
Browse files Browse the repository at this point in the history
…rors

Fix diagnostics for coroutines with () as input.

This may be a more real-life example to trigger the diagnostic:

```rust
#![features(try_blocks, coroutine_trait, coroutines)]

use std::ops::Coroutine;

struct Request;
struct Response;
fn get_args() -> Result<String, String> { todo!() }
fn build_request(_arg: String) -> Request { todo!() }
fn work() -> impl Coroutine<Option<Response>, Yield = Request> {
    #[coroutine]
    |_| {
        let r: Result<(), String> = try {
            let req = get_args()?;
            yield build_request(req)
        };
        if let Err(msg) = r {
            eprintln!("Error: {msg}");
        }
    }
}
```
  • Loading branch information
GuillaumeGomez authored Sep 26, 2024
2 parents a4a591a + 986e20d commit 0acddf5
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2635,49 +2635,47 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
// This shouldn't be common unless manually implementing one of the
// traits manually, but don't make it more confusing when it does
// happen.
Ok(
if Some(expected_trait_ref.def_id) != self.tcx.lang_items().coroutine_trait()
&& not_tupled
{
self.report_and_explain_type_error(
TypeTrace::trait_refs(
&obligation.cause,
true,
expected_trait_ref,
found_trait_ref,
),
ty::error::TypeError::Mismatch,
)
} else if found.len() == expected.len() {
self.report_closure_arg_mismatch(
span,
found_span,
found_trait_ref,
expected_trait_ref,
obligation.cause.code(),
found_node,
obligation.param_env,
)
} else {
let (closure_span, closure_arg_span, found) = found_did
.and_then(|did| {
let node = self.tcx.hir().get_if_local(did)?;
let (found_span, closure_arg_span, found) =
self.get_fn_like_arguments(node)?;
Some((Some(found_span), closure_arg_span, found))
})
.unwrap_or((found_span, None, found));

self.report_arg_count_mismatch(
if Some(expected_trait_ref.def_id) != self.tcx.lang_items().coroutine_trait() && not_tupled
{
return Ok(self.report_and_explain_type_error(
TypeTrace::trait_refs(&obligation.cause, true, expected_trait_ref, found_trait_ref),
ty::error::TypeError::Mismatch,
));
}
if found.len() != expected.len() {
let (closure_span, closure_arg_span, found) = found_did
.and_then(|did| {
let node = self.tcx.hir().get_if_local(did)?;
let (found_span, closure_arg_span, found) = self.get_fn_like_arguments(node)?;
Some((Some(found_span), closure_arg_span, found))
})
.unwrap_or((found_span, None, found));

// If the coroutine take a single () as its argument,
// the trait argument would found the coroutine take 0 arguments,
// but get_fn_like_arguments would give 1 argument.
// This would result in "Expected to take 1 argument, but it takes 1 argument".
// Check again to avoid this.
if found.len() != expected.len() {
return Ok(self.report_arg_count_mismatch(
span,
closure_span,
expected,
found,
found_trait_ty.is_closure(),
closure_arg_span,
)
},
)
));
}
}
Ok(self.report_closure_arg_mismatch(
span,
found_span,
found_trait_ref,
expected_trait_ref,
obligation.cause.code(),
found_node,
obligation.param_env,
))
}

/// Given some node representing a fn-like thing in the HIR map,
Expand Down
11 changes: 11 additions & 0 deletions tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#![feature(coroutines, coroutine_trait, stmt_expr_attributes)]

use std::ops::Coroutine;

fn foo() -> impl Coroutine<u8> {
//~^ ERROR type mismatch in coroutine arguments
#[coroutine]
|_: ()| {}
}

fn main() { }
15 changes: 15 additions & 0 deletions tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
error[E0631]: type mismatch in coroutine arguments
--> $DIR/arg-count-mismatch-on-unit-input.rs:5:13
|
LL | fn foo() -> impl Coroutine<u8> {
| ^^^^^^^^^^^^^^^^^^ expected due to this
...
LL | |_: ()| {}
| ------- found signature defined here
|
= note: expected coroutine signature `fn(u8) -> _`
found coroutine signature `fn(()) -> _`

error: aborting due to 1 previous error

For more information about this error, try `rustc --explain E0631`.

0 comments on commit 0acddf5

Please sign in to comment.