
Commit 059ef86
Address second iteration of comments.
1 parent 5f38e93

7 files changed: +35 -28 lines

gcp_variant_transforms/options/variant_transform_options.py

Lines changed: 2 additions & 1 deletion
@@ -195,7 +195,8 @@ def add_arguments(self, parser):
     parser.add_argument(
         '--num_bigquery_write_shards',
         type=int, default=1,
-        help=('This flag is deprecated and may be removed in future releases.'))
+        help=('This flag is deprecated and will be removed in future '
+              'releases.'))
     parser.add_argument(
         '--null_numeric_value_replacement',
         type=int,

gcp_variant_transforms/pipeline_common.py

Lines changed: 13 additions & 3 deletions
@@ -71,7 +71,9 @@ def parse_args(argv, command_line_options):
   known_args, pipeline_args = parser.parse_known_args(argv)
   for transform_options in options:
     transform_options.validate(known_args)
-  _raise_error_on_invalid_flags(pipeline_args)
+  _raise_error_on_invalid_flags(
+      pipeline_args,
+      known_args.output_table if hasattr(known_args, 'output_table') else None)
   if hasattr(known_args, 'input_pattern') or hasattr(known_args, 'input_file'):
     known_args.all_patterns = _get_all_patterns(
         known_args.input_pattern, known_args.input_file)
@@ -301,8 +303,8 @@ def write_headers(merged_header, file_path):
        vcf_header_io.WriteVcfHeaders(file_path))


-def _raise_error_on_invalid_flags(pipeline_args):
-  # type: (List[str]) -> None
+def _raise_error_on_invalid_flags(pipeline_args, output_table):
+  # type: (List[str], Any) -> None
   """Raises an error if there are unrecognized flags."""
   parser = argparse.ArgumentParser()
   for cls in pipeline_options.PipelineOptions.__subclasses__():
@@ -315,6 +317,14 @@ def _raise_error_on_invalid_flags(pipeline_args):
       not known_pipeline_args.setup_file):
     raise ValueError('The --setup_file flag is required for DataflowRunner. '
                      'Please provide a path to the setup.py file.')
+  if output_table:
+    if (not hasattr(known_pipeline_args, 'temp_location') or
+        not known_pipeline_args.temp_location):
+      raise ValueError('--temp_location is required for BigQuery imports.')
+    if not known_pipeline_args.temp_location.startswith('gs://'):
+      raise ValueError(
+          '--temp_location must be valid GCS location for BigQuery imports')
+


def is_pipeline_direct_runner(pipeline):
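
Taken on its own, the new guard reduces to the following standalone check (a minimal sketch; validate_temp_location is an illustrative name, and the real code first runs pipeline_args through parse_known_args over Beam's PipelineOptions subclasses):

    def validate_temp_location(temp_location, output_table):
      # BigQuery output uses FILE_LOADS, which stages load files on GCS,
      # so a gs:// temp location is mandatory whenever a table is given.
      if not output_table:
        return
      if not temp_location:
        raise ValueError('--temp_location is required for BigQuery imports.')
      if not temp_location.startswith('gs://'):
        raise ValueError(
            '--temp_location must be valid GCS location for BigQuery imports')

    validate_temp_location('gs://bucket/temp', 'project:dataset.table')  # OK
    validate_temp_location('/local/tmp', 'project:dataset.table')  # raises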

gcp_variant_transforms/pipeline_common_test.py

Lines changed: 14 additions & 4 deletions
@@ -95,21 +95,31 @@ def test_fail_on_invalid_flags(self):
         'gcp-variant-transforms-test',
         '--staging_location',
         'gs://integration_test_runs/staging']
-    pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+    pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)

     # Add Dataflow runner (requires --setup_file).
     pipeline_args.extend(['--runner', 'DataflowRunner'])
     with self.assertRaisesRegexp(ValueError, 'setup_file'):
-      pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+      pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)

     # Add setup.py (required for Variant Transforms run). This is now valid.
     pipeline_args.extend(['--setup_file', 'setup.py'])
-    pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+    pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)
+
+    with self.assertRaisesRegexp(ValueError, '--temp_location is required*'):
+      pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
+
+    pipeline_args.extend(['--temp_location', 'wrong_gcs'])
+    with self.assertRaisesRegexp(ValueError, '--temp_location must be valid*'):
+      pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
+
+    pipeline_args = pipeline_args[:-1] + ['gs://valid_bucket/temp']
+    pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')

     # Add an unknown flag.
     pipeline_args.extend(['--unknown_flag', 'somevalue'])
     with self.assertRaisesRegexp(ValueError, 'Unrecognized.*unknown_flag'):
-      pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+      pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')

   def test_get_compression_type(self):
     vcf_metadata_list = [filesystem.FileMetadata(path, size) for
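
A side note on the regex arguments above: assertRaisesRegexp matches the pattern against the exception message with re.search, so the trailing '*' (which quantifies the last literal character rather than acting as a wildcard) still matches. A quick illustration:

    import re

    # The '*' quantifies the preceding 'd', so the pattern still finds a
    # match inside the full error message.
    assert re.search('--temp_location is required*',
                     '--temp_location is required for BigQuery imports.')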

gcp_variant_transforms/transforms/sample_info_to_bigquery.py

Lines changed: 2 additions & 4 deletions
@@ -50,7 +50,7 @@ def process(self, vcf_header):
 class SampleInfoToBigQuery(beam.PTransform):
   """Writes sample info to BigQuery."""

-  def __init__(self, output_table_prefix, temp_location, append=False,
+  def __init__(self, output_table_prefix, append=False,
                samples_span_multiple_files=False):
     # type: (str, Dict[str, str], bool, bool) -> None
     """Initializes the transform.
@@ -67,7 +67,6 @@ def __init__(self, output_table_prefix, temp_location, append=False,
     self._append = append
     self._samples_span_multiple_files = samples_span_multiple_files
     self._schema = sample_info_table_schema_generator.generate_schema()
-    self._temp_location = temp_location

   def expand(self, pcoll):
     return (pcoll
@@ -82,5 +81,4 @@ def expand(self, pcoll):
                 beam.io.BigQueryDisposition.WRITE_APPEND
                 if self._append
                 else beam.io.BigQueryDisposition.WRITE_TRUNCATE),
-            method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
-            custom_gcs_temp_location=self._temp_location))
+            method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
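
Nothing replaces the dropped custom_gcs_temp_location inside this class because Beam's WriteToBigQuery falls back to the pipeline-level --temp_location for FILE_LOADS when no custom location is given. A minimal sketch of the resulting call shape (table and schema are hypothetical placeholders):

    import apache_beam as beam

    def write_sample_rows(rows, table, schema):
      # With Method.FILE_LOADS and no custom_gcs_temp_location, Beam stages
      # its load files under the pipeline's --temp_location, which the new
      # _raise_error_on_invalid_flags check guarantees is a gs:// path.
      return rows | beam.io.WriteToBigQuery(
          table,
          schema=schema,
          create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
          write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE,
          method=beam.io.WriteToBigQuery.Method.FILE_LOADS)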

gcp_variant_transforms/transforms/sample_info_to_bigquery_test.py

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@ def test_convert_sample_info_to_row(self):
         | transforms.Create([vcf_header_1, vcf_header_2])
         | 'ConvertToRow'
         >> transforms.ParDo(sample_info_to_bigquery.ConvertSampleInfoToRow(
-            ), False))
+            False), ))

     assert_that(bigquery_rows, equal_to(expected_rows))
     pipeline.run()
@@ -83,7 +83,7 @@ def test_convert_sample_info_to_row_without_file_in_hash(self):
         | transforms.Create([vcf_header_1, vcf_header_2])
         | 'ConvertToRow'
         >> transforms.ParDo(sample_info_to_bigquery.ConvertSampleInfoToRow(
-            ), True))
+            True), ))

     assert_that(bigquery_rows, equal_to(expected_rows))
     pipeline.run()
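
The reordered arguments reflect where the flag now lives: previously False/True was passed through ParDo as an extra argument to process(); after the signature change it is a ConvertSampleInfoToRow constructor argument. A minimal sketch of the two styles with a hypothetical DoFn:

    import apache_beam as beam

    class TagElements(beam.DoFn):  # hypothetical stand-in for the real DoFn
      def __init__(self, flag):
        self._flag = flag  # new style: the flag is part of the DoFn's state

      def process(self, element):
        yield (element, self._flag)

    # New style (this commit):  pcoll | beam.ParDo(TagElements(False))
    # Old style: extra positional args to beam.ParDo are forwarded to
    # process(), e.g. pcoll | beam.ParDo(SomeDoFn(), False).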

gcp_variant_transforms/transforms/variant_to_bigquery.py

Lines changed: 1 addition & 4 deletions
@@ -59,7 +59,6 @@ def __init__(
       self,
       output_table,  # type: str
       header_fields,  # type: vcf_header_io.VcfHeader
-      temp_location,  # type: str
       variant_merger=None,  # type: variant_merge_strategy.VariantMergeStrategy
       proc_var_factory=None,  # type: processed_variant.ProcessedVariantFactory
       # TODO(bashir2): proc_var_factory is a required argument and if `None` is
@@ -99,7 +98,6 @@ def __init__(
     """
     self._output_table = output_table
     self._header_fields = header_fields
-    self._temp_location = temp_location
     self._variant_merger = variant_merger
     self._proc_var_factory = proc_var_factory
     self._append = append
@@ -137,5 +135,4 @@ def expand(self, pcoll):
                 beam.io.BigQueryDisposition.WRITE_APPEND
                 if self._append
                 else beam.io.BigQueryDisposition.WRITE_TRUNCATE),
-            method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
-            custom_gcs_temp_location=self._temp_location))
+            method=beam.io.WriteToBigQuery.Method.FILE_LOADS))

gcp_variant_transforms/vcf_to_bq.py

Lines changed: 1 addition & 10 deletions
@@ -384,7 +384,6 @@ def _run_annotation_pipeline(known_args, pipeline_args):
 def _create_sample_info_table(pipeline,  # type: beam.Pipeline
                               pipeline_mode,  # type: PipelineModes
                               known_args,  # type: argparse.Namespace,
-                              temp_directory,  # str
                              ):
   # type: (...) -> None
   headers = pipeline_common.read_headers(
@@ -395,7 +394,6 @@ def _create_sample_info_table(pipeline,  # type: beam.Pipeline
   _ = (headers | 'SampleInfoToBigQuery' >>
        sample_info_to_bigquery.SampleInfoToBigQuery(
            known_args.output_table,
-           temp_directory,
            known_args.append,
            known_args.samples_span_multiple_files))

@@ -406,8 +404,6 @@ def run(argv=None):
   logging.info('Command: %s', ' '.join(argv or sys.argv))
   known_args, pipeline_args = pipeline_common.parse_args(argv,
                                                          _COMMAND_LINE_OPTIONS)
-  if known_args.output_table and '--temp_location' not in pipeline_args:
-    raise ValueError('--temp_location is required for BigQuery imports.')
   if known_args.auto_flags_experiment:
     _get_input_dimensions(known_args, pipeline_args)

@@ -483,10 +479,6 @@ def run(argv=None):
     num_partitions = 1

   if known_args.output_table:
-    temp_directory = pipeline_options.PipelineOptions(pipeline_args).view_as(
-        pipeline_options.GoogleCloudOptions).temp_location
-    if not temp_directory:
-      raise ValueError('--temp_location must be set when writing to BigQuery.')
     for i in range(num_partitions):
       table_suffix = ''
       if partitioner and partitioner.get_partition_name(i):
@@ -496,7 +488,6 @@ def run(argv=None):
           variant_to_bigquery.VariantToBigQuery(
               table_name,
               header_fields,
-              temp_directory,
               variant_merger,
               processed_variant_factory,
               append=known_args.append,
@@ -507,7 +498,7 @@ def run(argv=None):
               known_args.null_numeric_value_replacement)))
   if known_args.generate_sample_info_table:
     _create_sample_info_table(
-        pipeline, pipeline_mode, known_args, temp_directory)
+        pipeline, pipeline_mode, known_args)

   if known_args.output_avro_path:
     # TODO(bashir2): Add an integration test that outputs to Avro files and
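
The deleted view_as() lookup is the same resolution Beam performs internally once custom_gcs_temp_location is gone. A minimal self-contained sketch of that lookup (get_temp_location is an illustrative helper, not part of the codebase):

    from apache_beam.options import pipeline_options

    def get_temp_location(pipeline_args):
      # The same lookup the deleted block performed inline in run().
      options = pipeline_options.PipelineOptions(pipeline_args)
      return options.view_as(pipeline_options.GoogleCloudOptions).temp_location

    # get_temp_location(['--temp_location', 'gs://bucket/temp'])
    #   -> 'gs://bucket/temp'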
