Line data Source code
1 : //
2 : // Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
3 : // Use of this source code is governed by a BSD-style license that can be
4 : // found in the LICENSE file.
5 : //
6 : // SimplifyLoopConditions is an AST traverser that converts loop conditions and loop expressions
7 : // to regular statements inside the loop. This way further transformations that generate statements
8 : // from loop conditions and loop expressions work correctly.
9 : //
10 :
11 : #include "compiler/translator/SimplifyLoopConditions.h"
12 :
13 : #include "compiler/translator/IntermNode.h"
14 : #include "compiler/translator/IntermNodePatternMatcher.h"
15 :
16 : namespace sh
17 : {
18 :
19 : namespace
20 : {
21 :
22 0 : TIntermConstantUnion *CreateBoolConstantNode(bool value)
23 : {
24 0 : TConstantUnion *u = new TConstantUnion;
25 0 : u->setBConst(value);
26 : TIntermConstantUnion *node =
27 0 : new TIntermConstantUnion(u, TType(EbtBool, EbpUndefined, EvqConst, 1));
28 0 : return node;
29 : }
30 :
31 0 : class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser
32 : {
33 : public:
34 : SimplifyLoopConditionsTraverser(unsigned int conditionsToSimplifyMask,
35 : const TSymbolTable &symbolTable,
36 : int shaderVersion);
37 :
38 : void traverseLoop(TIntermLoop *node) override;
39 :
40 : bool visitBinary(Visit visit, TIntermBinary *node) override;
41 : bool visitAggregate(Visit visit, TIntermAggregate *node) override;
42 : bool visitTernary(Visit visit, TIntermTernary *node) override;
43 : bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
44 :
45 : void nextIteration();
46 0 : bool foundLoopToChange() const { return mFoundLoopToChange; }
47 :
48 : protected:
49 : // Marked to true once an operation that needs to be hoisted out of the expression has been
50 : // found. After that, no more AST updates are performed on that traversal.
51 : bool mFoundLoopToChange;
52 : bool mInsideLoopInitConditionOrExpression;
53 : IntermNodePatternMatcher mConditionsToSimplify;
54 : };
55 :
56 0 : SimplifyLoopConditionsTraverser::SimplifyLoopConditionsTraverser(
57 : unsigned int conditionsToSimplifyMask,
58 : const TSymbolTable &symbolTable,
59 0 : int shaderVersion)
60 : : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
61 : mFoundLoopToChange(false),
62 : mInsideLoopInitConditionOrExpression(false),
63 0 : mConditionsToSimplify(conditionsToSimplifyMask)
64 : {
65 0 : }
66 :
67 0 : void SimplifyLoopConditionsTraverser::nextIteration()
68 : {
69 0 : mFoundLoopToChange = false;
70 0 : mInsideLoopInitConditionOrExpression = false;
71 0 : nextTemporaryIndex();
72 0 : }
73 :
74 : // The visit functions operate in three modes:
75 : // 1. If a matching expression has already been found, we return early since only one loop can
76 : // be transformed on one traversal.
77 : // 2. We try to find loops. In case a node is not inside a loop and can not contain loops, we
78 : // stop traversing the subtree.
79 : // 3. If we're inside a loop initialization, condition or expression, we check for expressions
80 : // that should be moved out of the loop condition or expression. If one is found, the loop
81 : // is processed.
82 0 : bool SimplifyLoopConditionsTraverser::visitBinary(Visit visit, TIntermBinary *node)
83 : {
84 :
85 0 : if (mFoundLoopToChange)
86 0 : return false;
87 :
88 0 : if (!mInsideLoopInitConditionOrExpression)
89 0 : return false;
90 :
91 0 : mFoundLoopToChange = mConditionsToSimplify.match(node, getParentNode(), isLValueRequiredHere());
92 0 : return !mFoundLoopToChange;
93 : }
94 :
95 0 : bool SimplifyLoopConditionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
96 : {
97 0 : if (mFoundLoopToChange)
98 0 : return false;
99 :
100 0 : if (!mInsideLoopInitConditionOrExpression)
101 0 : return false;
102 :
103 0 : mFoundLoopToChange = mConditionsToSimplify.match(node, getParentNode());
104 0 : return !mFoundLoopToChange;
105 : }
106 :
107 0 : bool SimplifyLoopConditionsTraverser::visitTernary(Visit visit, TIntermTernary *node)
108 : {
109 0 : if (mFoundLoopToChange)
110 0 : return false;
111 :
112 0 : if (!mInsideLoopInitConditionOrExpression)
113 0 : return false;
114 :
115 0 : mFoundLoopToChange = mConditionsToSimplify.match(node);
116 0 : return !mFoundLoopToChange;
117 : }
118 :
119 0 : bool SimplifyLoopConditionsTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
120 : {
121 0 : if (mFoundLoopToChange)
122 0 : return false;
123 :
124 0 : if (!mInsideLoopInitConditionOrExpression)
125 0 : return false;
126 :
127 0 : mFoundLoopToChange = mConditionsToSimplify.match(node);
128 0 : return !mFoundLoopToChange;
129 : }
130 :
131 0 : void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
132 : {
133 0 : if (mFoundLoopToChange)
134 0 : return;
135 :
136 : // Mark that we're inside a loop condition or expression, and transform the loop if needed.
137 :
138 0 : incrementDepth(node);
139 :
140 : // Note: No need to traverse the loop init node.
141 :
142 0 : mInsideLoopInitConditionOrExpression = true;
143 0 : TLoopType loopType = node->getType();
144 :
145 0 : if (!mFoundLoopToChange && node->getInit())
146 : {
147 0 : node->getInit()->traverse(this);
148 : }
149 :
150 0 : if (!mFoundLoopToChange && node->getCondition())
151 : {
152 0 : node->getCondition()->traverse(this);
153 : }
154 :
155 0 : if (!mFoundLoopToChange && node->getExpression())
156 : {
157 0 : node->getExpression()->traverse(this);
158 : }
159 :
160 0 : if (mFoundLoopToChange)
161 : {
162 : // Replace the loop condition with a boolean variable that's updated on each iteration.
163 0 : if (loopType == ELoopWhile)
164 : {
165 : // Transform:
166 : // while (expr) { body; }
167 : // into
168 : // bool s0 = expr;
169 : // while (s0) { { body; } s0 = expr; }
170 0 : TIntermSequence tempInitSeq;
171 0 : tempInitSeq.push_back(createTempInitDeclaration(node->getCondition()->deepCopy()));
172 0 : insertStatementsInParentBlock(tempInitSeq);
173 :
174 0 : TIntermBlock *newBody = new TIntermBlock();
175 0 : if (node->getBody())
176 : {
177 0 : newBody->getSequence()->push_back(node->getBody());
178 : }
179 0 : newBody->getSequence()->push_back(
180 0 : createTempAssignment(node->getCondition()->deepCopy()));
181 :
182 : // Can't use queueReplacement to replace old body, since it may have been nullptr.
183 : // It's safe to do the replacements in place here - this node won't be traversed
184 : // further.
185 0 : node->setBody(newBody);
186 0 : node->setCondition(createTempSymbol(node->getCondition()->getType()));
187 : }
188 0 : else if (loopType == ELoopDoWhile)
189 : {
190 : // Transform:
191 : // do {
192 : // body;
193 : // } while (expr);
194 : // into
195 : // bool s0 = true;
196 : // do {
197 : // { body; }
198 : // s0 = expr;
199 : // } while (s0);
200 0 : TIntermSequence tempInitSeq;
201 0 : tempInitSeq.push_back(createTempInitDeclaration(CreateBoolConstantNode(true)));
202 0 : insertStatementsInParentBlock(tempInitSeq);
203 :
204 0 : TIntermBlock *newBody = new TIntermBlock();
205 0 : if (node->getBody())
206 : {
207 0 : newBody->getSequence()->push_back(node->getBody());
208 : }
209 0 : newBody->getSequence()->push_back(
210 0 : createTempAssignment(node->getCondition()->deepCopy()));
211 :
212 : // Can't use queueReplacement to replace old body, since it may have been nullptr.
213 : // It's safe to do the replacements in place here - this node won't be traversed
214 : // further.
215 0 : node->setBody(newBody);
216 0 : node->setCondition(createTempSymbol(node->getCondition()->getType()));
217 : }
218 0 : else if (loopType == ELoopFor)
219 : {
220 : // Move the loop condition inside the loop.
221 : // Transform:
222 : // for (init; expr; exprB) { body; }
223 : // into
224 : // {
225 : // init;
226 : // bool s0 = expr;
227 : // while (s0) { { body; } exprB; s0 = expr; }
228 : // }
229 0 : TIntermBlock *loopScope = new TIntermBlock();
230 0 : if (node->getInit())
231 : {
232 0 : loopScope->getSequence()->push_back(node->getInit());
233 : }
234 0 : loopScope->getSequence()->push_back(
235 0 : createTempInitDeclaration(node->getCondition()->deepCopy()));
236 :
237 0 : TIntermBlock *whileLoopBody = new TIntermBlock();
238 0 : if (node->getBody())
239 : {
240 0 : whileLoopBody->getSequence()->push_back(node->getBody());
241 : }
242 0 : if (node->getExpression())
243 : {
244 0 : whileLoopBody->getSequence()->push_back(node->getExpression());
245 : }
246 0 : whileLoopBody->getSequence()->push_back(
247 0 : createTempAssignment(node->getCondition()->deepCopy()));
248 : TIntermLoop *whileLoop = new TIntermLoop(
249 0 : ELoopWhile, nullptr, createTempSymbol(node->getCondition()->getType()), nullptr,
250 0 : whileLoopBody);
251 0 : loopScope->getSequence()->push_back(whileLoop);
252 0 : queueReplacementWithParent(getAncestorNode(1), node, loopScope,
253 0 : OriginalNode::IS_DROPPED);
254 : }
255 : }
256 :
257 0 : mInsideLoopInitConditionOrExpression = false;
258 :
259 0 : if (!mFoundLoopToChange && node->getBody())
260 0 : node->getBody()->traverse(this);
261 :
262 0 : decrementDepth();
263 : }
264 :
265 : } // namespace
266 :
267 0 : void SimplifyLoopConditions(TIntermNode *root,
268 : unsigned int conditionsToSimplifyMask,
269 : unsigned int *temporaryIndex,
270 : const TSymbolTable &symbolTable,
271 : int shaderVersion)
272 : {
273 0 : SimplifyLoopConditionsTraverser traverser(conditionsToSimplifyMask, symbolTable, shaderVersion);
274 0 : ASSERT(temporaryIndex != nullptr);
275 0 : traverser.useTemporaryIndex(temporaryIndex);
276 : // Process one loop at a time, and reset the traverser between iterations.
277 0 : do
278 : {
279 0 : traverser.nextIteration();
280 0 : root->traverse(&traverser);
281 0 : if (traverser.foundLoopToChange())
282 0 : traverser.updateTree();
283 : } while (traverser.foundLoopToChange());
284 0 : }
285 :
286 : } // namespace sh
|