1

I'm using binary trees to build a simple computation graph. I understand that linked data structures (lists, trees) are a pain in Rust, but a tree is a very convenient structure for what I'm doing. I tried using `Box` and `Rc<RefCell<_>>` for the child nodes, but it didn't work out the way I wanted, so I used `unsafe`:

use std::ops::{Add, Mul};

/// A node in a binary computation graph.
///
/// Each node stores its computed `value` plus optional raw pointers to the two
/// operand nodes it was built from. Raw pointers carry no ownership: this
/// struct never frees its children, and whoever creates a node must keep the
/// pointees alive (and not otherwise borrowed while mutated) for as long as the
/// pointers are dereferenced.
///
/// NOTE(review): deriving `Copy` means copying a node duplicates the child
/// pointers, so several nodes can end up pointing at the same children.
#[derive(Debug, Copy, Clone)]
struct MyStruct {
    /// Result of the operation this node represents, or a leaf's own value.
    value: i32,
    /// Left operand, or `None` for a leaf.
    lchild: Option<*mut MyStruct>,
    /// Right operand, or `None` for a leaf.
    rchild: Option<*mut MyStruct>,
}

impl MyStruct {
    unsafe fn print_tree(&mut self, set_to_zero: bool) {
        if set_to_zero {
            self.value = 0;
        }
        println!("{:?}", self);
    
        let mut nodes = vec![self.lchild, self.rchild];
        while nodes.len() > 0 {
            let child;
            match nodes.pop() {
                Some(popped_child) => child = popped_child.unwrap(),
                None => continue,
            }

            if set_to_zero {
                (*child).value = 0;
            }
            println!("{:?}", *child);
            
            if !(*child).lchild.is_none() {
                nodes.push((*child).lchild);
            }
            if !(*child).rchild.is_none() {
                nodes.push((*child).rchild);
            }
        }
        
        println!("");
    }
}

impl Add for MyStruct {
    type Output = Self;

    /// Builds a node whose value is `self.value + other.value` and whose
    /// children point at heap copies of the two operands.
    ///
    /// The operands are moved into fresh boxes and released with
    /// `Box::into_raw`, so the child pointers remain valid after this function
    /// returns. Taking `&self` / `&other` instead would point at this
    /// function's own by-value arguments, which are gone as soon as it returns.
    /// Trade-off: nothing here frees these allocations, so they are leaked.
    fn add(self, other: Self) -> MyStruct {
        MyStruct {
            value: self.value + other.value,
            lchild: Some(Box::into_raw(Box::new(self))),
            rchild: Some(Box::into_raw(Box::new(other))),
        }
    }
}

impl Mul for MyStruct {
    type Output = Self;

    /// Builds a node whose value is `self.value * other.value` and whose
    /// children point at heap copies of the two operands.
    ///
    /// The operands are moved into fresh boxes and released with
    /// `Box::into_raw`, so the child pointers remain valid after this function
    /// returns. Taking `&self` / `&other` instead would point at this
    /// function's own by-value arguments, which are gone as soon as it returns.
    /// Trade-off: nothing here frees these allocations, so they are leaked.
    fn mul(self, other: Self) -> Self {
        MyStruct {
            value: self.value * other.value,
            lchild: Some(Box::into_raw(Box::new(self))),
            rchild: Some(Box::into_raw(Box::new(other))),
        }
    }
}

fn main() {
    let mut tree: MyStruct;
    
    {
        let a = MyStruct{ value: 10, lchild: None, rchild: None };
        let b = MyStruct{ value: 20, lchild: None, rchild: None };
        
        let c = a + b;
        println!("c.value: {}", c.value); // 30
        
        let mut d = a + b;
        println!("d.value: {}", d.value); // 30
        
        d.value = 40;
        println!("d.value: {}", d.value); // 40
        
        let mut e = c * d;
        println!("e.value: {}", e.value); // 1200
        
        unsafe {
            e.print_tree(false); // correct values
            e.print_tree(true); // all zeros
            e.print_tree(false); // all zeros, everything is set correctly
        }
        
        tree = e;
    }
    
    unsafe { tree.print_tree(false); } // same here, only zeros
}

Link to the playground

I honestly don't mind using `unsafe` that much, but is there a safe way of doing it? How bad is the use of `unsafe` here?

  • What you have looks very unsound to me. You should just be able to use `Option<Box<MyStruct>>` for the two children. – PitaJ Nov 17 '22 at 19:23
  • This is wildly incorrect. You're dropping every node but one while keeping and dereferencing pointers to them. – isaactfa Nov 17 '22 at 19:35

1 Answer

3

You can just box both of the children, since you have a unidirectional tree:

use std::ops::{Add, Mul};
use std::fmt;

/// A node in a binary computation graph.
///
/// Each node owns its two optional operand subtrees through `Box`. The whole
/// tree is therefore freed when the root is dropped, and no `unsafe` is needed
/// to walk it. Cloning a node deep-copies its entire subtree, because cloning a
/// `Box` clones its contents.
#[derive(Clone)]
struct MyStruct {
    /// Result of this node's operation, or a leaf's own value.
    value: i32,
    /// Left operand subtree; `None` for a leaf.
    lchild: Option<Box<MyStruct>>,
    /// Right operand subtree; `None` for a leaf.
    rchild: Option<Box<MyStruct>>,
}

impl fmt::Debug for MyStruct {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        f.debug_struct("MyStruct")
            .field("value", &self.value)
            .field("lchild", &self.lchild.as_deref())
            .field("rchild", &self.rchild.as_deref())
            .finish()
    }
}

impl MyStruct {
    fn print_tree(&mut self, set_to_zero: bool) {
        if set_to_zero {
            self.value = 0;
        }

        println!("MyStruct {{ value: {:?}, lchild: {:?}, rchild: {:?} }}", self.value, &self.lchild as *const _, &self.rchild as *const _);

        if let Some(child) = &mut self.lchild {
            child.print_tree(set_to_zero);
        }

        if let Some(child) = &mut self.rchild {
            child.print_tree(set_to_zero);
        }
    }
}

impl Add for MyStruct {
    type Output = Self;
    fn add(self, other: Self) -> MyStruct {
        MyStruct {
            value: self.value + other.value,
            lchild: Some(Box::new(self)),
            rchild: Some(Box::new(other)),
        }
    }
}

impl Mul for MyStruct {
    type Output = Self;
    fn mul(self, other: Self) -> Self {
        MyStruct {
            value: self.value * other.value,
            lchild: Some(Box::new(self)),
            rchild: Some(Box::new(other)),
        }
    }
}

/// Demo: builds `(a + b) * (a + b)`, with the second sum's value overwritten
/// to 40. It prints the tree three times (zeroing it on the second pass) and
/// then dumps the final tree with `dbg!`.
fn main() {
    let tree = {
        let a = MyStruct {
            value: 10,
            lchild: None,
            rchild: None,
        };
        let b = MyStruct {
            value: 20,
            lchild: None,
            rchild: None,
        };

        // `Add` consumes its operands; clone here because `a` and `b` are
        // needed again for `d` below.
        let c = a.clone() + b.clone();
        println!("c.value: {}", c.value); // 30

        // Last use of `a` and `b`, so move them instead of cloning again.
        let mut d = a + b;
        println!("d.value: {}", d.value); // 30

        d.value = 40;
        println!("d.value: {}", d.value); // 40

        let mut e = c * d;
        println!("e.value: {}", e.value); // 1200

        println!();

        e.print_tree(false); // correct values
        println!();
        e.print_tree(true); // all zeros
        println!();
        e.print_tree(false); // all zeros, everything is set correctly
        println!();

        e
    };

    dbg!(tree);
}

I implemented `Debug` manually and reimplemented `print_tree` recursively. I don't know if there is a way to write a mutating `print_tree` like that without recursion, but it's certainly possible if you take `&self` instead (dropping the `set_to_zero` logic).

playground

Edit: It turns out it is possible to mutably iterate over the tree's values without recursion. The following code is derived from the playground in this comment by @Shepmaster.

impl MyStruct {
    /// Sets `value` to zero on this node and on every descendant.
    ///
    /// This walks the tree with an explicit stack of `&mut MyStruct` instead of
    /// recursing. Each node is zeroed as it is popped, so there is no need to
    /// first collect references to every value into a second vector.
    fn zero_tree(&mut self) {
        let mut pending: Vec<&mut MyStruct> = vec![self];

        // Destructuring a `&mut MyStruct` yields disjoint `&mut` borrows of
        // each field, so both children can be pushed while `value` is written.
        while let Some(MyStruct { value, lchild, rchild }) = pending.pop() {
            *value = 0;

            if let Some(child) = lchild {
                pending.push(child);
            }
            if let Some(child) = rchild {
                pending.push(child);
            }
        }
    }
}
PitaJ
  • 12,969
  • 6
  • 36
  • 55