
To challenge myself while learning Rust, I'm trying to write something reasonably complex: an MNIST classifier. It has lots of variables holding references to other variables, so it seemed like a good way to understand lifetimes... well, I don't.

Here is a minimal example where I'm just trying to build a simple Network structure:

struct Connection<'a> {
    weight: f32,
    neuron: &'a Neuron<'a>,
}

struct Neuron<'a> {
    bias: Option<f32>,
    inputs: Vec<Connection<'a>>,
}

struct Layer<'a> {
    neurons: Vec<Neuron<'a>>,
}

struct Network<'a> {
    layers: Vec<Layer<'a>>,
}

impl <'a> Network<'a> {
    pub fn new() -> Network<'a> {
        Network {
            layers: Vec::new(),
        }
    }

    pub fn add_layer(
        &'a mut self,
        neurons_count: usize,
    ) -> &'a mut Network<'a> {
        let mut layer = Layer {
            neurons: Vec::new(),
        };
        for _ in 0..neurons_count {
            let neuron = Neuron {
                bias: None,
                inputs: Vec::new(),
            };
            layer.neurons.push(neuron);
        }

        if let Some(previous_layer) = self.layers.last_mut() {
            for neuron in &mut layer.neurons {
                for previous_neuron in &previous_layer.neurons {
                    let connection = Connection {
                        weight: 0.0,
                        neuron: &previous_neuron,
                    };
                    neuron.inputs.push(connection);
                }
            }
        }

        self.layers.push(layer); // cannot borrow `self.layers` as mutable more than once at a time
        self // cannot borrow `*self` as mutable more than once at a time
    }
}

pub fn create_network<'a>() -> Network<'a> {
    let mut network = Network::new();
    network
        .add_layer(2)
        .add_layer(2)
        .add_layer(1)
    ;

    network // cannot return value referencing local variable `network`
}

And here is a link to the corresponding playground.

In the add_layer method, I don't understand how to release the borrow taken by self.layers.last_mut() so that self.layers.push(layer) becomes valid.

I have the same problem in create_network: the chained add_layer calls keep network borrowed, so it cannot be returned.

Am I approaching the issue in a totally wrong way, or am I just missing something completely obvious?

Asked by djfm; edited by Herohtar.
  • In typical Rust, keeping a reference to an object will prevent you from mutating or moving it. This is what happens in a *self-referential* struct except the borrow is being made internally, but the rules are the same. For this reason, references don't really work for expressing graph-like data structures. – kmdreko Apr 19 '22 at 16:46
  • @kmdreko thanks, yes, I wouldn't say it answers it yet, but it helps :) – djfm Apr 19 '22 at 17:04
  • OK so I'll try to find a way to do without the references – djfm Apr 19 '22 at 17:04
  • Yup, I've answered my own question, thanks again: without references it compiles fine. The answer you provided made me understand why. – djfm Apr 19 '22 at 17:21
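A minimal sketch of the rule kmdreko describes, assuming a plain Vec<Vec<f32>> stands in for the layers and using a hypothetical helper named append_layer_linked_to_last: a reference into a Vec must stop being used before the Vec can be mutated, whereas a copied position (layer index, index in layer) can be kept across the mutation. This is the same idea the answer below applies to Connection.

```rust
/// Hypothetical helper: appends a one-value "layer" linked to the first value of
/// the previous layer by position (layer index, index in layer), not by reference.
/// Returns None if there is no previous layer or it is empty.
fn append_layer_linked_to_last(layers: &mut Vec<Vec<f32>>) -> Option<(usize, usize)> {
    let depth = layers.len().checked_sub(1)?; // None if there is no previous layer
    let link = (depth, 0);
    // Copy the f32 out instead of keeping `&layers[depth][0]` around. If such a
    // reference were still used after the `push` below, the borrow checker would
    // reject the `push` (error E0502), a conflict of the same kind as in the question.
    let value = *layers[depth].first()?;
    layers.push(vec![value]); // allowed: no live borrow of `layers` remains
    Some(link)
}

fn main() {
    let mut layers = vec![vec![0.5, 1.5]];
    if let Some((depth, pos)) = append_layer_linked_to_last(&mut layers) {
        // The stored position is still valid after the Vec has grown,
        // which a reference into the old buffer would not be.
        println!("{} {}", layers[depth][pos], layers.len()); // 0.5 2
    }
}
```

The trade-off is that an index is not checked by the compiler: if layers were removed or reordered, a stored position could silently point at the wrong neuron or panic on access.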

1 Answer


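As @kmdreko helpfully pointed out, the issue came from the self-referential data structure: each Connection borrowed a Neuron owned by the same Network.

I rewrote it without references, storing each connection's position (layer depth and position in the layer) instead: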

struct Connection {
    weight: f32,
    layer_depth: usize,
    pos_in_layer: usize,
}

struct Neuron {
    bias: Option<f32>,
    inputs: Vec<Connection>,
    output: f32,
}

impl Neuron {
    fn compute_input(&mut self, network: &Network) -> &mut Neuron {
        let mut sum = 0.0;
        for input in &self.inputs {
            let input_value = network.layers[input.layer_depth].neurons[input.pos_in_layer].output;
            sum += input_value * input.weight;
        }
        self.output = sum;
        self
    }
}

struct Layer {
    neurons: Vec<Neuron>,
}

pub struct Network {
    layers: Vec<Layer>,
}

impl Network {
    pub fn new() -> Network {
        Network {
            layers: Vec::new(),
        }
    }

    pub fn add_layer(
        &mut self,
        neurons_count: usize,
    ) -> &mut Network {
        let mut layer = Layer {
            neurons: Vec::new(),
        };
        for _ in 0..neurons_count {
            let neuron = Neuron {
                bias: None,
                inputs: Vec::new(),
                output: 0.0f32,
            };
            layer.neurons.push(neuron);
        }

        if let Some(previous_layer) = self.layers.last() {
            for neuron in &mut layer.neurons {
                for (pos_in_layer, _) in previous_layer.neurons.iter().enumerate() {
                    let connection = Connection {
                        weight: 0.0,
                        layer_depth: self.layers.len() - 1,
                        pos_in_layer,
                    };
                    neuron.inputs.push(connection);
                }
            }
        }

        self.layers.push(layer);
        self
    }
}

pub fn create_network() -> Network {
    let mut network = Network::new();
    network
        .add_layer(2)
        .add_layer(2)
        .add_layer(1)
    ;

    network
}

And it compiles fine. I still have a lot to learn :)
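One caveat with compute_input as written: it takes &mut self on a Neuron and &Network at the same time, so it cannot be called on a neuron that lives inside that same network (the mutable and shared borrows overlap). Below is a minimal sketch of a forward pass that sidesteps this, assuming the structs above (redeclared so the block stands alone), a hypothetical forward method with no activation function, and a hypothetical weight parameter on add_layer (the answer uses 0.0). It uses slice::split_at_mut to separate the already-computed earlier layers (read-only) from the current layer (mutable).

```rust
struct Connection {
    weight: f32,
    layer_depth: usize,
    pos_in_layer: usize,
}

struct Neuron {
    bias: Option<f32>,
    inputs: Vec<Connection>,
    output: f32,
}

struct Layer {
    neurons: Vec<Neuron>,
}

pub struct Network {
    layers: Vec<Layer>,
}

impl Network {
    /// Like the answer's add_layer, but every connection gets `weight`
    /// (hypothetical parameter, added so the example produces non-zero output).
    pub fn add_layer(&mut self, neurons_count: usize, weight: f32) -> &mut Network {
        // (depth, size) of the previous layer, copied out so no borrow is kept.
        let previous = self.layers.last().map(|l| (self.layers.len() - 1, l.neurons.len()));
        let neurons = (0..neurons_count)
            .map(|_| Neuron {
                bias: None,
                inputs: match previous {
                    Some((depth, count)) => (0..count)
                        .map(|pos_in_layer| Connection { weight, layer_depth: depth, pos_in_layer })
                        .collect(),
                    None => Vec::new(),
                },
                output: 0.0,
            })
            .collect();
        self.layers.push(Layer { neurons });
        self
    }

    /// Hypothetical forward pass: copies `input` into the first layer's outputs,
    /// then sets each later neuron's output to bias + weighted sum of its inputs.
    pub fn forward(&mut self, input: &[f32]) -> Vec<f32> {
        if let Some(first) = self.layers.first_mut() {
            for (neuron, &x) in first.neurons.iter_mut().zip(input) {
                neuron.output = x;
            }
        }
        for depth in 1..self.layers.len() {
            // `earlier` = layers 0..depth (read-only), `rest[0]` = current layer.
            let (earlier, rest) = self.layers.split_at_mut(depth);
            for neuron in &mut rest[0].neurons {
                let mut sum = neuron.bias.unwrap_or(0.0);
                for c in &neuron.inputs {
                    // Connections only point backwards, so c.layer_depth < depth.
                    sum += earlier[c.layer_depth].neurons[c.pos_in_layer].output * c.weight;
                }
                neuron.output = sum;
            }
        }
        self.layers
            .last()
            .map(|layer| layer.neurons.iter().map(|n| n.output).collect())
            .unwrap_or_default()
    }
}

fn main() {
    let mut network = Network { layers: Vec::new() };
    network.add_layer(2, 0.5).add_layer(2, 0.5).add_layer(1, 1.0);
    // Layer 1: 0.5 * 1.0 + 0.5 * 3.0 = 2.0 per neuron; layer 2: 2.0 + 2.0 = 4.0.
    println!("{:?}", network.forward(&[1.0, 3.0])); // [4.0]
}
```

Because the connections always point to an earlier layer, splitting the layer Vec at the current depth gives the borrow checker two disjoint slices, so no RefCell or unsafe code is needed.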

Answered by djfm.