diff --git a/doc/syntax.md b/doc/syntax.md index b3af6a18d..2e2777026 100644 --- a/doc/syntax.md +++ b/doc/syntax.md @@ -402,13 +402,13 @@ _decl_ ::= _class-decl_ # Class declaration | _const-decl_ # Constant declaration | _global-decl_ # Global declaration -_class-decl_ ::= `class` _class-name_ _type-parameters_ _members_ `end` - | `class` _class-name_ _type-parameters_ `<` _class-name_ _type-arguments_ _members_ `end` +_class-decl_ ::= `class` _class-name_ _module-type-parameters_ _members_ `end` + | `class` _class-name_ _module-type-parameters_ `<` _class-name_ _type-arguments_ _members_ `end` -_module-decl_ ::= `module` _module-name_ _type-parameters_ _members_ `end` - | `module` _module-name_ _type-parameters_ `:` _class-name_ _type-arguments_ _members_ `end` +_module-decl_ ::= `module` _module-name_ _module-type-parameters_ _members_ `end` + | `module` _module-name_ _module-type-parameters_ `:` _class-name_ _type-arguments_ _members_ `end` -_interface-decl_ ::= `interface` _interface-name_ _type-parameters_ _interface-members_ `end` +_interface-decl_ ::= `interface` _interface-name_ _module-type-parameters_ _interface-members_ `end` _interface-members_ ::= _method-member_ # Method | _include-member_ # Mixin (include) @@ -424,6 +424,12 @@ _global-decl_ ::= _global-name_ `:` _type_ _const-name_ ::= _namespace_ /[A-Z]\w*/ _global-name_ ::= /$[a-zA-Z]\w+/ | ... + +_module-type-parameters_ ::= # Empty + | `[` _module-type-parameter_ `,` ... `]` + +_module-type-parameter_ ::= _variance_ _type-variable_ +_variance_ ::= `out` | `in` ``` ### Class declaration diff --git a/lib/ruby/signature/ast/declarations.rb b/lib/ruby/signature/ast/declarations.rb index 1886591db..1f1f207a5 100644 --- a/lib/ruby/signature/ast/declarations.rb +++ b/lib/ruby/signature/ast/declarations.rb @@ -2,6 +2,59 @@ module Ruby module Signature module AST module Declarations + class ModuleTypeParams + attr_reader :params + + TypeParam = Struct.new(:name, :variance, :skip_validation, keyword_init: true) + + def initialize() + @params = [] + end + + def add(param) + params << param + self + end + + def ==(other) + other.is_a?(ModuleTypeParams) && other.params == params + end + + def [](name) + params.find {|p| p.name == name } + end + + def to_json(*a) + { + params: params + }.to_json(*a) + end + + def each(&block) + params.each(&block) + end + + def self.empty + new + end + + def variance(name) + self[name].variance + end + + def skip_validation?(name) + self[name].skip_validation + end + + def empty? + params.empty? + end + + def size + params.size + end + end + class Class class Super attr_reader :name diff --git a/lib/ruby/signature/cli.rb b/lib/ruby/signature/cli.rb index 9c8c55e42..eda8f7f06 100644 --- a/lib/ruby/signature/cli.rb +++ b/lib/ruby/signature/cli.rb @@ -144,9 +144,9 @@ def run_ancestors(args, options) if env.class?(type_name) ancestor = case kind when :instance - definition = env.find_class(type_name) + decl = env.find_class(type_name) Definition::Ancestor::Instance.new(name: type_name, - args: Types::Variable.build(definition.type_params)) + args: Types::Variable.build(decl.type_params.each.map(&:name))) when :singleton Definition::Ancestor::Singleton.new(name: type_name) end diff --git a/lib/ruby/signature/definition.rb b/lib/ruby/signature/definition.rb index dcd12604a..e58a1e4e4 100644 --- a/lib/ruby/signature/definition.rb +++ b/lib/ruby/signature/definition.rb @@ -129,6 +129,15 @@ def type_params @self_type.args.map(&:name) end + def type_params_decl + case declaration + when AST::Declarations::Extension + nil + else + declaration.type_params + end + end + def each_type(&block) if block_given? methods.each_value do |method| diff --git a/lib/ruby/signature/definition_builder.rb b/lib/ruby/signature/definition_builder.rb index 361788b52..eeb11cd0c 100644 --- a/lib/ruby/signature/definition_builder.rb +++ b/lib/ruby/signature/definition_builder.rb @@ -23,16 +23,16 @@ def build_ancestors(self_ancestor, ancestors: [], building_ancestors: [], locati case self_ancestor when Definition::Ancestor::Instance args = self_ancestor.args - params = decl.type_params + param_names = decl.type_params.each.map(&:name) InvalidTypeApplicationError.check!( type_name: self_ancestor.name, args: args, - params: params, + params: decl.type_params, location: location || decl.location ) - sub = Substitution.build(params, args) + sub = Substitution.build(param_names, args) case decl when AST::Declarations::Class @@ -203,7 +203,7 @@ def build_instance(type_name) try_cache type_name, cache: instance_cache do decl = env.find_class(type_name) self_ancestor = Definition::Ancestor::Instance.new(name: type_name, - args: Types::Variable.build(decl.type_params)) + args: Types::Variable.build(decl.type_params.each.map(&:name))) self_type = Types::ClassInstance.new(name: type_name, args: self_ancestor.args, location: nil) case decl @@ -323,7 +323,7 @@ def build_one_instance(type_name, extension_name: nil) case decl when AST::Declarations::Class, AST::Declarations::Module self_type = Types::ClassInstance.new(name: type_name, - args: Types::Variable.build(decl.type_params), + args: Types::Variable.build(decl.type_params.each.map(&:name)), location: nil) ancestors = [Definition::Ancestor::Instance.new(name: type_name, args: self_type.args)] when AST::Declarations::Extension @@ -462,7 +462,7 @@ def build_one_instance(type_name, extension_name: nil) InvalidTypeApplicationError.check!( type_name: absolute_name, args: absolute_args, - params: interface_definition.type_params, + params: interface_definition.type_params_decl, location: member.location ) @@ -584,7 +584,7 @@ def build_one_singleton(type_name, extension_name: nil) InvalidTypeApplicationError.check!( type_name: absolute_name, args: absolute_args, - params: interface_definition.type_params, + params: interface_definition.type_params_decl, location: member.location ) @@ -705,7 +705,7 @@ def try_cache(type_name, cache:) def build_interface(type_name, declaration) self_type = Types::Interface.new( name: type_name, - args: declaration.type_params.map {|x| Types::Variable.new(name: x, location: nil) }, + args: declaration.type_params.each.map {|p| Types::Variable.new(name: p.name, location: nil) }, location: nil ) @@ -728,7 +728,7 @@ def build_interface(type_name, declaration) location: member.location ) - sub = Substitution.build(type_params, args) + sub = Substitution.build(type_params.each.map(&:name), args) mixin.methods.each do |name, method| definition.methods[name] = method.sub(sub) end diff --git a/lib/ruby/signature/errors.rb b/lib/ruby/signature/errors.rb index 57c2e1790..ad3d16801 100644 --- a/lib/ruby/signature/errors.rb +++ b/lib/ruby/signature/errors.rb @@ -11,7 +11,7 @@ def initialize(type_name:, args:, params:, location:) @args = args @params = params @location = location - super "#{Location.to_string location}: #{type_name} expects parameters [#{params.join(", ")}], but given args [#{args.join(", ")}]" + super "#{Location.to_string location}: #{type_name} expects parameters [#{params.each.map(&:name).join(", ")}], but given args [#{args.join(", ")}]" end def self.check!(type_name:, args:, params:, location:) diff --git a/lib/ruby/signature/parser.y b/lib/ruby/signature/parser.y index 510045091..14c11dffc 100644 --- a/lib/ruby/signature/parser.y +++ b/lib/ruby/signature/parser.y @@ -10,6 +10,7 @@ class Ruby::Signature::Parser kINTERFACE kEND kINCLUDE kEXTEND kATTRREADER kATTRWRITER kATTRACCESSOR tOPERATOR tQUOTEDMETHOD kPREPEND kEXTENSION kINCOMPATIBLE type_TYPE type_SIGNATURE type_METHODTYPE tEOF + kOUT kIN kUNCHECKED prechigh nonassoc kQUESTION @@ -78,13 +79,13 @@ rule extension_name: tUIDENT | tLIDENT class_decl: - annotations kCLASS start_new_scope class_name type_params super_class class_members kEND { + annotations kCLASS start_new_scope class_name module_type_params super_class class_members kEND { reset_variable_scope location = val[1].location + val[7].location result = Declarations::Class.new( name: val[3].value, - type_params: val[4]&.value || [], + type_params: val[4]&.value || Declarations::ModuleTypeParams.empty, super_class: val[5], members: val[6], annotations: val[0], @@ -105,13 +106,13 @@ rule } module_decl: - annotations kMODULE start_new_scope class_name type_params module_self_type class_members kEND { + annotations kMODULE start_new_scope class_name module_type_params module_self_type class_members kEND { reset_variable_scope location = val[1].location + val[7].location result = Declarations::Module.new( name: val[3].value, - type_params: val[4]&.value || [], + type_params: val[4]&.value || Declarations::ModuleTypeParams.empty, self_type: val[5], members: val[6], annotations: val[0], @@ -125,7 +126,7 @@ rule location = val[1].location + val[6].location result = Declarations::Module.new( name: val[3].value, - type_params: [], + type_params: Declarations::ModuleTypeParams.empty, self_type: val[4], members: val[5], annotations: val[0], @@ -272,13 +273,13 @@ rule } interface_decl: - annotations kINTERFACE start_new_scope interface_name type_params interface_members kEND { + annotations kINTERFACE start_new_scope interface_name module_type_params interface_members kEND { reset_variable_scope location = val[1].location + val[6].location result = Declarations::Interface.new( name: val[3].value, - type_params: val[4]&.value || [], + type_params: val[4]&.value || Declarations::ModuleTypeParams.empty, members: val[5], annotations: val[0], location: location, @@ -477,7 +478,7 @@ rule method_name: tOPERATOR - | kAMP | kHAT | kSTAR | kLT | kEXCLAMATION | kSTAR2 | kBAR + | kAMP | kHAT | kSTAR | kLT | kEXCLAMATION | kSTAR2 | kBAR | kOUT | kIN | method_name0 | method_name0 kQUESTION { unless val[0].location.pred?(val[1].location) @@ -504,6 +505,40 @@ rule kCLASS | kVOID | kNIL | kANY | kTOP | kBOT | kINSTANCE | kBOOL | kSINGLETON | kTYPE | kMODULE | kPRIVATE | kPUBLIC | kEND | kINCLUDE | kEXTEND | kPREPEND | kATTRREADER | kATTRACCESSOR | kATTRWRITER | kDEF | kEXTENSION | kSELF | kINCOMPATIBLE + | kUNCHECKED + + module_type_params: + { result = nil } + | kLBRACKET module_type_params0 kRBRACKET { + val[1].each {|p| insert_bound_variable(p.name) } + + result = LocatedValue.new(value: val[1], location: val[0].location + val[2].location) + } + + module_type_params0: + module_type_param { + result = Declarations::ModuleTypeParams.new() + result.add(val[0]) + } + | module_type_params0 kCOMMA module_type_param { + result = val[0].add(val[2]) + } + + module_type_param: + type_param_check type_param_variance tUIDENT { + result = Declarations::ModuleTypeParams::TypeParam.new(name: val[2].value.to_sym, + variance: val[1], + skip_validation: val[0]) + } + + type_param_variance: + { result = :invariant } + | kOUT { result = :covariant } + | kIN { result = :contravariant } + + type_param_check: + { result = false } + | kUNCHECKED { result = true } type_params: { result = nil } @@ -1126,7 +1161,10 @@ KEYWORDS = { "private" => :kPRIVATE, "alias" => :kALIAS, "extension" => :kEXTENSION, - "incompatible" => :kINCOMPATIBLE + "incompatible" => :kINCOMPATIBLE, + "unchecked" => :kUNCHECKED, + "out" => :kOUT, + "in" => :kIN, } KEYWORDS_RE = /#{Regexp.union(*KEYWORDS.keys)}\b/ diff --git a/lib/ruby/signature/scaffold/rbi.rb b/lib/ruby/signature/scaffold/rbi.rb index 877f7fc42..2ed31a5ab 100644 --- a/lib/ruby/signature/scaffold/rbi.rb +++ b/lib/ruby/signature/scaffold/rbi.rb @@ -49,7 +49,7 @@ def push_class(name, super_class, comment:) modules.push AST::Declarations::Class.new( name: nested_name(name), super_class: super_class && AST::Declarations::Class::Super.new(name: const_to_name(super_class), args: []), - type_params: [], + type_params: AST::Declarations::ModuleTypeParams.empty, members: [], annotations: [], location: nil, @@ -66,7 +66,7 @@ def push_class(name, super_class, comment:) def push_module(name, comment:) modules.push AST::Declarations::Module.new( name: nested_name(name), - type_params: [], + type_params: AST::Declarations::ModuleTypeParams.empty, members: [], annotations: [], location: nil, @@ -204,7 +204,19 @@ def process(node, outer: [], comments:) node.type == :HASH && each_arg(node.children[0]).each_slice(2).any? {|a, _| a.type == :LIT && a.children[0] == :fixed } } - current_module.type_params << node.children[0] + if (a0 = each_arg(send.children[1]).to_a[0])&.type == :LIT + variance = case a0.children[0] + when :out + :covariant + when :in + :contravariant + end + end + + current_module.type_params.add( + AST::Declarations::ModuleTypeParams::TypeParam.new(name: node.children[0], + variance: variance || :invariant, + skip_validation: false)) end else name = node.children[0].yield_self do |n| @@ -418,7 +430,7 @@ def type_of(type_node, variables:) def type_of0(type_node, variables:) case when type_node.type == :CONST - if variables.include?(type_node.children[0]) + if variables.each.include?(type_node.children[0]) Types::Variable.new(name: type_node.children[0], location: nil) else Types::ClassInstance.new(name: const_to_name(type_node), args: [], location: nil) diff --git a/lib/ruby/signature/types.rb b/lib/ruby/signature/types.rb index 059f94f84..653a134c1 100644 --- a/lib/ruby/signature/types.rb +++ b/lib/ruby/signature/types.rb @@ -126,6 +126,8 @@ def self.build(v) new(name: v, location: nil) when Array v.map {|x| new(name: x, location: nil) } + else + raise end end diff --git a/lib/ruby/signature/writer.rb b/lib/ruby/signature/writer.rb index d17b3b7b1..edecb8231 100644 --- a/lib/ruby/signature/writer.rb +++ b/lib/ruby/signature/writer.rb @@ -54,7 +54,7 @@ def write_decl(decl) end write_comment decl.comment, level: 0 write_annotation decl.annotations, level: 0 - out.puts "class #{name_and_args(decl.name, decl.type_params)}#{super_class}" + out.puts "class #{name_and_params(decl.name, decl.type_params)}#{super_class}" decl.members.each.with_index do |member, index| if index > 0 @@ -72,7 +72,7 @@ def write_decl(decl) write_comment decl.comment, level: 0 write_annotation decl.annotations, level: 0 - out.puts "module #{name_and_args(decl.name, decl.type_params)}#{self_type}" + out.puts "module #{name_and_params(decl.name, decl.type_params)}#{self_type}" decl.members.each.with_index do |member, index| if index > 0 out.puts @@ -96,7 +96,7 @@ def write_decl(decl) when AST::Declarations::Interface write_comment decl.comment, level: 0 write_annotation decl.annotations, level: 0 - out.puts "interface #{name_and_args(decl.name, decl.type_params)}" + out.puts "interface #{name_and_params(decl.name, decl.type_params)}" decl.members.each.with_index do |member, index| if index > 0 out.puts @@ -119,6 +119,30 @@ def write_decl(decl) end end + def name_and_params(name, params) + if params.empty? + "#{name}" + else + ps = params.each.map do |param| + s = "" + if param.skip_validation + s << "unchecked " + end + case param.variance + when :invariant + # nop + when :covariant + s << "out " + when :contravariant + s << "in " + end + s + param.name.to_s + end + + "#{name}[#{ps.join(", ")}]" + end + end + def name_and_args(name, args) if name && args if args.empty? diff --git a/test/ruby/signature/rbi_scaffold_test.rb b/test/ruby/signature/rbi_scaffold_test.rb index ba183a060..1201d64e2 100644 --- a/test/ruby/signature/rbi_scaffold_test.rb +++ b/test/ruby/signature/rbi_scaffold_test.rb @@ -307,7 +307,7 @@ class Array EOF assert_write parser.decls, <<-EOF -class Array[Elem] +class Array[out Elem] include Enumerable end EOF @@ -406,6 +406,28 @@ class Dir assert_write parser.decls, <<-EOF class Dir include Enumerable +end + EOF + end + + def test_parameter_type_member_variance + parser = RBI.new + + parser.parse <<-EOF +class Dir + extend T::Generic + + X = type_member(:out) + Y = type_member(:in) + Z = type_member() + + include Enumerable +end + EOF + + assert_write parser.decls, <<-EOF +class Dir[out X, in Y, Z] + include Enumerable end EOF end diff --git a/test/ruby/signature/signature_parsing_test.rb b/test/ruby/signature/signature_parsing_test.rb index ed601a342..2b2fdab6d 100644 --- a/test/ruby/signature/signature_parsing_test.rb +++ b/test/ruby/signature/signature_parsing_test.rb @@ -92,7 +92,7 @@ def test_interface assert_instance_of Declarations::Interface, interface_decl assert_equal TypeName.new(name: :_Each, namespace: Namespace.empty), interface_decl.name - assert_equal [:A, :B], interface_decl.type_params + assert_equal [:A, :B], interface_decl.type_params.each.map(&:name) assert_equal [], interface_decl.members assert_equal "interface _Each[A, B] end", interface_decl.location.source end @@ -116,7 +116,7 @@ def count: -> Integer assert_instance_of Declarations::Interface, interface_decl assert_equal TypeName.new(name: :_Each, namespace: Namespace.empty), interface_decl.name - assert_equal [:A, :B], interface_decl.type_params + assert_equal [:A, :B], interface_decl.type_params.each.map(&:name) assert_equal 2, interface_decl.members.size interface_decl.members[0].yield_self do |def_member| @@ -174,7 +174,7 @@ def test_module assert_instance_of Declarations::Module, module_decl assert_equal TypeName.new(name: :Enumerable, namespace: Namespace.empty), module_decl.name - assert_equal [:A, :B], module_decl.type_params + assert_equal [:A, :B], module_decl.type_params.each.map(&:name) assert_equal parse_type("_Each"), module_decl.self_type assert_equal [], module_decl.members assert_equal "module Enumerable[A, B] : _Each end", module_decl.location.source @@ -363,7 +363,7 @@ def test_class decls[0].yield_self do |class_decl| assert_instance_of Declarations::Class, class_decl assert_equal TypeName.new(name: :Array, namespace: Namespace.empty), class_decl.name - assert_equal [:A], class_decl.type_params + assert_equal [:A], class_decl.type_params.each.map(&:name) assert_nil class_decl.super_class end end @@ -374,7 +374,7 @@ def test_class decls[0].yield_self do |class_decl| assert_instance_of Declarations::Class, class_decl assert_equal TypeName.new(name: :Array, namespace: Namespace.root), class_decl.name - assert_equal [:A], class_decl.type_params + assert_equal [:A], class_decl.type_params.each.map(&:name) assert_instance_of Declarations::Class::Super, class_decl.super_class assert_equal TypeName.new(name: :Object, namespace: Namespace.empty), class_decl.super_class.name @@ -892,4 +892,30 @@ def binding: () -> Binding EOF end end + + def test_module_type_param_variance + Parser.parse_signature("interface _Each[A, out B, unchecked in C] end").yield_self do |decls| + assert_equal 1, decls.size + + interface_decl = decls[0] + + assert_instance_of Declarations::Interface, interface_decl + a, b, c = interface_decl.type_params.each.to_a + + assert_instance_of Declarations::ModuleTypeParams::TypeParam, a + assert_equal :A, a.name + assert_equal :invariant, a.variance + refute a.skip_validation + + assert_instance_of Declarations::ModuleTypeParams::TypeParam, b + assert_equal :B, b.name + assert_equal :covariant, b.variance + refute b.skip_validation + + assert_instance_of Declarations::ModuleTypeParams::TypeParam, c + assert_equal :C, c.name + assert_equal :contravariant, c.variance + assert c.skip_validation + end + end end diff --git a/test/ruby/signature/writer_test.rb b/test/ruby/signature/writer_test.rb index cfb81e8e7..68ab72617 100644 --- a/test/ruby/signature/writer_test.rb +++ b/test/ruby/signature/writer_test.rb @@ -126,6 +126,13 @@ def initialize: () -> void class Bar def self.new: (String) -> Bar +end + SIG + end + + def test_variance + assert_writer <<-SIG +class Foo[out A, unchecked B, in C] < Bar[A, C, B] end SIG end