diff --git a/cpp2rust/converter/converter.cpp b/cpp2rust/converter/converter.cpp index 0e06acec..923d1bd8 100644 --- a/cpp2rust/converter/converter.cpp +++ b/cpp2rust/converter/converter.cpp @@ -665,6 +665,12 @@ bool Converter::RecordDerivesDefault(const clang::RecordDecl *decl) { } bool Converter::RecordDerivesCopy(const clang::RecordDecl *decl) { + auto *derives = Mapper::MappedDerives(ctx_.getCanonicalTagType(decl)); + return derives && + std::find(derives->begin(), derives->end(), "Copy") != derives->end(); +} + +bool Converter::RecordHasCopyableFields(const clang::RecordDecl *decl) { for (auto f : decl->fields()) { // Records that contain std::vector, std::array, std::string or anything // that is translated to Vec<>, do not derive Copy @@ -751,8 +757,11 @@ void Converter::EmitRustStructOrUnion(clang::RecordDecl *decl) { if (EmitsReprCForRecords()) { StrCat("#[repr(C)]"); } + auto attrs = GetStructAttributes(decl); + Mapper::SetDerives(ctx_.getCanonicalTagType(decl), + std::vector(attrs.begin(), attrs.end())); StrCat("#[derive("); - for (auto *attr : GetStructAttributes(decl)) { + for (auto *attr : attrs) { StrCat(attr, ','); } StrCat(")]"); @@ -3107,6 +3116,8 @@ bool Converter::VisitEnumDecl(clang::EnumDecl *decl) { return false; } Mapper::AddRuleForUserDefinedType(decl); + Mapper::SetDerives(ctx_.getCanonicalTagType(decl), + {"Clone", "Copy", "PartialEq", "Debug", "Default"}); StrCat("#[derive(Clone, Copy, PartialEq, Debug, Default)]"); StrCat(std::format("enum {}", GetRecordName(decl))); StrCat('{'); @@ -3324,6 +3335,12 @@ std::string Converter::GetArrayDefaultAsString(clang::QualType qual_type) { auto size_as_string = GetNumAsString(array_type->getSize()); auto element_type = array_type->getElementType(); auto element_type_as_string = GetDefaultAsString(element_type); + if (auto *rec = element_type->getAsRecordDecl()) { + if (!RecordDerivesCopy(rec)) { + return std::format("std::array::from_fn::<_, {}, _>(|_| {})", + size_as_string.c_str(), element_type_as_string); + } + } return std::format("[{}; {}]", element_type_as_string, size_as_string.c_str()); } @@ -3487,7 +3504,7 @@ Converter::GetStructAttributes(const clang::RecordDecl *decl) { std::vector struct_attrs; - if (RecordDerivesCopy(decl)) { + if (RecordHasCopyableFields(decl)) { struct_attrs.emplace_back("Copy"); } diff --git a/cpp2rust/converter/converter.h b/cpp2rust/converter/converter.h index a8d41078..63de20e4 100644 --- a/cpp2rust/converter/converter.h +++ b/cpp2rust/converter/converter.h @@ -587,6 +587,8 @@ class Converter : public clang::RecursiveASTVisitor { bool RecordDerivesCopy(const clang::RecordDecl *decl); + bool RecordHasCopyableFields(const clang::RecordDecl *decl); + bool ShouldReplaceWithMappedBody(clang::DeclRefExpr *expr) const; std::string *rs_code_; diff --git a/cpp2rust/converter/mapper.cpp b/cpp2rust/converter/mapper.cpp index cd664f1b..c9f6ba0b 100644 --- a/cpp2rust/converter/mapper.cpp +++ b/cpp2rust/converter/mapper.cpp @@ -460,6 +460,12 @@ void addBuiltinTypes(Model model) { const std::string &initializer = {}) { auto plain = TranslationRule::TypeRule::Plain(rust); plain.initializer = initializer; + std::vector derives = {"Copy", "Clone", "Default", + "Debug", "PartialEq", "PartialOrd"}; + if (!(rust == "f32" || rust == "f64")) { + derives.insert(derives.end(), {"Eq", "Ord", "Hash"}); + } + plain.type_info.derives = std::move(derives); AddTypeRule(cxx, TranslationRule::TypeRule(plain)); AddTypeRule("const " + cxx, std::move(plain)); @@ -696,6 +702,16 @@ bool MapsToRefcountPointer(clang::QualType qual_type) { return rule && rule->type_info.is_refcount_pointer; } +const std::vector *MappedDerives(clang::QualType qual_type) { + auto rule = search(qual_type).first; + return rule ? &rule->type_info.derives : nullptr; +} + +void SetDerives(clang::QualType qual_type, std::vector derives) { + if (auto *rule = search(qual_type).first) { + rule->type_info.derives = std::move(derives); + } +} bool ReturnsPointer(const clang::Expr *expr) { auto rule = search(expr); return rule && rule->return_type.is_pointer(); diff --git a/cpp2rust/converter/mapper.h b/cpp2rust/converter/mapper.h index a967870e..34f9afd6 100644 --- a/cpp2rust/converter/mapper.h +++ b/cpp2rust/converter/mapper.h @@ -37,6 +37,8 @@ std::string GetParamType(const clang::Expr *expr, unsigned index); bool ParamIsPointer(const clang::Expr *expr, unsigned index); bool MapsToPointer(clang::QualType qual_type); bool MapsToRefcountPointer(clang::QualType qual_type); +const std::vector *MappedDerives(clang::QualType qual_type); +void SetDerives(clang::QualType qual_type, std::vector derives); enum class ScalarSugar { kDesugar, diff --git a/cpp2rust/converter/translation_rule.cpp b/cpp2rust/converter/translation_rule.cpp index 3f24e59a..65fcaa3f 100644 --- a/cpp2rust/converter/translation_rule.cpp +++ b/cpp2rust/converter/translation_rule.cpp @@ -25,6 +25,12 @@ TypeInfo ParseTypeInfoJSON(const llvm::json::Object &obj) { if (auto v = obj.getBoolean("is_unsafe_pointer")) info.is_unsafe_pointer = *v; assert(!(info.is_refcount_pointer && info.is_unsafe_pointer)); + if (auto *arr = obj.getArray("derives")) { + for (const auto &elem : *arr) { + if (auto s = elem.getAsString()) + info.derives.emplace_back(s->str()); + } + } return info; } @@ -329,6 +335,8 @@ void TypeInfo::dump() const { log() << " [rc_ptr]"; if (is_unsafe_pointer) log() << " [unsafe_ptr]"; + for (const auto &d : derives) + log() << " +" << d; } void TypeRule::dump() const { diff --git a/cpp2rust/converter/translation_rule.h b/cpp2rust/converter/translation_rule.h index a2962cec..906cc7f3 100644 --- a/cpp2rust/converter/translation_rule.h +++ b/cpp2rust/converter/translation_rule.h @@ -54,6 +54,7 @@ struct MethodCallFragment { }; struct TypeInfo { + std::vector derives; std::string type; bool is_refcount_pointer = false; bool is_unsafe_pointer = false; @@ -83,13 +84,13 @@ struct TypeRule { void dump() const; static TypeRule Plain(std::string type) { - return {{}, {}, {std::move(type), false, false}}; + return {{}, {}, {{}, std::move(type), false, false}}; } static TypeRule RefcountPtr(std::string type) { - return {{}, {}, {std::move(type), true, false}}; + return {{}, {}, {{}, std::move(type), true, false}}; } static TypeRule UnsafePtr(std::string type) { - return {{}, {}, {std::move(type), false, true}}; + return {{}, {}, {{}, std::move(type), false, true}}; } }; diff --git a/rule-preprocessor/src/ir.rs b/rule-preprocessor/src/ir.rs index 4697724d..97a4ce0c 100644 --- a/rule-preprocessor/src/ir.rs +++ b/rule-preprocessor/src/ir.rs @@ -44,6 +44,8 @@ pub struct TypeInfo { pub is_refcount_pointer: bool, #[serde(default, skip_serializing_if = "std::ops::Not::not")] pub is_unsafe_pointer: bool, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub derives: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/rule-preprocessor/src/main.rs b/rule-preprocessor/src/main.rs index bd0cca15..afd92a36 100644 --- a/rule-preprocessor/src/main.rs +++ b/rule-preprocessor/src/main.rs @@ -5,9 +5,11 @@ extern crate rustc_driver; extern crate rustc_hir; +extern crate rustc_infer; extern crate rustc_interface; extern crate rustc_middle; extern crate rustc_span; +extern crate rustc_trait_selection; mod ir; mod semantic; diff --git a/rule-preprocessor/src/semantic.rs b/rule-preprocessor/src/semantic.rs index 5f2dabb4..77a12f5c 100644 --- a/rule-preprocessor/src/semantic.rs +++ b/rule-preprocessor/src/semantic.rs @@ -90,6 +90,7 @@ fn find_rlib(deps_dir: &Path, crate_name: &str) -> Option { struct FnDecl<'tcx> { source_file: String, name: String, + def_id: rustc_span::def_id::DefId, body: &'tcx rustc_hir::Body<'tcx>, } @@ -124,11 +125,17 @@ struct MethodResolver { } impl MethodResolver { - fn resolve_fn_decl<'tcx>(&mut self, tcx: rustc_middle::ty::TyCtxt<'tcx>, f: &FnDecl<'tcx>) { - if let Some(file_ir) = self.ir.all_ir.get_mut(&f.source_file) - && let Some(RuleIr::Fn(fn_ir)) = file_ir.get_mut(&f.name) - { - f.resolve_unknowns(tcx, fn_ir); + fn resolve_rule<'tcx>(&mut self, tcx: rustc_middle::ty::TyCtxt<'tcx>, f: &FnDecl<'tcx>) { + let Some(file_ir) = self.ir.all_ir.get_mut(&f.source_file) else { + return; + }; + match file_ir.get_mut(&f.name) { + Some(RuleIr::Fn(fn_ir)) => f.resolve_unknowns(tcx, fn_ir), + Some(RuleIr::Type(type_ir)) => { + let ret_ty = tcx.fn_sig(f.def_id).skip_binder().output().skip_binder(); + type_ir.type_info.derives = type_derives(tcx, ret_ty); + } + None => {} } } @@ -154,7 +161,7 @@ impl rustc_driver::Callbacks for MethodResolver { tcx: rustc_middle::ty::TyCtxt<'_>, ) -> rustc_driver::Compilation { for f in iter_fn_decls(tcx) { - self.resolve_fn_decl(tcx, &f); + self.resolve_rule(tcx, &f); } rustc_driver::Compilation::Stop @@ -181,6 +188,7 @@ fn iter_fn_decls<'tcx>(tcx: rustc_middle::ty::TyCtxt<'tcx>) -> Vec> result.push(FnDecl { source_file, name: ident.name.as_str().to_string(), + def_id: decl_id.owner_id.to_def_id(), body: tcx.hir_body(body_id), }); } @@ -205,6 +213,44 @@ fn decl_source_file( ) } +fn type_derives<'tcx>( + tcx: rustc_middle::ty::TyCtxt<'tcx>, + ty: rustc_middle::ty::Ty<'tcx>, +) -> Vec { + use rustc_infer::infer::TyCtxtInferExt; + use rustc_span::sym; + use rustc_trait_selection::infer::InferCtxtExt; + + let lang = tcx.lang_items(); + let derivable = [ + lang.copy_trait(), + lang.clone_trait(), + tcx.get_diagnostic_item(sym::Debug), + tcx.get_diagnostic_item(sym::Default), + tcx.get_diagnostic_item(sym::PartialEq), + tcx.get_diagnostic_item(sym::Eq), + tcx.get_diagnostic_item(sym::PartialOrd), + tcx.get_diagnostic_item(sym::Ord), + tcx.get_diagnostic_item(sym::Hash), + ]; + + let infcx = tcx + .infer_ctxt() + .build(rustc_middle::ty::TypingMode::non_body_analysis()); + + derivable + .into_iter() + .flatten() + .filter(|&trait_def_id| { + let args = vec![ty; tcx.generics_of(trait_def_id).count()]; + infcx + .type_implements_trait(trait_def_id, args, rustc_middle::ty::ParamEnv::empty()) + .must_apply_modulo_regions() + }) + .map(|trait_def_id| tcx.item_name(trait_def_id).to_string()) + .collect() +} + struct AstVisitor<'a, 'tcx> { tcx: rustc_middle::ty::TyCtxt<'tcx>, param_names: Vec, diff --git a/rule-preprocessor/src/syntactic.rs b/rule-preprocessor/src/syntactic.rs index 3e448636..aa3d8871 100644 --- a/rule-preprocessor/src/syntactic.rs +++ b/rule-preprocessor/src/syntactic.rs @@ -322,6 +322,7 @@ impl<'a> FnIrBuilder<'a> { ty: ty_str, is_refcount_pointer, is_unsafe_pointer, + derives: Vec::new(), }) } } @@ -482,6 +483,7 @@ impl<'a> FnIrBuilder<'a> { ty: p.ty.clone(), is_refcount_pointer: p.is_refcount_pointer, is_unsafe_pointer: p.is_unsafe_pointer, + derives: Vec::new(), }, ) }) @@ -560,6 +562,7 @@ impl<'a> TypeIrBuilder<'a> { ty: ty.syntax().text().to_string(), is_refcount_pointer, is_unsafe_pointer, + derives: Vec::new(), }, } } diff --git a/rules/socket/src.c b/rules/socket/src.c index 25ee0d0a..ca290b1d 100644 --- a/rules/socket/src.c +++ b/rules/socket/src.c @@ -2,6 +2,8 @@ #include #include +typedef struct sockaddr t1; + int f1() { return MSG_NOSIGNAL; } diff --git a/rules/socket/tgt_unsafe.rs b/rules/socket/tgt_unsafe.rs index 0889064c..f263d8c5 100644 --- a/rules/socket/tgt_unsafe.rs +++ b/rules/socket/tgt_unsafe.rs @@ -1,3 +1,7 @@ +fn t1() -> libc::sockaddr { + unsafe { std::mem::zeroed() } +} + unsafe fn f1() -> i32 { libc::MSG_NOSIGNAL } diff --git a/tests/unit/array_of_noncopy_struct.cpp b/tests/unit/array_of_noncopy_struct.cpp new file mode 100644 index 00000000..39abb492 --- /dev/null +++ b/tests/unit/array_of_noncopy_struct.cpp @@ -0,0 +1,19 @@ +#include +#include + +struct NonCopy { + std::vector data; + int tag = 0; +}; + +int main() { + NonCopy arr[3]; + arr[0].tag = 7; + arr[1].data.push_back(42); + assert(arr[0].tag == 7); + assert(arr[1].data.size() == 1); + assert(arr[1].data[0] == 42); + assert(arr[2].tag == 0); + assert(arr[2].data.size() == 0); + return 0; +} diff --git a/tests/unit/out/refcount/array_of_noncopy_struct.rs b/tests/unit/out/refcount/array_of_noncopy_struct.rs new file mode 100644 index 00000000..4af56349 --- /dev/null +++ b/tests/unit/out/refcount/array_of_noncopy_struct.rs @@ -0,0 +1,46 @@ +extern crate libcc2rs; +use libcc2rs::*; +use std::cell::RefCell; +use std::collections::BTreeMap; +use std::io::prelude::*; +use std::io::{Read, Seek, Write}; +use std::os::fd::AsFd; +use std::rc::{Rc, Weak}; +#[derive(Default)] +pub struct NonCopy { + pub data: Value>, + pub tag: Value, +} +impl Clone for NonCopy { + fn clone(&self) -> Self { + let mut this = Self { + data: Rc::new(RefCell::new((*self.data.borrow()).clone())), + tag: Rc::new(RefCell::new((*self.tag.borrow()))), + }; + this + } +} +impl ByteRepr for NonCopy {} +pub fn main() { + std::process::exit(main_0()); +} +fn main_0() -> i32 { + let arr: Value> = Rc::new(RefCell::new( + (0..3) + .map(|_| ::default()) + .collect::>(), + )); + (*(*arr.borrow())[(0) as usize].tag.borrow_mut()) = 7; + (*(*arr.borrow())[(1) as usize].data.borrow_mut()).push(42); + assert!(((*(*arr.borrow())[(0) as usize].tag.borrow()) == 7)); + assert!(((*(*arr.borrow())[(1) as usize].data.borrow()).len() == 1_usize)); + assert!( + ((((*arr.borrow())[(1) as usize].data.as_pointer() as Ptr) + .offset(0_usize as isize) + .read()) + == 42) + ); + assert!(((*(*arr.borrow())[(2) as usize].tag.borrow()) == 0)); + assert!(((*(*arr.borrow())[(2) as usize].data.borrow()).len() == 0_usize)); + return 0; +} diff --git a/tests/unit/out/unsafe/array_of_noncopy_struct.rs b/tests/unit/out/unsafe/array_of_noncopy_struct.rs new file mode 100644 index 00000000..7fedfe8c --- /dev/null +++ b/tests/unit/out/unsafe/array_of_noncopy_struct.rs @@ -0,0 +1,30 @@ +extern crate libc; +use libc::*; +extern crate libcc2rs; +use libcc2rs::*; +use std::collections::BTreeMap; +use std::io::{Read, Seek, Write}; +use std::os::fd::{AsFd, FromRawFd, IntoRawFd}; +use std::rc::Rc; +#[repr(C)] +#[derive(Clone, Default)] +pub struct NonCopy { + pub data: Vec, + pub tag: i32, +} +pub fn main() { + unsafe { + std::process::exit(main_0() as i32); + } +} +unsafe fn main_0() -> i32 { + let mut arr: [NonCopy; 3] = std::array::from_fn::<_, 3, _>(|_| ::default()); + arr[(0) as usize].tag = 7; + arr[(1) as usize].data.push(42); + assert!(((arr[(0) as usize].tag) == (7))); + assert!(((arr[(1) as usize].data.len()) == (1_usize))); + assert!(((arr[(1) as usize].data[(0_usize)]) == (42))); + assert!(((arr[(2) as usize].tag) == (0))); + assert!(((arr[(2) as usize].data.len()) == (0_usize))); + return 0; +} diff --git a/tests/unit/out/unsafe/libc_struct_without_default.rs b/tests/unit/out/unsafe/libc_struct_without_default.rs index ede5837a..7dd5b896 100644 --- a/tests/unit/out/unsafe/libc_struct_without_default.rs +++ b/tests/unit/out/unsafe/libc_struct_without_default.rs @@ -23,12 +23,12 @@ impl Default for UserDefined { #[repr(C)] #[derive(Copy, Clone)] pub struct FieldIsLibcType { - pub addr: sockaddr, + pub addr: libc::sockaddr, } impl Default for FieldIsLibcType { fn default() -> Self { FieldIsLibcType { - addr: unsafe { std::mem::zeroed::() }, + addr: unsafe { std::mem::zeroed() }, } } } diff --git a/tests/unit/out/unsafe/socket_transparent_union.rs b/tests/unit/out/unsafe/socket_transparent_union.rs index 4dc280cf..910ec07a 100644 --- a/tests/unit/out/unsafe/socket_transparent_union.rs +++ b/tests/unit/out/unsafe/socket_transparent_union.rs @@ -18,7 +18,7 @@ unsafe fn main_0() -> i32 { assert!( ((((libc::getsockname( fd, - ((&mut ssloc as *mut sockaddr_storage) as *mut sockaddr), + ((&mut ssloc as *mut sockaddr_storage) as *mut libc::sockaddr), (&mut slen as *mut u32) )) == (-1_i32)) as i32) != 0) @@ -28,7 +28,7 @@ unsafe fn main_0() -> i32 { assert!( ((((libc::getsockname( fd, - ((&mut sin as *mut sockaddr_in) as *mut sockaddr), + ((&mut sin as *mut sockaddr_in) as *mut libc::sockaddr), (&mut inlen as *mut u32) )) == (-1_i32)) as i32) != 0)