APEX-specific symbols can be tagged as # apex

If a symbol is tagged as # apex, then it is exported when gen_stub_libs
is invoked with --apex.

Bug: 120638081
Test: python3 cc/test_gen_stub_libs.py

Change-Id: I190bca35d1a4fb422b37d1be41a34de1ad64de6b
diff --git a/cc/gen_stub_libs.py b/cc/gen_stub_libs.py
index c49d197..4906ea2 100755
--- a/cc/gen_stub_libs.py
+++ b/cc/gen_stub_libs.py
@@ -108,7 +108,7 @@
     return version.endswith('_PRIVATE') or version.endswith('_PLATFORM')
 
 
-def should_omit_version(version, arch, api, vndk):
+def should_omit_version(version, arch, api, vndk, apex):
     """Returns True if the version section should be ommitted.
 
     We want to omit any sections that do not have any symbols we'll have in the
@@ -121,6 +121,8 @@
         return True
     if 'vndk' in version.tags and not vndk:
         return True
+    if 'apex' in version.tags and not apex:
+        return True
     if not symbol_in_arch(version.tags, arch):
         return True
     if not symbol_in_api(version.tags, arch, api):
@@ -128,10 +130,12 @@
     return False
 
 
-def should_omit_symbol(symbol, arch, api, vndk):
+def should_omit_symbol(symbol, arch, api, vndk, apex):
     """Returns True if the symbol should be omitted."""
     if not vndk and 'vndk' in symbol.tags:
         return True
+    if not apex and 'apex' in symbol.tags:
+        return True
     if not symbol_in_arch(symbol.tags, arch):
         return True
     if not symbol_in_api(symbol.tags, arch, api):
@@ -239,15 +243,15 @@
     def __eq__(self, other):
         return self.name == other.name and set(self.tags) == set(other.tags)
 
-
 class SymbolFileParser(object):
     """Parses NDK symbol files."""
-    def __init__(self, input_file, api_map, arch, api, vndk):
+    def __init__(self, input_file, api_map, arch, api, vndk, apex):
         self.input_file = input_file
         self.api_map = api_map
         self.arch = arch
         self.api = api
         self.vndk = vndk
+        self.apex = apex
         self.current_line = None
 
     def parse(self):
@@ -275,11 +279,11 @@
         symbol_names = set()
         multiply_defined_symbols = set()
         for version in versions:
-            if should_omit_version(version, self.arch, self.api, self.vndk):
+            if should_omit_version(version, self.arch, self.api, self.vndk, self.apex):
                 continue
 
             for symbol in version.symbols:
-                if should_omit_symbol(symbol, self.arch, self.api, self.vndk):
+                if should_omit_symbol(symbol, self.arch, self.api, self.vndk, self.apex):
                     continue
 
                 if symbol.name in symbol_names:
@@ -363,12 +367,13 @@
 
 class Generator(object):
     """Output generator that writes stub source files and version scripts."""
-    def __init__(self, src_file, version_script, arch, api, vndk):
+    def __init__(self, src_file, version_script, arch, api, vndk, apex):
         self.src_file = src_file
         self.version_script = version_script
         self.arch = arch
         self.api = api
         self.vndk = vndk
+        self.apex = apex
 
     def write(self, versions):
         """Writes all symbol data to the output files."""
@@ -377,14 +382,14 @@
 
     def write_version(self, version):
         """Writes a single version block's data to the output files."""
-        if should_omit_version(version, self.arch, self.api, self.vndk):
+        if should_omit_version(version, self.arch, self.api, self.vndk, self.apex):
             return
 
         section_versioned = symbol_versioned_in_api(version.tags, self.api)
         version_empty = True
         pruned_symbols = []
         for symbol in version.symbols:
-            if should_omit_symbol(symbol, self.arch, self.api, self.vndk):
+            if should_omit_symbol(symbol, self.arch, self.api, self.vndk, self.apex):
                 continue
 
             if symbol_versioned_in_api(symbol.tags, self.api):
@@ -447,6 +452,8 @@
         help='Architecture being targeted.')
     parser.add_argument(
         '--vndk', action='store_true', help='Use the VNDK variant.')
+    parser.add_argument(
+        '--apex', action='store_true', help='Use the APEX variant.')
 
     parser.add_argument(
         '--api-map', type=os.path.realpath, required=True,
@@ -481,14 +488,14 @@
     with open(args.symbol_file) as symbol_file:
         try:
             versions = SymbolFileParser(symbol_file, api_map, args.arch, api,
-                                        args.vndk).parse()
+                                        args.vndk, args.apex).parse()
         except MultiplyDefinedSymbolError as ex:
             sys.exit('{}: error: {}'.format(args.symbol_file, ex))
 
     with open(args.stub_src, 'w') as src_file:
         with open(args.version_script, 'w') as version_file:
             generator = Generator(src_file, version_file, args.arch, api,
-                                  args.vndk)
+                                  args.vndk, args.apex)
             generator.write(versions)
 
 
diff --git a/cc/test_gen_stub_libs.py b/cc/test_gen_stub_libs.py
index 3b5585a..594c1bc 100755
--- a/cc/test_gen_stub_libs.py
+++ b/cc/test_gen_stub_libs.py
@@ -165,92 +165,115 @@
     def test_omit_private(self):
         self.assertFalse(
             gsl.should_omit_version(
-                gsl.Version('foo', None, [], []), 'arm', 9, False))
+                gsl.Version('foo', None, [], []), 'arm', 9, False, False))
 
         self.assertTrue(
             gsl.should_omit_version(
-                gsl.Version('foo_PRIVATE', None, [], []), 'arm', 9, False))
+                gsl.Version('foo_PRIVATE', None, [], []), 'arm', 9, False, False))
         self.assertTrue(
             gsl.should_omit_version(
-                gsl.Version('foo_PLATFORM', None, [], []), 'arm', 9, False))
+                gsl.Version('foo_PLATFORM', None, [], []), 'arm', 9, False, False))
 
         self.assertTrue(
             gsl.should_omit_version(
                 gsl.Version('foo', None, ['platform-only'], []), 'arm', 9,
-                False))
+                False, False))
 
     def test_omit_vndk(self):
         self.assertTrue(
             gsl.should_omit_version(
-                gsl.Version('foo', None, ['vndk'], []), 'arm', 9, False))
+                gsl.Version('foo', None, ['vndk'], []), 'arm', 9, False, False))
 
         self.assertFalse(
             gsl.should_omit_version(
-                gsl.Version('foo', None, [], []), 'arm', 9, True))
+                gsl.Version('foo', None, [], []), 'arm', 9, True, False))
         self.assertFalse(
             gsl.should_omit_version(
-                gsl.Version('foo', None, ['vndk'], []), 'arm', 9, True))
+                gsl.Version('foo', None, ['vndk'], []), 'arm', 9, True, False))
+
+    def test_omit_apex(self):
+        self.assertTrue(
+            gsl.should_omit_version(
+                gsl.Version('foo', None, ['apex'], []), 'arm', 9, False, False))
+
+        self.assertFalse(
+            gsl.should_omit_version(
+                gsl.Version('foo', None, [], []), 'arm', 9, False, True))
+        self.assertFalse(
+            gsl.should_omit_version(
+                gsl.Version('foo', None, ['apex'], []), 'arm', 9, False, True))
 
     def test_omit_arch(self):
         self.assertFalse(
             gsl.should_omit_version(
-                gsl.Version('foo', None, [], []), 'arm', 9, False))
+                gsl.Version('foo', None, [], []), 'arm', 9, False, False))
         self.assertFalse(
             gsl.should_omit_version(
-                gsl.Version('foo', None, ['arm'], []), 'arm', 9, False))
+                gsl.Version('foo', None, ['arm'], []), 'arm', 9, False, False))
 
         self.assertTrue(
             gsl.should_omit_version(
-                gsl.Version('foo', None, ['x86'], []), 'arm', 9, False))
+                gsl.Version('foo', None, ['x86'], []), 'arm', 9, False, False))
 
     def test_omit_api(self):
         self.assertFalse(
             gsl.should_omit_version(
-                gsl.Version('foo', None, [], []), 'arm', 9, False))
+                gsl.Version('foo', None, [], []), 'arm', 9, False, False))
         self.assertFalse(
             gsl.should_omit_version(
                 gsl.Version('foo', None, ['introduced=9'], []), 'arm', 9,
-                False))
+                False, False))
 
         self.assertTrue(
             gsl.should_omit_version(
                 gsl.Version('foo', None, ['introduced=14'], []), 'arm', 9,
-                False))
+                False, False))
 
 
 class OmitSymbolTest(unittest.TestCase):
     def test_omit_vndk(self):
         self.assertTrue(
             gsl.should_omit_symbol(
-                gsl.Symbol('foo', ['vndk']), 'arm', 9, False))
+                gsl.Symbol('foo', ['vndk']), 'arm', 9, False, False))
 
         self.assertFalse(
-            gsl.should_omit_symbol(gsl.Symbol('foo', []), 'arm', 9, True))
+            gsl.should_omit_symbol(gsl.Symbol('foo', []), 'arm', 9, True, False))
         self.assertFalse(
             gsl.should_omit_symbol(
-                gsl.Symbol('foo', ['vndk']), 'arm', 9, True))
+                gsl.Symbol('foo', ['vndk']), 'arm', 9, True, False))
+
+    def test_omit_apex(self):
+        self.assertTrue(
+            gsl.should_omit_symbol(
+                gsl.Symbol('foo', ['apex']), 'arm', 9, False, False))
+
+        self.assertFalse(
+            gsl.should_omit_symbol(gsl.Symbol('foo', []), 'arm', 9, False, True))
+        self.assertFalse(
+            gsl.should_omit_symbol(
+                gsl.Symbol('foo', ['apex']), 'arm', 9, False, True))
 
     def test_omit_arch(self):
         self.assertFalse(
-            gsl.should_omit_symbol(gsl.Symbol('foo', []), 'arm', 9, False))
+            gsl.should_omit_symbol(gsl.Symbol('foo', []), 'arm', 9, False, False))
         self.assertFalse(
             gsl.should_omit_symbol(
-                gsl.Symbol('foo', ['arm']), 'arm', 9, False))
+                gsl.Symbol('foo', ['arm']), 'arm', 9, False, False))
 
         self.assertTrue(
             gsl.should_omit_symbol(
-                gsl.Symbol('foo', ['x86']), 'arm', 9, False))
+                gsl.Symbol('foo', ['x86']), 'arm', 9, False, False))
 
     def test_omit_api(self):
         self.assertFalse(
-            gsl.should_omit_symbol(gsl.Symbol('foo', []), 'arm', 9, False))
+            gsl.should_omit_symbol(gsl.Symbol('foo', []), 'arm', 9, False, False))
         self.assertFalse(
             gsl.should_omit_symbol(
-                gsl.Symbol('foo', ['introduced=9']), 'arm', 9, False))
+                gsl.Symbol('foo', ['introduced=9']), 'arm', 9, False, False))
 
         self.assertTrue(
             gsl.should_omit_symbol(
-                gsl.Symbol('foo', ['introduced=14']), 'arm', 9, False))
+                gsl.Symbol('foo', ['introduced=14']), 'arm', 9, False, False))
 
 
 class SymbolFileParseTest(unittest.TestCase):
@@ -262,7 +285,7 @@
             # baz
             qux
         """))
-        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False)
+        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False, False)
         self.assertIsNone(parser.current_line)
 
         self.assertEqual('foo', parser.next_line().strip())
@@ -287,7 +310,7 @@
             VERSION_2 {
             } VERSION_1; # asdf
         """))
-        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False)
+        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False, False)
 
         parser.next_line()
         version = parser.parse_version()
@@ -311,7 +334,7 @@
         input_file = io.StringIO(textwrap.dedent("""\
             VERSION_1 {
         """))
-        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False)
+        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False, False)
         parser.next_line()
         with self.assertRaises(gsl.ParseError):
             parser.parse_version()
@@ -322,7 +345,7 @@
                 foo:
             }
         """))
-        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False)
+        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False, False)
         parser.next_line()
         with self.assertRaises(gsl.ParseError):
             parser.parse_version()
@@ -332,7 +355,7 @@
             foo;
             bar; # baz qux
         """))
-        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False)
+        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False, False)
 
         parser.next_line()
         symbol = parser.parse_symbol()
@@ -350,7 +373,7 @@
                 *;
             };
         """))
-        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False)
+        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False, False)
         parser.next_line()
         with self.assertRaises(gsl.ParseError):
             parser.parse_version()
@@ -362,7 +385,7 @@
                     *;
             };
         """))
-        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False)
+        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False, False)
         parser.next_line()
         version = parser.parse_version()
         self.assertEqual([], version.symbols)
@@ -373,7 +396,7 @@
                 foo
             };
         """))
-        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False)
+        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False, False)
         parser.next_line()
         with self.assertRaises(gsl.ParseError):
             parser.parse_version()
@@ -381,7 +404,7 @@
     def test_parse_fails_invalid_input(self):
         with self.assertRaises(gsl.ParseError):
             input_file = io.StringIO('foo')
-            parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False)
+            parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False, False)
             parser.parse()
 
     def test_parse(self):
@@ -402,7 +425,7 @@
                     qwerty;
             } VERSION_1;
         """))
-        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False)
+        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False, False)
         versions = parser.parse()
 
         expected = [
@@ -418,6 +441,30 @@
 
         self.assertEqual(expected, versions)
 
+    def test_parse_vndk_apex_symbol(self):
+        input_file = io.StringIO(textwrap.dedent("""\
+            VERSION_1 {
+                foo;
+                bar; # vndk
+                baz; # vndk apex
+                qux; # apex
+            };
+        """))
+        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False, True)
+
+        parser.next_line()
+        version = parser.parse_version()
+        self.assertEqual('VERSION_1', version.name)
+        self.assertIsNone(version.base)
+
+        expected_symbols = [
+            gsl.Symbol('foo', []),
+            gsl.Symbol('bar', ['vndk']),
+            gsl.Symbol('baz', ['vndk', 'apex']),
+            gsl.Symbol('qux', ['apex']),
+        ]
+        self.assertEqual(expected_symbols, version.symbols)
+
 
 class GeneratorTest(unittest.TestCase):
     def test_omit_version(self):
@@ -425,7 +472,7 @@
         # OmitVersionTest, PrivateVersionTest, and SymbolPresenceTest.
         src_file = io.StringIO()
         version_file = io.StringIO()
-        generator = gsl.Generator(src_file, version_file, 'arm', 9, False)
+        generator = gsl.Generator(src_file, version_file, 'arm', 9, False, False)
 
         version = gsl.Version('VERSION_PRIVATE', None, [], [
             gsl.Symbol('foo', []),
@@ -453,7 +500,7 @@
         # SymbolPresenceTest.
         src_file = io.StringIO()
         version_file = io.StringIO()
-        generator = gsl.Generator(src_file, version_file, 'arm', 9, False)
+        generator = gsl.Generator(src_file, version_file, 'arm', 9, False, False)
 
         version = gsl.Version('VERSION_1', None, [], [
             gsl.Symbol('foo', ['x86']),
@@ -476,10 +523,17 @@
         self.assertEqual('', src_file.getvalue())
         self.assertEqual('', version_file.getvalue())
 
+        version = gsl.Version('VERSION_1', None, [], [
+            gsl.Symbol('foo', ['apex']),
+        ])
+        generator.write_version(version)
+        self.assertEqual('', src_file.getvalue())
+        self.assertEqual('', version_file.getvalue())
+
     def test_write(self):
         src_file = io.StringIO()
         version_file = io.StringIO()
-        generator = gsl.Generator(src_file, version_file, 'arm', 9, False)
+        generator = gsl.Generator(src_file, version_file, 'arm', 9, False, False)
 
         versions = [
             gsl.Version('VERSION_1', None, [], [
@@ -554,18 +608,19 @@
             VERSION_4 { # versioned=9
                 wibble;
                 wizzes; # vndk
+                waggle; # apex
             } VERSION_2;
 
             VERSION_5 { # versioned=14
                 wobble;
             } VERSION_4;
         """))
-        parser = gsl.SymbolFileParser(input_file, api_map, 'arm', 9, False)
+        parser = gsl.SymbolFileParser(input_file, api_map, 'arm', 9, False, False)
         versions = parser.parse()
 
         src_file = io.StringIO()
         version_file = io.StringIO()
-        generator = gsl.Generator(src_file, version_file, 'arm', 9, False)
+        generator = gsl.Generator(src_file, version_file, 'arm', 9, False, False)
         generator.write(versions)
 
         expected_src = textwrap.dedent("""\
@@ -610,12 +665,12 @@
                     *;
             };
         """))
-        parser = gsl.SymbolFileParser(input_file, api_map, 'arm', 9001, False)
+        parser = gsl.SymbolFileParser(input_file, api_map, 'arm', 9001, False, False)
         versions = parser.parse()
 
         src_file = io.StringIO()
         version_file = io.StringIO()
-        generator = gsl.Generator(src_file, version_file, 'arm', 9001, False)
+        generator = gsl.Generator(src_file, version_file, 'arm', 9001, False, False)
         generator.write(versions)
 
         expected_src = textwrap.dedent("""\
@@ -658,13 +713,84 @@
             } VERSION_2;
 
         """))
-        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False)
+        parser = gsl.SymbolFileParser(input_file, {}, 'arm', 16, False, False)
 
         with self.assertRaises(gsl.MultiplyDefinedSymbolError) as cm:
             parser.parse()
         self.assertEquals(['bar', 'foo'],
                           cm.exception.multiply_defined_symbols)
 
+    def test_integration_with_apex(self):
+        api_map = {
+            'O': 9000,
+            'P': 9001,
+        }
+
+        input_file = io.StringIO(textwrap.dedent("""\
+            VERSION_1 {
+                global:
+                    foo; # var
+                    bar; # x86
+                    fizz; # introduced=O
+                    buzz; # introduced=P
+                local:
+                    *;
+            };
+
+            VERSION_2 { # arm
+                baz; # introduced=9
+                qux; # versioned=14
+            } VERSION_1;
+
+            VERSION_3 { # introduced=14
+                woodly;
+                doodly; # var
+            } VERSION_2;
+
+            VERSION_4 { # versioned=9
+                wibble;
+                wizzes; # vndk
+                waggle; # apex
+            } VERSION_2;
+
+            VERSION_5 { # versioned=14
+                wobble;
+            } VERSION_4;
+        """))
+        parser = gsl.SymbolFileParser(input_file, api_map, 'arm', 9, False, True)
+        versions = parser.parse()
+
+        src_file = io.StringIO()
+        version_file = io.StringIO()
+        generator = gsl.Generator(src_file, version_file, 'arm', 9, False, True)
+        generator.write(versions)
+
+        expected_src = textwrap.dedent("""\
+            int foo = 0;
+            void baz() {}
+            void qux() {}
+            void wibble() {}
+            void waggle() {}
+            void wobble() {}
+        """)
+        self.assertEqual(expected_src, src_file.getvalue())
+
+        expected_version = textwrap.dedent("""\
+            VERSION_1 {
+                global:
+                    foo;
+            };
+            VERSION_2 {
+                global:
+                    baz;
+            } VERSION_1;
+            VERSION_4 {
+                global:
+                    wibble;
+                    waggle;
+            } VERSION_2;
+        """)
+        self.assertEqual(expected_version, version_file.getvalue())
 
 def main():
     suite = unittest.TestLoader().loadTestsFromName(__name__)