■ ■ ■ ■ ■ ■
pkg/ast/languages/ruby/patterns/patterns.go
1 | 1 | | package patterns |
2 | 2 | | |
3 | 3 | | import ( |
| 4 | + | "errors" |
4 | 5 | | "fmt" |
5 | 6 | | "log" |
6 | 7 | | |
7 | 8 | | sitter "github.com/smacker/go-tree-sitter" |
| 9 | + | "golang.org/x/exp/slices" |
8 | 10 | | |
| 11 | + | "github.com/bearer/bearer/new/language/implementation" |
9 | 12 | | builderinput "github.com/bearer/bearer/new/language/patternquery/builder/input" |
10 | 13 | | querytypes "github.com/bearer/bearer/new/language/patternquery/types" |
11 | 14 | | "github.com/bearer/bearer/pkg/ast/idgenerator" |
| skipped 33 lines |
45 | 48 | | |
46 | 49 | | type patternWriter struct { |
47 | 50 | | *filewriter.Writer |
48 | | - | inputParams *builderinput.InputParams |
49 | | - | input []byte |
50 | | - | literals []writerbase.Literal |
51 | | - | childIndex uint32 |
52 | | - | rootElement writerbase.LiteralElement |
53 | | - | parentElement writerbase.LiteralElement |
54 | | - | nodeVariableGenerator *nodeVariableGenerator |
55 | | - | tempIdGenerator *idgenerator.Generator |
56 | | - | handled set.Set[*sitter.Node] |
57 | | - | variableNodes map[string][]writerbase.Identifier |
| 51 | + | inputParams *builderinput.InputParams |
| 52 | + | input []byte |
| 53 | + | literals []writerbase.Literal |
| 54 | + | childIndex uint32 |
| 55 | + | rootElement writerbase.LiteralElement |
| 56 | + | parentElement writerbase.LiteralElement |
| 57 | + | nodeVariableGenerator *nodeVariableGenerator |
| 58 | + | tempIdGenerator *idgenerator.Generator |
| 59 | + | handled set.Set[*sitter.Node] |
| 60 | + | variableNodes map[string][]writerbase.Identifier |
| 61 | + | langImplementation implementation.Implementation |
| 62 | + | lastChildIndexVariable writerbase.LiteralElement |
58 | 63 | | } |
| 64 | + | |
// ErrSkipped is the sentinel error CompileRule returns when it declines to
// emit a rule (currently: when the pattern produced more than 20 literals).
// Callers should test for it with errors.Is rather than by message.
var ErrSkipped = errors.New("skipped")

// Skipped is the original name of ErrSkipped and refers to the same error
// value, so existing comparisons keep working.
//
// Deprecated: use ErrSkipped instead.
var Skipped = ErrSkipped
59 | 66 | | |
60 | 67 | | func CompileRule( |
61 | 68 | | walker *walker.Walker, |
62 | 69 | | inputParams *builderinput.InputParams, |
| 70 | + | langImplementation implementation.Implementation, |
63 | 71 | | patternId string, |
64 | 72 | | input []byte, |
65 | 73 | | rootNode *sitter.Node, |
| skipped 7 lines |
73 | 81 | | tempIdGenerator: idgenerator.NewGenerator(), |
74 | 82 | | handled: set.New[*sitter.Node](), |
75 | 83 | | variableNodes: make(map[string][]writerbase.Identifier), |
| 84 | + | langImplementation: langImplementation, |
76 | 85 | | } |
77 | 86 | | |
| 87 | + | matchNode := findMatchNode( |
| 88 | + | walker, |
| 89 | + | inputParams.MatchNodeOffset, |
| 90 | + | langImplementation.PatternMatchNodeContainerTypes(), |
| 91 | + | rootNode, |
| 92 | + | ) |
| 93 | + | |
78 | 94 | | err := walker.Walk(rootNode, w.visitNode) |
79 | 95 | | if err != nil { |
80 | 96 | | return err |
| skipped 24 lines |
105 | 121 | | } |
106 | 122 | | |
107 | 123 | | if len(w.literals) > 20 { |
108 | | - | log.Printf("rule too large, skipping") |
109 | | - | return nil |
| 124 | + | log.Printf("rule too large %d, skipping %s", len(w.literals), patternId) |
| 125 | + | return Skipped |
110 | 126 | | } |
111 | | - | log.Printf("#literals: %d", len(w.literals)) |
| 127 | + | // log.Printf("#literals: %d", len(w.literals)) |
112 | 128 | | |
113 | 129 | | writer.WriteRelation( |
114 | 130 | | fmt.Sprintf("Rule_Match_%s", patternId), |
| skipped 4 lines |
119 | 135 | | if err := writer.WriteRule( |
120 | 136 | | []writerbase.Predicate{writer.Predicate( |
121 | 137 | | fmt.Sprintf("Rule_Match_%s", patternId), |
122 | | - | append([]writerbase.LiteralElement{w.rootElement}, variableElements...)..., |
| 138 | + | append( |
| 139 | + | []writerbase.LiteralElement{w.Identifier(w.nodeVariableGenerator.Get(matchNode))}, |
| 140 | + | variableElements..., |
| 141 | + | )..., |
123 | 142 | | )}, |
124 | 143 | | append(w.literals, variableConstraints...), |
125 | 144 | | ); err != nil { |
126 | 145 | | return err |
| 146 | + | } |
| 147 | + | |
| 148 | + | if patternId == "blowfish_init_0" { |
| 149 | + | log.Printf("RULE: %s", matchNode.String()) |
127 | 150 | | } |
128 | 151 | | |
129 | 152 | | return nil |
| skipped 25 lines |
155 | 178 | | writer.Predicate("AST_NodeField", writer.parentElement, nodeElement, writer.Symbol(fname)), |
156 | 179 | | ) |
157 | 180 | | } else { |
| 181 | + | childIndexVariable := writer.Identifier(fmt.Sprintf("tmp%d", writer.tempIdGenerator.Get())) |
| 182 | + | nodeAnchoredBefore, _ := writer.langImplementation.PatternIsAnchored(node) |
| 183 | + | |
| 184 | + | // Anchored before |
| 185 | + | if node.IsNamed() && nodeAnchoredBefore && !slices.Contains(writer.inputParams.UnanchoredOffsets, int(node.StartByte())) { |
| 186 | + | if writer.lastChildIndexVariable != nil { |
| 187 | + | writer.literals = append( |
| 188 | + | writer.literals, |
| 189 | + | // FIXME: constraint hack! |
| 190 | + | writer.Constraint(childIndexVariable, "= 1+", writer.lastChildIndexVariable), |
| 191 | + | ) |
| 192 | + | } else { |
| 193 | + | writer.literals = append( |
| 194 | + | writer.literals, |
| 195 | + | writer.Constraint(childIndexVariable, "=", writer.Unsigned(writer.childIndex)), |
| 196 | + | ) |
| 197 | + | |
| 198 | + | } |
| 199 | + | } else { |
| 200 | + | if writer.lastChildIndexVariable != nil { |
| 201 | + | writer.literals = append( |
| 202 | + | writer.literals, |
| 203 | + | writer.Constraint(childIndexVariable, ">", writer.lastChildIndexVariable), |
| 204 | + | ) |
| 205 | + | } else { |
| 206 | + | writer.literals = append( |
| 207 | + | writer.literals, |
| 208 | + | writer.Constraint(childIndexVariable, ">=", writer.Unsigned(0)), |
| 209 | + | ) |
| 210 | + | } |
| 211 | + | } |
| 212 | + | |
| 213 | + | // FIXME: end anchoring |
| 214 | + | // if node.IsNamed() && isLastChild && nodeAnchoredAfter && !slices.Contains(writer.inputParams.UnanchoredOffsets, int(node.EndByte())) { |
| 215 | + | // // Last anchored |
| 216 | + | // } |
| 217 | + | |
158 | 218 | | writer.literals = append( |
159 | 219 | | writer.literals, |
160 | | - | writer.Predicate("AST_ParentChild", writer.parentElement, writer.Unsigned(writer.childIndex), nodeElement), |
| 220 | + | writer.Predicate("AST_ParentChild", writer.parentElement, childIndexVariable, nodeElement), |
161 | 221 | | ) |
162 | 222 | | |
163 | 223 | | writer.childIndex++ |
| 224 | + | writer.lastChildIndexVariable = childIndexVariable |
164 | 225 | | } |
165 | 226 | | } |
166 | 227 | | |
| skipped 60 lines |
227 | 288 | | oldParentElement := writer.parentElement |
228 | 289 | | oldChildIndex := writer.childIndex |
229 | 290 | | writer.childIndex = 0 |
| 291 | + | writer.lastChildIndexVariable = nil |
230 | 292 | | writer.parentElement = nodeElement |
231 | 293 | | err := visitChildren() |
232 | 294 | | writer.childIndex = oldChildIndex |
| skipped 61 lines |
294 | 356 | | return nil |
295 | 357 | | } |
296 | 358 | | |
| 359 | + | func findMatchNode( |
| 360 | + | walker *walker.Walker, |
| 361 | + | offset int, |
| 362 | + | containerTypes []string, |
| 363 | + | rootNode *sitter.Node, |
| 364 | + | ) (matchNode *sitter.Node) { |
| 365 | + | err := walker.Walk(rootNode, func(node *sitter.Node, visitChildren func() error) error { |
| 366 | + | // FIXME: do this generically! |
| 367 | + | if node.Type() != "program" { |
| 368 | + | if node.StartByte() == uint32(offset) && !slices.Contains(containerTypes, node.Type()) { |
| 369 | + | matchNode = node |
| 370 | + | return nil |
| 371 | + | } |
| 372 | + | } |
| 373 | + | |
| 374 | + | return visitChildren() |
| 375 | + | }) |
| 376 | + | |
| 377 | + | // walk itself shouldn't trigger an error, and we aren't creating any |
| 378 | + | if err != nil { |
| 379 | + | panic(err) |
| 380 | + | } |
| 381 | + | |
| 382 | + | return |
| 383 | + | } |
| 384 | + | |