I think the easiest solution, taking into account the code in the previous question, would be to run the queries inside the AddFilenamesFn ParDo within the for loop. Keep in mind that beam.io.Read(beam.io.BigQuerySource(query=bqquery)) is used to read rows as a source, not as an intermediate step. So, in the approach I propose, you can use the Python Client Library directly (google-cloud-bigquery>0.27.0):
class AddFilenamesFn(beam.DoFn):
    """ParDo to output a dict with file id (retrieved from BigQuery) and row"""
    def process(self, element, file_path):
        from google.cloud import bigquery
        client = bigquery.Client()
        file_name = file_path.split("/")[-1]
        query_job = client.query("""
            SELECT FILE_ID
            FROM test.file_mapping
            WHERE FILENAME = '{0}'
            LIMIT 1""".format(file_name))
        results = query_job.result()
        for row in results:
            file_id = row.FILE_ID
            yield {'filename': file_id, 'row': element}
This would be the most straightforward solution to implement, but it raises a potential issue: instead of running all ~20 possible queries at the start of the pipeline, we are running a query for each line/record. For example, if a single file has 3,000 elements, the same query will be launched 3,000 times. However, each distinct query should actually run only once, and subsequent query "repeats" will hit the cache. Also note that cached queries do not count towards the interactive query limit.
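If the per-record queries are still a concern, a possible variation (just a sketch I haven't benchmarked, reusing the same test.file_mapping table; the cache attributes are my addition) is to memoize the lookup inside the DoFn so each worker queries BigQuery only once per file:

class AddFilenamesCachedFn(beam.DoFn):
    """Same as AddFilenamesFn but caches the FILE_ID per file name on the worker"""
    def __init__(self):
        self._cache = {}     # file_name -> FILE_ID, filled lazily
        self._client = None  # created lazily so the DoFn stays picklable

    def process(self, element, file_path):
        from google.cloud import bigquery
        if self._client is None:
            self._client = bigquery.Client()
        file_name = file_path.split("/")[-1]
        if file_name not in self._cache:
            query_job = self._client.query("""
                SELECT FILE_ID
                FROM test.file_mapping
                WHERE FILENAME = '{0}'
                LIMIT 1""".format(file_name))
            rows = list(query_job.result())
            self._cache[file_name] = rows[0].FILE_ID if rows else None
        yield {'filename': self._cache[file_name], 'row': element}

Since file_path is constant for each ParDo in the full pipeline below, this effectively turns the per-record queries into one query per file and worker.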
I used the same files as in my previous answer:
$ gsutil cat gs://$BUCKET/countries1.csv
id,country
1,sweden
2,spain
$ gsutil cat gs://$BUCKET/countries2.csv
id,country
3,italy
4,france
and added a new table:
bq mk test.file_mapping FILENAME:STRING,FILE_ID:STRING
bq query --use_legacy_sql=false 'INSERT INTO test.file_mapping (FILENAME, FILE_ID) values ("countries1.csv", "COUNTRIES ONE"), ("countries2.csv", "COUNTRIES TWO")'

and the output is:
INFO:root:{'filename': u'COUNTRIES ONE', 'row': u'id,country'}
INFO:root:{'filename': u'COUNTRIES ONE', 'row': u'1,sweden'}
INFO:root:{'filename': u'COUNTRIES ONE', 'row': u'2,spain'}
INFO:root:{'filename': u'COUNTRIES TWO', 'row': u'id,country'}
INFO:root:{'filename': u'COUNTRIES TWO', 'row': u'3,italy'}
INFO:root:{'filename': u'COUNTRIES TWO', 'row': u'4,france'}
Another solution would be to load the whole table and materialize it as a side input (depending on size this can be problematic, of course) with beam.io.BigQuerySource() or, as you say, break it down into N queries and save each one into a different side input. Then you could select the appropriate one for each record and pass it as an additional input to AddFilenamesFn. It would be interesting to try to write that one, too.
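As a rough, untested sketch of that side-input idea (the 'Read mapping' query, the AsDict step and AddFilenamesSideInputFn are illustrative names of mine, not something from your code):

# read the whole mapping table once and turn it into a {FILENAME: FILE_ID} dict side input
file_mapping = (p
    | 'Read mapping' >> beam.io.Read(beam.io.BigQuerySource(
        query='SELECT FILENAME, FILE_ID FROM test.file_mapping',
        use_standard_sql=True))
    | 'To KV' >> beam.Map(lambda row: (row['FILENAME'], row['FILE_ID'])))

class AddFilenamesSideInputFn(beam.DoFn):
    """Looks up the file id in the side-input dict instead of querying BigQuery per element"""
    def process(self, element, file_path, mapping):
        file_name = file_path.split("/")[-1]
        yield {'filename': mapping.get(file_name), 'row': element}

# inside the per-file loop, pass the dict as an additional (side) input:
# ... | beam.ParDo(AddFilenamesSideInputFn(), result[i].path,
#                  mapping=beam.pvalue.AsDict(file_mapping))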
Full code for my first proposed solution:
import argparse, logging
from operator import add
from functools import reduce  # so the script also works on Python 3 (reduce is a builtin only in Python 2)

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.io import ReadFromText
from apache_beam.io.filesystem import FileMetadata
from apache_beam.io.filesystem import FileSystem
from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem

class GCSFileReader:
    """Helper class to read gcs files"""
    def __init__(self, gcs):
        self.gcs = gcs

class AddFilenamesFn(beam.DoFn):
    """ParDo to output a dict with file id (retrieved from BigQuery) and row"""
    def process(self, element, file_path):
        from google.cloud import bigquery
        client = bigquery.Client()
        file_name = file_path.split("/")[-1]
        query_job = client.query("""
            SELECT FILE_ID
            FROM test.file_mapping
            WHERE FILENAME = '{0}'
            LIMIT 1""".format(file_name))
        results = query_job.result()
        for row in results:
            file_id = row.FILE_ID
            yield {'filename': file_id, 'row': element}

# just logging output to visualize results
def write_res(element):
    logging.info(element)
    return element

def run(argv=None):
    parser = argparse.ArgumentParser()
    known_args, pipeline_args = parser.parse_known_args(argv)

    p = beam.Pipeline(options=PipelineOptions(pipeline_args))
    gcs = GCSFileSystem(PipelineOptions(pipeline_args))
    gcs_reader = GCSFileReader(gcs)

    # in my case I am looking for files that start with 'countries'
    BUCKET='BUCKET_NAME'
    result = [m.metadata_list for m in gcs.match(['gs://{}/countries*'.format(BUCKET)])]
    result = reduce(add, result)

    # create each input PCollection name and unique step labels
    variables = ['p{}'.format(i) for i in range(len(result))]
    read_labels = ['Read file {}'.format(i) for i in range(len(result))]
    add_filename_labels = ['Add filename {}'.format(i) for i in range(len(result))]

    # load each input file into a separate PCollection and add filename to each row
    for i in range(len(result)):
        globals()[variables[i]] = p | read_labels[i] >> ReadFromText(result[i].path) | add_filename_labels[i] >> beam.ParDo(AddFilenamesFn(), result[i].path)

    # flatten all PCollections into a single one
    merged = [globals()[variables[i]] for i in range(len(result))] | 'Flatten PCollections' >> beam.Flatten() | 'Write results' >> beam.Map(write_res)

    p.run()

if __name__ == '__main__':
    run()
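If you want to reproduce it, the script can be launched with the usual Google Cloud pipeline options; something like the following should work, assuming the file is saved as script.py and after replacing the placeholders and the hardcoded BUCKET_NAME (--project and --temp_location are standard Beam options):

python script.py --project <PROJECT_ID> --temp_location gs://<BUCKET>/tmp/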