diff --git a/CHANGES.md b/CHANGES.md index e2dcf6e0f2ca..afaf6a896cc4 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,6 +69,7 @@ ## New Features / Improvements * (Python) Added exception chaining to preserve error context in CloudSQLEnrichmentHandler, processes utilities, and core transforms ([#37422](https://github.com/apache/beam/issues/37422)). +* (Python) Added `take(n)` convenience for PCollection: `beam.take(n)` and `pcoll.take(n)` to get the first N elements deterministically without Top.Of + FlatMap ([#X](https://github.com/apache/beam/issues/37429)). * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## Breaking Changes diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index ca9a662d399e..6621d96127d4 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -176,6 +176,25 @@ def from_(pcoll: PValue, is_bounded: Optional[bool] = None) -> 'PCollection': is_bounded = pcoll.is_bounded return PCollection(pcoll.pipeline, is_bounded=is_bounded) + def take(self, n: int) -> 'PCollection[T]': + """Takes the first N elements from this PCollection. + + This is a convenience method that returns a new PCollection containing + at most N elements from this PCollection. The elements are taken + deterministically (not randomly sampled). + + Args: + n: Number of elements to take. Must be a positive integer. + + Returns: + A new PCollection containing at most N elements. + + Example:: + first_10 = pcoll.take(10) + """ + from apache_beam.transforms import util + return self | util.take(n) + def to_runner_api( self, context: 'PipelineContext') -> beam_runner_api_pb2.PCollection: return beam_runner_api_pb2.PCollection( diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index fbaab6b4ebbb..dd14bd8f57bd 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -54,6 +54,7 @@ from apache_beam.pvalue import PCollection from apache_beam.transforms import window from apache_beam.transforms.combiners import CountCombineFn +from apache_beam.transforms.combiners import Top from apache_beam.transforms.core import CombinePerKey from apache_beam.transforms.core import Create from apache_beam.transforms.core import DoFn @@ -105,11 +106,13 @@ 'Reshuffle', 'Secret', 'ToString', + 'Take', 'Tee', 'Values', 'WithKeys', 'GroupIntoBatches', - 'WaitOn' + 'WaitOn', + 'take', ] K = TypeVar('K') @@ -1967,6 +1970,75 @@ def expand(self, input): )) +@typehints.with_input_types(T) +@typehints.with_output_types(T) +class Take(PTransform): + """Takes the first N elements from a PCollection. + + This transform returns a PCollection containing at most N elements from the + input PCollection. The elements are taken deterministically (not randomly + sampled). + + Args: + n: Number of elements to take. Must be a positive integer. + + Returns: + A PCollection containing at most N elements. + + Example:: + # Take first 10 elements + first_10 = pcoll | beam.take(10) + + # Or as a method + first_10 = pcoll.take(10) + """ + def __init__(self, n): + """Initializes Take transform. + + Args: + n: Number of elements to take. Must be positive. + """ + if n <= 0: + raise ValueError('n must be positive, got %d' % n) + self._n = n + + def expand(self, pcoll): + """Expands the Take transform. + + Args: + pcoll: Input PCollection. + + Returns: + A PCollection containing at most N elements. + """ + # Use Top.Of with a constant key to get first N elements deterministically. + # Top.Of returns a list, so we flatten it to get individual elements. + return ( + pcoll + | Top.Of(self._n, key=lambda x: 0).without_defaults() + | FlatMap(lambda elements: elements)) + + def default_label(self): + return 'Take(%d)' % self._n + + +def take(n): + """Convenience function for Take transform. + + Takes the first N elements from a PCollection. + + Args: + n: Number of elements to take. Must be positive. + + Returns: + A Take transform instance. + + Example:: + first_10 = pcoll | beam.take(10) + """ + return Take(n) + + class Reify(object): """PTransforms for converting between explicit and implicit form of various Beam values.""" diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 7389568691cd..448ba8a7ad9d 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1934,6 +1934,45 @@ def test_tostring_kvs_empty_delimeter(self): assert_that(result, equal_to(["one1", "two2"])) +class TakeTest(unittest.TestCase): + def test_take_function_syntax(self): + with TestPipeline() as p: + result = p | beam.Create([1, 2, 3, 4, 5]) | util.take(3) + assert_that(result, equal_to([1, 2, 3])) + + def test_take_method_syntax(self): + with TestPipeline() as p: + pcoll = p | beam.Create([10, 20, 30, 40, 50]) + result = pcoll.take(2) + assert_that(result, equal_to([10, 20])) + + def test_take_more_than_available(self): + with TestPipeline() as p: + result = p | beam.Create([1, 2, 3]) | util.take(10) + assert_that(result, equal_to([1, 2, 3])) + + def test_take_single_element(self): + with TestPipeline() as p: + result = p | beam.Create([100, 200, 300]) | util.take(1) + assert_that(result, equal_to([100])) + + def test_take_all_elements(self): + with TestPipeline() as p: + data = [1, 2, 3, 4, 5] + result = p | beam.Create(data) | util.take(len(data)) + assert_that(result, equal_to(data)) + + def test_take_invalid_n_zero(self): + with self.assertRaises(ValueError) as ctx: + util.Take(0) + self.assertIn('n must be positive', str(ctx.exception)) + + def test_take_invalid_n_negative(self): + with self.assertRaises(ValueError) as ctx: + util.Take(-1) + self.assertIn('n must be positive', str(ctx.exception)) + + class LogElementsTest(unittest.TestCase): @pytest.fixture(scope="function") def _capture_stdout_log(request, capsys):