Line data Source code
1 : //
2 : // Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
3 : // Use of this source code is governed by a BSD-style license that can be
4 : // found in the LICENSE file.
5 : //
6 :
7 : #include "compiler/translator/ValidateLimitations.h"
8 : #include "compiler/translator/InfoSink.h"
9 : #include "compiler/translator/InitializeParseContext.h"
10 : #include "compiler/translator/ParseContext.h"
11 : #include "angle_gl.h"
12 :
13 : namespace sh
14 : {
15 :
16 : namespace
17 : {
18 :
19 0 : int GetLoopSymbolId(TIntermLoop *loop)
20 : {
21 : // Here we assume all the operations are valid, because the loop node is
22 : // already validated before this call.
23 0 : TIntermSequence *declSeq = loop->getInit()->getAsDeclarationNode()->getSequence();
24 0 : TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
25 0 : TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
26 :
27 0 : return symbol->getId();
28 : }
29 :
30 : // Traverses a node to check if it represents a constant index expression.
31 : // Definition:
32 : // constant-index-expressions are a superset of constant-expressions.
33 : // Constant-index-expressions can include loop indices as defined in
34 : // GLSL ES 1.0 spec, Appendix A, section 4.
35 : // The following are constant-index-expressions:
36 : // - Constant expressions
37 : // - Loop indices as defined in section 4
38 : // - Expressions composed of both of the above
39 0 : class ValidateConstIndexExpr : public TIntermTraverser
40 : {
41 : public:
42 0 : ValidateConstIndexExpr(const std::vector<int> &loopSymbols)
43 0 : : TIntermTraverser(true, false, false), mValid(true), mLoopSymbolIds(loopSymbols)
44 : {
45 0 : }
46 :
47 : // Returns true if the parsed node represents a constant index expression.
48 0 : bool isValid() const { return mValid; }
49 :
50 0 : void visitSymbol(TIntermSymbol *symbol) override
51 : {
52 : // Only constants and loop indices are allowed in a
53 : // constant index expression.
54 0 : if (mValid)
55 : {
56 0 : bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(),
57 0 : symbol->getId()) != mLoopSymbolIds.end();
58 0 : mValid = (symbol->getQualifier() == EvqConst) || isLoopSymbol;
59 : }
60 0 : }
61 :
62 : private:
63 : bool mValid;
64 : const std::vector<int> mLoopSymbolIds;
65 : };
66 :
67 : } // namespace anonymous
68 :
69 0 : ValidateLimitations::ValidateLimitations(sh::GLenum shaderType, TInfoSinkBase *sink)
70 : : TIntermTraverser(true, false, false),
71 : mShaderType(shaderType),
72 : mSink(sink),
73 : mNumErrors(0),
74 : mValidateIndexing(true),
75 0 : mValidateInnerLoops(true)
76 : {
77 0 : }
78 :
79 : // static
80 0 : bool ValidateLimitations::IsLimitedForLoop(TIntermLoop *loop)
81 : {
82 : // The shader type doesn't matter in this case.
83 0 : ValidateLimitations validate(GL_FRAGMENT_SHADER, nullptr);
84 0 : validate.mValidateIndexing = false;
85 0 : validate.mValidateInnerLoops = false;
86 0 : if (!validate.validateLoopType(loop))
87 0 : return false;
88 0 : if (!validate.validateForLoopHeader(loop))
89 0 : return false;
90 0 : TIntermNode *body = loop->getBody();
91 0 : if (body != nullptr)
92 : {
93 0 : validate.mLoopSymbolIds.push_back(GetLoopSymbolId(loop));
94 0 : body->traverse(&validate);
95 0 : validate.mLoopSymbolIds.pop_back();
96 : }
97 0 : return (validate.mNumErrors == 0);
98 : }
99 :
100 0 : bool ValidateLimitations::visitBinary(Visit, TIntermBinary *node)
101 : {
102 : // Check if loop index is modified in the loop body.
103 0 : validateOperation(node, node->getLeft());
104 :
105 : // Check indexing.
106 0 : switch (node->getOp())
107 : {
108 : case EOpIndexDirect:
109 : case EOpIndexIndirect:
110 0 : if (mValidateIndexing)
111 0 : validateIndexing(node);
112 0 : break;
113 : default:
114 0 : break;
115 : }
116 0 : return true;
117 : }
118 :
119 0 : bool ValidateLimitations::visitUnary(Visit, TIntermUnary *node)
120 : {
121 : // Check if loop index is modified in the loop body.
122 0 : validateOperation(node, node->getOperand());
123 :
124 0 : return true;
125 : }
126 :
127 0 : bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate *node)
128 : {
129 0 : switch (node->getOp()) {
130 : case EOpFunctionCall:
131 0 : validateFunctionCall(node);
132 0 : break;
133 : default:
134 0 : break;
135 : }
136 0 : return true;
137 : }
138 :
139 0 : bool ValidateLimitations::visitLoop(Visit, TIntermLoop *node)
140 : {
141 0 : if (!mValidateInnerLoops)
142 0 : return true;
143 :
144 0 : if (!validateLoopType(node))
145 0 : return false;
146 :
147 0 : if (!validateForLoopHeader(node))
148 0 : return false;
149 :
150 0 : TIntermNode *body = node->getBody();
151 0 : if (body != NULL)
152 : {
153 0 : mLoopSymbolIds.push_back(GetLoopSymbolId(node));
154 0 : body->traverse(this);
155 0 : mLoopSymbolIds.pop_back();
156 : }
157 :
158 : // The loop is fully processed - no need to visit children.
159 0 : return false;
160 : }
161 :
162 0 : void ValidateLimitations::error(TSourceLoc loc,
163 : const char *reason, const char *token)
164 : {
165 0 : if (mSink)
166 : {
167 0 : mSink->prefix(EPrefixError);
168 0 : mSink->location(loc);
169 0 : (*mSink) << "'" << token << "' : " << reason << "\n";
170 : }
171 0 : ++mNumErrors;
172 0 : }
173 :
174 0 : bool ValidateLimitations::withinLoopBody() const
175 : {
176 0 : return !mLoopSymbolIds.empty();
177 : }
178 :
179 0 : bool ValidateLimitations::isLoopIndex(TIntermSymbol *symbol)
180 : {
181 0 : return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->getId()) !=
182 0 : mLoopSymbolIds.end();
183 : }
184 :
185 0 : bool ValidateLimitations::validateLoopType(TIntermLoop *node)
186 : {
187 0 : TLoopType type = node->getType();
188 0 : if (type == ELoopFor)
189 0 : return true;
190 :
191 : // Reject while and do-while loops.
192 0 : error(node->getLine(),
193 : "This type of loop is not allowed",
194 0 : type == ELoopWhile ? "while" : "do");
195 0 : return false;
196 : }
197 :
198 0 : bool ValidateLimitations::validateForLoopHeader(TIntermLoop *node)
199 : {
200 0 : ASSERT(node->getType() == ELoopFor);
201 :
202 : //
203 : // The for statement has the form:
204 : // for ( init-declaration ; condition ; expression ) statement
205 : //
206 0 : int indexSymbolId = validateForLoopInit(node);
207 0 : if (indexSymbolId < 0)
208 0 : return false;
209 0 : if (!validateForLoopCond(node, indexSymbolId))
210 0 : return false;
211 0 : if (!validateForLoopExpr(node, indexSymbolId))
212 0 : return false;
213 :
214 0 : return true;
215 : }
216 :
217 0 : int ValidateLimitations::validateForLoopInit(TIntermLoop *node)
218 : {
219 0 : TIntermNode *init = node->getInit();
220 0 : if (init == NULL)
221 : {
222 0 : error(node->getLine(), "Missing init declaration", "for");
223 0 : return -1;
224 : }
225 :
226 : //
227 : // init-declaration has the form:
228 : // type-specifier identifier = constant-expression
229 : //
230 0 : TIntermDeclaration *decl = init->getAsDeclarationNode();
231 0 : if (decl == nullptr)
232 : {
233 0 : error(init->getLine(), "Invalid init declaration", "for");
234 0 : return -1;
235 : }
236 : // To keep things simple do not allow declaration list.
237 0 : TIntermSequence *declSeq = decl->getSequence();
238 0 : if (declSeq->size() != 1)
239 : {
240 0 : error(decl->getLine(), "Invalid init declaration", "for");
241 0 : return -1;
242 : }
243 0 : TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
244 0 : if ((declInit == NULL) || (declInit->getOp() != EOpInitialize))
245 : {
246 0 : error(decl->getLine(), "Invalid init declaration", "for");
247 0 : return -1;
248 : }
249 0 : TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
250 0 : if (symbol == NULL)
251 : {
252 0 : error(declInit->getLine(), "Invalid init declaration", "for");
253 0 : return -1;
254 : }
255 : // The loop index has type int or float.
256 0 : TBasicType type = symbol->getBasicType();
257 0 : if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat)) {
258 0 : error(symbol->getLine(),
259 0 : "Invalid type for loop index", getBasicString(type));
260 0 : return -1;
261 : }
262 : // The loop index is initialized with constant expression.
263 0 : if (!isConstExpr(declInit->getRight()))
264 : {
265 0 : error(declInit->getLine(),
266 : "Loop index cannot be initialized with non-constant expression",
267 0 : symbol->getSymbol().c_str());
268 0 : return -1;
269 : }
270 :
271 0 : return symbol->getId();
272 : }
273 :
274 0 : bool ValidateLimitations::validateForLoopCond(TIntermLoop *node,
275 : int indexSymbolId)
276 : {
277 0 : TIntermNode *cond = node->getCondition();
278 0 : if (cond == NULL)
279 : {
280 0 : error(node->getLine(), "Missing condition", "for");
281 0 : return false;
282 : }
283 : //
284 : // condition has the form:
285 : // loop_index relational_operator constant_expression
286 : //
287 0 : TIntermBinary *binOp = cond->getAsBinaryNode();
288 0 : if (binOp == NULL)
289 : {
290 0 : error(node->getLine(), "Invalid condition", "for");
291 0 : return false;
292 : }
293 : // Loop index should be to the left of relational operator.
294 0 : TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode();
295 0 : if (symbol == NULL)
296 : {
297 0 : error(binOp->getLine(), "Invalid condition", "for");
298 0 : return false;
299 : }
300 0 : if (symbol->getId() != indexSymbolId)
301 : {
302 0 : error(symbol->getLine(),
303 0 : "Expected loop index", symbol->getSymbol().c_str());
304 0 : return false;
305 : }
306 : // Relational operator is one of: > >= < <= == or !=.
307 0 : switch (binOp->getOp())
308 : {
309 : case EOpEqual:
310 : case EOpNotEqual:
311 : case EOpLessThan:
312 : case EOpGreaterThan:
313 : case EOpLessThanEqual:
314 : case EOpGreaterThanEqual:
315 0 : break;
316 : default:
317 0 : error(binOp->getLine(),
318 : "Invalid relational operator",
319 0 : GetOperatorString(binOp->getOp()));
320 0 : break;
321 : }
322 : // Loop index must be compared with a constant.
323 0 : if (!isConstExpr(binOp->getRight()))
324 : {
325 0 : error(binOp->getLine(),
326 : "Loop index cannot be compared with non-constant expression",
327 0 : symbol->getSymbol().c_str());
328 0 : return false;
329 : }
330 :
331 0 : return true;
332 : }
333 :
334 0 : bool ValidateLimitations::validateForLoopExpr(TIntermLoop *node,
335 : int indexSymbolId)
336 : {
337 0 : TIntermNode *expr = node->getExpression();
338 0 : if (expr == NULL)
339 : {
340 0 : error(node->getLine(), "Missing expression", "for");
341 0 : return false;
342 : }
343 :
344 : // for expression has one of the following forms:
345 : // loop_index++
346 : // loop_index--
347 : // loop_index += constant_expression
348 : // loop_index -= constant_expression
349 : // ++loop_index
350 : // --loop_index
351 : // The last two forms are not specified in the spec, but I am assuming
352 : // its an oversight.
353 0 : TIntermUnary *unOp = expr->getAsUnaryNode();
354 0 : TIntermBinary *binOp = unOp ? NULL : expr->getAsBinaryNode();
355 :
356 0 : TOperator op = EOpNull;
357 0 : TIntermSymbol *symbol = NULL;
358 0 : if (unOp != NULL)
359 : {
360 0 : op = unOp->getOp();
361 0 : symbol = unOp->getOperand()->getAsSymbolNode();
362 : }
363 0 : else if (binOp != NULL)
364 : {
365 0 : op = binOp->getOp();
366 0 : symbol = binOp->getLeft()->getAsSymbolNode();
367 : }
368 :
369 : // The operand must be loop index.
370 0 : if (symbol == NULL)
371 : {
372 0 : error(expr->getLine(), "Invalid expression", "for");
373 0 : return false;
374 : }
375 0 : if (symbol->getId() != indexSymbolId)
376 : {
377 0 : error(symbol->getLine(),
378 0 : "Expected loop index", symbol->getSymbol().c_str());
379 0 : return false;
380 : }
381 :
382 : // The operator is one of: ++ -- += -=.
383 0 : switch (op)
384 : {
385 : case EOpPostIncrement:
386 : case EOpPostDecrement:
387 : case EOpPreIncrement:
388 : case EOpPreDecrement:
389 0 : ASSERT((unOp != NULL) && (binOp == NULL));
390 0 : break;
391 : case EOpAddAssign:
392 : case EOpSubAssign:
393 0 : ASSERT((unOp == NULL) && (binOp != NULL));
394 0 : break;
395 : default:
396 0 : error(expr->getLine(), "Invalid operator", GetOperatorString(op));
397 0 : return false;
398 : }
399 :
400 : // Loop index must be incremented/decremented with a constant.
401 0 : if (binOp != NULL)
402 : {
403 0 : if (!isConstExpr(binOp->getRight()))
404 : {
405 0 : error(binOp->getLine(),
406 : "Loop index cannot be modified by non-constant expression",
407 0 : symbol->getSymbol().c_str());
408 0 : return false;
409 : }
410 : }
411 :
412 0 : return true;
413 : }
414 :
415 0 : bool ValidateLimitations::validateFunctionCall(TIntermAggregate *node)
416 : {
417 0 : ASSERT(node->getOp() == EOpFunctionCall);
418 :
419 : // If not within loop body, there is nothing to check.
420 0 : if (!withinLoopBody())
421 0 : return true;
422 :
423 : // List of param indices for which loop indices are used as argument.
424 : typedef std::vector<size_t> ParamIndex;
425 0 : ParamIndex pIndex;
426 0 : TIntermSequence *params = node->getSequence();
427 0 : for (TIntermSequence::size_type i = 0; i < params->size(); ++i)
428 : {
429 0 : TIntermSymbol *symbol = (*params)[i]->getAsSymbolNode();
430 0 : if (symbol && isLoopIndex(symbol))
431 0 : pIndex.push_back(i);
432 : }
433 : // If none of the loop indices are used as arguments,
434 : // there is nothing to check.
435 0 : if (pIndex.empty())
436 0 : return true;
437 :
438 0 : bool valid = true;
439 0 : TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable;
440 0 : TSymbol *symbol = symbolTable.find(node->getFunctionSymbolInfo()->getName(),
441 0 : GetGlobalParseContext()->getShaderVersion());
442 0 : ASSERT(symbol && symbol->isFunction());
443 0 : TFunction *function = static_cast<TFunction *>(symbol);
444 0 : for (ParamIndex::const_iterator i = pIndex.begin();
445 0 : i != pIndex.end(); ++i)
446 : {
447 0 : const TConstParameter ¶m = function->getParam(*i);
448 0 : TQualifier qual = param.type->getQualifier();
449 0 : if ((qual == EvqOut) || (qual == EvqInOut))
450 : {
451 0 : error((*params)[*i]->getLine(),
452 : "Loop index cannot be used as argument to a function out or inout parameter",
453 0 : (*params)[*i]->getAsSymbolNode()->getSymbol().c_str());
454 0 : valid = false;
455 : }
456 : }
457 :
458 0 : return valid;
459 : }
460 :
461 0 : bool ValidateLimitations::validateOperation(TIntermOperator *node,
462 : TIntermNode* operand)
463 : {
464 : // Check if loop index is modified in the loop body.
465 0 : if (!withinLoopBody() || !node->isAssignment())
466 0 : return true;
467 :
468 0 : TIntermSymbol *symbol = operand->getAsSymbolNode();
469 0 : if (symbol && isLoopIndex(symbol))
470 : {
471 0 : error(node->getLine(),
472 : "Loop index cannot be statically assigned to within the body of the loop",
473 0 : symbol->getSymbol().c_str());
474 : }
475 0 : return true;
476 : }
477 :
478 0 : bool ValidateLimitations::isConstExpr(TIntermNode *node)
479 : {
480 0 : ASSERT(node != nullptr);
481 0 : return node->getAsConstantUnion() != nullptr && node->getAsTyped()->getQualifier() == EvqConst;
482 : }
483 :
484 0 : bool ValidateLimitations::isConstIndexExpr(TIntermNode *node)
485 : {
486 0 : ASSERT(node != NULL);
487 :
488 0 : ValidateConstIndexExpr validate(mLoopSymbolIds);
489 0 : node->traverse(&validate);
490 0 : return validate.isValid();
491 : }
492 :
493 0 : bool ValidateLimitations::validateIndexing(TIntermBinary *node)
494 : {
495 0 : ASSERT((node->getOp() == EOpIndexDirect) ||
496 : (node->getOp() == EOpIndexIndirect));
497 :
498 0 : bool valid = true;
499 0 : TIntermTyped *index = node->getRight();
500 : // The index expession must be a constant-index-expression unless
501 : // the operand is a uniform in a vertex shader.
502 0 : TIntermTyped *operand = node->getLeft();
503 0 : bool skip = (mShaderType == GL_VERTEX_SHADER) &&
504 0 : (operand->getQualifier() == EvqUniform);
505 0 : if (!skip && !isConstIndexExpr(index))
506 : {
507 0 : error(index->getLine(), "Index expression must be constant", "[]");
508 0 : valid = false;
509 : }
510 0 : return valid;
511 : }
512 :
513 : } // namespace sh
|