You mentioned using the existing train_test_split routine from scikit-learn. If this is the only thing you're using scikit-learn for, it would be overkill. But if you're already using it for other parts of your task, you might as well. Astropy Tables are already backed by NumPy arrays, so you don't need to "convert your data back and forth".
Since the 'ID' column of your table identifies its rows, it would be useful to formally set it as an index of the table, so that ID values can be used to look up rows (independently of their actual positional index). For example:
>>> from astropy.table import Table
>>> import numpy as np
>>> t = Table({
... 'ID': [1, 3, 5, 6, 7, 9],
... 'a': np.random.random(6),
... 'b': np.random.random(6)
... })
>>> t
<Table length=6>
ID a b
int64 float64 float64
----- ------------------- -------------------
1 0.7285295918917892 0.6180944983953155
3 0.9273855839237182 0.28085439237508925
5 0.8677312765220222 0.5996267567496841
6 0.06182255608446752 0.6604620336092745
7 0.21450048405835265 0.5351066893214822
9 0.928930682667869 0.8178640424254757
Then set 'ID' as the table's index:
>>> t.add_index('ID')
Use train_test_split to partition the IDs however you want:
>>> from sklearn.model_selection import train_test_split
>>> train_ids, test_ids = train_test_split(t['ID'], test_size=0.2)
>>> train_ids
<Column name='ID' dtype='int64' length=4>
7
9
5
1
>>> test_ids
<Column name='ID' dtype='int64' length=2>
6
3
>>> train_set = t.loc[train_ids]
>>> test_set = t.loc[test_ids]
>>> train_set
<Table length=4>
ID a b
int64 float64 float64
----- ------------------- ------------------
7 0.21450048405835265 0.5351066893214822
9 0.928930682667869 0.8178640424254757
5 0.8677312765220222 0.5996267567496841
1 0.7285295918917892 0.6180944983953155
>>> test_set
<Table length=2>
ID a b
int64 float64 float64
----- ------------------- -------------------
6 0.06182255608446752 0.6604620336092745
3 0.9273855839237182 0.28085439237508925
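(Note: a Column is itself a subclass of numpy.ndarray, so it can be passed directly to functions that expect NumPy arrays: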
>>> isinstance(t['ID'], np.ndarray)
True
>>> type(t['ID']).__mro__
(astropy.table.column.Column,
astropy.table.column.BaseColumn,
astropy.table._column_mixins._ColumnGetitemShim,
numpy.ndarray,
object)
)
For what it's worth, and since it might help you find answers to problems like this more easily in the future, it helps to consider what you're trying to do more abstractly (it seems you are already doing this, but the phrasing of your question suggests otherwise). The columns in your table are just NumPy arrays. Once the data is in that form, it's irrelevant that it was read from FITS files, and at that point what you're doing has nothing directly to do with Astropy either. The question just becomes how to randomly partition a NumPy array.
You can find generic answers to this problem, for example, in this question. But it's also nice to use an existing, special-purpose utility like train_test_split if you have it.
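If you'd rather not depend on scikit-learn just for this, the generic "randomly partition a NumPy array" approach can be sketched roughly as follows. This is only a minimal sketch: the helper name random_partition and its test_fraction and seed parameters are made up for illustration and are not part of NumPy or Astropy.

```python
import numpy as np


def random_partition(values, test_fraction=0.2, seed=None):
    """Randomly split a 1-D array into (train, test) parts.

    Hypothetical helper for illustration; not a NumPy or Astropy function.
    """
    values = np.asarray(values)
    rng = np.random.default_rng(seed)
    # Shuffle the positional indices rather than the values themselves
    shuffled = rng.permutation(len(values))
    # Round up so a non-zero fraction yields at least one test element
    n_test = int(np.ceil(test_fraction * len(values)))
    return values[shuffled[n_test:]], values[shuffled[:n_test]]


ids = np.array([1, 3, 5, 6, 7, 9])
train_ids, test_ids = random_partition(ids, test_fraction=0.2, seed=0)
```

The resulting train_ids and test_ids can then be used with t.loc[...] exactly as in the example above, assuming 'ID' has been set as the table's index.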