
I am working on a k-NN model and comparing different models and distance measures. I am trying to put 18 subplots into one figure. Unfortunately, the plotting only starts from the second graph: the first subplot position shows Graph #2.

[Image: example of the resulting figure]

I tried moving the `plt.subplot` call to different places in the loop, without success:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from sklearn.neighbors import KNeighborsClassifier

# options_generator, generate_data, X_test, test_colors and train_colors
# are defined elsewhere (not shown here)
types_of_data = ["separated", "mixed", "random"]
K_options = [1, 3, 9]
options_list = options_generator(types_of_data, K_options)
N = 70
graph_num = 1
rcParams['figure.figsize'] = 48, 30  # width, height

for option in options_list:
    X_train, y_train = generate_data(N, option[0])
    if option[2] == "L1":
        knn = KNeighborsClassifier(n_neighbors=option[1], metric="manhattan")
    else:
        knn = KNeighborsClassifier(n_neighbors=option[1])
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)

    # plotting test and training points:
    for i in range(3):
        a_trn = X_train[np.where(y_train == i)][:, 0]  # x coordinates of X_train points whose label is i
        b_trn = X_train[np.where(y_train == i)][:, 1]  # y coordinates of X_train points whose label is i
        a_tst = X_test[np.where(y_pred == i)][:, 0]    # x coordinates of X_test points predicted as i
        b_tst = X_test[np.where(y_pred == i)][:, 1]    # y coordinates of X_test points predicted as i

        plt.scatter(a_tst, b_tst, color=test_colors[i])
        plt.scatter(a_trn, b_trn, color=train_colors[i], label=i)

    tmp_title = "Graph #{0} \n Scatter type = {1}, K = {2}, Distance measure = {3}".format(graph_num, option[0], option[1], option[2])
    plt.title(tmp_title)
    plt.subplot(3, 6, graph_num)
    graph_num += 1
```
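`options_generator` is not shown. The loop reads `option[0]`, `option[1]` and `option[2]`, so it appears to yield (data type, K, distance) tuples. Three data types, three K values and two distance measures give 3 × 3 × 2 = 18 combinations, which matches the 18 subplots. The following is a hypothetical stand-in built on `itertools.product`. The label "L2" for the second distance is an assumption, because only "L1" appears in the code.

```python
from itertools import product


def options_generator(types_of_data, k_options, distances=("L1", "L2")):
    """Hypothetical stand-in: every (data type, K, distance) combination.

    The "L2" label is an assumption; only "L1" appears in the question.
    """
    # product varies the last argument fastest, so the distance changes first
    return list(product(types_of_data, k_options, distances))


types_of_data = ["separated", "mixed", "random"]
K_options = [1, 3, 9]
options_list = options_generator(types_of_data, K_options)
print(len(options_list))  # 18
print(options_list[0])    # ('separated', 1, 'L1')
```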

Any ideas on how to solve this issue?

Thanks a lot
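The reproduction below strips the loop down, replacing the k-NN part with a dummy scatter. It is a sketch, not the original data. In each iteration the `plt.scatter` and `plt.title` calls run before `plt.subplot`, so they draw on whatever Axes is current. That is a default full-figure Axes on the first pass, and the subplot created at the end of the previous pass after that. As a result, Graph #1 never lands in the 3×6 grid, Graph #2 lands in the first cell, and the 18th cell stays empty. Older Matplotlib releases also delete the overlapping default Axes when `plt.subplot(3, 6, 1)` is created, so Graph #1 disappears entirely.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the script runs headless
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 6))
created = []  # the Axes returned by each plt.subplot call, in order
graph_num = 1
for _ in range(18):
    # These calls draw on the *current* Axes: a default Axes on the first pass,
    # then the subplot created at the end of the previous pass.
    plt.scatter([0, 1], [0, 1])
    plt.title(f"Graph #{graph_num}")
    # plt.subplot creates a new, empty subplot and makes it current
    created.append(plt.subplot(3, 6, graph_num))
    graph_num += 1

print(created[0].get_title())   # Graph #2
print(created[-1].get_title())  # empty string: the last grid cell is never drawn on
```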

Asked by airline33; edited by Mr. T
  • There is something wrong with your indentation here. The first for-loop does not include anything. Also, what does graph_num start as? – eandklahn Dec 25 '21 at 12:14
  • Hi, I edited the code. Hope it makes more sense – airline33 Dec 25 '21 at 12:20
  • It does. However, if you are not modifying graph_num in any of the functions, then I simply can't tell how it would be 2 in the first graph. Are you sure that picture was made with the code that you have posted above? Or better yet, could you provide a minimal, reproducible example (https://stackoverflow.com/help/minimal-reproducible-example)? – eandklahn Dec 25 '21 at 12:31
  • Don’t use the pyplot interface this way. Use `plt.subplots` to layout your figure and subplots (Axes) ahead of time, and act on each Axes object directly (e.g., `axes[0].scatter`) – Paul H Dec 25 '21 at 13:42
  • You could move the call to `plt.subplot(3,6,graph_num)` to the start of your loop. Note that `plt.subplot()` starts a new subplot; it doesn't move the already created subplot. However, @PaulH's advice is a much better approach. You start with `fig, axes = plt.subplots(ncols=3, nrows=6, figsize=(...))` followed by `for option, ax in zip(options_list, axes.flatten()):` and finally use `ax.scatter(....)` – JohanC Dec 25 '21 at 17:26
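Below is a sketch of the loop rewritten around `plt.subplots`, following Paul H's and JohanC's suggestion. The grid is created up front and each iteration draws on its own `ax`, so the order of the title and scatter calls no longer matters. It uses `nrows=3, ncols=6` to match the question's `plt.subplot(3, 6, ...)`. Several pieces are not shown in the question, so the versions here are hypothetical stand-ins to swap for the real ones:

- `generate_data`, `X_test` and the colour lists: seeded random clusters and colour choices.
- The option tuples: the "L2" label is an assumption.

The non-L1 case uses `metric="euclidean"`. This is equivalent to scikit-learn's default Minkowski metric with `p=2`, which the question's `else` branch relies on.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np
from itertools import product
from sklearn.neighbors import KNeighborsClassifier


def generate_data(n, kind, rng):
    """Hypothetical stand-in for the question's generate_data: 3 labelled clusters."""
    spread = {"separated": 0.5, "mixed": 2.0, "random": 10.0}[kind]
    centers = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
    y = rng.integers(0, 3, size=n)                      # random labels 0..2
    X = centers[y] + rng.normal(scale=spread, size=(n, 2))
    return X, y


rng = np.random.default_rng(0)
# (data type, K, distance) tuples; the "L2" label is an assumption
options_list = list(product(["separated", "mixed", "random"], [1, 3, 9], ["L1", "L2"]))
X_test = rng.uniform(-5, 10, size=(200, 2))             # hypothetical test points
train_colors = ["tab:blue", "tab:orange", "tab:green"]  # hypothetical colours
test_colors = ["lightblue", "moccasin", "lightgreen"]

# Lay out all 18 Axes ahead of time, then draw on each one explicitly
fig, axes = plt.subplots(nrows=3, ncols=6, figsize=(30, 15))
for graph_num, (option, ax) in enumerate(zip(options_list, axes.flatten()), start=1):
    kind, k, distance = option
    X_train, y_train = generate_data(70, kind, rng)
    metric = "manhattan" if distance == "L1" else "euclidean"
    knn = KNeighborsClassifier(n_neighbors=k, metric=metric).fit(X_train, y_train)
    y_pred = knn.predict(X_test)
    for i in range(3):
        # test points coloured by predicted label, training points by true label
        ax.scatter(*X_test[y_pred == i].T, color=test_colors[i], s=10)
        ax.scatter(*X_train[y_train == i].T, color=train_colors[i], label=i)
    ax.set_title(f"Graph #{graph_num}\nScatter type = {kind}, K = {k}, "
                 f"Distance measure = {distance}")
fig.tight_layout()
```

`axes.flatten()` walks the grid row by row, so Graph #1 lands top-left and Graph #18 bottom-right.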

0 Answers