
I am trying to write unit tests for my SignUp handler, which makes a call to the database. However, the test panics on the database call inside the handler. It is a simple SignUp handler that receives a JSON body with a username, password, and email. Inside the handler I then run a SELECT statement to check whether the username is already taken.

This all works when I send a POST request to the running server. However, when I run the unit test, it fails with the panic below (the first line is marked [recovered] because the testing package catches the panic and re-raises it). I suspect this is because the database is never initialized in the test environment, but I am not sure how to do that without using a third-party framework to mock the database.

error message

panic: runtime error: invalid memory address or nil pointer dereference [recovered]
        panic: runtime error: invalid memory address or nil pointer dereference
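One way to see why this happens: if the database connection is a package-level pointer that is only assigned during application start-up (an assumption, since that code is not shown), nothing assigns it when `go test` runs the handler package, and any method that touches the pointer's fields panics with exactly this message. A minimal sketch, where `userStore`, `conn`, and `lookupPanics` are hypothetical stand-ins for the real pgx connection:

```go
package main

import "fmt"

// userStore is a hypothetical stand-in for a database handle such as *pgx.Conn.
type userStore struct {
	usernames map[string]bool
}

// UsernameExists reads a field of s, so calling it on a nil *userStore
// dereferences a nil pointer and panics.
func (s *userStore) UsernameExists(username string) bool {
	return s.usernames[username]
}

// conn mimics a package-level connection that is only assigned during
// application start-up; nothing assigns it when the tests run.
var conn *userStore

// lookupPanics calls conn.UsernameExists and reports whether it panicked.
func lookupPanics(username string) (panicked bool) {
	defer func() {
		if r := recover(); r != nil {
			panicked = true
		}
	}()
	conn.UsernameExists(username)
	return false
}

func main() {
	fmt.Println("panics before init:", lookupPanics("testusername"))
	conn = &userStore{usernames: map[string]bool{"testusername": true}}
	fmt.Println("panics after init:", lookupPanics("testusername"))
}
```

This prints `panics before init: true` and then `panics after init: false`; the exact stack frame inside pgx will differ, but the cause is the same nil receiver.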

signup.go

package handler

import (
    "context"
    "encoding/json"
    "io"
    "log"
    "net/http"
    "time"
    // plus the project's own database and utils packages (import paths not shown)
)

type SignUpJson struct {
    Username string `json:"username"`
    Password string `json:"password"`
    Email    string `json:"email"`
}

func SignUp(w http.ResponseWriter, r *http.Request) {
    // Set headers
    w.Header().Set("Content-Type", "application/json")
    var newUser SignUpJson

    // Read the request body and unmarshal it into the SignUpJson struct
    bs, _ := io.ReadAll(r.Body)
    if err := json.Unmarshal(bs, &newUser); err != nil {
        utils.ResponseJson(w, http.StatusInternalServerError, "Internal Server Error")
        log.Println("Internal Server Error in UnMarshal JSON body in SignUp route:", err)
        return
    }

    ctx := context.Background()
    ctx, cancel := context.WithTimeout(ctx, time.Minute*2)
    defer cancel()

    // Check if username already exists in database (duplicates not allowed)
    isExistingUsername := database.GetUsername(ctx, newUser.Username) // panics here when testing
    if isExistingUsername {
        utils.ResponseJson(w, http.StatusBadRequest, "Username has already been taken. Please try again.")
        return
    }

    // other code logic...
}

sqlquery.go

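package database

import (
    "context"
    "fmt"

    "github.com/jackc/pgx/v5" // or v4, depending on the version in go.mod
)

var SQL_SELECT_FROM_USERS = "SELECT %s FROM users WHERE %s = $1;"

// conn is the package-level connection declared elsewhere in this package (not shown).
// If nothing opens it before the tests run, it is nil and QueryRow panics.
func GetUsername(ctx context.Context, username string) bool {
    row := conn.QueryRow(ctx, fmt.Sprintf(SQL_SELECT_FROM_USERS, "username", "username"), username)
    return row.Scan() != pgx.ErrNoRows
}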

SignUp_test.go

package handler

import (
    "encoding/json"
    "net/http"
    "net/http/httptest"
    "strings"
    "testing"
)

func Test_SignUp(t *testing.T) {

    var tests = []struct {
        name               string
        postedData         SignUpJson
        expectedStatusCode int
    }{
        {
            name: "valid sign-up",
            postedData: SignUpJson{
                Username: "testusername",
                Password: "testpassword",
                Email:    "test@email.com",
            },
            expectedStatusCode: 200,
        },
    }

    for _, e := range tests {
        jsonStr, err := json.Marshal(e.postedData)
        if err != nil {
            t.Fatal(err)
        }

        // Build a request for testing
        req, err := http.NewRequest(http.MethodPost, "/signup", strings.NewReader(string(jsonStr)))
        if err != nil {
            t.Fatal(err)
        }
        req.Header.Set("Content-Type", "application/json")

        // Record the response
        res := httptest.NewRecorder()
        handler := http.HandlerFunc(SignUp)

        handler.ServeHTTP(res, req)

        if res.Code != e.expectedStatusCode {
            t.Errorf("%s: returned wrong status code; expected %d but got %d", e.name, e.expectedStatusCode, res.Code)
        }
    }
}

setup_test.go

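package handler

import (
    "os"
    "testing"
)

func TestMain(m *testing.M) {
    // Nothing here opens the database connection before the tests run.
    os.Exit(m.Run())
}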
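Without restructuring much and without any third-party library, one option is to route the lookup through a package-level function variable that tests (for example `TestMain`, before `m.Run()`) replace with an in-memory fake. A minimal sketch; `usernameExists`, `signUp`, `useFake`, and `statusFor` are hypothetical names, and in the real package the variable would default to the actual query (something like `database.GetUsername`):

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// signUpJSON mirrors the question's SignUpJson payload.
type signUpJSON struct {
	Username string `json:"username"`
	Password string `json:"password"`
	Email    string `json:"email"`
}

// usernameExists is a hypothetical package-level seam. In the real package it
// would default to the real query; here it is a stub so the example runs
// without a database.
var usernameExists = func(ctx context.Context, username string) bool {
	return false
}

// signUp calls the seam instead of the database package directly.
func signUp(w http.ResponseWriter, r *http.Request) {
	var u signUpJSON
	if err := json.NewDecoder(r.Body).Decode(&u); err != nil {
		http.Error(w, "invalid JSON", http.StatusBadRequest)
		return
	}
	if usernameExists(r.Context(), u.Username) {
		http.Error(w, "username already taken", http.StatusBadRequest)
		return
	}
	w.WriteHeader(http.StatusOK)
}

// useFake replaces the seam with an in-memory lookup, as TestMain or a test
// could do before exercising the handler.
func useFake(taken map[string]bool) {
	usernameExists = func(_ context.Context, username string) bool {
		return taken[username]
	}
}

// statusFor sends body to signUp via httptest and returns the status code.
func statusFor(body string) int {
	req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(body))
	rec := httptest.NewRecorder()
	signUp(rec, req)
	return rec.Code
}

func main() {
	useFake(map[string]bool{"testusername": true})
	fmt.Println(statusFor(`{"username":"newuser"}`))      // 200
	fmt.Println(statusFor(`{"username":"testusername"}`)) // 400
}
```

In a test file, `TestMain` would call `useFake(...)` and then `os.Exit(m.Run())`. The trade-off is shared global state: tests that swap the seam cannot safely run with `t.Parallel()`, which is one reason the comments below favour passing the dependency in explicitly.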

I have seen a similar question, but I am not sure it is the right approach, as there was no response to it and I found the answer confusing: How to write an unit test for a handler that invokes a function that interacts with db in Golang using pgx driver?

Jessica
  • To write unit tests, global structs (such as your database connection - `conn`) should be avoided because you couldn't isolate your code with them (aka couldn't write unit tests). That guy's solution is Go's style. https://stackoverflow.com/a/73658099/8546128 – Anh Nhat Tran Mar 21 '23 at 08:23
  • @AnhNhatTran Sorry, then how should I refactor my code? – Jessica Mar 21 '23 at 08:49
  • The easiest solution is adding a `database` as a parameter in `SignUp` function. Then you can create a `mockDatabase`. – Anh Nhat Tran Mar 21 '23 at 09:09
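Following the suggestion to pass the database into the handler, here is a minimal sketch that needs no third-party library: the handler depends on a small interface held in a struct field, production code would satisfy it with a type wrapping the real connection, and tests use a hand-written fake. `UserStore`, `SignUpHandler`, `fakeStore`, `errFakeDB`, and `serve` are hypothetical names:

```go
package main

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// UserStore is the only database behaviour the handler needs.
type UserStore interface {
	UsernameExists(ctx context.Context, username string) (bool, error)
}

// SignUpHandler receives its store via a struct field instead of a global.
type SignUpHandler struct {
	Store UserStore
}

func (h *SignUpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	var body struct {
		Username string `json:"username"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		http.Error(w, "invalid JSON", http.StatusBadRequest)
		return
	}
	exists, err := h.Store.UsernameExists(r.Context(), body.Username)
	if err != nil {
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	if exists {
		http.Error(w, "username already taken", http.StatusBadRequest)
		return
	}
	w.WriteHeader(http.StatusOK)
}

// fakeStore is an in-memory UserStore for tests; err simulates a DB failure.
type fakeStore struct {
	taken map[string]bool
	err   error
}

func (f fakeStore) UsernameExists(_ context.Context, username string) (bool, error) {
	return f.taken[username], f.err
}

var errFakeDB = errors.New("fake db failure")

// serve runs one request through h and returns the recorded status code.
func serve(h http.Handler, body string) int {
	req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(body))
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec.Code
}

func main() {
	h := &SignUpHandler{Store: fakeStore{taken: map[string]bool{"testusername": true}}}
	fmt.Println(serve(h, `{"username":"newuser"}`))      // 200
	fmt.Println(serve(h, `{"username":"testusername"}`)) // 400
	fmt.Println(serve(&SignUpHandler{Store: fakeStore{err: errFakeDB}}, `{"username":"x"}`)) // 500
}
```

In production, `Store` would be a small type wrapping the pgx connection (or a `*sql.DB`) whose `UsernameExists` runs the SELECT; returning an error alongside the bool keeps a failed query from being mistaken for "username taken".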

1 Answer

Let me try to help you figure out how to achieve this. I refactored your code a little, but the general idea and the tools used are still the same as yours. First, I'm gonna share the production code, which is split into two files: handlers/handlers.go and repo/repo.go.

handlers/handlers.go file

package handlers

import (
    "context"
    "database/sql"
    "encoding/json"
    "io"
    "net/http"
    "time"

    "handlertest/repo"
)

type SignUpJson struct {
    Username string `json:"username"`
    Password string `json:"password"`
    Email    string `json:"email"`
}

func SignUp(w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Content-Type", "application/json")

    var newUser SignUpJson
    bs, _ := io.ReadAll(r.Body)
    if err := json.Unmarshal(bs, &newUser); err != nil {
        w.WriteHeader(http.StatusBadRequest)
        w.Write([]byte(err.Error()))
        return
    }

    ctx, cancel := context.WithTimeout(r.Context(), time.Minute*2)
    defer cancel()

    db, _ := ctx.Value("DB").(*sql.DB)
    if isExistingUserName := repo.GetUserName(ctx, db, newUser.Username); isExistingUserName {
        w.WriteHeader(http.StatusBadRequest)
        w.Write([]byte("username already present"))
        return
    }
    w.WriteHeader(http.StatusOK)
}

Here, there are two main differences:

  1. The context used. You don't need to start from a fresh context.Background(); derive the timeout context from the one that comes with the http.Request (r.Context()), so the context passed to the query is also cancelled if the client disconnects.
  2. The sql client used. One way is to pass it through the context.Context. For this scenario, you don't have to build any structs or use any interface and so on. Just write a function that expects an *sql.DB as a parameter. Remember: functions are first-class citizens.

For sure, there is room for refactoring. The "DB" key should be a constant, and we have to check that this entry actually exists in the context values (otherwise db is nil and the query panics), but for the sake of brevity I omitted these checks.

repo/repo.go file

package repo

import (
    "context"
    "database/sql"
    "errors"
)

// GetUserName reports whether username is already present in the users table.
func GetUserName(ctx context.Context, db *sql.DB, username string) bool {
    var found string
    err := db.QueryRowContext(ctx, "SELECT username FROM users WHERE username = $1", username).Scan(&found)
    // database/sql signals "no matching row" with sql.ErrNoRows, not pgx.ErrNoRows.
    // As in the original version, any other error is treated as "exists".
    return !errors.Is(err, sql.ErrNoRows)
}

Here, the code is pretty similar to yours except for these small things:

  1. There is a dedicated method, QueryRowContext, to use when you want the query to honour a context.
  2. Pass values as query parameters (the $1 placeholder) instead of building the SQL string with fmt.Sprintf, for two reasons: security (user input never becomes part of the SQL text) and testability (the query is a fixed string that a mock can match exactly).
  3. With database/sql, a query that matches no row returns sql.ErrNoRows (pgx.ErrNoRows belongs to pgx's native API), so that is the error to check for; the column is also scanned into a variable instead of calling Scan() with no destination.

Now, we're gonna look at the test code.

handlers/handlers_test.go file

package handlers

import (
    "context"
    "net/http"
    "net/http/httptest"
    "strings"
    "testing"

    "github.com/DATA-DOG/go-sqlmock"
    "github.com/stretchr/testify/assert"
)

func TestSignUp(t *testing.T) {
    db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
    if err != nil {
        t.Fatalf("err not expected while opening a mock db, %v", err)
    }
    defer db.Close()
    t.Run("NewUser", func(t *testing.T) {
        // An empty result set makes Scan return sql.ErrNoRows, just like a real database.
        mock.ExpectQuery("SELECT username FROM users WHERE username = $1").WithArgs("john.doe@example.com").WillReturnRows(sqlmock.NewRows([]string{"username"}))

        w := httptest.NewRecorder()
        r := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(`{"username": "john.doe@example.com", "password": "1234", "email": "john.doe@example.com"}`))

        ctx := context.WithValue(r.Context(), "DB", db)
        r = r.WithContext(ctx)

        SignUp(w, r)

        assert.Equal(t, http.StatusOK, w.Code)
        if err := mock.ExpectationsWereMet(); err != nil {
            t.Errorf("not all expectations were met: %v", err)
        }
    })

    t.Run("AlreadyExistentUser", func(t *testing.T) {
        rows := sqlmock.NewRows([]string{"username"}).AddRow("john.doe@example.com")
        mock.ExpectQuery("SELECT username FROM users WHERE username = $1").WithArgs("john.doe@example.com").WillReturnRows(rows)

        w := httptest.NewRecorder()
        r := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(`{"username": "john.doe@example.com", "password": "1234", "email": "john.doe@example.com"}`))

        ctx := context.WithValue(r.Context(), "DB", db)
        r = r.WithContext(ctx)

        SignUp(w, r)

        assert.Equal(t, http.StatusBadRequest, w.Code)
        if err := mock.ExpectationsWereMet(); err != nil {
            t.Errorf("not all expectations were met: %v", err)
        }
    })
}

Here, there are a lot of changes compared to your version. Let me quickly recap them:

  • Use the sub-test feature to give a hierarchical structure to the tests.
  • Use the httptest package, which provides helpers for building HTTP requests and recording responses.
  • Use the sqlmock package, a widely used library for mocking database/sql in tests.
  • Use the context to pass the sql client alongside the http.Request.
  • Assertions have been done with the package github.com/stretchr/testify/assert.

The same applies here: there is room for refactoring (e.g. you could rework the tests using the table-driven tests pattern).

Outro

This can be considered an idiomatic way to write Go code. I know this can be very challenging, especially at the beginning. If you need further details on some parts just let me know and I'll be happy to help you, thanks!

ossan
  • Note that `context.Context` is intended for _request-scoped_ values. Using it to carry application dependencies ought therefore be considered an anti-pattern, besides it's not like there's no other way to provide non-global access to dependencies. – mkopriva Mar 21 '23 at 16:57
  • I assumed `context.Context` could be used this way, since it carries request-scoped values. If I was wrong, could you indicate how to correctly pass the dependencies? Thanks in advance @mkopriva – ossan Mar 22 '23 at 07:58
  • One approach would be to use closures, example: https://stackoverflow.com/a/34056746/965900 – mkopriva Mar 22 '23 at 08:36
  • Another approach is to use structs, just like in this question's snippet: https://stackoverflow.com/questions/75138738/dependency-injection-of-non-concurrency-safe-dependencies-for-a-net-http-server – mkopriva Mar 22 '23 at 08:38
  • Thanks for your answer! What about if you want to listen for a cancelation signal (such as a timeout) and cancel the ongoing request towards the DB. How do you achieve this without passing the DB client in the `context.Context` you're gonna cancel? – ossan Mar 22 '23 at 09:09
  • Passing the DB client in the `context.Context` will NOT, by itself, cancel long running queries when the context is canceled. To cancel long running queries you first need a driver that supports that and then you need to _pass the context to the db query_ (not the other way round) (e.g., QueryRowContext, QueryContext, etc.). And transaction rollback on context cancel is supported by `database/sql` with `BeginTx`, but cancelation of long running queries, if I'm not mistaken, must be implemented at the driver level. – mkopriva Mar 22 '23 at 09:23
  • Regardless of where its implemented (stdlib or driver), you cancel db operations by passing the context to the db func, not the db to the context. – mkopriva Mar 22 '23 at 09:26
  • I mean to already pass a contextualized `db` to the function. As a guide, you can see this link: https://gorm.io/docs/context.html#Chi-Middleware-Example Here, in a middleware that is request-scoped the `db` is already contextualized. Is this a valuable approach in your opinion? If you don't mind we can have a quick chat to let me better understand. – ossan Mar 22 '23 at 11:21
  • `sql.DB` and `gorm.DB` are not one and the same thing and I do not have enough knowledge about gorm to know whether gorm.DB can be considered request-scoped, application scoped, or both. – mkopriva Mar 22 '23 at 15:48
  • @IvanPesenti Hey Ivan, I like your answer. Correct me if I am wrong. In the `GetUserName` function which is calling a SELECT query to Postgres, you are passing in the database connection. This is essential for unit testing purposes to create a mock database in the `setup_test.go` file. So what I could have improved in my code was to pass the database connection in all of my SQL query functions such as the one in this code snippet `GetUserName` function. – Jessica Mar 24 '23 at 06:11
  • @Jessica you're right! The `sql.DB` is a database handle that represents a pool of underlying connections to the database. As it's very easy to mock it during unit tests, it's fine to pass it to functions that talk to the DB. – ossan Mar 24 '23 at 07:46
  • @IvanPesenti Great! If anyone is facing the same issue as me, we can pass in the database connection to the SQL Query functions as an argument or as a receiver to the functions. – Jessica Mar 24 '23 at 09:17
  • @IvanPesenti I have successfully performed a dependency injection from my database to my handlers! However, I am now facing an issue where the mock db is not being received in the handler. Would appreciate it if you can help! The post is here https://stackoverflow.com/questions/75857122/golang-dependency-injection-for-mock-database-in-unit-testing – Jessica Mar 27 '23 at 14:47