I just read up about smart pointers so i wanted to do a real demo example, hence i created the DLL code below, the problem is the nodes are placed properly and all, but nodes memory are not getting freed, not sure what i am doing wrong. as far as my understanding , when the scope runs out, the nodes must be deleted automatically. please correct me if i am wrong.
Original Code:
#include <iostream>
#include <memory>
using namespace std;
template <typename T>
class DLL {
class Node {
public:
T key;
std::shared_ptr<Node> next;
std::shared_ptr<Node> prev;
Node():key(),next(),prev(){}
Node(T data):key(data),next(),prev(){}
~Node(){
cout << "node deleted \n";
}
};
std::shared_ptr<Node> m_head;
std::shared_ptr<Node> m_tail;
std::size_t length;
public:
DLL() : m_head() ,m_tail() , length(0){
}
virtual ~DLL(){
}
void add_back(T data){
std::shared_ptr< Node > node = std::make_shared<Node>(data);
if(!m_tail){
m_tail = std::move(node);
m_head = m_tail;
}
else{
m_tail->next = std::move(node);
m_tail->next->prev = m_tail;
m_tail = m_tail->next;
}
length++;
}
void add_front(T data){
std::shared_ptr< Node > node = std::make_shared<Node>(data);
if(!m_head){
m_head = std::move(node);
m_tail = m_head;
}
else{
m_head->prev = std::move(node);
m_head->prev->next = m_head;
m_head = m_head->prev;
}
length++;
}
void printNodes(void){
for(std::shared_ptr< Node > temp = m_head; temp ; temp = temp->next) {
cout << temp->key << '\n';
}
}
void addAtPosition(T data , std::size_t pos){
if(pos < 0 || pos >= length) {
throw("Invalid position");
}
if(pos == 0){
add_front(data);
}
else if(pos == length - 1){
add_back(data);
}
else{
std::shared_ptr< Node > temp = m_head;
for(; temp && pos ; temp = temp->next) {
pos--;
}
std::shared_ptr< Node > node = std::make_shared<Node>(data);
std::shared_ptr< Node > prev = temp->prev;
temp->prev = std::move(node);
temp->prev->next = temp;
temp->prev->prev = prev;
prev->next = temp->prev;
length++;
}
}
};
int main(int argc , char** argv){
std::unique_ptr<DLL<int>> m_list = std::make_unique<DLL<int>>();
m_list->add_front(3);
m_list->add_front(2);
m_list->add_front(1);
m_list->add_back(4);
m_list->add_back(5);
m_list->add_back(6);
m_list->addAtPosition(7,0);
m_list->addAtPosition(7,4);
m_list->addAtPosition(7,7);
m_list->printNodes();
return 0;
}
Modified Code:
#include <iostream>
#include <memory>
using namespace std;
template <typename T>
class DLL {
class Node {
public:
T key;
std::shared_ptr<Node> next;
std::weak_ptr<Node> prev;
Node():key(),next(),prev() {}
Node(T data):key(data),next(),prev() {}
~Node(){cout << "deleted \n";}
};
std::shared_ptr<Node> m_head;
std::weak_ptr<Node> m_tail;
std::size_t length;
public:
DLL():m_head(),m_tail(),length(0){}
void addFront(T data){
std::shared_ptr< Node > node = std::make_shared<Node>(data);
if(length == 0){
m_head = std::move(node);
m_tail = m_head;
}
else{
node->next = m_head;
m_head->prev = node;
m_head = std::move(node);
}
length++;
}
void addBack(T data){
std::shared_ptr< Node > node = std::make_shared<Node>(data);
if(length == 0){
m_head = std::move(node);
m_tail = m_head;
}
else{
node->prev = m_tail.lock();
node->prev.lock()->next = std::move(node);
m_tail = m_tail.lock()->next;
}
length++;
}
void addAtPosition(T data , std::size_t pos){
if(pos == 0){
addFront(data);
}
else if(pos == length){
addBack(data);
}
else if(pos < 0 || pos >= length) {
throw("Invalid position");
}
else{
std::shared_ptr< Node > node = std::make_shared<Node>(data);
std::weak_ptr<Node> temp = m_head;
for(int cnt = 0; cnt < pos ; cnt++){
temp = temp.lock()->next;
}
node->next = temp.lock();
node->prev = node->next->prev;
node->prev = std::move(node);
length++;
}
}
void printNodes(void){
std::weak_ptr<Node> wp = m_head;
for(int i = 0; i < length; i++) {
auto& sp = *(wp.lock());
cout << sp.key;
wp = sp.next;
}
}
};
int main(){
std::unique_ptr<DLL<int>> m_list = std::make_unique<DLL<int>>();
for(int i = 0; i < 10 ; i++)
{
try{
m_list->addAtPosition(i,i);
}
catch(const char* mess){
cout << i <<' '<<mess << '\n';
}
}
m_list->printNodes();
return 0;
}
PS: Based on the input i have edited my code and its now working, but still i feel my methods are doing too much work and there is scope of optimization. can someone help me in optimizing my code using smart pointers. Also i am not trying to implement DLL, i just wrote enough code to get a hands-on feel using the new smart pointers.